Option for using fp16 weight when applying lora
@@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
+    for module in model.modules():
+        if hasattr(module, 'fp16_weight'):
+            del module.fp16_weight
+        if hasattr(module, 'fp16_bias'):
+            del module.fp16_bias
+
     if check_fp8(model):
         devices.fp8 = True
         first_stage = model.first_stage_model
         model.first_stage_model = None
         for module in model.modules():
-            if isinstance(module, torch.nn.Conv2d):
-                module.to(torch.float8_e4m3fn)
-            elif isinstance(module, torch.nn.Linear):
+            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+                if shared.opts.cache_fp16_weight:
+                    module.fp16_weight = module.weight.clone().half()
+                    if module.bias is not None:
+                        module.fp16_bias = module.bias.clone().half()
                 module.to(torch.float8_e4m3fn)
         model.first_stage_model = first_stage
         timer.record("apply fp8")
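
The hunk does two things: before weights are (re)loaded it drops any fp16_weight / fp16_bias attributes cached from a previous load, and when the model is cast to fp8 it optionally stores an fp16 clone of every Conv2d/Linear weight (and bias) first, gated by shared.opts.cache_fp16_weight. A consumer such as the LoRA code can then merge its delta against the cached fp16 copy instead of the quantized fp8 tensor. Below is a minimal sketch of that consuming side, not part of this patch; apply_lora_delta and lora_delta are hypothetical names, and it assumes a PyTorch build with float8_e4m3fn support (as the patch itself does).

import torch


def apply_lora_delta(module: torch.nn.Module, lora_delta: torch.Tensor) -> None:
    """Merge a LoRA weight delta into a Linear/Conv2d whose weight may be fp8.

    Hypothetical helper for illustration: prefers the fp16 copy cached by the
    patch above (when shared.opts.cache_fp16_weight is enabled) so the merge
    does not go through an fp8 round trip; otherwise upcasts the stored weight.
    """
    if hasattr(module, 'fp16_weight'):
        base = module.fp16_weight.to(lora_delta.device)
    else:
        base = module.weight.to(torch.float16)

    merged = base + lora_delta.half()
    # Write the result back in the module's current storage dtype (fp8 here).
    module.weight.data = merged.to(device=module.weight.device, dtype=module.weight.dtype)


# Usage sketch: cache fp16 weights, cast to fp8, then apply a (dummy) delta.
linear = torch.nn.Linear(4, 4, bias=False)
linear.fp16_weight = linear.weight.detach().clone().half()   # what the patch caches
linear.to(torch.float8_e4m3fn)

apply_lora_delta(linear, torch.zeros(4, 4))
print(linear.weight.dtype)  # torch.float8_e4m3fn

Because fp16_weight is a plain attribute rather than a registered parameter or buffer, the later module.to(torch.float8_e4m3fn) call leaves it in fp16, which is what makes this caching scheme work.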