mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 05:12:35 +00:00
Merge branch 'AUTOMATIC1111:master' into deepdanbooru_pre_process
This commit is contained in:
@@ -8,14 +8,14 @@ from torchvision import transforms
|
||||
|
||||
import random
|
||||
import tqdm
|
||||
from modules import devices
|
||||
from modules import devices, shared
|
||||
import re
|
||||
|
||||
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
@@ -32,12 +32,15 @@ class PersonalizedBase(Dataset):
|
||||
|
||||
assert data_root, 'dataset directory not specified'
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
image = Image.open(path)
|
||||
image = image.convert('RGB')
|
||||
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
try:
|
||||
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
filename = os.path.basename(path)
|
||||
filename_tokens = os.path.splitext(filename)[0]
|
||||
@@ -52,7 +55,13 @@ class PersonalizedBase(Dataset):
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
init_latent = init_latent.to(devices.cpu)
|
||||
|
||||
self.dataset.append((init_latent, filename_tokens))
|
||||
if include_cond:
|
||||
text = self.create_text(filename_tokens)
|
||||
cond = cond_model([text]).to(devices.cpu)
|
||||
else:
|
||||
cond = None
|
||||
|
||||
self.dataset.append((init_latent, filename_tokens, cond))
|
||||
|
||||
self.length = len(self.dataset) * repeats
|
||||
|
||||
@@ -63,6 +72,12 @@ class PersonalizedBase(Dataset):
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||
|
||||
def create_text(self, filename_tokens):
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
@@ -71,10 +86,7 @@ class PersonalizedBase(Dataset):
|
||||
self.shuffle()
|
||||
|
||||
index = self.indexes[i % len(self.indexes)]
|
||||
x, filename_tokens = self.dataset[index]
|
||||
x, filename_tokens, cond = self.dataset[index]
|
||||
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||
|
||||
return x, text
|
||||
text = self.create_text(filename_tokens)
|
||||
return x, text, cond
|
||||
|
34
modules/textual_inversion/learn_schedule.py
Normal file
34
modules/textual_inversion/learn_schedule.py
Normal file
@@ -0,0 +1,34 @@
|
||||
|
||||
class LearnSchedule:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0):
|
||||
pairs = learn_rate.split(',')
|
||||
self.rates = []
|
||||
self.it = 0
|
||||
self.maxit = 0
|
||||
for i, pair in enumerate(pairs):
|
||||
tmp = pair.split(':')
|
||||
if len(tmp) == 2:
|
||||
step = int(tmp[1])
|
||||
if step > cur_step:
|
||||
self.rates.append((float(tmp[0]), min(step, max_steps)))
|
||||
self.maxit += 1
|
||||
if step > max_steps:
|
||||
return
|
||||
elif step == -1:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
else:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.it < self.maxit:
|
||||
self.it += 1
|
||||
return self.rates[self.it - 1]
|
||||
else:
|
||||
raise StopIteration
|
@@ -60,7 +60,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||
subindex = [0]
|
||||
filename = os.path.join(src, imagefile)
|
||||
img = Image.open(filename).convert("RGB")
|
||||
try:
|
||||
img = Image.open(filename).convert("RGB")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
@@ -10,6 +10,7 @@ import datetime
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||
|
||||
|
||||
class Embedding:
|
||||
@@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||
return fn
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
@@ -189,8 +190,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||
embedding.vec.requires_grad = True
|
||||
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||
|
||||
losses = torch.zeros((32,))
|
||||
|
||||
last_saved_file = "<none>"
|
||||
@@ -200,15 +199,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
|
||||
epoch_len = (tr_img_len * num_repeats) + tr_img_len
|
||||
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, (x, text) in pbar:
|
||||
for i, (x, text, _) in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
|
||||
if embedding.step > steps:
|
||||
break
|
||||
if embedding.step > end_step:
|
||||
try:
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
except:
|
||||
break
|
||||
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = learn_rate
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
@@ -226,10 +234,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
epoch_num = embedding.step // epoch_len
|
||||
epoch_step = embedding.step - (epoch_num * epoch_len) + 1
|
||||
epoch_num = embedding.step // len(ds)
|
||||
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
|
||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||
@@ -238,12 +246,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
preview_text = text if preview_image_prompt == "" else preview_image_prompt
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
prompt=text,
|
||||
prompt=preview_text,
|
||||
steps=20,
|
||||
height=training_height,
|
||||
width=training_width,
|
||||
height=training_height,
|
||||
width=training_width,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
@@ -254,7 +264,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image += f", prompt: {text}"
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
|
||||
@@ -276,4 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
embedding.save(filename)
|
||||
|
||||
return embedding, filename
|
||||
|
||||
|
@@ -23,6 +23,8 @@ def preprocess(*args):
|
||||
|
||||
def train_embedding(*args):
|
||||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
||||
|
||||
try:
|
||||
sd_hijack.undo_optimizations()
|
||||
|
||||
|
Reference in New Issue
Block a user