Merge branch 'AUTOMATIC1111:master' into master

This commit is contained in:
Zac Liu
2022-12-06 09:16:15 +08:00
committed by GitHub
40 changed files with 582 additions and 121 deletions

View File

@@ -44,6 +44,15 @@ def get_optimal_device():
return cpu
def get_device_for(task):
from modules import shared
if task in shared.cmd_opts.use_cpu:
return cpu
return get_optimal_device()
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
@@ -53,37 +62,35 @@ def torch_gc():
def enable_tf32():
if torch.cuda.is_available():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
generator = torch.Generator(device=cpu)
generator.manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
return noise
torch.manual_seed(seed)
if device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
generator = torch.Generator(device=cpu)
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
return noise
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)