support loading clip/t5 from the main model checkpoint

This commit is contained in:
AUTOMATIC1111
2024-06-29 00:38:52 +03:00
parent d67348a0a5
commit 7e4b06fcd0
3 changed files with 24 additions and 25 deletions

View File

@@ -434,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
if hasattr(model, "before_load_weights"):
model.before_load_weights(state_dict)
model.load_state_dict(state_dict, strict=False)
timer.record("apply weights to model")
if hasattr(model, "after_load_weights"):
model.after_load_weights(state_dict)
del state_dict
# Set is_sdxl_inpaint flag.
@@ -838,9 +844,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if hasattr(sd_model, "after_load_weights"):
sd_model.after_load_weights()
timer.record("load weights from state dict")
send_model_to_device(sd_model)