change import statements for #14478

This commit is contained in:
AUTOMATIC1111
2023-12-31 22:38:30 +03:00
parent be5f1acc8f
commit a70dfb64a8
7 changed files with 14 additions and 17 deletions

View File

@@ -4,7 +4,7 @@ from functools import lru_cache
import torch
from modules import errors, shared
from modules.torch_utils import get_param
from modules import torch_utils
if sys.platform == "darwin":
from modules import mac_specific
@@ -132,7 +132,7 @@ patch_module_list = [
def manual_cast_forward(self, *args, **kwargs):
org_dtype = get_param(self).dtype
org_dtype = torch_utils.get_param(self).dtype
self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}