SD3 lora support

This commit is contained in:
AUTOMATIC1111
2024-07-15 08:31:55 +03:00
parent b2453d280a
commit 7e5cdaab4b
6 changed files with 106 additions and 24 deletions

View File

@@ -67,6 +67,7 @@ class BaseModel(torch.nn.Module):
}
self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
self.depth = depth
def apply_model(self, x, sigma, c_crossattn=None, y=None):
dtype = self.get_dtype()