add option to use batch size for training

This commit is contained in:
AUTOMATIC
2022-10-15 09:24:59 +03:00
parent acedbe67d2
commit c7a86f7fe9
4 changed files with 54 additions and 30 deletions

View File

@@ -24,11 +24,12 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -78,13 +79,13 @@ class PersonalizedBase(Dataset):
if include_cond:
entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
self.length = len(self.dataset) * repeats
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(self.length) % len(self.dataset)
self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None
self.shuffle()
@@ -101,13 +102,19 @@ class PersonalizedBase(Dataset):
return self.length
def __getitem__(self, i):
if i % len(self.dataset) == 0:
self.shuffle()
res = []
index = self.indexes[i % len(self.indexes)]
entry = self.dataset[index]
for j in range(self.batch_size):
position = i * self.batch_size + j
if position % len(self.indexes) == 0:
self.shuffle()
if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text)
index = self.indexes[position % len(self.indexes)]
entry = self.dataset[index]
return entry
if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text)
res.append(entry)
return res