This commit is contained in:
wangshuai09
2024-01-31 10:46:53 +08:00
parent 74ff85a1a1
commit cc3f604310
8 changed files with 22 additions and 20 deletions

View File

@@ -88,9 +88,16 @@ def torch_gc():
xpu_specific.torch_xpu_gc()
if npu_specific.has_npu:
torch_npu_set_device()
npu_specific.torch_npu_gc()
def torch_npu_set_device():
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
if npu_specific.has_npu:
torch.npu.set_device(0)
def enable_tf32():
if torch.cuda.is_available():