getting SD2.1 to run on SDXL repo

This commit is contained in:
AUTOMATIC1111
2023-07-11 21:16:43 +03:00
parent 7b833291b3
commit af081211ee
9 changed files with 152 additions and 24 deletions

40
modules/sd_models_xl.py Normal file
View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import torch
import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]):
for embedder in self.conditioner.embedders:
embedder.ucg_rate = 0.0
c = self.conditioner({'txt': batch})
return c
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
return self.model(x, t, cond)
def extend_sdxl(model):
dtype = next(model.model.diffusion_model.parameters()).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0]
model.cond_stage_key = model.cond_stage_model.input_key
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model