mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
Gradient accumulation, autocast fix, new latent sampling method, etc
This commit is contained in:
@@ -184,7 +184,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
if shared.opts.training_write_csv_every == 0:
|
||||
return
|
||||
|
||||
if (step + 1) % shared.opts.training_write_csv_every != 0:
|
||||
if step % shared.opts.training_write_csv_every != 0:
|
||||
return
|
||||
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
||||
|
||||
@@ -194,21 +194,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
if write_csv_header:
|
||||
csv_writer.writeheader()
|
||||
|
||||
epoch = step // epoch_len
|
||||
epoch_step = step % epoch_len
|
||||
epoch = (step - 1) // epoch_len
|
||||
epoch_step = (step - 1) % epoch_len
|
||||
|
||||
csv_writer.writerow({
|
||||
"step": step + 1,
|
||||
"step": step,
|
||||
"epoch": epoch,
|
||||
"epoch_step": epoch_step + 1,
|
||||
"epoch_step": epoch_step,
|
||||
**values,
|
||||
})
|
||||
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
assert model_name, f"{name} not selected"
|
||||
assert learn_rate, "Learning rate is empty or 0"
|
||||
assert isinstance(batch_size, int), "Batch size must be integer"
|
||||
assert batch_size > 0, "Batch size must be positive"
|
||||
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
|
||||
assert gradient_step > 0, "Gradient accumulation step must be positive"
|
||||
assert data_root, "Dataset directory is empty"
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
@@ -224,10 +226,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
|
||||
if save_model_every or create_image_every:
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
shared.state.job_count = steps
|
||||
@@ -255,161 +257,205 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
else:
|
||||
images_embeds_dir = None
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
if ititial_step >= steps:
|
||||
initial_step = embedding.step or 0
|
||||
if initial_step >= steps:
|
||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||
return embedding, filename
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=False)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
losses = torch.zeros((32,))
|
||||
batch_size = ds.batch_size
|
||||
gradient_step = ds.gradient_step
|
||||
# n steps = batch_size * gradient_step * n image processed
|
||||
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||
loss_step = 0
|
||||
_loss_step = 0 #internal
|
||||
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
for j, batch in enumerate(dl):
|
||||
# works as a drop_last=True for gradient accumulation
|
||||
if j == max_steps_per_epoch:
|
||||
break
|
||||
scheduler.apply(optimizer, embedding.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, entries in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
with torch.autocast("cuda"):
|
||||
# c = stack_conds(batch.cond).to(devices.device)
|
||||
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
|
||||
# print(mask)
|
||||
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||
del x
|
||||
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
#print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
|
||||
#scaler.unscale_(optimizer)
|
||||
#print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
|
||||
#torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=1.0)
|
||||
#print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
embedding.step += 1
|
||||
pbar.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss_step = _loss_step
|
||||
_loss_step = 0
|
||||
|
||||
scheduler.apply(optimizer, embedding.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
steps_done = embedding.step + 1
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
epoch_num = embedding.step // steps_per_epoch
|
||||
epoch_step = embedding.step % steps_per_epoch
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = cond_model([entry.cond_text for entry in entries])
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
del x
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
#if shared.opts.save_optimizer_state:
|
||||
#embedding.optimizer_state_dict = optimizer.state_dict()
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
|
||||
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||
"loss": f"{loss_step:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{embedding_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
|
||||
steps_done = embedding.step + 1
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
epoch_num = embedding.step // len(ds)
|
||||
epoch_step = embedding.step % len(ds)
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
do_not_reload_embeddings=True,
|
||||
)
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = batch.cond_text[0]
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
|
||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
preview_text = p.prompt
|
||||
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
||||
"loss": f"{losses.mean():.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{embedding_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
do_not_reload_embeddings=True,
|
||||
)
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = entries[0].cond_text
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||
|
||||
preview_text = p.prompt
|
||||
info = PngImagePlugin.PngInfo()
|
||||
data = torch.load(last_saved_file)
|
||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0]
|
||||
title = "<{}>".format(data.get('name', '???'))
|
||||
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
try:
|
||||
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
||||
except Exception as e:
|
||||
vectorSize = '?'
|
||||
|
||||
shared.state.current_image = image
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
info = PngImagePlugin.PngInfo()
|
||||
data = torch.load(last_saved_file)
|
||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
title = "<{}>".format(data.get('name', '???'))
|
||||
shared.state.job_no = embedding.step
|
||||
|
||||
try:
|
||||
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
||||
except Exception as e:
|
||||
vectorSize = '?'
|
||||
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||
|
||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
|
||||
shared.state.textinfo = f"""
|
||||
shared.state.textinfo = f"""
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Loss: {loss_step:.7f}<br/>
|
||||
Step: {embedding.step}<br/>
|
||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
except Exception:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
pass
|
||||
finally:
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
return embedding, filename
|
||||
|
||||
|
Reference in New Issue
Block a user