Allow saving "before-highres-fix. (#4150)

* Save image/s before doing highres fix.
This commit is contained in:
timntorres
2022-11-02 02:18:21 -07:00
committed by GitHub
parent 4a8cf01f6f
commit 9c67408004
3 changed files with 18 additions and 5 deletions

View File

@@ -521,7 +521,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
# Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix.
if isinstance(p, StableDiffusionProcessingTxt2Img):
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, n=n)
else:
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
@@ -649,7 +653,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, n=0):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
@@ -685,6 +689,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
# Save a copy of the image/s before doing highres fix, if applicable.
if opts.save and not self.do_not_save_samples and opts.save_images_before_highres_fix:
for i in range(self.batch_size):
# This batch's ith image.
img = sd_samplers.sample_to_image(samples, i)
# Index that accounts for both batch size and batch count.
ind = i + self.batch_size*n
images.save_image(img, self.outpath_samples, "", self.all_seeds[ind], self.all_prompts[ind], opts.samples_format, suffix=f"-before-highres-fix")
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)