Use options instead of cmd_args

This commit is contained in:
Kohaku-Blueleaf
2023-11-19 15:50:06 +08:00
parent b60e1088db
commit 598da5cd49
6 changed files with 49 additions and 42 deletions

View File

@@ -339,10 +339,28 @@ class SkipWritingToConfig:
SkipWritingToConfig.skip = self.previous
def check_fp8(model):
if model is None:
return None
if devices.get_optimal_device_name() == "mps":
enable_fp8 = False
elif shared.opts.fp8_storage == "Enable":
enable_fp8 = True
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
enable_fp8 = True
else:
enable_fp8 = False
return enable_fp8
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if not check_fp8(model) and devices.fp8:
# prevent model to load state dict in fp8
model.half()
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
if devices.get_optimal_device_name() == "mps":
enable_fp8 = False
elif shared.cmd_opts.opt_unet_fp8_storage:
enable_fp8 = True
elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
enable_fp8 = True
else:
enable_fp8 = False
if enable_fp8:
if check_fp8(model):
devices.fp8 = True
if model.is_sdxl:
cond_stage = model.conditioner
else:
cond_stage = model.cond_stage_model
for module in cond_stage.modules():
if isinstance(module, torch.nn.Linear):
first_stage = model.first_stage_model
model.first_stage_model = None
for module in model.modules():
if isinstance(module, torch.nn.Conv2d):
module.to(torch.float8_e4m3fn)
if devices.device == devices.cpu:
for module in model.model.diffusion_model.modules():
if isinstance(module, torch.nn.Conv2d):
module.to(torch.float8_e4m3fn)
elif isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
else:
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
elif isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")
else:
devices.fp8 = False
@@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
return None
def reload_model_weights(sd_model=None, info=None):
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None):
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
if check_fp8(sd_model) != devices.fp8:
# load from state dict again to prevent extra numerical errors
forced_reload = True
elif sd_model.sd_model_checkpoint == checkpoint_info.filename:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None: