mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Merge branch 'dev' into master
This commit is contained in:
@@ -1,22 +1,22 @@
|
||||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf import OmegaConf, ListConfig
|
||||
from os import mkdir
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
import numpy as np
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
@@ -49,11 +49,12 @@ class CheckpointInfo:
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
abspath = os.path.abspath(filename)
|
||||
abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
|
||||
|
||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||
if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
|
||||
name = abspath.replace(abs_ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
@@ -129,9 +130,12 @@ except Exception:
|
||||
|
||||
|
||||
def setup_model():
|
||||
"""called once at startup to do various one-time tasks related to SD models"""
|
||||
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
enable_midas_autodownload()
|
||||
patch_given_betas()
|
||||
|
||||
|
||||
def checkpoint_tiles(use_short=False):
|
||||
@@ -309,6 +313,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
if checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
# move to end as latest
|
||||
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
@@ -353,12 +359,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
if model.is_ssd:
|
||||
sd_hijack.model_hijack.conv_ssd(model)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
timer.record("apply weights to model")
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = state_dict
|
||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
timer.record("apply weights to model")
|
||||
|
||||
del state_dict
|
||||
|
||||
@@ -456,6 +462,20 @@ def enable_midas_autodownload():
|
||||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
def patch_given_betas():
|
||||
import ldm.models.diffusion.ddpm
|
||||
|
||||
def patched_register_schedule(*args, **kwargs):
|
||||
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
|
||||
|
||||
if isinstance(args[1], ListConfig):
|
||||
args = (args[0], np.array(args[1]), *args[2:])
|
||||
|
||||
original_register_schedule(*args, **kwargs)
|
||||
|
||||
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
||||
|
||||
|
||||
def repair_config(sd_config):
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
@@ -780,17 +800,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
model_data.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||
model_data.sd_model = None
|
||||
sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
print(f"Unloaded weights {timer.summary()}.")
|
||||
send_model_to_cpu(sd_model or shared.sd_model)
|
||||
|
||||
return sd_model
|
||||
|
||||
|
Reference in New Issue
Block a user