support loading clip/t5 from the main model checkpoint

This commit is contained in:
AUTOMATIC1111
2024-06-29 00:38:52 +03:00
parent d67348a0a5
commit 7e4b06fcd0
3 changed files with 24 additions and 25 deletions

View File

@@ -31,7 +31,7 @@ class SD3Inferencer(torch.nn.Module):
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
self.cond_stage_model = SD3Cond()
self.text_encoders = SD3Cond()
self.cond_stage_key = 'txt'
self.parameterization = "eps"
@@ -40,8 +40,12 @@ class SD3Inferencer(torch.nn.Module):
self.latent_format = SD3LatentFormat()
self.latent_channels = 16
def after_load_weights(self):
self.cond_stage_model.load_weights()
@property
def cond_stage_model(self):
return self.text_encoders
def before_load_weights(self, state_dict):
self.cond_stage_model.before_load_weights(state_dict)
def ema_scope(self):
return contextlib.nullcontext()