medvram support for SD3

This commit is contained in:
AUTOMATIC1111
2024-06-24 10:15:46 +03:00
parent a65dd315ad
commit a8fba9af35
4 changed files with 35 additions and 8 deletions

View File

@@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module):
def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX
def medvram_modules(self):
return [self.clip_g, self.clip_l, self.t5xxl]
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
def __init__(self, inner_model, sigmas):
@@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module):
return self.cond_stage_model(batch)
def apply_model(self, x, t, cond):
return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent)
@@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module):
def create_denoiser(self):
return SD3Denoiser(self, self.model.model_sampling.sigmas)
def medvram_fields(self):
return [
(self, 'first_stage_model'),
(self, 'cond_stage_model'),
(self, 'model'),
]