mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-09 13:49:48 +00:00
Learning rate schedule syntax support for gradient clipping
This commit is contained in:
@@ -255,9 +255,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
ititial_step = embedding.step or 0
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
|
||||
clip_grad_mode_value = clip_grad_mode == "value"
|
||||
clip_grad_mode_norm = clip_grad_mode == "norm"
|
||||
clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
|
||||
if clip_grad_enabled:
|
||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
@@ -273,6 +276,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
if clip_grad_enabled:
|
||||
clip_grad_sched.step(embedding.step)
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = cond_model([entry.cond_text for entry in entries])
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
@@ -285,9 +291,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
loss.backward()
|
||||
|
||||
if clip_grad_mode_value:
|
||||
torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value)
|
||||
torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
|
||||
elif clip_grad_mode_norm:
|
||||
torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value)
|
||||
torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
|
Reference in New Issue
Block a user