train: change filename processing to be more simple and configurable

train: make it possible to make text files with prompts
train: rework scheduler so that there's less repeating code in textual inversion and hypernets
train: move epochs setting to options
This commit is contained in:
AUTOMATIC
2022-10-12 20:49:47 +03:00
parent cc5803603b
commit c3c8eef9fd
7 changed files with 106 additions and 63 deletions

View File

@@ -11,11 +11,21 @@ import tqdm
from modules import devices, shared
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None):
self.filename = filename
self.latent = latent
self.filename_text = filename_text
self.cond = None
self.cond_text = None
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
self.placeholder_token = placeholder_token
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
except Exception:
continue
text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
filename_tokens = re_tag.findall(filename_tokens)
if os.path.exists(text_filename):
with open(text_filename, "r", encoding="utf8") as file:
filename_text = file.read()
else:
filename_text = os.path.splitext(filename)[0]
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if re_word:
tokens = re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
if include_cond:
text = self.create_text(filename_tokens)
cond = cond_model([text]).to(devices.cpu)
else:
cond = None
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
self.dataset.append((init_latent, filename_tokens, cond))
if include_cond:
entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
self.dataset.append(entry)
self.length = len(self.dataset) * repeats
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def create_text(self, filename_tokens):
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens))
text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
x, filename_tokens, cond = self.dataset[index]
entry = self.dataset[index]
text = self.create_text(filename_tokens)
return x, text, cond
if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text)
return entry