mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
Gradient accumulation, autocast fix, new latent sampling method, etc
This commit is contained in:
@@ -3,7 +3,7 @@ import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
import random
|
||||
@@ -11,25 +11,28 @@ import tqdm
|
||||
from modules import devices, shared
|
||||
import re
|
||||
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
|
||||
def __init__(self, filename=None, latent=None, filename_text=None):
|
||||
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
|
||||
self.filename = filename
|
||||
self.latent = latent
|
||||
self.filename_text = filename_text
|
||||
self.cond = None
|
||||
self.cond_text = None
|
||||
self.latent_dist = latent_dist
|
||||
self.latent_sample = latent_sample
|
||||
self.cond = cond
|
||||
self.cond_text = cond_text
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
@@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
|
||||
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.tag_drop_out = tag_drop_out
|
||||
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
if shared.state.interrupted:
|
||||
raise Exception("inturrupted")
|
||||
try:
|
||||
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
except Exception:
|
||||
@@ -71,37 +79,58 @@ class PersonalizedBase(Dataset):
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
||||
torchdata = torch.moveaxis(torchdata, 2, 0)
|
||||
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||
latent_sample = None
|
||||
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
init_latent = init_latent.to(devices.cpu)
|
||||
with torch.autocast("cuda"):
|
||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
|
||||
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
latent_sampling_method = "once"
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "deterministic":
|
||||
# Works only for DiagonalGaussianDistribution
|
||||
latent_dist.std = 0
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "random":
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
||||
|
||||
if include_cond:
|
||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
|
||||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
with torch.autocast("cuda"):
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
# elif not include_cond:
|
||||
# _, _, _, _, hijack_fixes, token_count = cond_model.process_text([entry.cond_text])
|
||||
# max_n = token_count // 75
|
||||
# index_list = [ [] for _ in range(max_n + 1) ]
|
||||
# for n, (z, _) in hijack_fixes[0]:
|
||||
# index_list[n].append(z)
|
||||
# with torch.autocast("cuda"):
|
||||
# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
# entry.emb_index = index_list
|
||||
|
||||
self.dataset.append(entry)
|
||||
del torchdata
|
||||
del latent_dist
|
||||
del latent_sample
|
||||
|
||||
assert len(self.dataset) > 0, "No images have been found in the dataset."
|
||||
self.length = len(self.dataset) * repeats // batch_size
|
||||
|
||||
self.dataset_length = len(self.dataset)
|
||||
self.indexes = None
|
||||
self.shuffle()
|
||||
|
||||
def shuffle(self):
|
||||
self.indexes = np.random.permutation(self.dataset_length)
|
||||
self.length = len(self.dataset)
|
||||
assert self.length > 0, "No images have been found in the dataset."
|
||||
self.batch_size = min(batch_size, self.length)
|
||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||
self.latent_sampling_method = latent_sampling_method
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
tags = filename_text.split(',')
|
||||
if shared.opts.tag_drop_out != 0:
|
||||
tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
|
||||
if shared.opts.shuffle_tags:
|
||||
if self.tag_drop_out != 0:
|
||||
tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||
if self.shuffle_tags:
|
||||
random.shuffle(tags)
|
||||
text = text.replace("[filewords]", ','.join(tags))
|
||||
return text
|
||||
@@ -110,19 +139,28 @@ class PersonalizedBase(Dataset):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
res = []
|
||||
entry = self.dataset[i]
|
||||
if self.tag_drop_out != 0 or self.shuffle_tags:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
if self.latent_sampling_method == "random":
|
||||
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist)
|
||||
return entry
|
||||
|
||||
for j in range(self.batch_size):
|
||||
position = i * self.batch_size + j
|
||||
if position % len(self.indexes) == 0:
|
||||
self.shuffle()
|
||||
class PersonalizedDataLoader(DataLoader):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs)
|
||||
self.collate_fn = collate_wrapper
|
||||
|
||||
|
||||
index = self.indexes[position % len(self.indexes)]
|
||||
entry = self.dataset[index]
|
||||
class BatchLoader:
|
||||
def __init__(self, data):
|
||||
self.cond_text = [entry.cond_text for entry in data]
|
||||
self.cond = [entry.cond for entry in data]
|
||||
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||
|
||||
if entry.cond is None:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
def pin_memory(self):
|
||||
self.latent_sample = self.latent_sample.pin_memory()
|
||||
return self
|
||||
|
||||
res.append(entry)
|
||||
|
||||
return res
|
||||
def collate_wrapper(batch):
|
||||
return BatchLoader(batch)
|
Reference in New Issue
Block a user