mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 13:19:54 +00:00
train: change filename processing to be more simple and configurable
train: make it possible to make text files with prompts train: rework scheduler so that there's less repeating code in textual inversion and hypernets train: move epochs setting to options
This commit is contained in:
@@ -11,11 +11,21 @@ import tqdm
|
||||
from modules import devices, shared
|
||||
import re
|
||||
|
||||
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
|
||||
def __init__(self, filename=None, latent=None, filename_text=None):
|
||||
self.filename = filename
|
||||
self.latent = latent
|
||||
self.filename_text = filename_text
|
||||
self.cond = None
|
||||
self.cond_text = None
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
text_filename = os.path.splitext(path)[0] + ".txt"
|
||||
filename = os.path.basename(path)
|
||||
filename_tokens = os.path.splitext(filename)[0]
|
||||
filename_tokens = re_tag.findall(filename_tokens)
|
||||
|
||||
if os.path.exists(text_filename):
|
||||
with open(text_filename, "r", encoding="utf8") as file:
|
||||
filename_text = file.read()
|
||||
else:
|
||||
filename_text = os.path.splitext(filename)[0]
|
||||
filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||
if re_word:
|
||||
tokens = re_word.findall(filename_text)
|
||||
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
init_latent = init_latent.to(devices.cpu)
|
||||
|
||||
if include_cond:
|
||||
text = self.create_text(filename_tokens)
|
||||
cond = cond_model([text]).to(devices.cpu)
|
||||
else:
|
||||
cond = None
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
|
||||
|
||||
self.dataset.append((init_latent, filename_tokens, cond))
|
||||
if include_cond:
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
|
||||
|
||||
self.dataset.append(entry)
|
||||
|
||||
self.length = len(self.dataset) * repeats
|
||||
|
||||
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||
|
||||
def create_text(self, filename_tokens):
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||
text = text.replace("[filewords]", filename_text)
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
|
||||
self.shuffle()
|
||||
|
||||
index = self.indexes[i % len(self.indexes)]
|
||||
x, filename_tokens, cond = self.dataset[index]
|
||||
entry = self.dataset[index]
|
||||
|
||||
text = self.create_text(filename_tokens)
|
||||
return x, text, cond
|
||||
if entry.cond is None:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
|
||||
return entry
|
||||
|
Reference in New Issue
Block a user