Revert "Merge pull request #16078 from huchenlei/fix_sd2"

This reverts commit 4cc3add770, reversing
changes made to 50514ce414.
This commit is contained in:
AUTOMATIC1111
2024-07-06 10:40:48 +03:00
parent 0a6628bad0
commit ffead92d4e
2 changed files with 1 additions and 8 deletions

View File

@@ -56,19 +56,14 @@ def is_using_v_parameterization_for_sd2(state_dict):
unet.eval()
with torch.no_grad():
unet_dtype = torch.float
original_unet_dtype = devices.dtype_unet
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
unet.load_state_dict(unet_sd, strict=True)
unet.to(device=device, dtype=unet_dtype)
devices.dtype_unet = unet_dtype
unet.to(device=device, dtype=torch.float)
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
devices.dtype_unet = original_unet_dtype
return out < -1