mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 05:12:35 +00:00
train: change filename processing to be more simple and configurable
train: make it possible to make text files with prompts train: rework scheduler so that there's less repeating code in textual inversion and hypernets train: move epochs setting to options
This commit is contained in:
@@ -11,11 +11,21 @@ import tqdm
|
||||
from modules import devices, shared
|
||||
import re
|
||||
|
||||
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
|
||||
def __init__(self, filename=None, latent=None, filename_text=None):
|
||||
self.filename = filename
|
||||
self.latent = latent
|
||||
self.filename_text = filename_text
|
||||
self.cond = None
|
||||
self.cond_text = None
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
text_filename = os.path.splitext(path)[0] + ".txt"
|
||||
filename = os.path.basename(path)
|
||||
filename_tokens = os.path.splitext(filename)[0]
|
||||
filename_tokens = re_tag.findall(filename_tokens)
|
||||
|
||||
if os.path.exists(text_filename):
|
||||
with open(text_filename, "r", encoding="utf8") as file:
|
||||
filename_text = file.read()
|
||||
else:
|
||||
filename_text = os.path.splitext(filename)[0]
|
||||
filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||
if re_word:
|
||||
tokens = re_word.findall(filename_text)
|
||||
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
init_latent = init_latent.to(devices.cpu)
|
||||
|
||||
if include_cond:
|
||||
text = self.create_text(filename_tokens)
|
||||
cond = cond_model([text]).to(devices.cpu)
|
||||
else:
|
||||
cond = None
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
|
||||
|
||||
self.dataset.append((init_latent, filename_tokens, cond))
|
||||
if include_cond:
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
|
||||
|
||||
self.dataset.append(entry)
|
||||
|
||||
self.length = len(self.dataset) * repeats
|
||||
|
||||
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||
|
||||
def create_text(self, filename_tokens):
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||
text = text.replace("[filewords]", filename_text)
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
|
||||
self.shuffle()
|
||||
|
||||
index = self.indexes[i % len(self.indexes)]
|
||||
x, filename_tokens, cond = self.dataset[index]
|
||||
entry = self.dataset[index]
|
||||
|
||||
text = self.create_text(filename_tokens)
|
||||
return x, text, cond
|
||||
if entry.cond is None:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
|
||||
return entry
|
||||
|
@@ -1,6 +1,12 @@
|
||||
import tqdm
|
||||
|
||||
class LearnSchedule:
|
||||
|
||||
class LearnScheduleIterator:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0):
|
||||
"""
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
|
||||
"""
|
||||
|
||||
pairs = learn_rate.split(',')
|
||||
self.rates = []
|
||||
self.it = 0
|
||||
@@ -32,3 +38,32 @@ class LearnSchedule:
|
||||
return self.rates[self.it - 1]
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
|
||||
class LearnRateScheduler:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
|
||||
self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
|
||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||
self.verbose = verbose
|
||||
|
||||
if self.verbose:
|
||||
print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
|
||||
|
||||
self.finished = False
|
||||
|
||||
def apply(self, optimizer, step_number):
|
||||
if step_number <= self.end_step:
|
||||
return
|
||||
|
||||
try:
|
||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||
except Exception:
|
||||
self.finished = True
|
||||
return
|
||||
|
||||
if self.verbose:
|
||||
tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
|
||||
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = self.learn_rate
|
||||
|
||||
|
@@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
|
||||
insert_image_data_embed, extract_image_data_embed,
|
||||
@@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||
return fn
|
||||
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
@@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
@@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, (x, text, _) in pbar:
|
||||
for i, entry in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
|
||||
if embedding.step > end_step:
|
||||
try:
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
except:
|
||||
break
|
||||
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = learn_rate
|
||||
scheduler.apply(optimizer, embedding.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = cond_model([text])
|
||||
c = cond_model([entry.cond_text])
|
||||
|
||||
x = x.to(devices.device)
|
||||
x = entry.latent.to(devices.device)
|
||||
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
||||
del x
|
||||
|
||||
@@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
preview_text = text if preview_image_prompt == "" else preview_image_prompt
|
||||
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
@@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {embedding.step}<br/>
|
||||
Last prompt: {html.escape(text)}<br/>
|
||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
|
Reference in New Issue
Block a user