make CLIP interrogator download original text files if the directory does not exist

remove random artist built-in extension (to re-added as a normal extension on demand)
remove artists.csv (but what does it mean????????????????????)
make interrogate buttons show Loading... when you click them
This commit is contained in:
AUTOMATIC
2023-01-21 09:14:27 +03:00
parent 40ff6db532
commit 6d805b669e
9 changed files with 46 additions and 3151 deletions

View File

@@ -126,8 +126,6 @@ class Api:
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
@@ -390,12 +388,6 @@ class Api:
return styleList
def get_artists_categories(self):
return shared.artist_db.cats
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db

View File

@@ -1,25 +0,0 @@
import os.path
import csv
from collections import namedtuple
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
class ArtistsDatabase:
def __init__(self, filename):
self.cats = set()
self.artists = []
if not os.path.exists(filename):
return
with open(filename, "r", newline='', encoding="utf8") as file:
reader = csv.DictReader(file)
for row in reader:
artist = Artist(row["artist"], float(row["score"]), row["category"])
self.artists.append(artist)
self.cats.add(artist.category)
def categories(self):
return sorted(self.cats)

View File

@@ -5,12 +5,13 @@ from collections import namedtuple
import re
import torch
import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
from modules import devices, paths, lowvram, modelloader
from modules import devices, paths, lowvram, modelloader, errors
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp"
try:
os.makedirs(tmpdir)
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
os.rename(tmpdir, content_dir)
except Exception as e:
errors.display(e, "downloading default CLIP interrogate categories")
finally:
if os.path.exists(tmpdir):
os.remove(tmpdir)
class InterrogateModels:
blip_model = None
clip_model = None
clip_preprocess = None
categories = None
dtype = None
running_on_cpu = None
def __init__(self, content_dir):
self.categories = []
self.loaded_categories = None
self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
def categories(self):
if self.loaded_categories is not None:
return self.loaded_categories
self.loaded_categories = []
if not os.path.exists(self.content_dir):
download_default_clip_interrogate_categories(self.content_dir)
if os.path.exists(self.content_dir):
for filename in os.listdir(self.content_dir):
m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1))
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
self.categories.append(Category(name=filename, topn=topn, items=lines))
self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
return self.loaded_categories
def load_blip_model(self):
import models.blip
@@ -139,7 +172,6 @@ class InterrogateModels:
shared.state.begin()
shared.state.job = 'interrogate'
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -159,12 +191,7 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True)
if shared.opts.interrogate_use_builtin_artists:
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
res += ", " + artist[0]
for name, topn, items in self.categories:
for name, topn, items in self.categories():
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:

View File

@@ -9,7 +9,6 @@ from PIL import Image
import gradio as gr
import tqdm
import modules.artists
import modules.interrogate
import modules.memmon
import modules.styles
@@ -254,8 +253,6 @@ class State:
state = State()
state.server_start = time.time()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename)
@@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),

View File

@@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
left, _ = os.path.splitext(filename)
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
return [gr_show(True), None]
return [gr.update(), None]
def interrogate(image):
prompt = shared.interrogator.interrogate(image.convert("RGB"))
return gr_show(True) if prompt is None else prompt
return gr.update() if prompt is None else prompt
def interrogate_deepbooru(image):
prompt = deepbooru.model.tag(image)
return gr_show(True) if prompt is None else prompt
return gr.update() if prompt is None else prompt
def create_seed_inputs(target_interface):
@@ -1039,19 +1039,18 @@ def create_ui():
init_img_inpaint,
],
outputs=[img2img_prompt, dummy_component],
show_progress=False,
)
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
img2img_interrogate.click(
fn=lambda *args : process_interrogate(interrogate, *args),
fn=lambda *args: process_interrogate(interrogate, *args),
**interrogate_args,
)
img2img_deepbooru.click(
fn=lambda *args : process_interrogate(interrogate_deepbooru, *args),
fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
**interrogate_args,
)