Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2025-08-04 03:10:21 +00:00
Merge branch 'dev' into npu_support
@@ -23,6 +23,23 @@ def has_mps() -> bool:
    return mac_specific.has_mps


def cuda_no_autocast(device_id=None) -> bool:
    if device_id is None:
        device_id = get_cuda_device_id()
    return (
        torch.cuda.get_device_capability(device_id) == (7, 5)
        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
    )


def get_cuda_device_id():
    return (
        int(shared.cmd_opts.device_id)
        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
        else 0
    ) or torch.cuda.current_device()


def get_cuda_device_string():
    if shared.cmd_opts.device_id is not None:
        return f"cuda:{shared.cmd_opts.device_id}"
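For orientation, here is a minimal standalone sketch (not part of the diff, and without the webui's shared.cmd_opts plumbing from modules/devices.py) of the check these two helpers perform: GTX 16xx cards report compute capability (7, 5), and the webui treats them as unable to use torch.autocast reliably.

# Hypothetical standalone version of the same check; the function name is illustrative.
import torch

def is_gtx16xx_no_autocast(device_id=None) -> bool:
    if device_id is None:
        device_id = torch.cuda.current_device()
    return (
        torch.cuda.get_device_capability(device_id) == (7, 5)
        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
    )

if torch.cuda.is_available():
    print("manual cast needed:", is_gtx16xx_no_autocast())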
@@ -79,8 +96,7 @@ def enable_tf32():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
-        if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+        if cuda_no_autocast():
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
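For reference, a standalone sketch (not part of the commit) of the backend switches this hunk touches; in the diff itself cudnn.benchmark is only enabled for the GTX 16xx case selected by cuda_no_autocast().

import torch

if torch.cuda.is_available():
    # cudnn autotuner; the diff gates this on cuda_no_autocast()
    torch.backends.cudnn.benchmark = True
    # allow TF32 tensor-core math for fp32 matmuls, which is what enable_tf32 is for
    torch.backends.cuda.matmul.allow_tf32 = True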
@@ -90,6 +106,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")

cpu: torch.device = torch.device("cpu")
fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -98,6 +115,7 @@ device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False

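For context, a minimal sketch (not from the diff; it assumes torch >= 2.1, where torch.float8_e4m3fn exists) of the storage-versus-compute split these globals encode: with fp8 enabled, parameters can live in a low-bit dtype while the math still runs in dtype_inference.

import torch

storage_dtype = torch.float8_e4m3fn     # stand-in for `dtype` when fp8 is enabled (assumption)
compute_dtype = torch.float16           # stand-in for `dtype_inference`

w = torch.randn(4, 4).to(storage_dtype)         # weights stored in fp8 to save memory
x = torch.randn(1, 4, dtype=compute_dtype)
y = x @ w.to(compute_dtype)                     # cast up to the inference dtype before the matmul
print(y.dtype)                                  # torch.float16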
@@ -110,15 +128,89 @@ def cond_cast_float(input):


nv_rng = None
patch_module_list = [
    torch.nn.Linear,
    torch.nn.Conv2d,
    torch.nn.MultiheadAttention,
    torch.nn.GroupNorm,
    torch.nn.LayerNorm,
]

def manual_cast_forward(target_dtype):
    def forward_wrapper(self, *args, **kwargs):
        if any(
            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
            for arg in args
        ):
            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

        org_dtype = target_dtype
        for param in self.parameters():
            if param.dtype != target_dtype:
                org_dtype = param.dtype
                break

        if org_dtype != target_dtype:
            self.to(target_dtype)
        result = self.org_forward(*args, **kwargs)
        if org_dtype != target_dtype:
            self.to(org_dtype)

        if target_dtype != dtype_inference:
            if isinstance(result, tuple):
                result = tuple(
                    i.to(dtype_inference)
                    if isinstance(i, torch.Tensor)
                    else i
                    for i in result
                )
            elif isinstance(result, torch.Tensor):
                result = result.to(dtype_inference)
        return result
    return forward_wrapper

@contextlib.contextmanager
def manual_cast(target_dtype):
    applied = False
    for module_type in patch_module_list:
        if hasattr(module_type, "org_forward"):
            continue
        applied = True
        org_forward = module_type.forward
        if module_type == torch.nn.MultiheadAttention:
            module_type.forward = manual_cast_forward(torch.float32)
        else:
            module_type.forward = manual_cast_forward(target_dtype)
        module_type.org_forward = org_forward
    try:
        yield None
    finally:
        if applied:
            for module_type in patch_module_list:
                if hasattr(module_type, "org_forward"):
                    module_type.forward = module_type.org_forward
                    delattr(module_type, "org_forward")

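A usage sketch of the context manager above (illustrative only; it assumes a webui checkout where this module imports cleanly as modules.devices): inside the with block, the patched layer types cast incoming tensors and their own parameters to the target dtype, and the original forward methods are restored on exit.

import torch
from modules import devices

model = torch.nn.Linear(8, 8)               # fp32 parameters
x = torch.randn(1, 8, dtype=torch.float16)

with devices.manual_cast(torch.float16):
    y = model(x)                            # forward_wrapper handles the casting

print(y.dtype)                              # torch.float16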
def autocast(disable=False):
    if disable:
        return contextlib.nullcontext()

-    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+    if fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

    if fp8 and dtype_inference == torch.float32:
        return manual_cast(dtype)

    if dtype == torch.float32 or dtype_inference == torch.float32:
        return contextlib.nullcontext()

    if has_xpu() or has_mps() or cuda_no_autocast():
        return manual_cast(dtype)

    return torch.autocast("cuda")

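Finally, a minimal sketch (not from the diff) of the default branch at the end of autocast(): on CUDA devices that are not special-cased above, PyTorch's built-in autocast runs eligible fp32 ops in reduced precision.

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()    # fp32 weights
    x = torch.randn(1, 8, device="cuda")
    with torch.autocast("cuda"):
        y = model(x)
    print(y.dtype)                          # typically torch.float16 under autocast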