Fix MPS cache cleanup

Importing torch does not import torch.mps so the call failed.
This commit is contained in:
Aarni Koskela
2023-07-10 21:18:34 +03:00
parent 7b833291b3
commit b85fc7187d
2 changed files with 17 additions and 2 deletions

View File

@@ -54,8 +54,9 @@ def torch_gc():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif has_mps() and hasattr(torch.mps, 'empty_cache'):
torch.mps.empty_cache()
if has_mps():
mac_specific.torch_mps_gc()
def enable_tf32():