Modular device management

This commit is contained in:
Abdullah Barhoum
2022-09-11 07:11:27 +02:00
committed by AUTOMATIC1111
parent 065e310a3f
commit b5d1af11b7
4 changed files with 19 additions and 13 deletions

12
modules/devices.py Normal file
View File

@@ -0,0 +1,12 @@
import torch
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)
def get_optimal_device():
if torch.cuda.is_available():
return torch.device("cuda")
if has_mps:
return torch.device("mps")
return torch.device("cpu")