This commit is contained in:
wangshuai09
2024-01-31 10:46:53 +08:00
parent 74ff85a1a1
commit cc3f604310
8 changed files with 22 additions and 20 deletions

View File

@@ -8,11 +8,10 @@ def check_for_npu():
if importlib.util.find_spec("torch_npu") is None:
return False
import torch_npu
torch_npu.npu.set_device(0)
try:
# Will raise a RuntimeError if no NPU is found
_ = torch.npu.device_count()
_ = torch_npu.npu.device_count()
return torch.npu.is_available()
except RuntimeError:
return False
@@ -25,8 +24,6 @@ def get_npu_device_string():
def torch_npu_gc():
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
torch.npu.set_device(0)
with torch.npu.device(get_npu_device_string()):
torch.npu.empty_cache()