mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
@@ -24,6 +24,7 @@ from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.paths as paths
|
||||
import modules.face_restoration
|
||||
from modules.hypertile import split_attention, set_hypertile_seed, largest_tile_size_available
|
||||
import modules.images as images
|
||||
import modules.styles
|
||||
import modules.sd_models as sd_models
|
||||
@@ -799,17 +800,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
unet_object = p.sd_model.model
|
||||
vae_model = p.sd_model.first_stage_model
|
||||
try:
|
||||
from hyper_tile import split_attention, flush
|
||||
except (ImportError, ModuleNotFoundError): # pip install git+https://github.com/tfernd/HyperTile@2ef64b2800d007d305755c33550537410310d7df
|
||||
split_attention = lambda *args, **kwargs: lambda x: x # return a no-op context manager
|
||||
flush = lambda: None
|
||||
import random
|
||||
saved_rng_state = random.getstate()
|
||||
random.seed(p.seed) # hyper_tile uses random, so we need to seed it
|
||||
|
||||
with torch.no_grad(), p.sd_model.ema_scope():
|
||||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
@@ -871,29 +861,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
p.comment(comment)
|
||||
|
||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||
|
||||
set_hypertile_seed(p.seed)
|
||||
# add batch size + hypertile status to information to reproduce the run
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
# get largest tile size available, which is 2^x which is factor of gcd of p.width and p.height
|
||||
gcd = math.gcd(p.width, p.height)
|
||||
largest_tile_size_available = 1
|
||||
while gcd % (largest_tile_size_available * 2) == 0:
|
||||
largest_tile_size_available *= 2
|
||||
aspect_ratio = p.width / p.height
|
||||
with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn):
|
||||
with split_attention(unet_object, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn):
|
||||
flush()
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
|
||||
if getattr(samples_ddim, 'already_decoded', False):
|
||||
x_samples_ddim = samples_ddim
|
||||
else:
|
||||
if opts.sd_vae_decode_method != 'Full':
|
||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||
with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn):
|
||||
flush()
|
||||
with split_attention(p.sd_model.first_stage_model, aspect_ratio = p.width / p.height, tile_size=min(largest_tile_size_available(p.width, p.height), 128), disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
@@ -1000,7 +981,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
random.setstate(saved_rng_state)
|
||||
if not p.disable_extra_networks and p.extra_network_data:
|
||||
extra_networks.deactivate(p, p.extra_network_data)
|
||||
|
||||
@@ -1161,24 +1141,25 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
aspect_ratio = self.width / self.height
|
||||
x = self.rng.next()
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
tile_size = largest_tile_size_available(self.width, self.height)
|
||||
with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
devices.torch_gc()
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
del x
|
||||
|
||||
if not self.enable_hr:
|
||||
return samples
|
||||
|
||||
if self.latent_scale_mode is None:
|
||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||
with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||
else:
|
||||
decoded_samples = None
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||
|
||||
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||
@@ -1186,7 +1167,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
return samples
|
||||
|
||||
self.is_hr_pass = True
|
||||
|
||||
target_width = self.hr_upscale_to_x
|
||||
target_height = self.hr_upscale_to_y
|
||||
|
||||
@@ -1264,18 +1244,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
if self.scripts is not None:
|
||||
self.scripts.before_hr(self)
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
tile_size = largest_tile_size_available(target_width, target_height)
|
||||
with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
with split_attention(self.sd_model.model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=3, max_depth=1,scale_depth=True, disable=not opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
|
||||
self.sampler = None
|
||||
devices.torch_gc()
|
||||
|
||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||
with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
self.is_hr_pass = False
|
||||
|
||||
return decoded_samples
|
||||
|
||||
def close(self):
|
||||
@@ -1550,8 +1531,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
if self.initial_noise_multiplier != 1.0:
|
||||
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
||||
x *= self.initial_noise_multiplier
|
||||
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
aspect_ratio = self.width / self.height
|
||||
tile_size = largest_tile_size_available(self.width, self.height)
|
||||
with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||
devices.torch_gc()
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
|
||||
if self.mask is not None:
|
||||
samples = samples * self.nmask + self.init_latent * self.mask
|
||||
|
Reference in New Issue
Block a user