Custom Width and Height

This commit is contained in:
alg-wiki
2022-10-10 22:35:35 +09:00
parent 4ee7519fc2
commit 04c745ea4f
4 changed files with 26 additions and 23 deletions

View File

@@ -15,13 +15,12 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
self.size = size
self.width = size
self.height = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []