allow baking in VAE in checkpoint merger tab

do not save config if it's the default for checkpoint merger tab
change file naming scheme for checkpoint merger tab
allow just saving A without any merging for checkpoint merger tab
some stylistic changes for UI in checkpoint merger tab
This commit is contained in:
AUTOMATIC
2023-01-19 10:39:51 +03:00
parent c7e50425f6
commit 0f5dbfffd0
7 changed files with 102 additions and 59 deletions

View File

@@ -20,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -1185,7 +1185,7 @@ def create_ui():
with gr.Column(variant='compact'):
gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with FormRow():
with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@@ -1197,13 +1197,20 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
with FormRow():
with gr.Column():
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
with gr.Column():
with FormRow():
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
@@ -1757,6 +1764,7 @@ def create_ui():
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
modelmerger_merge.click(
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
_js='modelmerger',
@@ -1771,6 +1779,7 @@ def create_ui():
custom_name,
checkpoint_format,
config_source,
bake_in_vae,
],
outputs=[
primary_model_name,