mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
do not load wait for shared.sd_model to load at startup
This commit is contained in:
@@ -2,6 +2,8 @@ import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
@@ -404,13 +406,39 @@ def repair_config(sd_config):
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||
|
||||
class SdModelData:
|
||||
def __init__(self):
|
||||
self.sd_model = None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get_sd_model(self):
|
||||
if self.sd_model is None:
|
||||
with self.lock:
|
||||
try:
|
||||
load_model()
|
||||
except Exception as e:
|
||||
errors.display(e, "loading stable diffusion model")
|
||||
print("", file=sys.stderr)
|
||||
print("Stable diffusion model failed to load", file=sys.stderr)
|
||||
self.sd_model = None
|
||||
|
||||
return self.sd_model
|
||||
|
||||
def set_sd_model(self, v):
|
||||
self.sd_model = v
|
||||
|
||||
|
||||
model_data = SdModelData()
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
if model_data.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||
model_data.sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
@@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
shared.sd_model = sd_model
|
||||
model_data.sd_model = sd_model
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
|
||||
@@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
sd_model = model_data.sd_model
|
||||
|
||||
if sd_model is None: # previous model load failed
|
||||
current_checkpoint_info = None
|
||||
@@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
return shared.sd_model
|
||||
return model_data.sd_model
|
||||
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
@@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
|
||||
return sd_model
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
timer = Timer()
|
||||
|
||||
if shared.sd_model:
|
||||
|
||||
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
# shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
shared.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
if model_data.sd_model:
|
||||
model_data.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||
model_data.sd_model = None
|
||||
sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
Reference in New Issue
Block a user