first attempt to produce crrect seeds in batch

This commit is contained in:
AUTOMATIC
2022-09-13 21:49:58 +03:00
parent 85b97cc49c
commit 9d40212485
3 changed files with 51 additions and 2 deletions

View File

@@ -48,3 +48,13 @@ def randn(seed, shape):
torch.manual_seed(seed)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
generator = torch.Generator(device=cpu)
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
return noise
return torch.randn(shape, device=device)