Refactor conditional casting, fix upscalers

This commit is contained in:
brkirch
2023-01-27 10:19:43 -05:00
parent c4b9b07db6
commit ada17dbd7c
5 changed files with 25 additions and 10 deletions

View File

@@ -83,6 +83,14 @@ dtype_unet = torch.float16
unet_needs_upcast = False
def cond_cast_unet(input):
return input.to(dtype_unet) if unet_needs_upcast else input
def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
def randn(seed, shape):
torch.manual_seed(seed)
if device.type == 'mps':