mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Fix MPS cache cleanup
Importing torch does not import torch.mps so the call failed.
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import platform
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||
# use check `getattr` and try it for compatibility.
|
||||
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
|
||||
return False
|
||||
else:
|
||||
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||
|
||||
|
||||
has_mps = check_for_mps()
|
||||
|
||||
|
||||
def torch_mps_gc() -> None:
|
||||
try:
|
||||
from torch.mps import empty_cache
|
||||
empty_cache()
|
||||
except Exception:
|
||||
log.warning("MPS garbage collection failed", exc_info=True)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
|
Reference in New Issue
Block a user