mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
apply lr schedule to hypernets
This commit is contained in:
@@ -14,6 +14,7 @@ import torch
|
||||
from torch import einsum
|
||||
from einops import rearrange, repeat
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
@@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
for weight in weights:
|
||||
weight.requires_grad = True
|
||||
|
||||
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
|
||||
|
||||
losses = torch.zeros((32,))
|
||||
|
||||
last_saved_file = "<none>"
|
||||
@@ -213,12 +212,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
if ititial_step > steps:
|
||||
return hypernetwork, filename
|
||||
|
||||
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
|
||||
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, (x, text, cond) in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
|
||||
if hypernetwork.step > steps:
|
||||
break
|
||||
if hypernetwork.step > end_step:
|
||||
try:
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
except Exception:
|
||||
break
|
||||
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = learn_rate
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
Reference in New Issue
Block a user