changes for #294

This commit is contained in:
AUTOMATIC
2022-09-12 20:09:32 +03:00
parent 11e03b9abd
commit c7e0e28ccd
3 changed files with 22 additions and 32 deletions

View File

@@ -31,3 +31,20 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
device = get_optimal_device()
device_codeformer = cpu if has_mps else device
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
generator = torch.Generator(device=cpu)
generator.manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
return noise
torch.manual_seed(seed)
return torch.randn(shape, device=device)