Disable ipex autocast due to its bad perf

This commit is contained in:
Nuullll
2023-12-02 14:00:46 +08:00
parent 8b40f475a3
commit 7499148ad4
4 changed files with 51 additions and 17 deletions

View File

@@ -3,11 +3,18 @@ import contextlib
from functools import lru_cache
import torch
from modules import errors, shared, xpu_specific
from modules import errors, shared
if sys.platform == "darwin":
from modules import mac_specific
if shared.cmd_opts.use_ipex:
from modules import xpu_specific
def has_xpu() -> bool:
return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
def has_mps() -> bool:
if sys.platform != "darwin":
@@ -30,7 +37,7 @@ def get_optimal_device_name():
if has_mps():
return "mps"
if xpu_specific.has_ipex:
if has_xpu():
return xpu_specific.get_xpu_device_string()
return "cpu"
@@ -57,6 +64,9 @@ def torch_gc():
if has_mps():
mac_specific.torch_mps_gc()
if has_xpu():
xpu_specific.torch_xpu_gc()
def enable_tf32():
if torch.cuda.is_available():
@@ -103,15 +113,11 @@ def autocast(disable=False):
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
if xpu_specific.has_xpu:
return torch.autocast("xpu")
return torch.autocast("cuda")
def without_autocast(disable=False):
device_type = "xpu" if xpu_specific.has_xpu else "cuda"
return torch.autocast(device_type, enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
class NansException(Exception):