Add --precision half cmd option

This commit is contained in:
huchenlei
2024-05-16 19:50:06 -04:00
parent ddb28b33a3
commit 2a8a60c2c5
6 changed files with 71 additions and 19 deletions

View File

@@ -114,6 +114,9 @@ errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
fp8: bool = False
# Force fp16 for all models in inference. No casting during inference.
# This flag is controlled by "--precision half" command line arg.
force_fp16: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -127,6 +130,8 @@ unet_needs_upcast = False
def cond_cast_unet(input):
if force_fp16:
return input.to(torch.float16)
return input.to(dtype_unet) if unet_needs_upcast else input
@@ -206,6 +211,11 @@ def autocast(disable=False):
if disable:
return contextlib.nullcontext()
if force_fp16:
# No casting during inference if force_fp16 is enabled.
# All tensor dtype conversion happens before inference.
return contextlib.nullcontext()
if fp8 and device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
@@ -269,3 +279,17 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
def force_model_fp16():
"""
ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which
force conversion of input to float32. If force_fp16 is enabled, we need to
prevent this casting.
"""
assert force_fp16
import sgm.modules.diffusionmodules.util as sgm_util
import ldm.modules.diffusionmodules.util as ldm_util
sgm_util.GroupNorm32 = torch.nn.GroupNorm
ldm_util.GroupNorm32 = torch.nn.GroupNorm
print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")