rework RNG to use generators instead of generating noise beforehand

This commit is contained in:
AUTOMATIC1111
2023-08-09 08:43:31 +03:00
parent d81d3fa8cd
commit 0d5dc9a6e7
5 changed files with 196 additions and 171 deletions

View File

@@ -1,5 +1,5 @@
import inspect
from collections import namedtuple, deque
from collections import namedtuple
import numpy as np
import torch
from PIL import Image
@@ -132,10 +132,15 @@ replace_torchsde_browinan()
class TorchHijack:
def __init__(self, sampler_noises):
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
# implementation.
self.sampler_noises = deque(sampler_noises)
"""This is here to replace torch.randn_like of k-diffusion.
k-diffusion has random_sampler argument for most samplers, but not for all, so
this is needed to properly replace every use of torch.randn_like.
We need to replace to make images generated in batches to be same as images generated individually."""
def __init__(self, p):
self.rng = p.rng
def __getattr__(self, item):
if item == 'randn_like':
@@ -147,12 +152,7 @@ class TorchHijack:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
if self.sampler_noises:
noise = self.sampler_noises.popleft()
if noise.shape == x.shape:
return noise
return devices.randn_like(x)
return self.rng.next()
class Sampler:
@@ -215,7 +215,7 @@ class Sampler:
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
k_diffusion.sampling.torch = TorchHijack(p)
extra_params_kwargs = {}
for param_name in self.extra_params: