Add utility to inspect a model's parameters (to get dtype/device)

This commit is contained in:
Aarni Koskela
2023-12-31 00:20:30 +02:00
parent a84e842189
commit 5768afc776
8 changed files with 53 additions and 7 deletions

View File

@@ -6,6 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
from modules.torch_utils import get_param
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
dtype = next(model.model.diffusion_model.parameters()).dtype
dtype = get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt'