mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 13:19:54 +00:00
support loading clip/t5 from the main model checkpoint
This commit is contained in:
@@ -174,15 +174,10 @@ class SD3Cond(torch.nn.Module):
|
||||
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
|
||||
self.model_t5 = Sd3T5(self.t5xxl)
|
||||
|
||||
self.weights_loaded = False
|
||||
|
||||
def forward(self, prompts: list[str]):
|
||||
with devices.without_autocast():
|
||||
lg_out, vector_out = self.model_lg(prompts)
|
||||
|
||||
token_count = lg_out.shape[1]
|
||||
|
||||
t5_out = self.model_t5(prompts, token_count=token_count)
|
||||
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
|
||||
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
|
||||
return {
|
||||
@@ -190,27 +185,24 @@ class SD3Cond(torch.nn.Module):
|
||||
'vector': vector_out,
|
||||
}
|
||||
|
||||
def load_weights(self):
|
||||
if self.weights_loaded:
|
||||
return
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
clip_path = os.path.join(shared.models_path, "CLIP")
|
||||
|
||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
if self.t5xxl:
|
||||
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
self.weights_loaded = True
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.tensor([[0]], device=devices.device) # XXX
|
||||
|
||||
|
Reference in New Issue
Block a user