CLIP interrogator

This commit is contained in:
AUTOMATIC
2022-09-11 18:48:36 +03:00
parent 13008bab90
commit f194457229
13 changed files with 204 additions and 13 deletions

View File

@@ -1,12 +1,16 @@
import torch
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
def get_optimal_device():
if torch.cuda.is_available():
return torch.device("cuda")
if has_mps:
return torch.device("mps")
return torch.device("cpu")
if torch.cuda.is_available():
return torch.device("cuda")
if has_mps:
return torch.device("mps")
return cpu

142
modules/interrogate.py Normal file
View File

@@ -0,0 +1,142 @@
import os
import sys
import traceback
from collections import namedtuple
import re
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
from modules import devices, paths
blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
class InterrogateModels:
blip_model = None
clip_model = None
clip_preprocess = None
categories = None
def __init__(self, content_dir):
self.categories = []
if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1))
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
self.categories.append(Category(name=filename, topn=topn, items=lines))
def load_blip_model(self):
import models.blip
blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval()
return blip_model
def load_clip_model(self):
import clip
model, preprocess = clip.load(clip_model_name)
model.eval()
model = model.to(shared.device)
return model, preprocess
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
self.blip_model = self.blip_model.to(shared.device)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
self.clip_model = self.clip_model.to(shared.device)
def unload(self):
if not shared.opts.interrogate_keep_models_in_memory:
if self.clip_model is not None:
self.clip_model = self.clip_model.to(devices.cpu)
if self.blip_model is not None:
self.blip_model = self.blip_model.to(devices.cpu)
def rank(self, image_features, text_array, top_count=1):
import clip
top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array]).cuda()
with torch.no_grad():
text_features = self.clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(shared.device)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
def generate_caption(self, pil_image):
gpu_image = transforms.Compose([
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).to(shared.device)
with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
return caption[0]
def interrogate(self, pil_image):
res = None
try:
self.load()
caption = self.generate_caption(pil_image)
res = caption
images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device)
with torch.no_grad():
image_features = self.clip_model.encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
if shared.opts.interrogate_use_builtin_artists:
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
res += ", " + artist[0]
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.unload()
return res

View File

@@ -18,6 +18,7 @@ path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
]
paths = {}

View File

@@ -11,6 +11,7 @@ import modules.artists
from modules.paths import script_path, sd_path
from modules.devices import get_optimal_device
import modules.styles
import modules.interrogate
config_filename = "config.json"
@@ -77,6 +78,8 @@ artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.c
styles_filename = os.path.join(script_path, 'styles.csv')
prompt_styles = modules.styles.load_styles(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
class Options:
@@ -123,6 +126,11 @@ class Options:
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"interrogate_keep_models_in_memory": OptionInfo(True, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
}
def __init__(self):

View File

@@ -242,9 +242,14 @@ def add_style(style_name, text):
return [update, update]
def interrogate(image):
prompt = shared.interrogator.interrogate(image)
return gr_show(True) if prompt is None else prompt
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
with gr.Row():
with gr.Row(elem_id="toprow"):
txt2img_prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1)
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
txt2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
@@ -365,10 +370,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
with gr.Row():
with gr.Row(elem_id="toprow"):
img2img_prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1)
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="img2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
img2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
img2img_interrogate = gr.Button('Interrogate', elem_id="img2img_interrogate", variant='primary')
submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary')
check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
@@ -461,6 +467,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
inpaint_full_res: gr_show(is_inpaint),
inpainting_mask_invert: gr_show(is_inpaint),
denoising_strength_change_factor: gr_show(is_loopback),
img2img_interrogate: gr_show(not is_inpaint),
}
switch_mode.change(
@@ -480,6 +487,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
inpaint_full_res,
inpainting_mask_invert,
denoising_strength_change_factor,
img2img_interrogate,
]
)
@@ -540,6 +548,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
img2img_interrogate.click(
fn=interrogate,
inputs=[init_img],
outputs=[img2img_prompt],
)
check_progress.click(
fn=check_progress_call,
show_progress=False,