Use options instead of cmd_args

This commit is contained in:
Kohaku-Blueleaf
2023-11-19 15:50:06 +08:00
parent b60e1088db
commit 598da5cd49
6 changed files with 49 additions and 42 deletions

View File

@@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool:
if device_id is None:
device_id = get_cuda_device_id()
return (
torch.cuda.get_device_capability(device_id) == (7, 5)
torch.cuda.get_device_capability(device_id) == (7, 5)
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
)
def get_cuda_device_id():
return (
int(shared.cmd_opts.device_id)
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
int(shared.cmd_opts.device_id)
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
else 0
) or torch.cuda.current_device()
@@ -116,16 +116,19 @@ patch_module_list = [
torch.nn.LayerNorm,
]
def manual_cast_forward(self, *args, **kwargs):
org_dtype = next(self.parameters()).dtype
self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
result = self.org_forward(*args, **kwargs)
self.to(org_dtype)
return result
@contextlib.contextmanager
def manual_autocast():
def manual_cast_forward(self, *args, **kwargs):
org_dtype = next(self.parameters()).dtype
self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
result = self.org_forward(*args, **kwargs)
self.to(org_dtype)
return result
for module_type in patch_module_list:
org_forward = module_type.forward
module_type.forward = manual_cast_forward