train: change filename processing to be more simple and configurable

train: make it possible to make text files with prompts
train: rework scheduler so that there's less repeating code in textual inversion and hypernets
train: move epochs setting to options
This commit is contained in:
AUTOMATIC
2022-10-12 20:49:47 +03:00
parent cc5803603b
commit c3c8eef9fd
7 changed files with 106 additions and 63 deletions

View File

@@ -14,7 +14,7 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule
from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
@@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
(learn_rate, end_step) = next(schedules)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text, cond) in pbar:
for i, entry in pbar:
hypernetwork.step = i + ititial_step
if hypernetwork.step > end_step:
try:
(learn_rate, end_step) = next(schedules)
except Exception:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
cond = cond.to(devices.device)
x = x.to(devices.device)
cond = entry.cond.to(devices.device)
x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
del cond
@@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
preview_text = text if preview_image_prompt == "" else preview_image_prompt
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -282,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
)
processed = processing.process_images(p)
image = processed.images[0]
image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
if image is not None:
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
@@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>