mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
support loading clip/t5 from the main model checkpoint
This commit is contained in:
@@ -434,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||
|
||||
if hasattr(model, "before_load_weights"):
|
||||
model.before_load_weights(state_dict)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
timer.record("apply weights to model")
|
||||
|
||||
if hasattr(model, "after_load_weights"):
|
||||
model.after_load_weights(state_dict)
|
||||
|
||||
del state_dict
|
||||
|
||||
# Set is_sdxl_inpaint flag.
|
||||
@@ -838,9 +844,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
|
||||
if hasattr(sd_model, "after_load_weights"):
|
||||
sd_model.after_load_weights()
|
||||
|
||||
timer.record("load weights from state dict")
|
||||
|
||||
send_model_to_device(sd_model)
|
||||
|
Reference in New Issue
Block a user