Add a check and explanation for tensor with all NaNs.

This commit is contained in:
AUTOMATIC
2023-01-16 22:59:46 +03:00
parent 52f6e94338
commit 9991967f40
3 changed files with 33 additions and 0 deletions

View File

@@ -106,6 +106,33 @@ def autocast(disable=False):
return torch.autocast("cuda")
class NansException(Exception):
pass
def test_for_nans(x, where):
from modules import shared
if not torch.all(torch.isnan(x)).item():
return
if where == "unet":
message = "A tensor with all NaNs was produced in Unet."
if not shared.cmd_opts.no_half:
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
else:
message = "A tensor with all NaNs was produced."
raise NansException(message)
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs):
@@ -156,3 +183,4 @@ if has_mps():
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )