add progress bar to modelmerger

This commit is contained in:
AUTOMATIC
2023-01-19 09:25:37 +03:00
parent 7cfc645030
commit c7e50425f6
5 changed files with 40 additions and 9 deletions

View File

@@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename)
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin()
shared.state.job = 'model-merge'
shared.state.job_count = 1
def fail(message):
shared.state.textinfo = message
shared.state.end()
return [message, *[gr.update() for _ in range(4)]]
return [*[gr.update() for _ in range(4)], message]
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if theta_func1:
shared.state.job_count += 1
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key:
if key in theta_2:
@@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
shared.state.sampling_step += 1
del theta_2
shared.state.nextjob()
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
@@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
if save_as_half:
theta_0[key] = theta_0[key].half()
shared.state.sampling_step += 1
# I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
@@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename)
shared.state.nextjob()
shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...")
@@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end()
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]