Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
apply lr schedule to hypernets
The diff below, apparently from modules/textual_inversion/textual_inversion.py, renames the schedule iterator from scheduleIter to schedules and moves the LearnSchedule class out into its own module, modules/textual_inversion/learn_schedule.py, so hypernetwork training can share the same learning-rate schedule.
@@ -10,6 +10,7 @@ import datetime
 
 from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
 
 
 class Embedding:
@@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     if ititial_step > steps:
         return embedding, filename
 
-    tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
-    epoch_len = (tr_img_len * num_repeats) + tr_img_len
-
-    scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
-    (learn_rate, end_step) = next(scheduleIter)
+    schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+    (learn_rate, end_step) = next(schedules)
     print(f'Training at rate of {learn_rate} until step {end_step}')
 
     optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
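
Note that the learn_rate argument here is a schedule string rather than a single float; LearnSchedule (removed from this file in the last hunk, now imported from its new module) parses it into (rate, end_step) pairs. A minimal sketch of the first fetch, with a made-up schedule string:

    # Hypothetical schedule: 0.005 until step 100, then 1e-5 for the rest of the run.
    schedules = iter(LearnSchedule("0.005:100, 1e-5", 2000, 0))
    (learn_rate, end_step) = next(schedules)  # learn_rate == 0.005, end_step == 100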
@@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
 
         if embedding.step > end_step:
             try:
-                (learn_rate, end_step) = next(scheduleIter)
+                (learn_rate, end_step) = next(schedules)
             except:
                 break
             tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
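
The hunk ends at the tqdm.tqdm.write call, so the diff does not show how the freshly fetched rate reaches the optimizer. A plausible continuation, offered here only as an assumed sketch, is an in-place update of the optimizer's parameter groups:

    # Assumed, not shown in this diff: push the new rate into the existing
    # optimizer rather than constructing a fresh AdamW each time the schedule advances.
    for pg in optimizer.param_groups:
        pg['lr'] = learn_rate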
@@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     embedding.save(filename)
 
     return embedding, filename
-
-class LearnSchedule:
-    def __init__(self, learn_rate, max_steps, cur_step=0):
-        pairs = learn_rate.split(',')
-        self.rates = []
-        self.it = 0
-        self.maxit = 0
-        for i, pair in enumerate(pairs):
-            tmp = pair.split(':')
-            if len(tmp) == 2:
-                step = int(tmp[1])
-                if step > cur_step:
-                    self.rates.append((float(tmp[0]), min(step, max_steps)))
-                    self.maxit += 1
-                    if step > max_steps:
-                        return
-                elif step == -1:
-                    self.rates.append((float(tmp[0]), max_steps))
-                    self.maxit += 1
-                    return
-            else:
-                self.rates.append((float(tmp[0]), max_steps))
-                self.maxit += 1
-                return
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.it < self.maxit:
-            self.it += 1
-            return self.rates[self.it - 1]
-        else:
-            raise StopIteration
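For reference, a hedged usage sketch of the class removed above: the schedule string and step budget are invented, and the import assumes the post-commit layout where LearnSchedule lives in modules/textual_inversion/learn_schedule.py (run from the webui repo root):

    # Each "rate:step" pair trains at `rate` until `step`; a bare rate
    # (or an explicit step of -1) runs at that rate through max_steps.
    from modules.textual_inversion.learn_schedule import LearnSchedule

    for learn_rate, end_step in LearnSchedule("0.005:100, 1e-3:1000, 1e-5", 2000):
        print(f'Training at rate of {learn_rate} until step {end_step}')

    # Expected output:
    # Training at rate of 0.005 until step 100
    # Training at rate of 0.001 until step 1000
    # Training at rate of 1e-05 until step 2000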