add XL support for live previews: approx and TAESD

This commit is contained in:
AUTOMATIC1111
2023-07-13 17:24:54 +03:00
parent 6f23da603d
commit b8159d0919
3 changed files with 40 additions and 25 deletions

View File

@@ -8,9 +8,9 @@ import os
import torch
import torch.nn as nn
from modules import devices, paths_internal
from modules import devices, paths_internal, shared
sd_vae_taesd = None
sd_vae_taesd_models = {}
def conv(n_in, n_out, **kwargs):
@@ -61,9 +61,7 @@ class TAESD(nn.Module):
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def download_model(model_path):
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
def download_model(model_path, model_url):
if not os.path.exists(model_path):
os.makedirs(os.path.dirname(model_path), exist_ok=True)
@@ -72,17 +70,19 @@ def download_model(model_path):
def model():
global sd_vae_taesd
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name)
if sd_vae_taesd is None:
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
download_model(model_path)
if loaded_model is None:
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
if os.path.exists(model_path):
sd_vae_taesd = TAESD(model_path)
sd_vae_taesd.eval()
sd_vae_taesd.to(devices.device, devices.dtype)
loaded_model = TAESD(model_path)
loaded_model.eval()
loaded_model.to(devices.device, devices.dtype)
sd_vae_taesd_models[model_name] = loaded_model
else:
raise FileNotFoundError('TAESD model not found')
return sd_vae_taesd.decoder
return loaded_model.decoder