Merge branch 'master' into inpaint_textual_inversion

This commit is contained in:
AUTOMATIC1111
2023-01-04 17:40:19 +03:00
committed by GitHub
131 changed files with 8627 additions and 8919 deletions

View File

@@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models, images
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -23,9 +23,12 @@ class Embedding:
self.vec = vec
self.name = name
self.step = step
self.shape = None
self.vectors = 0
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -39,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
optimizer_saved_dict = {
'hash': self.checksum(),
'optimizer_state_dict': self.optimizer_state_dict,
}
torch.save(optimizer_saved_dict, filename + '.optim')
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -57,14 +67,17 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
self.expected_shape = -1
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
# TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
@@ -74,21 +87,26 @@ class EmbeddingDatabase:
return embedding
def load_textual_inversion_embeddings(self):
def get_expected_shape(self):
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
name, ext = os.path.splitext(filename)
ext = ext.upper()
data = []
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -96,8 +114,10 @@ class EmbeddingDatabase:
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
else:
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
else:
return
# textual inversion embeddings
if 'string_to_param' in data:
@@ -121,7 +141,13 @@ class EmbeddingDatabase:
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir):
try:
@@ -132,12 +158,13 @@ class EmbeddingDatabase:
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
print("Embeddings:", ', '.join(self.word_embeddings.keys()))
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -155,13 +182,11 @@ class EmbeddingDatabase:
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
@@ -184,7 +209,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
if (step + 1) % shared.opts.training_write_csv_every != 0:
if step % shared.opts.training_write_csv_every != 0:
return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
@@ -194,21 +219,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if write_csv_header:
csv_writer.writeheader()
epoch = step // epoch_len
epoch_step = step % epoch_len
epoch = (step - 1) // epoch_len
epoch_step = (step - 1) % epoch_len
csv_writer.writerow({
"step": step + 1,
"step": step,
"epoch": epoch,
"epoch_step": epoch_step + 1,
"epoch_step": epoch_step,
**values,
})
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
assert batch_size > 0, "Batch size must be positive"
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
assert gradient_step > 0, "Gradient accumulation step must be positive"
assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
@@ -244,11 +271,12 @@ def create_dummy_mask(x, width=None, height=None):
return image_conditioning
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -275,31 +303,58 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
else:
images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
initial_step = embedding.step or 0
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
old_parallel_processing_allowed = shared.parallel_processing_allowed
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload:
shared.parallel_processing_allowed = False
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
if shared.opts.save_optimizer_state:
optimizer_state_dict = None
if os.path.exists(filename + '.optim'):
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
else:
print("No saved optimizer exists in checkpoint")
losses = torch.zeros((32,))
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
last_saved_file = "<none>"
last_saved_image = "<none>"
@@ -307,138 +362,166 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
img_c = None
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
break
with devices.autocast():
# c = stack_conds(batch.cond).to(devices.device)
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
if img_c is None:
img_c = create_dummy_mask(c, training_width, training_height)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0] / gradient_step
del x
if shared.state.interrupted:
break
_loss_step += loss.item()
scaler.scale(loss).backward()
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
if img_c is None:
img_c = create_dummy_mask(c, training_width, training_height)
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
scaler.step(optimizer)
scaler.update()
embedding.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0]
del x
steps_done = embedding.step + 1
losses[embedding.step % losses.shape[0]] = loss.item()
epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
steps_done = embedding.step + 1
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
})
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step % len(ds)
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
shared.sd_model.first_stage_model.to(devices.device)
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
do_not_reload_embeddings=True,
)
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate
})
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
preview_text = p.prompt
shared.sd_model.first_stage_model.to(devices.device)
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
do_not_reload_embeddings=True,
)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
p.width = training_width
p.height = training_height
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
preview_text = p.prompt
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
processed = processing.process_images(p)
image = processed.images[0]
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
shared.state.current_image = image
title = "<{}>".format(data.get('name', '???'))
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
except Exception as e:
vectorSize = '?'
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
title = "<{}>".format(data.get('name', '???'))
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
except Exception as e:
vectorSize = '?'
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
shared.state.job_no = embedding.step
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
finally:
pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -449,6 +532,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum:
embedding.cached_checksum = None
embedding.name = embedding_name
embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename)
except:
embedding.sd_checkpoint = old_sd_checkpoint