mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
Use options instead of cmd_args
This commit is contained in:
@@ -339,10 +339,28 @@ class SkipWritingToConfig:
|
||||
SkipWritingToConfig.skip = self.previous
|
||||
|
||||
|
||||
def check_fp8(model):
|
||||
if model is None:
|
||||
return None
|
||||
if devices.get_optimal_device_name() == "mps":
|
||||
enable_fp8 = False
|
||||
elif shared.opts.fp8_storage == "Enable":
|
||||
enable_fp8 = True
|
||||
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
|
||||
enable_fp8 = True
|
||||
else:
|
||||
enable_fp8 = False
|
||||
return enable_fp8
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if not check_fp8(model) and devices.fp8:
|
||||
# prevent model to load state dict in fp8
|
||||
model.half()
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
@@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
|
||||
if devices.get_optimal_device_name() == "mps":
|
||||
enable_fp8 = False
|
||||
elif shared.cmd_opts.opt_unet_fp8_storage:
|
||||
enable_fp8 = True
|
||||
elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
|
||||
enable_fp8 = True
|
||||
else:
|
||||
enable_fp8 = False
|
||||
|
||||
if enable_fp8:
|
||||
if check_fp8(model):
|
||||
devices.fp8 = True
|
||||
if model.is_sdxl:
|
||||
cond_stage = model.conditioner
|
||||
else:
|
||||
cond_stage = model.cond_stage_model
|
||||
|
||||
for module in cond_stage.modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
first_stage = model.first_stage_model
|
||||
model.first_stage_model = None
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
module.to(torch.float8_e4m3fn)
|
||||
|
||||
if devices.device == devices.cpu:
|
||||
for module in model.model.diffusion_model.modules():
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
module.to(torch.float8_e4m3fn)
|
||||
elif isinstance(module, torch.nn.Linear):
|
||||
module.to(torch.float8_e4m3fn)
|
||||
else:
|
||||
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
|
||||
elif isinstance(module, torch.nn.Linear):
|
||||
module.to(torch.float8_e4m3fn)
|
||||
model.first_stage_model = first_stage
|
||||
timer.record("apply fp8")
|
||||
else:
|
||||
devices.fp8 = False
|
||||
@@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
return None
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
@@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
current_checkpoint_info = None
|
||||
else:
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
if check_fp8(sd_model) != devices.fp8:
|
||||
# load from state dict again to prevent extra numerical errors
|
||||
forced_reload = True
|
||||
elif sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return sd_model
|
||||
|
||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
return sd_model
|
||||
|
||||
if sd_model is not None:
|
||||
|
Reference in New Issue
Block a user