From 74069addc31e6cb24a5fb394419aef87b43a8b2c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 6 Jul 2024 11:00:22 +0300 Subject: [PATCH] SD2 v autodetection fix --- modules/sd_models_config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 599153c2d..fb44c5a8d 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -58,12 +58,13 @@ def is_using_v_parameterization_for_sd2(state_dict): with torch.no_grad(): unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} unet.load_state_dict(unet_sd, strict=True) - unet.to(device=device, dtype=torch.float) + unet.to(device=device, dtype=devices.dtype_unet) test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 - out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() + with devices.autocast(): + out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item() return out < -1