mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 05:12:35 +00:00
Merge branch 'master' into saving
This commit is contained in:
164
modules/ui.py
164
modules/ui.py
@@ -15,11 +15,13 @@ import subprocess as sp
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
|
||||
import gradio as gr
|
||||
import gradio.utils
|
||||
import gradio.routes
|
||||
|
||||
from modules import sd_hijack
|
||||
from modules.paths import script_path
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.shared as shared
|
||||
@@ -32,6 +34,7 @@ import modules.codeformer_model
|
||||
import modules.styles
|
||||
import modules.generation_parameters_copypaste
|
||||
from modules.images import apply_filename_pattern, get_next_sequence_number
|
||||
import modules.textual_inversion.ui
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
||||
mimetypes.init()
|
||||
@@ -129,27 +132,37 @@ def save_files(js_data, images, index):
|
||||
writer = csv.writer(file)
|
||||
if at_start:
|
||||
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
||||
|
||||
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
||||
if file_decoration != "":
|
||||
file_decoration = "-" + file_decoration.lower()
|
||||
file_decoration = apply_filename_pattern(file_decoration, p, p.seed, p.prompt)
|
||||
truncated = (file_decoration[:240] + '..') if len(file_decoration) > 240 else file_decoration
|
||||
filename_base = truncated
|
||||
extension = opts.samples_format.lower()
|
||||
|
||||
basecount = get_next_sequence_number(path, "")
|
||||
for i, filedata in enumerate(images):
|
||||
file_number = f"{basecount+i:05}"
|
||||
filename = file_number + filename_base + ".png"
|
||||
filename = file_number + filename_base + f".{extension}"
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
|
||||
if filedata.startswith("data:image/png;base64,"):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text('parameters', infotexts[i])
|
||||
|
||||
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
|
||||
image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
|
||||
if opts.enable_pnginfo and extension == 'png':
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text('parameters', infotexts[i])
|
||||
image.save(filepath, pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(filepath, quality=opts.jpeg_quality)
|
||||
|
||||
if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"):
|
||||
piexif.insert(piexif.dump({"Exif": {
|
||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
|
||||
}}), filepath)
|
||||
|
||||
filenames.append(filename)
|
||||
|
||||
@@ -158,8 +171,8 @@ def save_files(js_data, images, index):
|
||||
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
def wrap_gradio_call(func):
|
||||
def f(*args, **kwargs):
|
||||
def wrap_gradio_call(func, extra_outputs=None):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
@@ -175,7 +188,10 @@ def wrap_gradio_call(func):
|
||||
shared.state.job = ""
|
||||
shared.state.job_count = 0
|
||||
|
||||
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
|
||||
@@ -195,6 +211,7 @@ def wrap_gradio_call(func):
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
||||
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
return tuple(res)
|
||||
|
||||
@@ -203,7 +220,7 @@ def wrap_gradio_call(func):
|
||||
|
||||
def check_progress_call(id_part):
|
||||
if shared.state.job_count == 0:
|
||||
return "", gr_show(False), gr_show(False)
|
||||
return "", gr_show(False), gr_show(False), gr_show(False)
|
||||
|
||||
progress = 0
|
||||
|
||||
@@ -235,13 +252,19 @@ def check_progress_call(id_part):
|
||||
else:
|
||||
preview_visibility = gr_show(True)
|
||||
|
||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
|
||||
if shared.state.textinfo is not None:
|
||||
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
|
||||
else:
|
||||
textinfo_result = gr_show(False)
|
||||
|
||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
|
||||
|
||||
|
||||
def check_progress_call_initial(id_part):
|
||||
shared.state.job_count = -1
|
||||
shared.state.current_latent = None
|
||||
shared.state.current_image = None
|
||||
shared.state.textinfo = None
|
||||
|
||||
return check_progress_call(id_part)
|
||||
|
||||
@@ -396,7 +419,7 @@ def create_toprow(is_img2img):
|
||||
with gr.Column(scale=1):
|
||||
with gr.Row():
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
||||
submit = gr.Button('Generate', elem_id="generate", variant='primary')
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
|
||||
interrupt.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
@@ -415,13 +438,16 @@ def create_toprow(is_img2img):
|
||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
|
||||
|
||||
|
||||
def setup_progressbar(progressbar, preview, id_part):
|
||||
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||
if textinfo is None:
|
||||
textinfo = gr.HTML(visible=False)
|
||||
|
||||
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
||||
check_progress.click(
|
||||
fn=lambda: check_progress_call(id_part),
|
||||
show_progress=False,
|
||||
inputs=[],
|
||||
outputs=[progressbar, preview, preview],
|
||||
outputs=[progressbar, preview, preview, textinfo],
|
||||
)
|
||||
|
||||
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
||||
@@ -429,11 +455,14 @@ def setup_progressbar(progressbar, preview, id_part):
|
||||
fn=lambda: check_progress_call_initial(id_part),
|
||||
show_progress=False,
|
||||
inputs=[],
|
||||
outputs=[progressbar, preview, preview],
|
||||
outputs=[progressbar, preview, preview, textinfo],
|
||||
)
|
||||
|
||||
|
||||
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
def create_ui(wrap_gradio_gpu_call):
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
|
||||
dummy_component = gr.Label(visible=False)
|
||||
@@ -499,7 +528,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
||||
txt2img_args = dict(
|
||||
fn=txt2img,
|
||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
||||
_js="submit",
|
||||
inputs=[
|
||||
txt2img_prompt,
|
||||
@@ -615,7 +644,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
|
||||
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
|
||||
|
||||
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
|
||||
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
|
||||
|
||||
with gr.Row():
|
||||
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
|
||||
@@ -691,7 +720,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
)
|
||||
|
||||
img2img_args = dict(
|
||||
fn=img2img,
|
||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
|
||||
_js="submit_img2img",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
@@ -844,7 +873,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||
|
||||
submit.click(
|
||||
fn=run_extras,
|
||||
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||
_js="get_extras_tab_index",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
@@ -894,7 +923,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
|
||||
image.change(
|
||||
fn=wrap_gradio_call(run_pnginfo),
|
||||
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||
inputs=[image],
|
||||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
@@ -903,7 +932,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
|
||||
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
|
||||
@@ -912,10 +941,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
||||
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
||||
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
with gr.Blocks() as textual_inversion_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column():
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
||||
|
||||
new_embedding_name = gr.Textbox(label="Name")
|
||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
create_embedding = gr.Button(value="Create", variant='primary')
|
||||
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
interrupt_training = gr.Button(value="Interrupt")
|
||||
train_embedding = gr.Button(value="Train", variant='primary')
|
||||
|
||||
with gr.Column():
|
||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
|
||||
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
|
||||
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
|
||||
|
||||
create_embedding.click(
|
||||
fn=modules.textual_inversion.ui.create_embedding,
|
||||
inputs=[
|
||||
new_embedding_name,
|
||||
nvpt,
|
||||
],
|
||||
outputs=[
|
||||
train_embedding_name,
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
|
||||
train_embedding.click(
|
||||
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
train_embedding_name,
|
||||
learn_rate,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
steps,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
|
||||
interrupt_training.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
def create_setting_component(key):
|
||||
def fun():
|
||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||
@@ -1027,6 +1142,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
(extras_interface, "Extras", "extras"),
|
||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||
(textual_inversion_interface, "Textual inversion", "ti"),
|
||||
(settings_interface, "Settings", "settings"),
|
||||
]
|
||||
|
||||
@@ -1060,11 +1176,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||
|
||||
def modelmerger(*args):
|
||||
try:
|
||||
results = run_modelmerger(*args)
|
||||
results = modules.extras.run_modelmerger(*args)
|
||||
except Exception as e:
|
||||
print("Error loading/saving model file:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
modules.sd_models.list_models() #To remove the potentially missing models from the list
|
||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||
return results
|
||||
|
||||
|
Reference in New Issue
Block a user