mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 13:19:54 +00:00
initial SD3 support
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
### Impls of the SD3 core diffusion model and VAE
|
||||
|
||||
import torch, math, einops
|
||||
from mmdit import MMDiT
|
||||
from modules.models.sd3.mmdit import MMDiT
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@@ -46,16 +46,16 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core MM-DiT model"""
|
||||
def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
|
||||
def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
|
||||
super().__init__()
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
|
||||
depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
|
||||
num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
|
||||
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
|
||||
context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
|
||||
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
|
Reference in New Issue
Block a user