support loading clip/t5 from the main model checkpoint

This commit is contained in:
AUTOMATIC1111
2024-06-29 00:38:52 +03:00
parent d67348a0a5
commit 7e4b06fcd0
3 changed files with 24 additions and 25 deletions

View File

@@ -174,15 +174,10 @@ class SD3Cond(torch.nn.Module):
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
self.model_t5 = Sd3T5(self.t5xxl)
self.weights_loaded = False
def forward(self, prompts: list[str]):
with devices.without_autocast():
lg_out, vector_out = self.model_lg(prompts)
token_count = lg_out.shape[1]
t5_out = self.model_t5(prompts, token_count=token_count)
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
return {
@@ -190,27 +185,24 @@ class SD3Cond(torch.nn.Module):
'vector': vector_out,
}
def load_weights(self):
if self.weights_loaded:
return
def before_load_weights(self, state_dict):
clip_path = os.path.join(shared.models_path, "CLIP")
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
with safetensors.safe_open(clip_g_file, framework="pt") as file:
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
with safetensors.safe_open(clip_g_file, framework="pt") as file:
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if self.t5xxl:
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
self.weights_loaded = True
def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX

View File

@@ -31,7 +31,7 @@ class SD3Inferencer(torch.nn.Module):
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
self.cond_stage_model = SD3Cond()
self.text_encoders = SD3Cond()
self.cond_stage_key = 'txt'
self.parameterization = "eps"
@@ -40,8 +40,12 @@ class SD3Inferencer(torch.nn.Module):
self.latent_format = SD3LatentFormat()
self.latent_channels = 16
def after_load_weights(self):
self.cond_stage_model.load_weights()
@property
def cond_stage_model(self):
return self.text_encoders
def before_load_weights(self, state_dict):
self.cond_stage_model.before_load_weights(state_dict)
def ema_scope(self):
return contextlib.nullcontext()