mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 05:12:35 +00:00
Merge branch 'master' into token_updates
This commit is contained in:
139
modules/ui.py
139
modules/ui.py
@@ -22,6 +22,7 @@ import gradio as gr
|
||||
import gradio.utils
|
||||
import gradio.routes
|
||||
|
||||
from modules import sd_hijack
|
||||
from modules.paths import script_path
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.shared as shared
|
||||
@@ -34,6 +35,7 @@ import modules.codeformer_model
|
||||
import modules.styles
|
||||
import modules.generation_parameters_copypaste
|
||||
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
|
||||
import modules.textual_inversion.ui
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
||||
mimetypes.init()
|
||||
@@ -144,8 +146,8 @@ def save_files(js_data, images, index):
|
||||
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
def wrap_gradio_call(func):
|
||||
def f(*args, **kwargs):
|
||||
def wrap_gradio_call(func, extra_outputs=None):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
@@ -161,7 +163,10 @@ def wrap_gradio_call(func):
|
||||
shared.state.job = ""
|
||||
shared.state.job_count = 0
|
||||
|
||||
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
|
||||
@@ -181,6 +186,7 @@ def wrap_gradio_call(func):
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
||||
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
return tuple(res)
|
||||
|
||||
@@ -189,7 +195,7 @@ def wrap_gradio_call(func):
|
||||
|
||||
def check_progress_call(id_part):
|
||||
if shared.state.job_count == 0:
|
||||
return "", gr_show(False), gr_show(False)
|
||||
return "", gr_show(False), gr_show(False), gr_show(False)
|
||||
|
||||
progress = 0
|
||||
|
||||
@@ -221,13 +227,19 @@ def check_progress_call(id_part):
|
||||
else:
|
||||
preview_visibility = gr_show(True)
|
||||
|
||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
|
||||
if shared.state.textinfo is not None:
|
||||
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
|
||||
else:
|
||||
textinfo_result = gr_show(False)
|
||||
|
||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
|
||||
|
||||
|
||||
def check_progress_call_initial(id_part):
|
||||
shared.state.job_count = -1
|
||||
shared.state.current_latent = None
|
||||
shared.state.current_image = None
|
||||
shared.state.textinfo = None
|
||||
|
||||
return check_progress_call(id_part)
|
||||
|
||||
@@ -403,13 +415,16 @@ def create_toprow(is_img2img):
|
||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||
|
||||
|
||||
def setup_progressbar(progressbar, preview, id_part):
|
||||
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||
if textinfo is None:
|
||||
textinfo = gr.HTML(visible=False)
|
||||
|
||||
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
||||
check_progress.click(
|
||||
fn=lambda: check_progress_call(id_part),
|
||||
show_progress=False,
|
||||
inputs=[],
|
||||
outputs=[progressbar, preview, preview],
|
||||
outputs=[progressbar, preview, preview, textinfo],
|
||||
)
|
||||
|
||||
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
||||
@@ -417,11 +432,14 @@ def setup_progressbar(progressbar, preview, id_part):
|
||||
fn=lambda: check_progress_call_initial(id_part),
|
||||
show_progress=False,
|
||||
inputs=[],
|
||||
outputs=[progressbar, preview, preview],
|
||||
outputs=[progressbar, preview, preview, textinfo],
|
||||
)
|
||||
|
||||
|
||||
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
def create_ui(wrap_gradio_gpu_call):
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||
dummy_component = gr.Label(visible=False)
|
||||
@@ -487,7 +505,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
||||
txt2img_args = dict(
|
||||
fn=txt2img,
|
||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
||||
_js="submit",
|
||||
inputs=[
|
||||
txt2img_prompt,
|
||||
@@ -681,7 +699,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
)
|
||||
|
||||
img2img_args = dict(
|
||||
fn=img2img,
|
||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
|
||||
_js="submit_img2img",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
@@ -838,7 +856,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||
|
||||
submit.click(
|
||||
fn=run_extras,
|
||||
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||
_js="get_extras_tab_index",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
@@ -888,7 +906,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
|
||||
image.change(
|
||||
fn=wrap_gradio_call(run_pnginfo),
|
||||
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||
inputs=[image],
|
||||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
@@ -897,7 +915,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
|
||||
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
|
||||
@@ -906,10 +924,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
||||
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
||||
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
with gr.Blocks() as textual_inversion_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
||||
|
||||
new_embedding_name = gr.Textbox(label="Name")
|
||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
create_embedding = gr.Button(value="Create", variant='primary')
|
||||
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
interrupt_training = gr.Button(value="Interrupt")
|
||||
train_embedding = gr.Button(value="Train", variant='primary')
|
||||
|
||||
with gr.Column():
|
||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
|
||||
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
|
||||
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
|
||||
|
||||
create_embedding.click(
|
||||
fn=modules.textual_inversion.ui.create_embedding,
|
||||
inputs=[
|
||||
new_embedding_name,
|
||||
nvpt,
|
||||
],
|
||||
outputs=[
|
||||
train_embedding_name,
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
|
||||
train_embedding.click(
|
||||
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
train_embedding_name,
|
||||
learn_rate,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
steps,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
|
||||
interrupt_training.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
def create_setting_component(key):
|
||||
def fun():
|
||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||
@@ -1021,6 +1125,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
(extras_interface, "Extras", "extras"),
|
||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||
(textual_inversion_interface, "Textual inversion", "ti"),
|
||||
(settings_interface, "Settings", "settings"),
|
||||
]
|
||||
|
||||
@@ -1054,11 +1159,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
|
||||
def modelmerger(*args):
|
||||
try:
|
||||
results = run_modelmerger(*args)
|
||||
results = modules.extras.run_modelmerger(*args)
|
||||
except Exception as e:
|
||||
print("Error loading/saving model file:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
modules.sd_models.list_models() #To remove the potentially missing models from the list
|
||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||
return results
|
||||
|
||||
|
Reference in New Issue
Block a user