Option for using fp16 weight when apply lora

This commit is contained in:
Kohaku-Blueleaf
2023-11-21 19:59:34 +08:00
parent b2e039d07b
commit 370a77f8e7
4 changed files with 25 additions and 7 deletions

View File

@@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
for module in model.modules():
if hasattr(module, 'fp16_weight'):
del module.fp16_weight
if hasattr(module, 'fp16_bias'):
del module.fp16_bias
if check_fp8(model):
devices.fp8 = True
first_stage = model.first_stage_model
model.first_stage_model = None
for module in model.modules():
if isinstance(module, torch.nn.Conv2d):
module.to(torch.float8_e4m3fn)
elif isinstance(module, torch.nn.Linear):
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
if shared.opts.cache_fp16_weight:
module.fp16_weight = module.weight.clone().half()
if module.bias is not None:
module.fp16_bias = module.bias.clone().half()
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")