Add NPU Support

This commit is contained in:
wangshuai09
2024-01-27 17:21:32 +08:00
parent cf2772fab0
commit ec124607f4
7 changed files with 62 additions and 3 deletions

View File

@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
from modules import errors, shared
from modules import errors, shared, npu_specific
if sys.platform == "darwin":
from modules import mac_specific
@@ -40,6 +40,9 @@ def get_optimal_device_name():
if has_xpu():
return xpu_specific.get_xpu_device_string()
if npu_specific.has_npu:
return npu_specific.get_npu_device_string()
return "cpu"
@@ -67,6 +70,9 @@ def torch_gc():
if has_xpu():
xpu_specific.torch_xpu_gc()
if npu_specific.has_npu:
npu_specific.torch_npu_gc()
def enable_tf32():
if torch.cuda.is_available():
@@ -164,4 +170,3 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)