mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
added resize seeds and variation seeds features
This commit is contained in:
@@ -29,7 +29,7 @@ def torch_gc():
|
||||
|
||||
|
||||
class StableDiffusionProcessing:
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
||||
self.sd_model = sd_model
|
||||
self.outpath_samples: str = outpath_samples
|
||||
self.outpath_grids: str = outpath_grids
|
||||
@@ -37,6 +37,10 @@ class StableDiffusionProcessing:
|
||||
self.prompt_for_display: str = None
|
||||
self.negative_prompt: str = (negative_prompt or "")
|
||||
self.seed: int = seed
|
||||
self.subseed: int = subseed
|
||||
self.subseed_strength: float = subseed_strength
|
||||
self.seed_resize_from_h: int = seed_resize_from_h
|
||||
self.seed_resize_from_w: int = seed_resize_from_w
|
||||
self.sampler_index: int = sampler_index
|
||||
self.batch_size: int = batch_size
|
||||
self.n_iter: int = n_iter
|
||||
@@ -84,23 +88,67 @@ class Processed:
|
||||
|
||||
return json.dumps(obj)
|
||||
|
||||
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||
def slerp(val, low, high):
|
||||
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
||||
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
||||
omega = torch.acos((low_norm*high_norm).sum(1))
|
||||
so = torch.sin(omega)
|
||||
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
||||
return res
|
||||
|
||||
def create_random_tensors(shape, seeds):
|
||||
|
||||
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
|
||||
xs = []
|
||||
for seed in seeds:
|
||||
torch.manual_seed(seed)
|
||||
for i, seed in enumerate(seeds):
|
||||
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
||||
|
||||
subnoise = None
|
||||
if subseeds is not None:
|
||||
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
||||
torch.manual_seed(subseed)
|
||||
subnoise = torch.randn(noise_shape, device=shared.device)
|
||||
|
||||
# randn results depend on device; gpu and cpu get different results for same seed;
|
||||
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
||||
# but the original script had it like this so I do not dare change it for now because
|
||||
# but the original script had it like this, so I do not dare change it for now because
|
||||
# it will break everyone's seeds.
|
||||
xs.append(torch.randn(shape, device=shared.device))
|
||||
x = torch.stack(xs)
|
||||
torch.manual_seed(seed)
|
||||
noise = torch.randn(noise_shape, device=shared.device)
|
||||
|
||||
if subnoise is not None:
|
||||
#noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
|
||||
noise = slerp(subseed_strength, noise, subnoise)
|
||||
|
||||
if noise_shape != shape:
|
||||
#noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
|
||||
# noise_shape = (64, 80)
|
||||
# shape = (64, 72)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
x = torch.randn(shape, device=shared.device)
|
||||
dx = (shape[2] - noise_shape[2]) // 2 # -4
|
||||
dy = (shape[1] - noise_shape[1]) // 2
|
||||
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
||||
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
||||
tx = 0 if dx < 0 else dx
|
||||
ty = 0 if dy < 0 else dy
|
||||
dx = max(-dx, 0)
|
||||
dy = max(-dy, 0)
|
||||
|
||||
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
|
||||
noise = x
|
||||
|
||||
|
||||
|
||||
xs.append(noise)
|
||||
x = torch.stack(xs).to(shared.device)
|
||||
return x
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
return int(random.randrange(4294967294)) if seed is None or seed == -1 else seed
|
||||
def fix_seed(p):
|
||||
p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed
|
||||
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
|
||||
|
||||
|
||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
@@ -111,7 +159,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
assert p.prompt is not None
|
||||
torch_gc()
|
||||
|
||||
seed = set_seed(p.seed)
|
||||
fix_seed(p)
|
||||
|
||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||
@@ -125,20 +173,31 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
else:
|
||||
all_prompts = p.batch_size * p.n_iter * [prompt]
|
||||
|
||||
if type(seed) == list:
|
||||
all_seeds = seed
|
||||
if type(p.seed) == list:
|
||||
all_seeds = int(p.seed)
|
||||
else:
|
||||
all_seeds = [int(seed + x) for x in range(len(all_prompts))]
|
||||
all_seeds = [int(p.seed + x) for x in range(len(all_prompts))]
|
||||
|
||||
if type(p.subseed) == list:
|
||||
all_subseeds = p.subseed
|
||||
else:
|
||||
all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]
|
||||
|
||||
def infotext(iteration=0, position_in_batch=0):
|
||||
index = position_in_batch + iteration * p.batch_size
|
||||
|
||||
generation_params = {
|
||||
"Steps": p.steps,
|
||||
"Sampler": samplers[p.sampler_index].name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Seed": all_seeds[position_in_batch + iteration * p.batch_size],
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
}
|
||||
|
||||
if p.extra_generation_params is not None:
|
||||
@@ -174,7 +233,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
comments += model_hijack.comments
|
||||
|
||||
# we manually generate all input noises because each one should have a specific seed
|
||||
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
|
||||
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=all_subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
|
||||
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
@@ -231,10 +290,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
output_images.insert(0, grid)
|
||||
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", seed, all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||
|
||||
torch_gc()
|
||||
return Processed(p, output_images, seed, infotext())
|
||||
return Processed(p, output_images, all_seeds[0], infotext())
|
||||
|
||||
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
Reference in New Issue
Block a user