mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
Add Tiny AE live preview
This commit is contained in:
76
modules/sd_vae_taesd.py
Normal file
76
modules/sd_vae_taesd.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Tiny AutoEncoder for Stable Diffusion
|
||||
(DNN for encoding / decoding SD's latent space)
|
||||
|
||||
https://github.com/madebyollin/taesd
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import devices, paths_internal
|
||||
|
||||
sd_vae_taesd = None
|
||||
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
|
||||
class Clamp(nn.Module):
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_in, n_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.fuse = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
|
||||
def decoder():
|
||||
return nn.Sequential(
|
||||
Clamp(), conv(4, 64), nn.ReLU(),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), conv(64, 3),
|
||||
)
|
||||
|
||||
|
||||
class TAESD(nn.Module):
|
||||
latent_magnitude = 2
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, decoder_path="taesd_decoder.pth"):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.decoder = decoder()
|
||||
self.decoder.load_state_dict(
|
||||
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
@staticmethod
|
||||
def unscale_latents(x):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
|
||||
def decode():
|
||||
global sd_vae_taesd
|
||||
|
||||
if sd_vae_taesd is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth")
|
||||
if os.path.exists(model_path):
|
||||
sd_vae_taesd = TAESD(model_path)
|
||||
sd_vae_taesd.eval()
|
||||
sd_vae_taesd.to(devices.device, devices.dtype)
|
||||
else:
|
||||
raise FileNotFoundError('Tiny AE mdoel not found')
|
||||
|
||||
return sd_vae_taesd.decoder
|
Reference in New Issue
Block a user