add an option to unload models during hypernetwork training to save VRAM

author AUTOMATIC
date 2022-10-11 19:03:08 +03:00
parent 6d09b8d1df
commit d4ea5f4d86
5 changed files with 46 additions and 18 deletions
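The change follows a simple offloading pattern: while the hypernetwork itself is being trained, the text encoder (cond_stage_model) and the VAE (first_stage_model) are parked on the CPU, and they are only moved back to the GPU when a preview image has to be generated. A minimal sketch of that pattern with stand-in modules (only the flag's meaning mirrors the diff below; none of these names are the webui API):

import torch

# Stand-ins for the frozen sub-models that are not needed while only the
# hypernetwork weights are being optimized.
text_encoder = torch.nn.Linear(8, 8)   # plays the role of cond_stage_model
vae = torch.nn.Linear(8, 8)            # plays the role of first_stage_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
unload = True  # mirrors shared.opts.unload_models_when_training

def offload():
    # Park the frozen sub-models on the CPU during training, freeing their VRAM.
    if unload:
        text_encoder.to(cpu)
        vae.to(cpu)

def reload_for_preview():
    # Both sub-models are needed for a full txt2img preview, so bring them
    # back to the GPU just for that step.
    text_encoder.to(device)
    vae.to(device)

offload()
# ... training steps that touch only the hypernetwork ...
reload_for_preview()
# ... generate the preview image ...
offload()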


@@ -175,6 +175,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+    unload = shared.opts.unload_models_when_training
 
     if save_hypernetwork_every > 0:
         hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
@@ -188,11 +189,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     else:
         images_dir = None
 
-    cond_model = shared.sd_model.cond_stage_model
-
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+
+    if unload:
+        shared.sd_model.cond_stage_model.to(devices.cpu)
+        shared.sd_model.first_stage_model.to(devices.cpu)
 
     hypernetwork = shared.loaded_hypernetwork
     weights = hypernetwork.weights()
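Passing include_cond=True asks the dataset to precompute the conditioning for every caption while the text encoder is still on the GPU, so the training loop below can be handed (x, text, cond) without calling the text encoder again; that is what makes it safe to move cond_stage_model to the CPU right after dataset preparation. A rough, self-contained illustration of the caching idea (the real change lives in modules.textual_inversion.dataset and is not shown here; none of these names are the webui API):

import torch

# Stand-in: encode_text plays the role of the text encoder (cond_stage_model);
# the latents are random placeholders for VAE output.
encode_text = torch.nn.Embedding(1000, 16)
captions = ["a photo of a cat", "a photo of a dog"]  # assumed training captions

entries = []
with torch.no_grad():
    for caption in captions:
        latent = torch.randn(4, 64, 64)                      # placeholder image latent
        token_ids = torch.tensor([hash(w) % 1000 for w in caption.split()])
        cond = encode_text(token_ids).cpu()                  # cache conditioning on the CPU
        entries.append((latent, caption, cond))

# Each cached entry now carries everything a training step needs, so the
# text encoder itself can stay off the GPU for the rest of training.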
@@ -211,7 +214,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
         return hypernetwork, filename
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, (x, text) in pbar:
+    for i, (x, text, cond) in pbar:
         hypernetwork.step = i + ititial_step
 
         if hypernetwork.step > steps:
@@ -221,11 +224,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             break
 
         with torch.autocast("cuda"):
-            c = cond_model([text])
-
+            cond = cond.to(devices.device)
             x = x.to(devices.device)
-            loss = shared.sd_model(x.unsqueeze(0), c)[0]
+            loss = shared.sd_model(x.unsqueeze(0), cond)[0]
             del x
+            del cond
 
         losses[hypernetwork.step % losses.shape[0]] = loss.item()
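With the conditioning cached by the dataset, each training step only has to move the stored tensors onto the GPU, run the frozen model (with the hypernetwork applied) to get a loss, and drop the references so the memory can be reclaimed. A toy version of that step with a stand-in model, purely illustrative:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in "model": in the real code this is shared.sd_model with the
# hypernetwork hooked into its attention layers.
model = torch.nn.Linear(16, 1).to(device)

# A cached training example, kept on the CPU between steps.
x_cpu = torch.randn(16)
cond_cpu = torch.randn(16)

x = x_cpu.to(device)
cond = cond_cpu.to(device)
loss = model(x.unsqueeze(0) + cond).mean()   # toy loss, not the diffusion loss
loss.backward()
del x
del cond                                     # drop GPU references promptly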
@@ -244,6 +247,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             preview_text = text if preview_image_prompt == "" else preview_image_prompt
 
             optimizer.zero_grad()
+
+            shared.sd_model.cond_stage_model.to(devices.device)
+            shared.sd_model.first_stage_model.to(devices.device)
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
                 prompt=preview_text,
@@ -255,6 +262,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             processed = processing.process_images(p)
             image = processed.images[0]
 
+            if unload:
+                shared.sd_model.cond_stage_model.to(devices.cpu)
+                shared.sd_model.first_stage_model.to(devices.cpu)
+
             shared.state.current_image = image
             image.save(last_saved_image)
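The preview path temporarily undoes the offloading: both frozen sub-models return to the GPU for process_images and, when unload is enabled, go back to the CPU once the image exists. A defensive variant of that round trip, sketched with a try/finally so the models are re-offloaded even if preview generation raises (an illustration, not what the commit itself does):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
unload = True

# Stand-ins for shared.sd_model.cond_stage_model / first_stage_model.
cond_stage_model = torch.nn.Linear(4, 4)
first_stage_model = torch.nn.Linear(4, 4)

def generate_preview():
    # Placeholder for processing.process_images(p) in the real code.
    return torch.zeros(3, 64, 64)

cond_stage_model.to(device)
first_stage_model.to(device)
try:
    image = generate_preview()
finally:
    if unload:
        # Re-offload even if preview generation failed, so VRAM is not left
        # occupied for the rest of the training run.
        cond_stage_model.to(cpu)
        first_stage_model.to(cpu)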