Option to use CPU for random number generation.

Makes a given manual seed generate the same images across different
platforms, independently of the GPU architecture in use.

Fixes #9613.
This commit is contained in:
Deciare
2023-04-18 23:18:58 -04:00
committed by Deciare
parent 22bcc7be42
commit d40e44ade4
4 changed files with 17 additions and 3 deletions

View File

@@ -60,3 +60,12 @@ def store_latent(decoded):
class InterruptedException(BaseException):
pass
if opts.use_cpu_randn:
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
torchsde._brownian.brownian_interval._randn = torchsde_randn