mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-09 13:49:48 +00:00
fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text
This commit is contained in:
78
modules/aesthetic_clip.py
Normal file
78
modules/aesthetic_clip.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import itertools
|
||||
import os
|
||||
from pathlib import Path
|
||||
import html
|
||||
import gc
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import shared
|
||||
from modules.shared import device, aesthetic_embeddings
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def get_all_images_in_folder(folder):
|
||||
return [os.path.join(folder, f) for f in os.listdir(folder) if
|
||||
os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
|
||||
|
||||
|
||||
def check_is_valid_image_file(filename):
|
||||
return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
|
||||
|
||||
|
||||
def batched(dataset, total, n=1):
|
||||
for ndx in range(0, total, n):
|
||||
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
|
||||
|
||||
|
||||
def iter_to_batched(iterable, n=1):
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = tuple(itertools.islice(it, n))
|
||||
if not chunk:
|
||||
return
|
||||
yield chunk
|
||||
|
||||
|
||||
def generate_imgs_embd(name, folder, batch_size):
|
||||
# clipModel = CLIPModel.from_pretrained(
|
||||
# shared.sd_model.cond_stage_model.clipModel.name_or_path
|
||||
# )
|
||||
model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
|
||||
processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)
|
||||
|
||||
with torch.no_grad():
|
||||
embs = []
|
||||
for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
|
||||
desc=f"Generating embeddings for {name}"):
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
|
||||
outputs = model.get_image_features(**inputs).cpu()
|
||||
embs.append(torch.clone(outputs))
|
||||
inputs.to("cpu")
|
||||
del inputs, outputs
|
||||
|
||||
embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
|
||||
|
||||
# The generated embedding will be located here
|
||||
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
|
||||
torch.save(embs, path)
|
||||
|
||||
model = model.cpu()
|
||||
del model
|
||||
del processor
|
||||
del embs
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
res = f"""
|
||||
Done generating embedding for {name}!
|
||||
Hypernetwork saved to {html.escape(path)}
|
||||
"""
|
||||
shared.update_aesthetic_embeddings()
|
||||
return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
|
||||
value=sorted(aesthetic_embeddings.keys())[0] if len(
|
||||
aesthetic_embeddings) > 0 else None), res, ""
|
Reference in New Issue
Block a user