added torch.mps.empty_cache() to torch_gc()

changed a bunch of places that use torch.cuda.empty_cache() to use torch_gc() instead
This commit is contained in:
AUTOMATIC1111
2023-07-08 17:13:18 +03:00
parent e161b5a025
commit da8916f926
6 changed files with 10 additions and 13 deletions

View File

@@ -85,7 +85,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def do_upscale(self, img: PIL.Image.Image, selected_file):
torch.cuda.empty_cache()
devices.torch_gc()
try:
model = self.load_model(selected_file)
@@ -110,7 +110,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
torch.cuda.empty_cache()
devices.torch_gc()
output = np_output.transpose((1, 2, 0)) # CHW to HWC
output = output[:, :, ::-1] # BGR to RGB