Add CPU fp8 support

Since norm layers need fp32, I only convert the linear operation layers (Conv2d/Linear).

The TE (text encoder) also uses some PyTorch functions that do not support bf16 amp on CPU, so I added a flag to autocast() that indicates whether the call is for the unet.
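As context, a minimal sketch (not the commit's actual code; the helper name is made up) of what "only convert the linear operation layers" means in PyTorch: cast the weights of Linear/Conv2d modules to an fp8 storage dtype and leave everything else, including the norm layers, in their original precision.

```python
import torch
import torch.nn as nn

def cast_linear_ops_to_fp8(model: nn.Module) -> nn.Module:
    """Store Linear/Conv2d weights in fp8; leave norm layers (and everything else) untouched."""
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # float8_e4m3fn requires PyTorch >= 2.1; the weight is only *stored*
            # in fp8 here and is expected to be cast up to a compute dtype
            # (e.g. fp16/bf16 under autocast) before it is actually used.
            module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
    return model
```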
Author: Kohaku-Blueleaf
Date: 2023-10-24 01:49:05 +08:00
Parent: 5f9ddfa46f
Commit: eaa9f5162f
3 changed files with 22 additions and 6 deletions


@@ -71,6 +71,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")
 cpu: torch.device = torch.device("cpu")
+fp8: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -93,10 +94,13 @@ def cond_cast_float(input):
 nv_rng = None
-def autocast(disable=False):
+def autocast(disable=False, unet=False):
     if disable:
         return contextlib.nullcontext()
+    if unet and fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
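A hypothetical call site for the new flag (the stand-in model and input below are illustrative, not part of this commit): wrapping the unet forward pass in devices.autocast(unet=True) makes the CPU + fp8 case run under a bf16 autocast context, while all other configurations behave as before.

```python
import torch
import torch.nn as nn

from modules import devices  # the module changed in this commit

unet = nn.Linear(4, 4)    # stand-in for the real unet
x = torch.randn(1, 4)     # stand-in input

# With devices.fp8 set and devices.device == devices.cpu, this enters
# torch.autocast("cpu", dtype=torch.bfloat16); otherwise the context is unchanged.
with devices.autocast(unet=True):
    out = unet(x)
```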