mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 05:12:35 +00:00
append_tag_shuffle
This commit is contained in:
@@ -24,7 +24,7 @@ class DatasetEntry:
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
@@ -33,6 +33,7 @@ class PersonalizedBase(Dataset):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
self.shuffle_tags = shuffle_tags
|
||||
|
||||
self.dataset = []
|
||||
|
||||
@@ -98,7 +99,12 @@ class PersonalizedBase(Dataset):
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", filename_text)
|
||||
if self.tag_shuffle:
|
||||
tags = filename_text.split(',')
|
||||
random.shuffle(tags)
|
||||
text = text.replace("[filewords]", ','.join(tags))
|
||||
else:
|
||||
text = text.replace("[filewords]", filename_text)
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
|
Reference in New Issue
Block a user