append_tag_shuffle

This commit is contained in:
TinkTheBoush
2022-11-01 23:29:12 +09:00
parent c28de154b0
commit 467cae167a
4 changed files with 15 additions and 6 deletions

View File

@@ -24,7 +24,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -33,6 +33,7 @@ class PersonalizedBase(Dataset):
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.shuffle_tags = shuffle_tags
self.dataset = []
@@ -98,7 +99,12 @@ class PersonalizedBase(Dataset):
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", filename_text)
if self.tag_shuffle:
tags = filename_text.split(',')
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))
else:
text = text.replace("[filewords]", filename_text)
return text
def __len__(self):