mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
made 'reuse seed' button give you the seed/subseed of the currently selected picture rather than the first
This commit is contained in:
@@ -83,7 +83,7 @@ class StableDiffusionProcessing:
|
||||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed, info, subseed=None):
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0):
|
||||
self.images = images_list
|
||||
self.prompt = p.prompt
|
||||
self.negative_prompt = p.negative_prompt
|
||||
@@ -93,26 +93,62 @@ class Processed:
|
||||
self.info = info
|
||||
self.width = p.width
|
||||
self.height = p.height
|
||||
self.sampler_index = p.sampler_index
|
||||
self.sampler = samplers[p.sampler_index].name
|
||||
self.cfg_scale = p.cfg_scale
|
||||
self.steps = p.steps
|
||||
self.batch_size = p.batch_size
|
||||
self.restore_faces = p.restore_faces
|
||||
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
||||
self.sd_model_hash = shared.sd_model.sd_model_hash
|
||||
self.seed_resize_from_w = p.seed_resize_from_w
|
||||
self.seed_resize_from_h = p.seed_resize_from_h
|
||||
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
||||
self.extra_generation_params = p.extra_generation_params
|
||||
self.index_of_first_image = index_of_first_image
|
||||
|
||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
|
||||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||
|
||||
self.all_prompts = all_prompts or [self.prompt]
|
||||
self.all_seeds = all_seeds or [self.seed]
|
||||
self.all_subseeds = all_subseeds or [self.subseed]
|
||||
|
||||
def js(self):
|
||||
obj = {
|
||||
"prompt": self.prompt if type(self.prompt) != list else self.prompt[0],
|
||||
"negative_prompt": self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0],
|
||||
"seed": int(self.seed if type(self.seed) != list else self.seed[0]),
|
||||
"subseed": int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1,
|
||||
"prompt": self.prompt,
|
||||
"all_prompts": self.all_prompts,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"seed": self.seed,
|
||||
"all_seeds": self.all_seeds,
|
||||
"subseed": self.subseed,
|
||||
"all_subseeds": self.all_subseeds,
|
||||
"subseed_strength": self.subseed_strength,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"sampler_index": self.sampler_index,
|
||||
"sampler": self.sampler,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"steps": self.steps,
|
||||
"batch_size": self.batch_size,
|
||||
"restore_faces": self.restore_faces,
|
||||
"face_restoration_model": self.face_restoration_model,
|
||||
"sd_model_hash": self.sd_model_hash,
|
||||
"seed_resize_from_w": self.seed_resize_from_w,
|
||||
"seed_resize_from_h": self.seed_resize_from_h,
|
||||
"denoising_strength": self.denoising_strength,
|
||||
"extra_generation_params": self.extra_generation_params,
|
||||
"index_of_first_image": self.index_of_first_image,
|
||||
}
|
||||
|
||||
return json.dumps(obj)
|
||||
|
||||
def infotext(self, p: StableDiffusionProcessing, index):
|
||||
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
||||
|
||||
|
||||
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||
def slerp(val, low, high):
|
||||
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
||||
@@ -156,11 +192,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
noise = devices.randn(seed, noise_shape)
|
||||
|
||||
if subnoise is not None:
|
||||
#noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
|
||||
noise = slerp(subseed_strength, noise, subnoise)
|
||||
|
||||
if noise_shape != shape:
|
||||
#noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
|
||||
x = devices.randn(seed, shape)
|
||||
dx = (shape[2] - noise_shape[2]) // 2
|
||||
dy = (shape[1] - noise_shape[1]) // 2
|
||||
@@ -194,6 +228,35 @@ def fix_seed(p):
|
||||
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
|
||||
|
||||
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
||||
index = position_in_batch + iteration * p.batch_size
|
||||
|
||||
generation_params = {
|
||||
"Steps": p.steps,
|
||||
"Sampler": samplers[p.sampler_index].name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
}
|
||||
|
||||
if p.extra_generation_params is not None:
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
||||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
|
||||
|
||||
|
||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
@@ -231,32 +294,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]
|
||||
|
||||
def infotext(iteration=0, position_in_batch=0):
|
||||
index = position_in_batch + iteration * p.batch_size
|
||||
|
||||
generation_params = {
|
||||
"Steps": p.steps,
|
||||
"Sampler": samplers[p.sampler_index].name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
}
|
||||
|
||||
if p.extra_generation_params is not None:
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
||||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
|
||||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir):
|
||||
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
|
||||
@@ -350,18 +388,20 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
p.color_corrections = None
|
||||
|
||||
index_of_first_image = 0
|
||||
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
|
||||
if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
|
||||
grid = images.image_grid(output_images, p.batch_size)
|
||||
|
||||
if opts.return_grid:
|
||||
output_images.insert(0, grid)
|
||||
index_of_first_image = 1
|
||||
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)
|
||||
|
||||
devices.torch_gc()
|
||||
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0])
|
||||
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image)
|
||||
|
||||
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
Reference in New Issue
Block a user