Add CPU fp8 support
Since norm layers need fp32, I only convert the linear operation layers (Conv2d/Linear) to fp8. The TE (text encoder) also uses some PyTorch functions that do not support bf16 amp on CPU, so I added a condition to indicate whether the autocast context is for the unet.
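A minimal sketch of the idea described above: cast only the weights of Linear/Conv2d modules to a float8 storage dtype while leaving norm layers (and everything else) untouched. The function name and the choice of torch.float8_e4m3fn are assumptions for illustration and are not taken from the diff below; compute still has to happen in bf16/fp16 (e.g. under autocast or by casting the weight back up in a patched forward).

import torch
import torch.nn as nn

def cast_linear_layers_to_fp8(model: nn.Module) -> nn.Module:
    # Hypothetical helper, not the webui code: only Linear/Conv2d weights
    # are cast to float8 storage; norm layers stay in fp32. Requires a
    # PyTorch build with float8_e4m3fn support (>= 2.1).
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
    return model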
@@ -71,6 +71,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")
 
 cpu: torch.device = torch.device("cpu")
+fp8: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -93,10 +94,13 @@ def cond_cast_float(input):
 nv_rng = None
 
 
-def autocast(disable=False):
+def autocast(disable=False, unet=False):
     if disable:
         return contextlib.nullcontext()
 
+    if unet and fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
 
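A hedged usage sketch of the new flag: a caller running the unet on CPU with fp8 weights would wrap the forward pass as below. The model and argument names are placeholders; only devices.autocast(unet=True) comes from the diff above.

from modules import devices

# unet_model, x_in, timesteps and context are placeholders for whatever
# the sampler actually passes; only the unet=True argument is new.
with devices.autocast(unet=True):
    # On CPU with devices.fp8 set, this context enables bf16 autocast so
    # matmul/conv run in bf16; otherwise it falls through to the
    # existing autocast behaviour.
    eps = unet_model(x_in, timesteps, context)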