mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-10 10:50:09 +00:00
Merge remote-tracking branch 'upstream/master' into interrogate_include_ranks_in_output
This commit is contained in:
@@ -1,21 +1,99 @@
|
||||
import os.path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import get_context
|
||||
import multiprocessing
|
||||
import time
|
||||
import re
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
||||
def get_deepbooru_tags(pil_image):
|
||||
"""
|
||||
This method is for running only one image at a time for simple use. Used to the img2img interrogate.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
|
||||
try:
|
||||
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
|
||||
return get_tags_from_process(pil_image)
|
||||
finally:
|
||||
release_process()
|
||||
|
||||
|
||||
def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
|
||||
def create_deepbooru_opts():
|
||||
from modules import shared
|
||||
|
||||
return {
|
||||
"use_spaces": shared.opts.deepbooru_use_spaces,
|
||||
"use_escape": shared.opts.deepbooru_escape,
|
||||
"alpha_sort": shared.opts.deepbooru_sort_alpha,
|
||||
}
|
||||
|
||||
|
||||
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
|
||||
model, tags = get_deepbooru_tags_model()
|
||||
while True: # while process is running, keep monitoring queue for new image
|
||||
pil_image = queue.get()
|
||||
if pil_image == "QUIT":
|
||||
break
|
||||
else:
|
||||
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
|
||||
|
||||
|
||||
def create_deepbooru_process(threshold, deepbooru_opts):
|
||||
"""
|
||||
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
|
||||
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
|
||||
dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
|
||||
to the dictionary and the method adding the image to the queue should wait for this value to be updated with
|
||||
the tags.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
shared.deepbooru_process_manager = multiprocessing.Manager()
|
||||
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
|
||||
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
|
||||
shared.deepbooru_process.start()
|
||||
|
||||
|
||||
def get_tags_from_process(image):
|
||||
from modules import shared
|
||||
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process_queue.put(image)
|
||||
while shared.deepbooru_process_return["value"] == -1:
|
||||
time.sleep(0.2)
|
||||
caption = shared.deepbooru_process_return["value"]
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
|
||||
return caption
|
||||
|
||||
|
||||
def release_process():
|
||||
"""
|
||||
Stops the deepbooru process to return used memory
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
shared.deepbooru_process_queue.put("QUIT")
|
||||
shared.deepbooru_process.join()
|
||||
shared.deepbooru_process_queue = None
|
||||
shared.deepbooru_process = None
|
||||
shared.deepbooru_process_return = None
|
||||
shared.deepbooru_process_manager = None
|
||||
|
||||
def get_deepbooru_tags_model():
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
this_folder = os.path.dirname(__file__)
|
||||
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
|
||||
if not os.path.exists(os.path.join(model_path, 'project.json')):
|
||||
# there is no point importing these every time
|
||||
import zipfile
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
|
||||
model_path)
|
||||
load_file_from_url(
|
||||
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
|
||||
model_path)
|
||||
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
|
||||
zip_ref.extractall(model_path)
|
||||
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
|
||||
@@ -24,6 +102,17 @@ def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
|
||||
model = dd.project.load_model_from_project(
|
||||
model_path, compile_model=True
|
||||
)
|
||||
return model, tags
|
||||
|
||||
|
||||
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
alpha_sort = deepbooru_opts['alpha_sort']
|
||||
use_spaces = deepbooru_opts['use_spaces']
|
||||
use_escape = deepbooru_opts['use_escape']
|
||||
|
||||
width = model.input_shape[2]
|
||||
height = model.input_shape[1]
|
||||
@@ -46,32 +135,35 @@ def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
|
||||
|
||||
for i, tag in enumerate(tags):
|
||||
result_dict[tag] = y[i]
|
||||
result_tags_out = []
|
||||
|
||||
unsorted_tags_in_theshold = []
|
||||
result_tags_print = []
|
||||
for tag in tags:
|
||||
if result_dict[tag] >= threshold:
|
||||
if tag.startswith("rating:"):
|
||||
continue
|
||||
tag_formatted = tag.replace('_', ' ').replace(':', ' ')
|
||||
if include_ranks:
|
||||
result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})')
|
||||
else:
|
||||
result_tags_out.append(tag_formatted)
|
||||
unsorted_tags_in_theshold.append((result_dict[tag], tag))
|
||||
result_tags_print.append(f'{result_dict[tag]} {tag}')
|
||||
|
||||
# sort tags
|
||||
result_tags_out = []
|
||||
sort_ndx = 0
|
||||
if alpha_sort:
|
||||
sort_ndx = 1
|
||||
|
||||
# sort by reverse by likelihood and normal for alpha
|
||||
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
|
||||
for weight, tag in unsorted_tags_in_theshold:
|
||||
result_tags_out.append(tag)
|
||||
|
||||
print('\n'.join(sorted(result_tags_print, reverse=True)))
|
||||
|
||||
return ', '.join(result_tags_out)
|
||||
tags_text = ', '.join(result_tags_out)
|
||||
|
||||
if use_spaces:
|
||||
tags_text = tags_text.replace('_', ' ')
|
||||
|
||||
def subprocess_init_no_cuda():
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
if use_escape:
|
||||
tags_text = re.sub(re_special, r'\\\1', tags_text)
|
||||
|
||||
|
||||
def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False):
|
||||
context = get_context('spawn')
|
||||
with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
|
||||
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks)
|
||||
ret = f.result() # will rethrow any exceptions
|
||||
return ret
|
||||
return tags_text.replace(':', ' ')
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
@@ -19,7 +20,7 @@ import gradio as gr
|
||||
cached_images = {}
|
||||
|
||||
|
||||
def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
||||
devices.torch_gc()
|
||||
|
||||
imageArr = []
|
||||
@@ -67,8 +68,13 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
image = res
|
||||
|
||||
if resize_mode == 1:
|
||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
||||
crop_info = " (crop)" if upscaling_crop else ""
|
||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
||||
|
||||
if upscaling_resize != 1.0:
|
||||
def upscale(image, scaler_index, resize):
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
|
||||
@@ -77,15 +83,19 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
||||
if c is None:
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
|
||||
c = cropped
|
||||
cached_images[key] = c
|
||||
|
||||
return c
|
||||
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
||||
res = upscale(image, extras_upscaler_1, upscaling_resize)
|
||||
res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize)
|
||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
||||
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
||||
|
||||
@@ -190,7 +200,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
||||
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
|
||||
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
theta_0[key] = theta_1[key]
|
||||
|
@@ -14,7 +14,7 @@ import torch
|
||||
from torch import einsum
|
||||
from einops import rearrange, repeat
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
@@ -120,6 +120,17 @@ def load_hypernetwork(filename):
|
||||
shared.loaded_hypernetwork = None
|
||||
|
||||
|
||||
def find_closest_hypernetwork_name(search: str):
|
||||
if not search:
|
||||
return None
|
||||
search = search.lower()
|
||||
applicable = [name for name in shared.hypernetworks if search in name.lower()]
|
||||
if not applicable:
|
||||
return None
|
||||
applicable = sorted(applicable, key=lambda name: len(name))
|
||||
return applicable[0]
|
||||
|
||||
|
||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
|
||||
@@ -164,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
|
||||
assert hypernetwork_name, 'embedding not selected'
|
||||
assert hypernetwork_name, 'hypernetwork not selected'
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
@@ -212,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
if ititial_step > steps:
|
||||
return hypernetwork, filename
|
||||
|
||||
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
|
||||
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, (x, text, cond) in pbar:
|
||||
for i, entry in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
|
||||
if hypernetwork.step > end_step:
|
||||
try:
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
except Exception:
|
||||
break
|
||||
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = learn_rate
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
cond = cond.to(devices.device)
|
||||
x = x.to(devices.device)
|
||||
cond = entry.cond.to(devices.device)
|
||||
x = entry.latent.to(devices.device)
|
||||
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
|
||||
del x
|
||||
del cond
|
||||
@@ -256,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||
|
||||
preview_text = text if preview_image_prompt == "" else preview_image_prompt
|
||||
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
||||
|
||||
optimizer.zero_grad()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
@@ -271,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
)
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0]
|
||||
image = processed.images[0] if len(processed.images)>0 else None
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
@@ -288,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {hypernetwork.step}<br/>
|
||||
Last prompt: {html.escape(text)}<br/>
|
||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
|
@@ -321,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
if len(remade_batch_tokens[j]) > 0:
|
||||
tokens.append(remade_batch_tokens[j][:75])
|
||||
multipliers.append(batch_multipliers[j][:75])
|
||||
else:
|
||||
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
|
||||
multipliers.append([1.0] * 75)
|
||||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
|
@@ -86,6 +86,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||
xformers_available = False
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
loaded_hypernetwork = None
|
||||
|
||||
@@ -229,7 +230,10 @@ options_templates.update(options_section(('system', "System"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"),
|
||||
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
@@ -255,8 +259,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
|
||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
||||
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
|
@@ -11,11 +11,21 @@ import tqdm
|
||||
from modules import devices, shared
|
||||
import re
|
||||
|
||||
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
|
||||
def __init__(self, filename=None, latent=None, filename_text=None):
|
||||
self.filename = filename
|
||||
self.latent = latent
|
||||
self.filename_text = filename_text
|
||||
self.cond = None
|
||||
self.cond_text = None
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
text_filename = os.path.splitext(path)[0] + ".txt"
|
||||
filename = os.path.basename(path)
|
||||
filename_tokens = os.path.splitext(filename)[0]
|
||||
filename_tokens = re_tag.findall(filename_tokens)
|
||||
|
||||
if os.path.exists(text_filename):
|
||||
with open(text_filename, "r", encoding="utf8") as file:
|
||||
filename_text = file.read()
|
||||
else:
|
||||
filename_text = os.path.splitext(filename)[0]
|
||||
filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||
if re_word:
|
||||
tokens = re_word.findall(filename_text)
|
||||
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
init_latent = init_latent.to(devices.cpu)
|
||||
|
||||
if include_cond:
|
||||
text = self.create_text(filename_tokens)
|
||||
cond = cond_model([text]).to(devices.cpu)
|
||||
else:
|
||||
cond = None
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
|
||||
|
||||
self.dataset.append((init_latent, filename_tokens, cond))
|
||||
if include_cond:
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
|
||||
|
||||
self.dataset.append(entry)
|
||||
|
||||
self.length = len(self.dataset) * repeats
|
||||
|
||||
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||
|
||||
def create_text(self, filename_tokens):
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||
text = text.replace("[filewords]", filename_text)
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
|
||||
self.shuffle()
|
||||
|
||||
index = self.indexes[i % len(self.indexes)]
|
||||
x, filename_tokens, cond = self.dataset[index]
|
||||
entry = self.dataset[index]
|
||||
|
||||
text = self.create_text(filename_tokens)
|
||||
return x, text, cond
|
||||
if entry.cond is None:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
|
||||
return entry
|
||||
|
219
modules/textual_inversion/image_embedding.py
Normal file
219
modules/textual_inversion/image_embedding.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import base64
|
||||
import json
|
||||
import numpy as np
|
||||
import zlib
|
||||
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
|
||||
from fonts.ttf import Roboto
|
||||
import torch
|
||||
|
||||
|
||||
class EmbeddingEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
class EmbeddingDecoder(json.JSONDecoder):
|
||||
def __init__(self, *args, **kwargs):
|
||||
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
|
||||
|
||||
def object_hook(self, d):
|
||||
if 'TORCHTENSOR' in d:
|
||||
return torch.from_numpy(np.array(d['TORCHTENSOR']))
|
||||
return d
|
||||
|
||||
|
||||
def embedding_to_b64(data):
|
||||
d = json.dumps(data, cls=EmbeddingEncoder)
|
||||
return base64.b64encode(d.encode())
|
||||
|
||||
|
||||
def embedding_from_b64(data):
|
||||
d = base64.b64decode(data)
|
||||
return json.loads(d, cls=EmbeddingDecoder)
|
||||
|
||||
|
||||
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
|
||||
while True:
|
||||
seed = (a * seed + c) % m
|
||||
yield seed % 255
|
||||
|
||||
|
||||
def xor_block(block):
|
||||
g = lcg()
|
||||
randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
|
||||
return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
|
||||
|
||||
|
||||
def style_block(block, sequence):
|
||||
im = Image.new('RGB', (block.shape[1], block.shape[0]))
|
||||
draw = ImageDraw.Draw(im)
|
||||
i = 0
|
||||
for x in range(-6, im.size[0], 8):
|
||||
for yi, y in enumerate(range(-6, im.size[1], 8)):
|
||||
offset = 0
|
||||
if yi % 2 == 0:
|
||||
offset = 4
|
||||
shade = sequence[i % len(sequence)]
|
||||
i += 1
|
||||
draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
|
||||
|
||||
fg = np.array(im).astype(np.uint8) & 0xF0
|
||||
|
||||
return block ^ fg
|
||||
|
||||
|
||||
def insert_image_data_embed(image, data):
|
||||
d = 3
|
||||
data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
|
||||
data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
|
||||
data_np_high = data_np_ >> 4
|
||||
data_np_low = data_np_ & 0x0F
|
||||
|
||||
h = image.size[1]
|
||||
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
|
||||
next_size = next_size + ((h*d)-(next_size % (h*d)))
|
||||
|
||||
data_np_low.resize(next_size)
|
||||
data_np_low = data_np_low.reshape((h, -1, d))
|
||||
|
||||
data_np_high.resize(next_size)
|
||||
data_np_high = data_np_high.reshape((h, -1, d))
|
||||
|
||||
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
|
||||
edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
|
||||
|
||||
data_np_low = style_block(data_np_low, sequence=edge_style)
|
||||
data_np_low = xor_block(data_np_low)
|
||||
data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
|
||||
data_np_high = xor_block(data_np_high)
|
||||
|
||||
im_low = Image.fromarray(data_np_low, mode='RGB')
|
||||
im_high = Image.fromarray(data_np_high, mode='RGB')
|
||||
|
||||
background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
|
||||
background.paste(im_low, (0, 0))
|
||||
background.paste(image, (im_low.size[0]+1, 0))
|
||||
background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
|
||||
|
||||
return background
|
||||
|
||||
|
||||
def crop_black(img, tol=0):
|
||||
mask = (img > tol).all(2)
|
||||
mask0, mask1 = mask.any(0), mask.any(1)
|
||||
col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
|
||||
row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
|
||||
return img[row_start:row_end, col_start:col_end]
|
||||
|
||||
|
||||
def extract_image_data_embed(image):
|
||||
d = 3
|
||||
outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
|
||||
black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
|
||||
if black_cols[0].shape[0] < 2:
|
||||
print('No Image data blocks found.')
|
||||
return None
|
||||
|
||||
data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
|
||||
data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
|
||||
|
||||
data_block_lower = xor_block(data_block_lower)
|
||||
data_block_upper = xor_block(data_block_upper)
|
||||
|
||||
data_block = (data_block_upper << 4) | (data_block_lower)
|
||||
data_block = data_block.flatten().tobytes()
|
||||
|
||||
data = zlib.decompress(data_block)
|
||||
return json.loads(data, cls=EmbeddingDecoder)
|
||||
|
||||
|
||||
def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
|
||||
from math import cos
|
||||
|
||||
image = srcimage.copy()
|
||||
|
||||
if textfont is None:
|
||||
try:
|
||||
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
|
||||
textfont = opts.font or Roboto
|
||||
except Exception:
|
||||
textfont = Roboto
|
||||
|
||||
factor = 1.5
|
||||
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
|
||||
for y in range(image.size[1]):
|
||||
mag = 1-cos(y/image.size[1]*factor)
|
||||
mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
|
||||
gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
|
||||
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
fontsize = 32
|
||||
font = ImageFont.truetype(textfont, fontsize)
|
||||
padding = 10
|
||||
|
||||
_, _, w, h = draw.textbbox((0, 0), title, font=font)
|
||||
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
|
||||
font = ImageFont.truetype(textfont, fontsize)
|
||||
_, _, w, h = draw.textbbox((0, 0), title, font=font)
|
||||
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
|
||||
|
||||
_, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
|
||||
fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
|
||||
_, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
|
||||
fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
|
||||
_, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
|
||||
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
|
||||
|
||||
font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
|
||||
|
||||
draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
|
||||
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
|
||||
draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
testEmbed = Image.open('test_embedding.png')
|
||||
data = extract_image_data_embed(testEmbed)
|
||||
assert data is not None
|
||||
|
||||
data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
|
||||
assert data is not None
|
||||
|
||||
image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
|
||||
cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
|
||||
|
||||
test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
|
||||
|
||||
embedded_image = insert_image_data_embed(cap_image, test_embed)
|
||||
|
||||
retrived_embed = extract_image_data_embed(embedded_image)
|
||||
|
||||
assert str(retrived_embed) == str(test_embed)
|
||||
|
||||
embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
|
||||
|
||||
assert embedded_image == embedded_image2
|
||||
|
||||
g = lcg()
|
||||
shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
|
||||
|
||||
reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
|
||||
95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
|
||||
160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
|
||||
38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
|
||||
30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
|
||||
41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
|
||||
66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
|
||||
204, 86, 73, 222, 44, 198, 118, 240, 97]
|
||||
|
||||
assert shared_random == reference_random
|
||||
|
||||
hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
|
||||
|
||||
assert 12731374 == hunna_kay_random_sum
|
@@ -1,6 +1,12 @@
|
||||
import tqdm
|
||||
|
||||
class LearnSchedule:
|
||||
|
||||
class LearnScheduleIterator:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0):
|
||||
"""
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
|
||||
"""
|
||||
|
||||
pairs = learn_rate.split(',')
|
||||
self.rates = []
|
||||
self.it = 0
|
||||
@@ -32,3 +38,32 @@ class LearnSchedule:
|
||||
return self.rates[self.it - 1]
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
|
||||
class LearnRateScheduler:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
|
||||
self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
|
||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||
self.verbose = verbose
|
||||
|
||||
if self.verbose:
|
||||
print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
|
||||
|
||||
self.finished = False
|
||||
|
||||
def apply(self, optimizer, step_number):
|
||||
if step_number <= self.end_step:
|
||||
return
|
||||
|
||||
try:
|
||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||
except Exception:
|
||||
self.finished = True
|
||||
return
|
||||
|
||||
if self.verbose:
|
||||
tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
|
||||
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = self.learn_rate
|
||||
|
||||
|
@@ -3,11 +3,35 @@ from PIL import Image, ImageOps
|
||||
import platform
|
||||
import sys
|
||||
import tqdm
|
||||
import time
|
||||
|
||||
from modules import shared, images
|
||||
from modules.shared import opts, cmd_opts
|
||||
if cmd_opts.deepdanbooru:
|
||||
import modules.deepbooru as deepbooru
|
||||
|
||||
|
||||
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
|
||||
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
||||
if process_caption_deepbooru:
|
||||
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts())
|
||||
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
|
||||
|
||||
finally:
|
||||
|
||||
if process_caption:
|
||||
shared.interrogator.send_blip_to_ram()
|
||||
|
||||
if process_caption_deepbooru:
|
||||
deepbooru.release_process()
|
||||
|
||||
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
||||
width = process_width
|
||||
height = process_height
|
||||
src = os.path.abspath(process_src)
|
||||
@@ -22,19 +46,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||
shared.state.textinfo = "Preprocessing..."
|
||||
shared.state.job_count = len(files)
|
||||
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
||||
def save_pic_with_caption(image, index):
|
||||
if process_caption:
|
||||
caption = "-" + shared.interrogator.generate_caption(image)
|
||||
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
|
||||
else:
|
||||
caption = filename
|
||||
caption = os.path.splitext(caption)[0]
|
||||
caption = os.path.basename(caption)
|
||||
caption = ""
|
||||
|
||||
if process_caption:
|
||||
caption += shared.interrogator.generate_caption(image)
|
||||
|
||||
if process_caption_deepbooru:
|
||||
if len(caption) > 0:
|
||||
caption += ", "
|
||||
caption += deepbooru.get_tags_from_process(image)
|
||||
|
||||
filename_part = filename
|
||||
filename_part = os.path.splitext(filename_part)[0]
|
||||
filename_part = os.path.basename(filename_part)
|
||||
|
||||
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
||||
image.save(os.path.join(dst, f"{basename}.png"))
|
||||
|
||||
if len(caption) > 0:
|
||||
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||
file.write(caption)
|
||||
|
||||
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
|
||||
subindex[0] += 1
|
||||
|
||||
def save_pic(image, index):
|
||||
@@ -79,30 +112,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||
save_pic(img, index)
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
if process_caption:
|
||||
shared.interrogator.send_blip_to_ram()
|
||||
|
||||
def sanitize_caption(base_path, original_caption, suffix):
|
||||
operating_system = platform.system().lower()
|
||||
if (operating_system == "windows"):
|
||||
invalid_path_characters = "\\/:*?\"<>|"
|
||||
max_path_length = 259
|
||||
else:
|
||||
invalid_path_characters = "/" #linux/macos
|
||||
max_path_length = 1023
|
||||
caption = original_caption
|
||||
for invalid_character in invalid_path_characters:
|
||||
caption = caption.replace(invalid_character, "")
|
||||
fixed_path_length = len(base_path) + len(suffix)
|
||||
if fixed_path_length + len(caption) <= max_path_length:
|
||||
return caption
|
||||
caption_tokens = caption.split()
|
||||
new_caption = ""
|
||||
for token in caption_tokens:
|
||||
last_caption = new_caption
|
||||
new_caption = new_caption + token + " "
|
||||
if (len(new_caption) + fixed_path_length - 1 > max_path_length):
|
||||
break
|
||||
print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
|
||||
return last_caption.strip()
|
||||
|
BIN
modules/textual_inversion/test_embedding.png
Normal file
BIN
modules/textual_inversion/test_embedding.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 478 KiB |
@@ -7,11 +7,15 @@ import tqdm
|
||||
import html
|
||||
import datetime
|
||||
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
|
||||
insert_image_data_embed, extract_image_data_embed,
|
||||
caption_image_overlay)
|
||||
|
||||
class Embedding:
|
||||
def __init__(self, vec, name, step=None):
|
||||
@@ -81,7 +85,18 @@ class EmbeddingDatabase:
|
||||
def process_file(path, filename):
|
||||
name = os.path.splitext(filename)[0]
|
||||
|
||||
data = torch.load(path, map_location="cpu")
|
||||
data = []
|
||||
|
||||
if filename.upper().endswith('.PNG'):
|
||||
embed_image = Image.open(path)
|
||||
if 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
@@ -157,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||
return fn
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
@@ -179,11 +194,17 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
else:
|
||||
images_dir = None
|
||||
|
||||
if create_image_every > 0 and save_image_with_stored_embedding:
|
||||
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
|
||||
os.makedirs(images_embeds_dir, exist_ok=True)
|
||||
else:
|
||||
images_embeds_dir = None
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
@@ -199,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, (x, text, _) in pbar:
|
||||
for i, entry in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
|
||||
if embedding.step > end_step:
|
||||
try:
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
except:
|
||||
break
|
||||
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = learn_rate
|
||||
scheduler.apply(optimizer, embedding.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = cond_model([text])
|
||||
c = cond_model([entry.cond_text])
|
||||
|
||||
x = x.to(devices.device)
|
||||
x = entry.latent.to(devices.device)
|
||||
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
||||
del x
|
||||
|
||||
@@ -246,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
preview_text = text if preview_image_prompt == "" else preview_image_prompt
|
||||
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
@@ -262,6 +275,26 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
image = processed.images[0]
|
||||
|
||||
shared.state.current_image = image
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
|
||||
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
info = PngImagePlugin.PngInfo()
|
||||
data = torch.load(last_saved_file)
|
||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||
|
||||
title = "<{}>".format(data.get('name', '???'))
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_right = '{}'.format(embedding.step)
|
||||
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||
|
||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
@@ -272,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {embedding.step}<br/>
|
||||
Last prompt: {html.escape(text)}<br/>
|
||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
|
100
modules/ui.py
100
modules/ui.py
@@ -107,7 +107,7 @@ def send_gradio_gallery_to_image(x):
|
||||
|
||||
|
||||
def save_files(js_data, images, do_make_zip, index):
|
||||
import csv
|
||||
import csv
|
||||
filenames = []
|
||||
fullfns = []
|
||||
|
||||
@@ -131,6 +131,8 @@ def save_files(js_data, images, do_make_zip, index):
|
||||
images = [images[index]]
|
||||
start_index = index
|
||||
|
||||
os.makedirs(opts.outdir_save, exist_ok=True)
|
||||
|
||||
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
||||
at_start = file.tell() == 0
|
||||
writer = csv.writer(file)
|
||||
@@ -181,8 +183,15 @@ def wrap_gradio_call(func, extra_outputs=None):
|
||||
try:
|
||||
res = list(func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
# When printing out our debug argument list, do not print out more than a MB of text
|
||||
max_debug_str_len = 131072 # (1024*1024)/8
|
||||
|
||||
print("Error completing request", file=sys.stderr)
|
||||
print("Arguments:", args, kwargs, file=sys.stderr)
|
||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||
if len(argStr) > max_debug_str_len:
|
||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.state.job = ""
|
||||
@@ -311,12 +320,13 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
|
||||
|
||||
|
||||
def interrogate(image):
|
||||
prompt = shared.interrogator.interrogate(image, include_ranks=opts.interrogate_return_ranks)
|
||||
prompt = shared.interrogator.interrogate(image)
|
||||
|
||||
return gr_show(True) if prompt is None else prompt
|
||||
|
||||
|
||||
def interrogate_deepbooru(image):
|
||||
prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold, opts.interrogate_return_ranks)
|
||||
prompt = get_deepbooru_tags(image)
|
||||
return gr_show(True) if prompt is None else prompt
|
||||
|
||||
|
||||
@@ -557,11 +567,11 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
@@ -745,11 +755,11 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
@@ -911,7 +921,15 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.TabItem('Batch Process'):
|
||||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
||||
|
||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
|
||||
with gr.Tabs(elem_id="extras_resize_mode"):
|
||||
with gr.TabItem('Scale by'):
|
||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
|
||||
with gr.TabItem('Scale to'):
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
|
||||
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
|
||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
@@ -942,6 +960,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||
_js="get_extras_tab_index",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
dummy_component,
|
||||
extras_image,
|
||||
image_batch,
|
||||
@@ -949,6 +968,9 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
codeformer_visibility,
|
||||
codeformer_weight,
|
||||
upscaling_resize,
|
||||
upscaling_resize_w,
|
||||
upscaling_resize_h,
|
||||
upscaling_crop,
|
||||
extras_upscaler_1,
|
||||
extras_upscaler_2,
|
||||
extras_upscaler_2_visibility,
|
||||
@@ -959,14 +981,14 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
extras_send_to_img2img.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_img2img",
|
||||
inputs=[result_images],
|
||||
outputs=[init_img],
|
||||
)
|
||||
|
||||
|
||||
extras_send_to_inpaint.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_inpaint",
|
||||
@@ -1013,14 +1035,14 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
with gr.Blocks() as textual_inversion_interface:
|
||||
with gr.Blocks() as train_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Tabs(elem_id="train_tabs"):
|
||||
|
||||
with gr.Tab(label="Create embedding"):
|
||||
new_embedding_name = gr.Textbox(label="Name")
|
||||
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||
@@ -1032,9 +1054,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Column():
|
||||
create_embedding = gr.Button(value="Create embedding", variant='primary')
|
||||
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
|
||||
|
||||
with gr.Tab(label="Create hypernetwork"):
|
||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
|
||||
@@ -1045,9 +1065,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Column():
|
||||
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
|
||||
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
|
||||
|
||||
with gr.Tab(label="Preprocess images"):
|
||||
process_src = gr.Textbox(label='Source directory')
|
||||
process_dst = gr.Textbox(label='Destination directory')
|
||||
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||
@@ -1056,7 +1074,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Row():
|
||||
process_flip = gr.Checkbox(label='Create flipped copies')
|
||||
process_split = gr.Checkbox(label='Split oversized images into two')
|
||||
process_caption = gr.Checkbox(label='Use BLIP caption as filename')
|
||||
process_caption = gr.Checkbox(label='Use BLIP for caption')
|
||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
@@ -1065,7 +1084,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Column():
|
||||
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||
|
||||
with gr.Group():
|
||||
with gr.Tab(label="Train"):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
||||
@@ -1076,9 +1095,9 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
|
||||
|
||||
with gr.Row():
|
||||
@@ -1134,6 +1153,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
process_flip,
|
||||
process_split,
|
||||
process_caption,
|
||||
process_caption_deepbooru
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
@@ -1152,10 +1172,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
training_width,
|
||||
training_height,
|
||||
steps,
|
||||
num_repeats,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
save_image_with_stored_embedding,
|
||||
preview_image_prompt,
|
||||
],
|
||||
outputs=[
|
||||
@@ -1351,7 +1371,7 @@ Requested path was: {f}
|
||||
outputs=[],
|
||||
_js='function(){restart_reload()}'
|
||||
)
|
||||
|
||||
|
||||
if column is not None:
|
||||
column.__exit__()
|
||||
|
||||
@@ -1361,7 +1381,7 @@ Requested path was: {f}
|
||||
(extras_interface, "Extras", "extras"),
|
||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||
(textual_inversion_interface, "Textual inversion", "ti"),
|
||||
(train_interface, "Train", "ti"),
|
||||
(settings_interface, "Settings", "settings"),
|
||||
]
|
||||
|
||||
@@ -1383,12 +1403,12 @@ Requested path was: {f}
|
||||
component_dict[k] = component
|
||||
|
||||
settings_interface.gradio_ref = demo
|
||||
|
||||
|
||||
with gr.Tabs() as tabs:
|
||||
for interface, label, ifid in interfaces:
|
||||
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
|
||||
interface.render()
|
||||
|
||||
|
||||
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
||||
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||
|
||||
@@ -1521,10 +1541,10 @@ Requested path was: {f}
|
||||
|
||||
if getattr(obj,'custom_script_source',None) is not None:
|
||||
key = 'customscript/' + obj.custom_script_source + '/' + key
|
||||
|
||||
|
||||
if getattr(obj, 'do_not_save_to_config', False):
|
||||
return
|
||||
|
||||
|
||||
saved_value = ui_settings.get(key, None)
|
||||
if saved_value is None:
|
||||
ui_settings[key] = getattr(obj, field)
|
||||
@@ -1548,10 +1568,10 @@ Requested path was: {f}
|
||||
|
||||
if type(x) == gr.Textbox:
|
||||
apply_field(x, 'value')
|
||||
|
||||
|
||||
if type(x) == gr.Number:
|
||||
apply_field(x, 'value')
|
||||
|
||||
|
||||
visit(txt2img_interface, loadsave, "txt2img")
|
||||
visit(img2img_interface, loadsave, "img2img")
|
||||
visit(extras_interface, loadsave, "extras")
|
||||
|
Reference in New Issue
Block a user