mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
Use options instead of cmd_args
This commit is contained in:
@@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool:
|
||||
if device_id is None:
|
||||
device_id = get_cuda_device_id()
|
||||
return (
|
||||
torch.cuda.get_device_capability(device_id) == (7, 5)
|
||||
torch.cuda.get_device_capability(device_id) == (7, 5)
|
||||
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
|
||||
)
|
||||
|
||||
|
||||
def get_cuda_device_id():
|
||||
return (
|
||||
int(shared.cmd_opts.device_id)
|
||||
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
|
||||
int(shared.cmd_opts.device_id)
|
||||
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
|
||||
else 0
|
||||
) or torch.cuda.current_device()
|
||||
|
||||
@@ -116,16 +116,19 @@ patch_module_list = [
|
||||
torch.nn.LayerNorm,
|
||||
]
|
||||
|
||||
|
||||
def manual_cast_forward(self, *args, **kwargs):
|
||||
org_dtype = next(self.parameters()).dtype
|
||||
self.to(dtype)
|
||||
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
self.to(org_dtype)
|
||||
return result
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_autocast():
|
||||
def manual_cast_forward(self, *args, **kwargs):
|
||||
org_dtype = next(self.parameters()).dtype
|
||||
self.to(dtype)
|
||||
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
self.to(org_dtype)
|
||||
return result
|
||||
for module_type in patch_module_list:
|
||||
org_forward = module_type.forward
|
||||
module_type.forward = manual_cast_forward
|
||||
|
Reference in New Issue
Block a user