Merge pull request #7309 from brkirch/fix-embeddings

Fix embeddings, upscalers, and refactor `--upcast-sampling`
AUTOMATIC1111 authored on 2023-01-28 18:44:36 +03:00, committed by GitHub
6 changed files with 26 additions and 14 deletions

modules/devices.py

@@ -87,6 +87,14 @@ dtype_unet = torch.float16
 unet_needs_upcast = False
 
 
+def cond_cast_unet(input):
+    return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+    return input.float() if unet_needs_upcast else input
+
+
 def randn(seed, shape):
     torch.manual_seed(seed)
     if device.type == 'mps':
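
For readers skimming the diff: the two new helpers are deliberate no-ops unless `unet_needs_upcast` is set, so call sites can cast unconditionally. A minimal standalone sketch of the pattern (the sample tensor, its shape, and the `unet_needs_upcast = True` setting are illustrative, not part of this commit):

```python
import torch

# Mirrors the module-level globals from the diff; values here are illustrative.
dtype_unet = torch.float16
unet_needs_upcast = True  # e.g. when --upcast-sampling runs an fp16 UNet

def cond_cast_unet(input):
    # Downcast tensors headed into the UNet only when upcasting is active.
    return input.to(dtype_unet) if unet_needs_upcast else input

def cond_cast_float(input):
    # Upcast results back to float32 only when upcasting is active.
    return input.float() if unet_needs_upcast else input

cond = torch.randn(2, 77, 768)                       # stand-in for a conditioning tensor
print(cond_cast_unet(cond).dtype)                    # torch.float16
print(cond_cast_float(cond_cast_unet(cond)).dtype)   # torch.float32
```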
@@ -199,6 +207,3 @@ if has_mps():
         cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
         torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
         torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
-        orig_narrow = torch.narrow
-        torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
-
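
`cumsum_fix` is defined earlier in devices.py and isn't shown in this hunk; the removed lines drop the MPS `torch.narrow` clone workaround. The sketch below is a simplified, hypothetical version of the fix that the `cumsum_needs_bool_fix` probe guards: on affected MPS builds, `cumsum` over bool tensors returns wrong values, so the input is cast to a wider integer dtype first.

```python
import torch

def cumsum_fix(input, cumsum_func, *args, **kwargs):
    # Simplified illustration: route bool cumsum on MPS through int32,
    # which matches CPU semantics on affected PyTorch builds.
    if input.device.type == 'mps' and input.dtype == torch.bool:
        return cumsum_func(input.to(torch.int32), *args, **kwargs)
    return cumsum_func(input, *args, **kwargs)

# The probe in the hunk cumsums [True, False] on MPS and compares it with
# the expected [1, 1]; when they differ, the lambdas wrap every cumsum call
# with the original functions captured in orig_cumsum / orig_Tensor_cumsum.
orig_cumsum = torch.cumsum
t = torch.tensor([True, False])
print(cumsum_fix(t, orig_cumsum, 0))  # tensor([1, 1])
```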