mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-06-27 09:26:47 +00:00
Revert "Apply correct inference precision implementation"
This reverts commit e00365962b
.
This commit is contained in:
parent
e00365962b
commit
1fd69655fe
@ -132,21 +132,6 @@ patch_module_list = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def cast_output(result):
|
|
||||||
if isinstance(result, tuple):
|
|
||||||
result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
|
|
||||||
elif isinstance(result, torch.Tensor):
|
|
||||||
result = result.to(dtype_inference)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def autocast_with_cast_output(self, *args, **kwargs):
|
|
||||||
result = self.org_forward(*args, **kwargs)
|
|
||||||
if dtype_inference != dtype:
|
|
||||||
result = cast_output(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def manual_cast_forward(target_dtype):
|
def manual_cast_forward(target_dtype):
|
||||||
def forward_wrapper(self, *args, **kwargs):
|
def forward_wrapper(self, *args, **kwargs):
|
||||||
if any(
|
if any(
|
||||||
@ -164,7 +149,15 @@ def manual_cast_forward(target_dtype):
|
|||||||
self.to(org_dtype)
|
self.to(org_dtype)
|
||||||
|
|
||||||
if target_dtype != dtype_inference:
|
if target_dtype != dtype_inference:
|
||||||
result = cast_output(result)
|
if isinstance(result, tuple):
|
||||||
|
result = tuple(
|
||||||
|
i.to(dtype_inference)
|
||||||
|
if isinstance(i, torch.Tensor)
|
||||||
|
else i
|
||||||
|
for i in result
|
||||||
|
)
|
||||||
|
elif isinstance(result, torch.Tensor):
|
||||||
|
result = result.to(dtype_inference)
|
||||||
return result
|
return result
|
||||||
return forward_wrapper
|
return forward_wrapper
|
||||||
|
|
||||||
@ -185,20 +178,6 @@ def manual_cast(target_dtype):
|
|||||||
module_type.forward = module_type.org_forward
|
module_type.forward = module_type.org_forward
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def precision_full_with_autocast(autocast_ctx):
|
|
||||||
for module_type in patch_module_list:
|
|
||||||
org_forward = module_type.forward
|
|
||||||
module_type.forward = autocast_with_cast_output
|
|
||||||
module_type.org_forward = org_forward
|
|
||||||
try:
|
|
||||||
with autocast_ctx:
|
|
||||||
yield None
|
|
||||||
finally:
|
|
||||||
for module_type in patch_module_list:
|
|
||||||
module_type.forward = module_type.org_forward
|
|
||||||
|
|
||||||
|
|
||||||
def autocast(disable=False):
|
def autocast(disable=False):
|
||||||
if disable:
|
if disable:
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
@ -212,9 +191,6 @@ def autocast(disable=False):
|
|||||||
if has_xpu() or has_mps() or cuda_no_autocast():
|
if has_xpu() or has_mps() or cuda_no_autocast():
|
||||||
return manual_cast(dtype_inference)
|
return manual_cast(dtype_inference)
|
||||||
|
|
||||||
if dtype_inference == torch.float32 and dtype != torch.float32:
|
|
||||||
return precision_full_with_autocast(torch.autocast("cuda"))
|
|
||||||
|
|
||||||
return torch.autocast("cuda")
|
return torch.autocast("cuda")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user