mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
Support SDXL v-pred models
This commit is contained in:
@@ -783,7 +783,7 @@ def get_obj_from_str(string, reload=False):
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_config=None):
|
||||
from modules import sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
||||
@@ -801,7 +801,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
else:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
if not checkpoint_config:
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
||||
|
||||
timer.record("find config")
|
||||
@@ -974,7 +975,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
if sd_model is not None:
|
||||
send_model_to_trash(sd_model)
|
||||
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, checkpoint_config=checkpoint_config)
|
||||
return model_data.sd_model
|
||||
|
||||
try:
|
||||
|
Reference in New Issue
Block a user