mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-03 19:02:27 +00:00
add progress bar to modelmerger
This commit is contained in:
@@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
|
||||
shutil.copyfile(cfg, checkpoint_filename)
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
shared.state.job_count = 1
|
||||
|
||||
def fail(message):
|
||||
shared.state.textinfo = message
|
||||
shared.state.end()
|
||||
return [message, *[gr.update() for _ in range(4)]]
|
||||
return [*[gr.update() for _ in range(4)], message]
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
@@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
|
||||
if theta_func1:
|
||||
shared.state.job_count += 1
|
||||
|
||||
print(f"Loading {tertiary_model_info.filename}...")
|
||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||
|
||||
shared.state.sampling_steps = len(theta_1.keys())
|
||||
for key in tqdm.tqdm(theta_1.keys()):
|
||||
if 'model' in key:
|
||||
if key in theta_2:
|
||||
@@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||
else:
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
del theta_2
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
@@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
|
||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||
|
||||
shared.state.sampling_steps = len(theta_0.keys())
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
|
||||
@@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
|
||||
# I believe this part should be discarded, but I'll leave it for now until I am sure
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
@@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
@@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||
shared.state.end()
|
||||
|
||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
||||
|
Reference in New Issue
Block a user