mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
initial SD3 support
This commit is contained in:
@@ -34,9 +34,9 @@ class Block(nn.Module):
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
|
||||
def decoder():
|
||||
def decoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
Clamp(), conv(4, 64), nn.ReLU(),
|
||||
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
@@ -44,13 +44,13 @@ def decoder():
|
||||
)
|
||||
|
||||
|
||||
def encoder():
|
||||
def encoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 4),
|
||||
conv(64, latent_channels),
|
||||
)
|
||||
|
||||
|
||||
@@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, decoder_path="taesd_decoder.pth"):
|
||||
def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.decoder = decoder()
|
||||
|
||||
if latent_channels is None:
|
||||
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
|
||||
|
||||
self.decoder = decoder(latent_channels)
|
||||
self.decoder.load_state_dict(
|
||||
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
@@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path="taesd_encoder.pth"):
|
||||
def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.encoder = encoder()
|
||||
|
||||
if latent_channels is None:
|
||||
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
|
||||
|
||||
self.encoder = encoder(latent_channels)
|
||||
self.encoder.load_state_dict(
|
||||
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
@@ -87,7 +95,13 @@ def download_model(model_path, model_url):
|
||||
|
||||
|
||||
def decoder_model():
|
||||
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "taesd3_decoder.pth"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "taesdxl_decoder.pth"
|
||||
else:
|
||||
model_name = "taesd_decoder.pth"
|
||||
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if loaded_model is None:
|
||||
@@ -106,7 +120,13 @@ def decoder_model():
|
||||
|
||||
|
||||
def encoder_model():
|
||||
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "taesd3_encoder.pth"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "taesdxl_encoder.pth"
|
||||
else:
|
||||
model_name = "taesd_encoder.pth"
|
||||
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if loaded_model is None:
|
||||
|
Reference in New Issue
Block a user