mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-03 10:50:23 +00:00
improve interrogate
This commit is contained in:
@@ -20,6 +20,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
|
||||
|
||||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
|
||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||
|
||||
def download_default_clip_interrogate_categories(content_dir):
|
||||
print("Downloading CLIP categories...")
|
||||
@@ -27,12 +28,8 @@ def download_default_clip_interrogate_categories(content_dir):
|
||||
tmpdir = content_dir + "_tmp"
|
||||
try:
|
||||
os.makedirs(tmpdir)
|
||||
|
||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
|
||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
|
||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
|
||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
|
||||
|
||||
for category_type in category_types:
|
||||
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||
os.rename(tmpdir, content_dir)
|
||||
|
||||
except Exception as e:
|
||||
@@ -51,12 +48,13 @@ class InterrogateModels:
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.loaded_categories = None
|
||||
self.selected_categories = []
|
||||
self.content_dir = content_dir
|
||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||
|
||||
def categories(self):
|
||||
if self.loaded_categories is not None:
|
||||
return self.loaded_categories
|
||||
if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories:
|
||||
return self.loaded_categories
|
||||
|
||||
self.loaded_categories = []
|
||||
|
||||
@@ -64,14 +62,19 @@ class InterrogateModels:
|
||||
download_default_clip_interrogate_categories(self.content_dir)
|
||||
|
||||
if os.path.exists(self.content_dir):
|
||||
for filename in os.listdir(self.content_dir):
|
||||
self.selected_categories = shared.opts.interrogate_clip_categories
|
||||
for category_type in category_types:
|
||||
if 'all' not in self.selected_categories and category_type not in self.selected_categories:
|
||||
continue
|
||||
filename = os.path.join(self.content_dir, f"{category_type}.txt")
|
||||
if not os.path.isfile(filename):
|
||||
continue
|
||||
m = re_topn.search(filename)
|
||||
topn = 1 if m is None else int(m.group(1))
|
||||
|
||||
with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
lines = [x.strip() for x in file.readlines()]
|
||||
|
||||
self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
|
||||
self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines))
|
||||
|
||||
return self.loaded_categories
|
||||
|
||||
@@ -139,6 +142,8 @@ class InterrogateModels:
|
||||
def rank(self, image_features, text_array, top_count=1):
|
||||
import clip
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if shared.opts.interrogate_clip_dict_limit != 0:
|
||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||
|
||||
|
Reference in New Issue
Block a user