mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-10 10:50:09 +00:00
Merge remote-tracking branch 'upstream/master' into roy.add_simple_interrogate_api
This commit is contained in:
@@ -1,34 +1,32 @@
|
||||
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI, InterrogateAPI
|
||||
import time
|
||||
import uvicorn
|
||||
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import modules.shared as shared
|
||||
from modules import devices
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.sd_samplers import all_samplers
|
||||
from modules.extras import run_pnginfo
|
||||
import modules.shared as shared
|
||||
import uvicorn
|
||||
from fastapi import Body, APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, Json
|
||||
from typing import List
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
from PIL import Image
|
||||
from modules.extras import run_extras, run_pnginfo
|
||||
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
|
||||
|
||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
|
||||
class InterrogateResponse(BaseModel):
|
||||
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
||||
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
|
||||
reqDict.pop('upscaler_1')
|
||||
reqDict.pop('upscaler_2')
|
||||
return reqDict
|
||||
|
||||
|
||||
class Api:
|
||||
@@ -36,26 +34,22 @@ class Api:
|
||||
self.router = APIRouter()
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
|
||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
|
||||
def __base64_to_image(self, base64_string):
|
||||
# if has a comma, deal with prefix
|
||||
if "," in base64_string:
|
||||
base64_string = base64_string.split(",")[1]
|
||||
imgdata = base64.b64decode(base64_string)
|
||||
# convert base64 to PIL image
|
||||
return Image.open(io.BytesIO(imgdata))
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True
|
||||
@@ -63,40 +57,39 @@ class Api:
|
||||
)
|
||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||
# Override object param
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
|
||||
b64images = []
|
||||
for i in processed.images:
|
||||
buffer = io.BytesIO()
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
|
||||
|
||||
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(img2imgreq.sampler_index)
|
||||
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
|
||||
mask = img2imgreq.mask
|
||||
if mask:
|
||||
mask = self.__base64_to_image(mask)
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True,
|
||||
"do_not_save_grid": True,
|
||||
"mask": mask
|
||||
}
|
||||
)
|
||||
@@ -104,27 +97,89 @@ class Api:
|
||||
|
||||
imgs = []
|
||||
for img in init_images:
|
||||
img = self.__base64_to_image(img)
|
||||
img = decode_base64_to_image(img)
|
||||
imgs = [img] * p.batch_size
|
||||
|
||||
p.init_images = imgs
|
||||
# Override object param
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
|
||||
b64images = []
|
||||
for i in processed.images:
|
||||
buffer = io.BytesIO()
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
if (not img2imgreq.include_init_images):
|
||||
img2imgreq.init_images = None
|
||||
img2imgreq.mask = None
|
||||
|
||||
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
|
||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateAPI):
|
||||
def extrasapi(self):
|
||||
raise NotImplementedError
|
||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
|
||||
|
||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
|
||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
def prepareFiles(file):
|
||||
file = decode_base64_to_file(file.data, file_path=file.name)
|
||||
file.orig_name = file.name
|
||||
return file
|
||||
|
||||
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
||||
reqDict.pop('imageList')
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
|
||||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: PNGInfoRequest):
|
||||
if(not req.image.strip()):
|
||||
return PNGInfoResponse(info="")
|
||||
|
||||
result = run_pnginfo(decode_base64_to_image(req.image.strip()))
|
||||
|
||||
return PNGInfoResponse(info=result[1])
|
||||
|
||||
def progressapi(self, req: ProgressRequest = Depends()):
|
||||
# copy from check_progress_call of ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
|
||||
|
||||
# avoid dividing zero
|
||||
progress = 0.01
|
||||
|
||||
if shared.state.job_count > 0:
|
||||
progress += shared.state.job_no / shared.state.job_count
|
||||
if shared.state.sampling_steps > 0:
|
||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||
|
||||
time_since_start = time.time() - shared.state.time_start
|
||||
eta = (time_since_start/progress)
|
||||
eta_relative = eta-time_since_start
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
current_image = None
|
||||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
@@ -139,13 +194,7 @@ class Api:
|
||||
with self.queue_lock:
|
||||
processed = shared.interrogator.interrogate(img)
|
||||
|
||||
return InterrogateResponse(caption=processed, parameters=json.dumps(vars(interrogatereq)), info=None)
|
||||
|
||||
def extrasapi(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def pnginfoapi(self):
|
||||
raise NotImplementedError
|
||||
return InterrogateResponse(caption=processed)
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
|
@@ -1,10 +1,11 @@
|
||||
from array import array
|
||||
from inflection import underscore
|
||||
from typing import Any, Dict, Optional
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
import inspect
|
||||
|
||||
from click import prompt
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
@@ -51,17 +52,17 @@ class PydanticModelGenerator:
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
|
||||
if class_instance is not None:
|
||||
@@ -78,11 +79,11 @@ class PydanticModelGenerator:
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"],
|
||||
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
||||
@@ -99,21 +100,79 @@ class PydanticModelGenerator:
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
|
||||
).generate_model()
|
||||
|
||||
InterrogateAPI = PydanticModelGenerator(
|
||||
"Interrogate",
|
||||
None,
|
||||
[{"key": "image", "type": str, "default": None}]
|
||||
).generate_model()
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ExtrasBaseRequest(BaseModel):
|
||||
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
||||
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
||||
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
||||
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
||||
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
|
||||
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
||||
|
||||
class ExtraBaseResponse(BaseModel):
|
||||
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
||||
|
||||
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
||||
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
|
||||
class FileData(BaseModel):
|
||||
data: str = Field(title="File data", description="Base64 representation of the file")
|
||||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with all the info the image had")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class InterrogateResponse(BaseModel):
|
||||
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||
|
83
modules/extensions.py
Normal file
83
modules/extensions.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
class Extension:
|
||||
def __init__(self, name, path, enabled=True):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(path, ".git")):
|
||||
repo = git.Repo(path)
|
||||
except Exception:
|
||||
print(f"Error reading github repository info from {path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
|
||||
res = []
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename)))
|
||||
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return res
|
||||
|
||||
def check_updates(self):
|
||||
repo = git.Repo(self.path)
|
||||
for fetch in repo.remote().fetch("--dry-run"):
|
||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||
self.can_update = True
|
||||
self.status = "behind"
|
||||
return
|
||||
|
||||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def pull(self):
|
||||
repo = git.Repo(self.path)
|
||||
repo.remotes.origin.pull()
|
||||
|
||||
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
path = os.path.join(extensions_dir, dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
|
||||
extensions.append(extension)
|
@@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import math
|
||||
import os
|
||||
|
||||
@@ -7,6 +8,10 @@ from PIL import Image
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from typing import Callable, List, OrderedDict, Tuple
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import processing, shared, images, devices, sd_models
|
||||
from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
@@ -17,10 +22,38 @@ import piexif.helper
|
||||
import gradio as gr
|
||||
|
||||
|
||||
cached_images = {}
|
||||
class LruCache(OrderedDict):
|
||||
@dataclass(frozen=True)
|
||||
class Key:
|
||||
image_hash: int
|
||||
info_hash: int
|
||||
args_hash: int
|
||||
|
||||
@dataclass
|
||||
class Value:
|
||||
image: Image.Image
|
||||
info: str
|
||||
|
||||
def __init__(self, max_size: int = 5, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._max_size = max_size
|
||||
|
||||
def get(self, key: LruCache.Key) -> LruCache.Value:
|
||||
ret = super().get(key)
|
||||
if ret is not None:
|
||||
self.move_to_end(key) # Move to end of eviction list
|
||||
return ret
|
||||
|
||||
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
|
||||
self[key] = value
|
||||
while len(self) > self._max_size:
|
||||
self.popitem(last=False)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
||||
cached_images: LruCache = LruCache(max_size=5)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool):
|
||||
devices.torch_gc()
|
||||
|
||||
imageArr = []
|
||||
@@ -39,7 +72,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
|
||||
if input_dir == '':
|
||||
return outputs, "Please select an input directory.", ''
|
||||
image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for img in image_list:
|
||||
try:
|
||||
image = Image.open(img)
|
||||
@@ -56,7 +89,91 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
else:
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
|
||||
# Extra operation definitions
|
||||
|
||||
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if gfpgan_visibility < 1.0:
|
||||
res = Image.blend(image, res, gfpgan_visibility)
|
||||
|
||||
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
||||
return (res, info)
|
||||
|
||||
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if codeformer_visibility < 1.0:
|
||||
res = Image.blend(image, res, codeformer_visibility)
|
||||
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
return (res, info)
|
||||
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
|
||||
res = cropped
|
||||
return res
|
||||
|
||||
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
|
||||
nonlocal upscaling_resize
|
||||
if resize_mode == 1:
|
||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
||||
crop_info = " (crop)" if upscaling_crop else ""
|
||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
||||
return (image, info)
|
||||
|
||||
@dataclass
|
||||
class UpscaleParams:
|
||||
upscaler_idx: int
|
||||
blend_alpha: float
|
||||
|
||||
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
blended_result: Image.Image = None
|
||||
for upscaler in params:
|
||||
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
|
||||
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
|
||||
info_hash=hash(info),
|
||||
args_hash=hash(upscale_args))
|
||||
cached_entry = cached_images.get(cache_key)
|
||||
if cached_entry is None:
|
||||
res = upscale(image, *upscale_args)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
|
||||
cached_images.put(cache_key, LruCache.Value(image=res, info=info))
|
||||
else:
|
||||
res, info = cached_entry.image, cached_entry.info
|
||||
|
||||
if blended_result is None:
|
||||
blended_result = res
|
||||
else:
|
||||
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
|
||||
return (blended_result, info)
|
||||
|
||||
# Build a list of operations to run
|
||||
facefix_ops: List[Callable] = []
|
||||
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
|
||||
facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
|
||||
|
||||
upscale_ops: List[Callable] = []
|
||||
upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
|
||||
|
||||
if upscaling_resize != 0:
|
||||
step_params: List[UpscaleParams] = []
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
|
||||
|
||||
upscale_ops.append(partial(run_upscalers_blend, step_params))
|
||||
|
||||
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
|
||||
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
@@ -64,64 +181,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
|
||||
image = image.convert("RGB")
|
||||
info = ""
|
||||
# Run each operation on each image
|
||||
for op in extras_ops:
|
||||
image, info = op(image, info)
|
||||
|
||||
if gfpgan_visibility > 0:
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if gfpgan_visibility < 1.0:
|
||||
res = Image.blend(image, res, gfpgan_visibility)
|
||||
|
||||
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
||||
image = res
|
||||
|
||||
if codeformer_visibility > 0:
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if codeformer_visibility < 1.0:
|
||||
res = Image.blend(image, res, codeformer_visibility)
|
||||
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
image = res
|
||||
|
||||
if resize_mode == 1:
|
||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
||||
crop_info = " (crop)" if upscaling_crop else ""
|
||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
||||
|
||||
if upscaling_resize != 1.0:
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
|
||||
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
|
||||
|
||||
c = cached_images.get(key)
|
||||
if c is None:
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
|
||||
c = cropped
|
||||
cached_images[key] = c
|
||||
|
||||
return c
|
||||
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
||||
res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
||||
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
||||
|
||||
image = res
|
||||
|
||||
while len(cached_images) > 2:
|
||||
del cached_images[next(iter(cached_images.keys()))]
|
||||
|
||||
if opts.use_original_name_batch and image_name != None:
|
||||
basename = os.path.splitext(os.path.basename(image_name))[0]
|
||||
else:
|
||||
@@ -141,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
|
||||
return outputs, plaintext_to_html(info), ''
|
||||
|
||||
def clear_cache():
|
||||
cached_images.clear()
|
||||
|
||||
|
||||
def run_pnginfo(image):
|
||||
if image is None:
|
||||
|
@@ -1,14 +1,25 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import gradio as gr
|
||||
from modules.shared import script_path
|
||||
from modules import shared
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
type_of_gr_update = type(gr.update())
|
||||
paste_fields = {}
|
||||
bind_list = []
|
||||
|
||||
|
||||
def reset():
|
||||
paste_fields.clear()
|
||||
bind_list.clear()
|
||||
|
||||
|
||||
def quote(text):
|
||||
@@ -20,6 +31,111 @@ def quote(text):
|
||||
text = text.replace('"', '\\"')
|
||||
return f'"{text}"'
|
||||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if type(filedata) == dict and filedata["is_file"]:
|
||||
filename = filedata["name"]
|
||||
tempdir = os.path.normpath(tempfile.gettempdir())
|
||||
normfn = os.path.normpath(filename)
|
||||
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
|
||||
|
||||
return Image.open(filename)
|
||||
|
||||
if type(filedata) == list:
|
||||
if len(filedata) == 0:
|
||||
return None
|
||||
|
||||
filedata = filedata[0]
|
||||
|
||||
if filedata.startswith("data:image/png;base64,"):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
||||
image = Image.open(io.BytesIO(filedata))
|
||||
return image
|
||||
|
||||
|
||||
def add_paste_fields(tabname, init_img, fields):
|
||||
paste_fields[tabname] = {"init_img": init_img, "fields": fields}
|
||||
|
||||
# backwards compatibility for existing extensions
|
||||
import modules.ui
|
||||
if tabname == 'txt2img':
|
||||
modules.ui.txt2img_paste_fields = fields
|
||||
elif tabname == 'img2img':
|
||||
modules.ui.img2img_paste_fields = fields
|
||||
|
||||
|
||||
def integrate_settings_paste_fields(component_dict):
|
||||
from modules import ui
|
||||
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'sd_hypernetwork_strength': 'Hypernet strength',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
}
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||
for k, v in settings_map.items()
|
||||
]
|
||||
|
||||
for tabname, info in paste_fields.items():
|
||||
if info["fields"] is not None:
|
||||
info["fields"] += settings_paste_fields
|
||||
|
||||
|
||||
def create_buttons(tabs_list):
|
||||
buttons = {}
|
||||
for tab in tabs_list:
|
||||
buttons[tab] = gr.Button(f"Send to {tab}")
|
||||
return buttons
|
||||
|
||||
|
||||
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
|
||||
def bind_buttons(buttons, send_image, send_generate_info):
|
||||
bind_list.append([buttons, send_image, send_generate_info])
|
||||
|
||||
|
||||
def run_bind():
|
||||
for buttons, send_image, send_generate_info in bind_list:
|
||||
for tab in buttons:
|
||||
button = buttons[tab]
|
||||
if send_image and paste_fields[tab]["init_img"]:
|
||||
if type(send_image) == gr.Gallery:
|
||||
button.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery",
|
||||
inputs=[send_image],
|
||||
outputs=[paste_fields[tab]["init_img"]],
|
||||
)
|
||||
else:
|
||||
button.click(
|
||||
fn=lambda x: x,
|
||||
inputs=[send_image],
|
||||
outputs=[paste_fields[tab]["init_img"]],
|
||||
)
|
||||
|
||||
if send_generate_info and paste_fields[tab]["fields"] is not None:
|
||||
if send_generate_info in paste_fields:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
|
||||
button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
|
||||
)
|
||||
else:
|
||||
connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
|
||||
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"switch_to_{tab}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
||||
```
|
||||
@@ -68,7 +184,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, js=None):
|
||||
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||
filename = os.path.join(script_path, "params.txt")
|
||||
@@ -106,7 +222,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
|
||||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
_js=js,
|
||||
_js=jsfunc,
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
)
|
||||
|
||||
|
||||
|
@@ -25,6 +25,7 @@ from statistics import stdev, mean
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
@@ -208,13 +209,16 @@ def list_hypernetworks(path):
|
||||
res = {}
|
||||
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
res[name] = filename
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name] = filename
|
||||
return res
|
||||
|
||||
|
||||
def load_hypernetwork(filename):
|
||||
path = shared.hypernetworks.get(filename, None)
|
||||
if path is not None:
|
||||
# Prevent any file named "None.pt" from being loaded.
|
||||
if path is not None and filename != "None":
|
||||
print(f"Loading hypernetwork {filename}")
|
||||
try:
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
@@ -331,7 +335,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
assert hypernetwork_name, 'hypernetwork not selected'
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
@@ -357,39 +363,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step >= steps:
|
||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
weights = hypernetwork.weights()
|
||||
for weight in weights:
|
||||
weight.requires_grad = True
|
||||
|
||||
size = len(ds.indexes)
|
||||
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
losses = torch.zeros((size,))
|
||||
previous_mean_losses = [0]
|
||||
previous_mean_loss = 0
|
||||
print("Mean loss of {} elements".format(size))
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step > steps:
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
weights = hypernetwork.weights()
|
||||
for weight in weights:
|
||||
weight.requires_grad = True
|
||||
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
|
||||
steps_without_grad = 0
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, entries in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
@@ -428,7 +439,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
|
||||
optimizer.step()
|
||||
|
||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||
steps_done = hypernetwork.step + 1
|
||||
|
||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||
raise RuntimeError("Loss diverged.")
|
||||
|
||||
if len(previous_mean_losses) > 1:
|
||||
@@ -438,19 +451,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
|
||||
pbar.set_description(dataset_loss_info)
|
||||
|
||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||
hypernetwork.save(last_saved_file)
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||
"loss": f"{previous_mean_loss:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
|
||||
optimizer.zero_grad()
|
||||
@@ -503,13 +516,23 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
"""
|
||||
|
||||
report_statistics(loss_dict)
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
|
||||
hypernetwork.name = hypernetwork_name
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||
hypernetwork.save(filename)
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||
|
||||
return hypernetwork, filename
|
||||
|
||||
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
||||
old_hypernetwork_name = hypernetwork.name
|
||||
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
||||
try:
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
hypernetwork.name = hypernetwork_name
|
||||
hypernetwork.save(filename)
|
||||
except:
|
||||
hypernetwork.sd_checkpoint = old_sd_checkpoint
|
||||
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
||||
hypernetwork.name = old_hypernetwork_name
|
||||
raise
|
||||
|
@@ -8,7 +8,8 @@ import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack, shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
|
||||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
# Remove illegal characters from name.
|
||||
|
@@ -300,8 +300,8 @@ class FilenameGenerator:
|
||||
'seed': lambda self: self.seed if self.seed is not None else '',
|
||||
'steps': lambda self: self.p and self.p.steps,
|
||||
'cfg': lambda self: self.p and self.p.cfg_scale,
|
||||
'width': lambda self: self.p and self.p.width,
|
||||
'height': lambda self: self.p and self.p.height,
|
||||
'width': lambda self: self.image.width,
|
||||
'height': lambda self: self.image.height,
|
||||
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
||||
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
|
||||
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
||||
@@ -315,10 +315,11 @@ class FilenameGenerator:
|
||||
}
|
||||
default_time_format = '%Y%m%d%H%M%S'
|
||||
|
||||
def __init__(self, p, seed, prompt):
|
||||
def __init__(self, p, seed, prompt, image):
|
||||
self.p = p
|
||||
self.seed = seed
|
||||
self.prompt = prompt
|
||||
self.image = image
|
||||
|
||||
def prompt_no_style(self):
|
||||
if self.p is None or self.prompt is None:
|
||||
@@ -449,7 +450,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
txt_fullfn (`str` or None):
|
||||
If a text file is saved for this image, this will be its full path. Otherwise None.
|
||||
"""
|
||||
namegen = FilenameGenerator(p, seed, prompt)
|
||||
namegen = FilenameGenerator(p, seed, prompt, image)
|
||||
|
||||
if save_to_dirs is None:
|
||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||
|
@@ -19,7 +19,7 @@ import modules.scripts
|
||||
def process_batch(p, input_dir, output_dir, args):
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
||||
images = shared.listfiles(input_dir)
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
|
@@ -129,6 +129,73 @@ class StableDiffusionProcessing():
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
return torch.zeros(
|
||||
x.shape[0], 5, 1, 1,
|
||||
dtype=x.dtype,
|
||||
device=x.device
|
||||
)
|
||||
|
||||
height = height or self.height
|
||||
width = width or self.width
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
return torch.zeros(
|
||||
latent_image.shape[0], 5, 1, 1,
|
||||
dtype=latent_image.dtype,
|
||||
device=latent_image.device
|
||||
)
|
||||
|
||||
# Handle the different mask inputs
|
||||
if image_mask is not None:
|
||||
if torch.is_tensor(image_mask):
|
||||
conditioning_mask = image_mask
|
||||
else:
|
||||
conditioning_mask = np.array(image_mask.convert("L"))
|
||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||
|
||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
else:
|
||||
conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
|
||||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
|
||||
conditioning_mask = conditioning_mask.to(source_image.device)
|
||||
conditioning_image = torch.lerp(
|
||||
source_image,
|
||||
source_image * (1.0 - conditioning_mask),
|
||||
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
|
||||
)
|
||||
|
||||
# Encode the new masked image using first stage of network.
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||
|
||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
||||
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
||||
image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
||||
image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
pass
|
||||
|
||||
@@ -329,6 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
@@ -411,7 +479,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.run_alwayson_scripts(p)
|
||||
p.scripts.process(p)
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
@@ -434,7 +502,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
|
||||
if (len(prompts) == 0):
|
||||
if len(prompts) == 0:
|
||||
break
|
||||
|
||||
with devices.autocast():
|
||||
@@ -523,7 +591,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
devices.torch_gc()
|
||||
return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
@@ -571,37 +645,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||
|
||||
def create_dummy_mask(self, x, width=None, height=None):
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
height = height or self.height
|
||||
width = width or self.width
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
else:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
|
||||
if not self.enable_hr:
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
return samples
|
||||
|
||||
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
|
||||
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||
|
||||
@@ -634,11 +687,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
||||
image_conditioning = self.txt2img_image_conditioning(x)
|
||||
|
||||
# GC now before running the next img2img to prevent running out of memory
|
||||
x = None
|
||||
devices.torch_gc()
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
return samples
|
||||
|
||||
@@ -770,33 +825,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
elif self.inpainting_fill == 3:
|
||||
self.init_latent = self.init_latent * self.mask
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
if self.image_mask is not None:
|
||||
conditioning_mask = np.array(self.image_mask.convert("L"))
|
||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||
|
||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
else:
|
||||
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
|
||||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
conditioning_mask = conditioning_mask.to(image.device)
|
||||
conditioning_image = image * (1.0 - conditioning_mask)
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||
|
||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
|
||||
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
||||
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
||||
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
||||
else:
|
||||
self.image_conditioning = torch.zeros(
|
||||
self.init_latent.shape[0], 5, 1, 1,
|
||||
dtype=self.init_latent.dtype,
|
||||
device=self.init_latent.device
|
||||
)
|
||||
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
|
||||
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
|
@@ -7,7 +7,7 @@ import modules.ui as ui
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared, paths, script_callbacks
|
||||
from modules import shared, paths, script_callbacks, extensions
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
@@ -64,7 +64,16 @@ class Script:
|
||||
def process(self, p, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
scripts. You can modify the processing object (p) here, inject hooks, etc.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
@@ -98,17 +107,8 @@ def list_scripts(scriptdirname, extension):
|
||||
for filename in sorted(os.listdir(basedir)):
|
||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||
|
||||
extdir = os.path.join(paths.script_path, "extensions")
|
||||
if os.path.exists(extdir):
|
||||
for dirname in sorted(os.listdir(extdir)):
|
||||
dirpath = os.path.join(extdir, dirname)
|
||||
scriptdirpath = os.path.join(dirpath, scriptdirname)
|
||||
|
||||
if not os.path.isdir(scriptdirpath):
|
||||
continue
|
||||
|
||||
for filename in sorted(os.listdir(scriptdirpath)):
|
||||
scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
|
||||
for ext in extensions.active():
|
||||
scripts_list += ext.list_files(scriptdirname, extension)
|
||||
|
||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
@@ -118,11 +118,7 @@ def list_scripts(scriptdirname, extension):
|
||||
def list_files_with_name(filename):
|
||||
res = []
|
||||
|
||||
dirs = [paths.script_path]
|
||||
|
||||
extdir = os.path.join(paths.script_path, "extensions")
|
||||
if os.path.exists(extdir):
|
||||
dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
|
||||
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
|
||||
|
||||
for dirpath in dirs:
|
||||
if not os.path.isdir(dirpath):
|
||||
@@ -236,7 +232,7 @@ class ScriptRunner:
|
||||
with gr.Group():
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
|
||||
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
|
||||
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
|
||||
dropdown.save_to_config = True
|
||||
inputs[0] = dropdown
|
||||
|
||||
@@ -289,13 +285,22 @@ class ScriptRunner:
|
||||
|
||||
return processed
|
||||
|
||||
def run_alwayson_scripts(self, p):
|
||||
def process(self, p):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.process(p, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
|
||||
print(f"Error running process: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess(p, processed, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def reload_sources(self, cache):
|
||||
|
@@ -3,6 +3,7 @@ import os.path
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
@@ -35,8 +36,10 @@ def setup_model():
|
||||
list_models()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
return sorted([x.title for x in checkpoints_list.values()])
|
||||
def checkpoint_tiles():
|
||||
convert = lambda name: int(name) if name.isdigit() else name.lower()
|
||||
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
||||
|
||||
|
||||
def list_models():
|
||||
@@ -170,7 +173,9 @@ def load_model_weights(model, checkpoint_info):
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
missing, extra = model.load_state_dict(sd, strict=False)
|
||||
del pl_sd
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
@@ -194,9 +199,10 @@ def load_model_weights(model, checkpoint_info):
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
else:
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||
|
@@ -40,7 +40,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
@@ -82,6 +82,7 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
restricted_opts = {
|
||||
@@ -96,6 +97,9 @@ restricted_opts = {
|
||||
"outdir_save",
|
||||
}
|
||||
|
||||
if cmd_opts.share or cmd_opts.listen:
|
||||
cmd_opts.disable_extension_access = True
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
|
||||
|
||||
@@ -131,6 +135,7 @@ class State:
|
||||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
textinfo = None
|
||||
need_restart = False
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
@@ -143,9 +148,38 @@ class State:
|
||||
self.sampling_step = 0
|
||||
self.current_image_sampling_step = 0
|
||||
|
||||
def get_job_timestamp(self):
|
||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.skipped,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_no": self.job_no,
|
||||
"sampling_step": self.sampling_step,
|
||||
"sampling_steps": self.sampling_steps,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
def begin(self):
|
||||
self.sampling_step = 0
|
||||
self.job_count = -1
|
||||
self.job_no = 0
|
||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.current_latent = None
|
||||
self.current_image = None
|
||||
self.current_image_sampling_step = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
def end(self):
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
state = State()
|
||||
|
||||
@@ -267,6 +301,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
@@ -303,6 +338,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
@@ -322,6 +358,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||
}))
|
||||
|
||||
options_templates.update()
|
||||
|
||||
|
||||
class Options:
|
||||
data = None
|
||||
@@ -333,8 +375,9 @@ class Options:
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if self.data is not None:
|
||||
if key in self.data:
|
||||
if key in self.data or key in self.data_labels:
|
||||
self.data[key] = value
|
||||
return
|
||||
|
||||
return super(Options, self).__setattr__(key, value)
|
||||
|
||||
@@ -449,3 +492,8 @@ total_tqdm = TotalTQDM()
|
||||
|
||||
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
||||
mem_mon.start()
|
||||
|
||||
|
||||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
|
||||
self.lines = lines
|
||||
|
||||
assert data_root, 'dataset directory not specified'
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
@@ -86,12 +88,12 @@ class PersonalizedBase(Dataset):
|
||||
assert len(self.dataset) > 0, "No images have been found in the dataset."
|
||||
self.length = len(self.dataset) * repeats // batch_size
|
||||
|
||||
self.initial_indexes = np.arange(len(self.dataset))
|
||||
self.dataset_length = len(self.dataset)
|
||||
self.indexes = None
|
||||
self.shuffle()
|
||||
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
|
||||
self.indexes = np.random.permutation(self.dataset_length)
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
|
@@ -4,30 +4,37 @@ import tqdm
|
||||
class LearnScheduleIterator:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0):
|
||||
"""
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
|
||||
"""
|
||||
|
||||
pairs = learn_rate.split(',')
|
||||
self.rates = []
|
||||
self.it = 0
|
||||
self.maxit = 0
|
||||
for i, pair in enumerate(pairs):
|
||||
tmp = pair.split(':')
|
||||
if len(tmp) == 2:
|
||||
step = int(tmp[1])
|
||||
if step > cur_step:
|
||||
self.rates.append((float(tmp[0]), min(step, max_steps)))
|
||||
self.maxit += 1
|
||||
if step > max_steps:
|
||||
try:
|
||||
for i, pair in enumerate(pairs):
|
||||
if not pair.strip():
|
||||
continue
|
||||
tmp = pair.split(':')
|
||||
if len(tmp) == 2:
|
||||
step = int(tmp[1])
|
||||
if step > cur_step:
|
||||
self.rates.append((float(tmp[0]), min(step, max_steps)))
|
||||
self.maxit += 1
|
||||
if step > max_steps:
|
||||
return
|
||||
elif step == -1:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
elif step == -1:
|
||||
else:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
else:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
assert self.rates
|
||||
except (ValueError, AssertionError):
|
||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
@@ -52,7 +59,7 @@ class LearnRateScheduler:
|
||||
self.finished = False
|
||||
|
||||
def apply(self, optimizer, step_number):
|
||||
if step_number <= self.end_step:
|
||||
if step_number < self.end_step:
|
||||
return
|
||||
|
||||
try:
|
||||
|
@@ -119,7 +119,7 @@ class EmbeddingDatabase:
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('hash', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
|
||||
@@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
if shared.opts.training_write_csv_every == 0:
|
||||
return
|
||||
|
||||
if step % shared.opts.training_write_csv_every != 0:
|
||||
if (step + 1) % shared.opts.training_write_csv_every != 0:
|
||||
return
|
||||
|
||||
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
||||
|
||||
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
|
||||
@@ -196,18 +195,39 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
csv_writer.writeheader()
|
||||
|
||||
epoch = step // epoch_len
|
||||
epoch_step = step - epoch * epoch_len
|
||||
epoch_step = step % epoch_len
|
||||
|
||||
csv_writer.writerow({
|
||||
"step": step + 1,
|
||||
"epoch": epoch + 1,
|
||||
"epoch": epoch,
|
||||
"epoch_step": epoch_step + 1,
|
||||
**values,
|
||||
})
|
||||
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
assert model_name, f"{name} not selected"
|
||||
assert learn_rate, "Learning rate is empty or 0"
|
||||
assert isinstance(batch_size, int), "Batch size must be integer"
|
||||
assert batch_size > 0, "Batch size must be positive"
|
||||
assert data_root, "Dataset directory is empty"
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
assert template_file, "Prompt template file is empty"
|
||||
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
||||
assert steps, "Max steps is empty or 0"
|
||||
assert isinstance(steps, int), "Max steps must be integer"
|
||||
assert steps > 0 , "Max steps must be positive"
|
||||
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
||||
assert save_model_every >= 0 , "Save {name} must be positive or 0"
|
||||
assert isinstance(create_image_every, int), "Create image must be integer"
|
||||
assert create_image_every >= 0 , "Create image must be positive or 0"
|
||||
if save_model_every or create_image_every:
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
shared.state.job_count = steps
|
||||
@@ -233,17 +253,28 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
os.makedirs(images_embeds_dir, exist_ok=True)
|
||||
else:
|
||||
images_embeds_dir = None
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
if ititial_step >= steps:
|
||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||
return embedding, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
|
||||
losses = torch.zeros((32,))
|
||||
|
||||
@@ -252,13 +283,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
forced_filename = "<none>"
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, entries in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
@@ -282,17 +306,18 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
steps_done = embedding.step + 1
|
||||
|
||||
epoch_num = embedding.step // len(ds)
|
||||
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
||||
epoch_step = embedding.step % len(ds)
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
|
||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding.name = f'{embedding_name}-{embedding.step}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
|
||||
embedding.save(last_saved_file)
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
||||
@@ -300,8 +325,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
forced_filename = f'{embedding_name}-{embedding.step}'
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{embedding_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
@@ -334,7 +359,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||
|
||||
info = PngImagePlugin.PngInfo()
|
||||
data = torch.load(last_saved_file)
|
||||
@@ -350,7 +375,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, embedding.step)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||
@@ -373,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
embedding.sd_checkpoint = checkpoint.hash
|
||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||
embedding.cached_checksum = None
|
||||
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
|
||||
embedding.name = embedding_name
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
|
||||
embedding.save(filename)
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
|
||||
return embedding, filename
|
||||
|
||||
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
old_embedding_name = embedding.name
|
||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
|
||||
try:
|
||||
embedding.sd_checkpoint = checkpoint.hash
|
||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||
if remove_cached_checksum:
|
||||
embedding.cached_checksum = None
|
||||
embedding.name = embedding_name
|
||||
embedding.save(filename)
|
||||
except:
|
||||
embedding.sd_checkpoint = old_sd_checkpoint
|
||||
embedding.sd_checkpoint_name = old_sd_checkpoint_name
|
||||
embedding.name = old_embedding_name
|
||||
embedding.cached_checksum = old_cached_checksum
|
||||
raise
|
||||
|
392
modules/ui.py
392
modules/ui.py
@@ -1,6 +1,4 @@
|
||||
import base64
|
||||
import html
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import mimetypes
|
||||
@@ -18,15 +16,10 @@ import gradio as gr
|
||||
import gradio.routes
|
||||
import gradio.utils
|
||||
import numpy as np
|
||||
import piexif
|
||||
import torch
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
import gradio as gr
|
||||
import gradio.utils
|
||||
import gradio.routes
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
|
||||
from modules.paths import script_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
@@ -35,7 +28,7 @@ if cmd_opts.deepdanbooru:
|
||||
from modules.deepbooru import get_deepbooru_tags
|
||||
|
||||
import modules.codeformer_model
|
||||
import modules.generation_parameters_copypaste
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
import modules.gfpgan_model
|
||||
import modules.hypernetworks.ui
|
||||
import modules.ldsr_model
|
||||
@@ -49,13 +42,11 @@ from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
import modules.textual_inversion.ui
|
||||
import modules.hypernetworks.ui
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
txt2img_paste_fields = []
|
||||
img2img_paste_fields = []
|
||||
|
||||
|
||||
if not cmd_opts.share and not cmd_opts.listen:
|
||||
# fix gradio phoning home
|
||||
@@ -98,37 +89,11 @@ def plaintext_to_html(text):
|
||||
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
||||
return text
|
||||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if type(filedata) == dict and filedata["is_file"]:
|
||||
filename = filedata["name"]
|
||||
tempdir = os.path.normpath(tempfile.gettempdir())
|
||||
normfn = os.path.normpath(filename)
|
||||
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
|
||||
|
||||
return Image.open(filename)
|
||||
|
||||
if type(filedata) == list:
|
||||
if len(filedata) == 0:
|
||||
return None
|
||||
|
||||
filedata = filedata[0]
|
||||
|
||||
if filedata.startswith("data:image/png;base64,"):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
||||
image = Image.open(io.BytesIO(filedata))
|
||||
return image
|
||||
|
||||
|
||||
def send_gradio_gallery_to_image(x):
|
||||
if len(x) == 0:
|
||||
return None
|
||||
|
||||
return image_from_url_text(x[0])
|
||||
|
||||
|
||||
def save_files(js_data, images, do_make_zip, index):
|
||||
import csv
|
||||
filenames = []
|
||||
@@ -192,7 +157,6 @@ def save_files(js_data, images, do_make_zip, index):
|
||||
|
||||
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
@@ -626,10 +590,88 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
|
||||
return refresh_button
|
||||
|
||||
|
||||
def create_output_panel(tabname, outdir):
|
||||
def open_folder(f):
|
||||
if not os.path.exists(f):
|
||||
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
||||
return
|
||||
elif not os.path.isdir(f):
|
||||
print(f"""
|
||||
WARNING
|
||||
An open_folder request was made with an argument that is not a folder.
|
||||
This could be an error or a malicious attempt to run code on your computer.
|
||||
Requested path was: {f}
|
||||
""", file=sys.stderr)
|
||||
return
|
||||
|
||||
if not shared.cmd_opts.hide_ui_dir_config:
|
||||
path = os.path.normpath(f)
|
||||
if platform.system() == "Windows":
|
||||
os.startfile(path)
|
||||
elif platform.system() == "Darwin":
|
||||
sp.Popen(["open", path])
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
with gr.Group():
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
||||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
if tabname != "extras":
|
||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
||||
|
||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||
open_folder_button = gr.Button(folder_symbol, elem_id=button_id)
|
||||
|
||||
open_folder_button.click(
|
||||
fn=lambda: open_folder(opts.outdir_samples or outdir),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
if tabname != "extras":
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
|
||||
save.click(
|
||||
fn=wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
result_gallery,
|
||||
do_make_zip,
|
||||
html_info,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_info,
|
||||
html_info,
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
else:
|
||||
html_info_x = gr.HTML()
|
||||
html_info = gr.HTML()
|
||||
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
|
||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
|
||||
|
||||
|
||||
def create_ui(wrap_gradio_gpu_call):
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
||||
parameters_copypaste.reset()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||
@@ -675,30 +717,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
|
||||
with gr.Group():
|
||||
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
|
||||
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
save = gr.Button('Save')
|
||||
send_to_img2img = gr.Button('Send to img2img')
|
||||
send_to_inpaint = gr.Button('Send to inpaint')
|
||||
send_to_extras = gr.Button('Send to extras')
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
||||
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
|
||||
|
||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
@@ -756,23 +776,6 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
outputs=[hr_options],
|
||||
)
|
||||
|
||||
save.click(
|
||||
fn=wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
txt2img_gallery,
|
||||
do_make_zip,
|
||||
html_info,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_info,
|
||||
html_info,
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_txt2img_tokens",
|
||||
@@ -784,7 +787,6 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
]
|
||||
)
|
||||
|
||||
global txt2img_paste_fields
|
||||
txt2img_paste_fields = [
|
||||
(txt2img_prompt, "Prompt"),
|
||||
(txt2img_negative_prompt, "Negative prompt"),
|
||||
@@ -807,6 +809,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
(firstphase_height, "First pass size-2"),
|
||||
*modules.scripts.scripts_txt2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
|
||||
|
||||
txt2img_preview_params = [
|
||||
txt2img_prompt,
|
||||
@@ -893,30 +896,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
|
||||
with gr.Group():
|
||||
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
|
||||
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
save = gr.Button('Save')
|
||||
img2img_send_to_img2img = gr.Button('Send to img2img')
|
||||
img2img_send_to_inpaint = gr.Button('Send to inpaint')
|
||||
img2img_send_to_extras = gr.Button('Send to extras')
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||
|
||||
with gr.Row():
|
||||
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
|
||||
|
||||
with gr.Row():
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
|
||||
|
||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
@@ -1003,25 +984,9 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
fn=interrogate_deepbooru,
|
||||
inputs=[init_img],
|
||||
outputs=[img2img_prompt],
|
||||
)
|
||||
|
||||
save.click(
|
||||
fn=wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
img2img_gallery,
|
||||
do_make_zip,
|
||||
html_info,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_info,
|
||||
html_info,
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_img2img_tokens",
|
||||
@@ -1055,7 +1020,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
outputs=[prompt, negative_prompt, style1, style2],
|
||||
)
|
||||
|
||||
global img2img_paste_fields
|
||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||
|
||||
img2img_paste_fields = [
|
||||
(img2img_prompt, "Prompt"),
|
||||
(img2img_negative_prompt, "Negative prompt"),
|
||||
@@ -1074,7 +1040,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
(denoising_strength, "Denoising strength"),
|
||||
*modules.scripts.scripts_img2img.infotext_fields
|
||||
]
|
||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
|
||||
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
@@ -1087,12 +1054,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
||||
|
||||
with gr.TabItem('Batch from Directory'):
|
||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
|
||||
placeholder="A directory on the same machine where the server is running."
|
||||
)
|
||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
|
||||
placeholder="Leave blank to save images to the default path."
|
||||
)
|
||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
|
||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
|
||||
show_extras_results = gr.Checkbox(label='Show result images', value=True)
|
||||
|
||||
with gr.Tabs(elem_id="extras_resize_mode"):
|
||||
@@ -1104,9 +1067,9 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
|
||||
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
|
||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
|
||||
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
@@ -1119,17 +1082,12 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
|
||||
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
|
||||
|
||||
with gr.Group():
|
||||
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
|
||||
|
||||
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
result_images = gr.Gallery(label="Result", show_label=False)
|
||||
html_info_x = gr.HTML()
|
||||
html_info = gr.HTML()
|
||||
extras_send_to_img2img = gr.Button('Send to img2img')
|
||||
extras_send_to_inpaint = gr.Button('Send to inpaint')
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||
|
||||
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
|
||||
|
||||
submit.click(
|
||||
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||
@@ -1152,6 +1110,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
extras_upscaler_1,
|
||||
extras_upscaler_2,
|
||||
extras_upscaler_2_visibility,
|
||||
upscale_before_face_fix,
|
||||
],
|
||||
outputs=[
|
||||
result_images,
|
||||
@@ -1159,19 +1118,11 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
html_info,
|
||||
]
|
||||
)
|
||||
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
||||
|
||||
extras_send_to_img2img.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_img2img",
|
||||
inputs=[result_images],
|
||||
outputs=[init_img],
|
||||
)
|
||||
|
||||
extras_send_to_inpaint.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_inpaint",
|
||||
inputs=[result_images],
|
||||
outputs=[init_img_with_mask],
|
||||
extras_image.change(
|
||||
fn=modules.extras.clear_cache,
|
||||
inputs=[], outputs=[]
|
||||
)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
||||
@@ -1183,17 +1134,16 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
html = gr.HTML()
|
||||
generation_info = gr.Textbox(visible=False)
|
||||
html2 = gr.HTML()
|
||||
|
||||
with gr.Row():
|
||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
|
||||
parameters_copypaste.bind_buttons(buttons, image, generation_info)
|
||||
|
||||
image.change(
|
||||
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||
inputs=[image],
|
||||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as modelmerger_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
@@ -1238,7 +1188,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
|
||||
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
|
||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
||||
@@ -1491,28 +1441,6 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
script_callbacks.ui_settings_callback()
|
||||
opts.reorder()
|
||||
|
||||
def open_folder(f):
|
||||
if not os.path.exists(f):
|
||||
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
||||
return
|
||||
elif not os.path.isdir(f):
|
||||
print(f"""
|
||||
WARNING
|
||||
An open_folder request was made with an argument that is not a folder.
|
||||
This could be an error or a malicious attempt to run code on your computer.
|
||||
Requested path was: {f}
|
||||
""", file=sys.stderr)
|
||||
return
|
||||
|
||||
if not shared.cmd_opts.hide_ui_dir_config:
|
||||
path = os.path.normpath(f)
|
||||
if platform.system() == "Windows":
|
||||
os.startfile(path)
|
||||
elif platform.system() == "Darwin":
|
||||
sp.Popen(["open", path])
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
|
||||
def run_settings(*args):
|
||||
changed = 0
|
||||
|
||||
@@ -1584,8 +1512,9 @@ Requested path was: {f}
|
||||
column = None
|
||||
with gr.Row(elem_id="settings").style(equal_height=False):
|
||||
for i, (k, item) in enumerate(opts.data_labels.items()):
|
||||
section_must_be_skipped = item.section[0] is None
|
||||
|
||||
if previous_section != item.section:
|
||||
if previous_section != item.section and not section_must_be_skipped:
|
||||
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
|
||||
if column is not None:
|
||||
column.__exit__()
|
||||
@@ -1604,6 +1533,8 @@ Requested path was: {f}
|
||||
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
|
||||
quicksettings_list.append((i, k, item))
|
||||
components.append(dummy_component)
|
||||
elif section_must_be_skipped:
|
||||
components.append(dummy_component)
|
||||
else:
|
||||
component = create_setting_component(k)
|
||||
component_dict[k] = component
|
||||
@@ -1645,9 +1576,10 @@ Requested path was: {f}
|
||||
|
||||
def request_restart():
|
||||
shared.state.interrupt()
|
||||
settings_interface.gradio_ref.do_restart = True
|
||||
shared.state.need_restart = True
|
||||
|
||||
restart_gradio.click(
|
||||
|
||||
fn=request_restart,
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
@@ -1666,10 +1598,6 @@ Requested path was: {f}
|
||||
(train_interface, "Train", "ti"),
|
||||
]
|
||||
|
||||
interfaces += script_callbacks.ui_tabs_callback()
|
||||
|
||||
interfaces += [(settings_interface, "Settings", "settings")]
|
||||
|
||||
css = ""
|
||||
|
||||
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
||||
@@ -1686,13 +1614,20 @@ Requested path was: {f}
|
||||
if not cmd_opts.no_progressbar_hiding:
|
||||
css += css_hide_progressbar
|
||||
|
||||
interfaces += script_callbacks.ui_tabs_callback()
|
||||
interfaces += [(settings_interface, "Settings", "settings")]
|
||||
|
||||
extensions_interface = ui_extensions.create_ui()
|
||||
interfaces += [(extensions_interface, "Extensions", "extensions")]
|
||||
|
||||
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||
with gr.Row(elem_id="quicksettings"):
|
||||
for i, k, item in quicksettings_list:
|
||||
component = create_setting_component(k, is_quicksettings=True)
|
||||
component_dict[k] = component
|
||||
|
||||
settings_interface.gradio_ref = demo
|
||||
parameters_copypaste.integrate_settings_paste_fields(component_dict)
|
||||
parameters_copypaste.run_bind()
|
||||
|
||||
with gr.Tabs(elem_id="tabs") as tabs:
|
||||
for interface, label, ifid in interfaces:
|
||||
@@ -1747,85 +1682,6 @@ Requested path was: {f}
|
||||
component_dict['sd_model_checkpoint'],
|
||||
]
|
||||
)
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
|
||||
txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
|
||||
img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
|
||||
send_to_img2img.click(
|
||||
fn=lambda img, *args: (image_from_url_text(img),*args),
|
||||
_js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
|
||||
inputs=[txt2img_gallery] + txt2img_fields,
|
||||
outputs=[init_img] + img2img_fields,
|
||||
)
|
||||
|
||||
send_to_inpaint.click(
|
||||
fn=lambda x, *args: (image_from_url_text(x), *args),
|
||||
_js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
|
||||
inputs=[txt2img_gallery] + txt2img_fields,
|
||||
outputs=[init_img_with_mask] + img2img_fields,
|
||||
)
|
||||
|
||||
img2img_send_to_img2img.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_img2img",
|
||||
inputs=[img2img_gallery],
|
||||
outputs=[init_img],
|
||||
)
|
||||
|
||||
img2img_send_to_inpaint.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_inpaint",
|
||||
inputs=[img2img_gallery],
|
||||
outputs=[init_img_with_mask],
|
||||
)
|
||||
|
||||
send_to_extras.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_extras",
|
||||
inputs=[txt2img_gallery],
|
||||
outputs=[extras_image],
|
||||
)
|
||||
|
||||
open_txt2img_folder.click(
|
||||
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
open_img2img_folder.click(
|
||||
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
open_extras_folder.click(
|
||||
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
img2img_send_to_extras.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery_extras",
|
||||
inputs=[img2img_gallery],
|
||||
outputs=[extras_image],
|
||||
)
|
||||
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
}
|
||||
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
|
||||
for k, v in settings_map.items()
|
||||
]
|
||||
|
||||
modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
|
||||
modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
|
||||
|
||||
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
|
||||
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
|
||||
|
||||
ui_config_file = cmd_opts.ui_config_file
|
||||
ui_settings = {}
|
||||
@@ -1845,7 +1701,7 @@ Requested path was: {f}
|
||||
def apply_field(obj, field, condition=None, init_field=None):
|
||||
key = path + "/" + field
|
||||
|
||||
if getattr(obj,'custom_script_source',None) is not None:
|
||||
if getattr(obj, 'custom_script_source', None) is not None:
|
||||
key = 'customscript/' + obj.custom_script_source + '/' + key
|
||||
|
||||
if getattr(obj, 'do_not_save_to_config', False):
|
||||
@@ -1905,7 +1761,7 @@ def load_javascript(raw_response):
|
||||
javascript = f'<script>{jsfile.read()}</script>'
|
||||
|
||||
scripts_list = modules.scripts.list_scripts("javascript", ".js")
|
||||
|
||||
|
||||
for basedir, filename, path in scripts_list:
|
||||
with open(path, "r", encoding="utf8") as jsfile:
|
||||
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
|
||||
|
172
modules/ui_extensions.py
Normal file
172
modules/ui_extensions.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import json
|
||||
import os.path
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
import gradio as gr
|
||||
import html
|
||||
|
||||
from modules import extensions, shared, paths
|
||||
|
||||
|
||||
def check_access():
|
||||
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
|
||||
|
||||
|
||||
def apply_and_restart(disable_list, update_list):
|
||||
check_access()
|
||||
|
||||
disabled = json.loads(disable_list)
|
||||
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
||||
|
||||
update = json.loads(update_list)
|
||||
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
|
||||
|
||||
update = set(update)
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.name not in update:
|
||||
continue
|
||||
|
||||
try:
|
||||
ext.pull()
|
||||
except Exception:
|
||||
print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.opts.disabled_extensions = disabled
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
shared.state.interrupt()
|
||||
shared.state.need_restart = True
|
||||
|
||||
|
||||
def check_updates():
|
||||
check_access()
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.remote is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
ext.check_updates()
|
||||
except Exception:
|
||||
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return extension_table()
|
||||
|
||||
|
||||
def extension_table():
|
||||
code = f"""<!-- {time.time()} -->
|
||||
<table id="extensions">
|
||||
<thead>
|
||||
<tr>
|
||||
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
|
||||
<th>URL</th>
|
||||
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
|
||||
for ext in extensions.extensions:
|
||||
if ext.can_update:
|
||||
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
|
||||
else:
|
||||
ext_status = ext.status
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
|
||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
code += """
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def install_extension_from_url(dirname, url):
|
||||
check_access()
|
||||
|
||||
assert url, 'No URL specified'
|
||||
|
||||
if dirname is None or dirname == "":
|
||||
*parts, last_part = url.split('/')
|
||||
last_part = last_part.replace(".git", "")
|
||||
|
||||
dirname = last_part
|
||||
|
||||
target_dir = os.path.join(extensions.extensions_dir, dirname)
|
||||
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
|
||||
|
||||
assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'
|
||||
|
||||
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
|
||||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
repo = git.Repo.clone_from(url, tmpdir)
|
||||
repo.remote().fetch()
|
||||
|
||||
os.rename(tmpdir, target_dir)
|
||||
|
||||
extensions.list_extensions()
|
||||
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.ui
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as ui:
|
||||
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
||||
with gr.TabItem("Installed"):
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||
check = gr.Button(value="Check for updates")
|
||||
|
||||
extensions_table = gr.HTML(lambda: extension_table())
|
||||
|
||||
apply.click(
|
||||
fn=apply_and_restart,
|
||||
_js="extensions_apply",
|
||||
inputs=[extensions_disabled_list, extensions_update_list],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
check.click(
|
||||
fn=check_updates,
|
||||
_js="extensions_check",
|
||||
inputs=[],
|
||||
outputs=[extensions_table],
|
||||
)
|
||||
|
||||
with gr.TabItem("Install from URL"):
|
||||
install_url = gr.Text(label="URL for extension's git repository")
|
||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||
intall_button = gr.Button(value="Install", variant="primary")
|
||||
intall_result = gr.HTML(elem_id="extension_install_result")
|
||||
|
||||
intall_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
|
||||
inputs=[install_dirname, install_url],
|
||||
outputs=[extensions_table, intall_result],
|
||||
)
|
||||
|
||||
return ui
|
Reference in New Issue
Block a user