ManualCast for 10/16 series GPUs

This commit is contained in:
Kohaku-Blueleaf
2023-10-28 15:24:26 +08:00
parent 0beb131c7f
commit d4d3134f6d
3 changed files with 62 additions and 14 deletions

View File

@@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if enable_fp8:
devices.fp8 = True
if model.is_sdxl:
cond_stage = model.conditioner
else:
cond_stage = model.cond_stage_model
for module in cond_stage.modules():
if isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
if devices.device == devices.cpu:
for module in model.model.diffusion_model.modules():
if isinstance(module, torch.nn.Conv2d):
module.to(torch.float8_e4m3fn)
elif isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet for cpu")
else:
if model.is_sdxl:
cond_stage = model.conditioner
else:
cond_stage = model.cond_stage_model
for module in cond_stage.modules():
if isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet")
timer.record("apply fp8")
else:
devices.fp8 = False
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16