add checkpoints tab for extra networks UI

This commit is contained in:
AUTOMATIC
2023-01-28 22:52:27 +03:00
parent 91c8d0dcfc
commit 1d8e06d542
8 changed files with 94 additions and 8 deletions

View File

@@ -1,4 +1,6 @@
import os.path
import urllib.parse
from pathlib import Path
from modules import shared
import gradio as gr
@@ -8,12 +10,31 @@ import html
from modules.generation_parameters_copypaste import image_from_url_text
extra_pages = []
allowed_dirs = set()
def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
extra_pages.append(page)
allowed_dirs.clear()
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
def add_pages_to_demo(app):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
if os.path.splitext(filename)[1].lower() != ".png":
raise ValueError(f"File cannot be fetched: {filename}. Only png.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
class ExtraNetworksPage:
@@ -26,6 +47,9 @@ class ExtraNetworksPage:
def refresh(self):
pass
def link_preview(self, filename):
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
def create_html(self, tabname):
view = shared.opts.extra_networks_default_view
items_html = ''
@@ -54,13 +78,17 @@ class ExtraNetworksPage:
def create_html_for_item(self, item, tabname):
preview = item.get("preview", None)
onclick = item.get("onclick", None)
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item["prompt"],
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
"card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
"card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
}
@@ -143,7 +171,7 @@ def path_is_parent(parent_path, child_path):
parent_path = os.path.abspath(parent_path)
child_path = os.path.abspath(child_path)
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
return child_path.startswith(parent_path)
def setup_ui(ui, gallery):
@@ -173,7 +201,8 @@ def setup_ui(ui, gallery):
ui.button_save_preview.click(
fn=save_preview,
_js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
outputs=[*ui.pages]
)