Merge branch 'dev' into bgh-handle-metadata-issues-more-cleanly

This commit is contained in:
AUTOMATIC1111
2024-06-08 12:16:55 +03:00
committed by GitHub
50 changed files with 481 additions and 255 deletions

View File

@@ -149,10 +149,12 @@ def list_models():
cmd_ckpt = shared.cmd_opts.ckpt
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
model_url = None
expected_sha256 = None
else:
model_url = f"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"], hash_prefix=expected_sha256)
if os.path.exists(cmd_ckpt):
checkpoint_info = CheckpointInfo(cmd_ckpt)
@@ -384,6 +386,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
# Set is_sdxl_inpaint flag.
diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None)
model.is_sdxl_inpaint = (
model.is_sdxl and
diffusion_model_input is not None and
diffusion_model_input.shape[1] == 9
)
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
@@ -407,6 +416,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.float()
model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
timer.record("apply float()")
else:
vae = model.first_stage_model
@@ -544,7 +554,7 @@ def repair_config(sd_config):
if hasattr(sd_config.model.params, 'unet_config'):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling:
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
@@ -555,6 +565,14 @@ def repair_config(sd_config):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
# Do not use checkpoint for inference.
# This helps prevent extra performance overhead on checking parameters.
# The perf overhead is about 100ms/it on 4090 for SDXL.
if hasattr(sd_config.model.params, "network_config"):
sd_config.model.params.network_config.params.use_checkpoint = False
if hasattr(sd_config.model.params, "unet_config"):
sd_config.model.params.unet_config.params.use_checkpoint = False
def rescale_zero_terminal_snr_abar(alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
@@ -663,10 +681,11 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
if m is not None:
if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
devices.torch_gc()