mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-10 10:50:09 +00:00
Merge branch 'AUTOMATIC1111:master' into fix-sd-arch-switch-in-override-settings
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
import datetime
|
||||
import uvicorn
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from gradio.processing_utils import decode_base64_to_file
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from secrets import compare_digest
|
||||
|
||||
@@ -18,7 +19,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
@@ -67,6 +68,27 @@ def encode_pil_to_base64(image):
|
||||
bytes_data = output_bytes.getvalue()
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
@app.middleware("http")
|
||||
async def log_and_time(req: Request, call_next):
|
||||
ts = time.time()
|
||||
res: Response = await call_next(req)
|
||||
duration = str(round(time.time() - ts, 4))
|
||||
res.headers["X-Process-Time"] = duration
|
||||
endpoint = req.scope.get('path', 'err')
|
||||
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
|
||||
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
|
||||
t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
||||
code = res.status_code,
|
||||
ver = req.scope.get('http_version', '0.0'),
|
||||
cli = req.scope.get('client', ('0:0.0.0', 0))[0],
|
||||
prot = req.scope.get('scheme', 'err'),
|
||||
method = req.scope.get('method', 'err'),
|
||||
endpoint = endpoint,
|
||||
duration = duration,
|
||||
))
|
||||
return res
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
@@ -79,6 +101,7 @@ class Api:
|
||||
self.router = APIRouter()
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
api_middleware(self.app)
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
@@ -100,6 +123,7 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
||||
@@ -128,15 +152,14 @@ class Api:
|
||||
)
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||
# Override object param
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
|
||||
|
||||
shared.state.begin()
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
@@ -163,16 +186,14 @@ class Api:
|
||||
|
||||
args = vars(populate)
|
||||
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
|
||||
p = StableDiffusionProcessingImg2Img(**args)
|
||||
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
|
||||
shared.state.end()
|
||||
shared.state.begin()
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
@@ -305,7 +326,7 @@ class Api:
|
||||
return upscalers
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
|
||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
@@ -330,6 +351,26 @@ class Api:
|
||||
def get_artists(self):
|
||||
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
def convert_embedding(embedding):
|
||||
return {
|
||||
"step": embedding.step,
|
||||
"sd_checkpoint": embedding.sd_checkpoint,
|
||||
"sd_checkpoint_name": embedding.sd_checkpoint_name,
|
||||
"shape": embedding.shape,
|
||||
"vectors": embedding.vectors,
|
||||
}
|
||||
|
||||
def convert_embeddings(embeddings):
|
||||
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
||||
|
||||
return {
|
||||
"loaded": convert_embeddings(db.word_embeddings),
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
|
@@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
|
||||
score: float = Field(title="Score")
|
||||
category: str = Field(title="Category")
|
||||
|
||||
class EmbeddingItem(BaseModel):
|
||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
@@ -2,9 +2,30 @@ import sys
|
||||
import traceback
|
||||
|
||||
|
||||
def print_error_explanation(message):
|
||||
lines = message.strip().split("\n")
|
||||
max_len = max([len(x) for x in lines])
|
||||
|
||||
print('=' * max_len, file=sys.stderr)
|
||||
for line in lines:
|
||||
print(line, file=sys.stderr)
|
||||
print('=' * max_len, file=sys.stderr)
|
||||
|
||||
|
||||
def display(e: Exception, task):
|
||||
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
message = str(e)
|
||||
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||
print_error_explanation("""
|
||||
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
|
||||
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
||||
""")
|
||||
|
||||
|
||||
def run(code, task):
|
||||
try:
|
||||
code()
|
||||
except Exception as e:
|
||||
print(f"{task}: {type(e).__name__}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
display(task, e)
|
||||
|
@@ -19,8 +19,6 @@ from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.codeformer_model
|
||||
import piexif
|
||||
import piexif.helper
|
||||
import gradio as gr
|
||||
import safetensors.torch
|
||||
|
||||
@@ -58,6 +56,9 @@ cached_images: LruCache = LruCache(max_size=5)
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'extras'
|
||||
|
||||
imageArr = []
|
||||
# Also keep track of original file names
|
||||
imageNameArr = []
|
||||
@@ -94,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
# Extra operation definitions
|
||||
|
||||
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
shared.state.job = 'extras-gfpgan'
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
@@ -104,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
return (res, info)
|
||||
|
||||
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
shared.state.job = 'extras-codeformer'
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
@@ -114,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
return (res, info)
|
||||
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
shared.state.job = 'extras-upscale'
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
@@ -180,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
|
||||
shared.state.textinfo = f'Processing image {image_name}'
|
||||
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
image = image.convert("RGB")
|
||||
@@ -193,6 +200,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
else:
|
||||
basename = ''
|
||||
|
||||
if opts.enable_pnginfo: # append info before save
|
||||
image.info = existing_pnginfo
|
||||
image.info["extras"] = info
|
||||
|
||||
if save_output:
|
||||
# Add upscaler name as a suffix.
|
||||
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
|
||||
@@ -203,10 +214,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
image.info = existing_pnginfo
|
||||
image.info["extras"] = info
|
||||
|
||||
if extras_mode != 2 or show_extras_results :
|
||||
outputs.append(image)
|
||||
|
||||
@@ -242,6 +249,9 @@ def run_pnginfo(image):
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
@@ -263,8 +273,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
|
||||
if theta_func1 and not tertiary_model_info:
|
||||
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
|
||||
shared.state.end()
|
||||
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
|
||||
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
|
||||
@@ -281,6 +294,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
del theta_2
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
|
||||
@@ -291,6 +305,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
a = theta_0[key]
|
||||
b = theta_1[key]
|
||||
|
||||
shared.state.textinfo = f'Merging layer {key}'
|
||||
# this enables merging an inpainting model (A) with another one (B);
|
||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||
@@ -330,6 +345,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
@@ -341,4 +357,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
sd_models.list_models()
|
||||
|
||||
print("Checkpoint saved.")
|
||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||
shared.state.end()
|
||||
|
||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
|
@@ -1,12 +1,13 @@
|
||||
import base64
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from modules.shared import script_path
|
||||
from modules import shared
|
||||
from modules import shared, ui_tempdir
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
@@ -36,9 +37,12 @@ def quote(text):
|
||||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if type(filedata) == dict and filedata["is_file"]:
|
||||
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
|
||||
filedata = filedata[0]
|
||||
|
||||
if type(filedata) == dict and filedata.get("is_file", False):
|
||||
filename = filedata["name"]
|
||||
is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
|
||||
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||
|
||||
return Image.open(filename)
|
||||
@@ -93,7 +97,7 @@ def integrate_settings_paste_fields(component_dict):
|
||||
def create_buttons(tabs_list):
|
||||
buttons = {}
|
||||
for tab in tabs_list:
|
||||
buttons[tab] = gr.Button(f"Send to {tab}")
|
||||
buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
|
||||
return buttons
|
||||
|
||||
|
||||
@@ -102,35 +106,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
|
||||
bind_list.append([buttons, send_image, send_generate_info])
|
||||
|
||||
|
||||
def send_image_and_dimensions(x):
|
||||
if isinstance(x, Image.Image):
|
||||
img = x
|
||||
else:
|
||||
img = image_from_url_text(x)
|
||||
|
||||
if shared.opts.send_size and isinstance(img, Image.Image):
|
||||
w = img.width
|
||||
h = img.height
|
||||
else:
|
||||
w = gr.update()
|
||||
h = gr.update()
|
||||
|
||||
return img, w, h
|
||||
|
||||
|
||||
def run_bind():
|
||||
for buttons, send_image, send_generate_info in bind_list:
|
||||
for buttons, source_image_component, send_generate_info in bind_list:
|
||||
for tab in buttons:
|
||||
button = buttons[tab]
|
||||
if send_image and paste_fields[tab]["init_img"]:
|
||||
if type(send_image) == gr.Gallery:
|
||||
button.click(
|
||||
fn=lambda x: image_from_url_text(x),
|
||||
_js="extract_image_from_gallery",
|
||||
inputs=[send_image],
|
||||
outputs=[paste_fields[tab]["init_img"]],
|
||||
)
|
||||
else:
|
||||
button.click(
|
||||
fn=lambda x: x,
|
||||
inputs=[send_image],
|
||||
outputs=[paste_fields[tab]["init_img"]],
|
||||
)
|
||||
destination_image_component = paste_fields[tab]["init_img"]
|
||||
fields = paste_fields[tab]["fields"]
|
||||
|
||||
if send_generate_info and paste_fields[tab]["fields"] is not None:
|
||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||
|
||||
if source_image_component and destination_image_component:
|
||||
if isinstance(source_image_component, gr.Gallery):
|
||||
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
||||
jsfunc = "extract_image_from_gallery"
|
||||
else:
|
||||
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
||||
jsfunc = None
|
||||
|
||||
button.click(
|
||||
fn=func,
|
||||
_js=jsfunc,
|
||||
inputs=[source_image_component],
|
||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||
)
|
||||
|
||||
if send_generate_info and fields is not None:
|
||||
if send_generate_info in paste_fields:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in fields if name in paste_field_names],
|
||||
)
|
||||
else:
|
||||
connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
|
||||
connect_paste(button, fields, send_generate_info)
|
||||
|
||||
button.click(
|
||||
fn=None,
|
||||
@@ -164,6 +190,34 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
||||
return None
|
||||
|
||||
|
||||
def restore_old_hires_fix_params(res):
|
||||
"""for infotexts that specify old First pass size parameter, convert it into
|
||||
width, height, and hr scale"""
|
||||
|
||||
firstpass_width = res.get('First pass size-1', None)
|
||||
firstpass_height = res.get('First pass size-2', None)
|
||||
|
||||
if firstpass_width is None or firstpass_height is None:
|
||||
return
|
||||
|
||||
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
|
||||
width = int(res.get("Size-1", 512))
|
||||
height = int(res.get("Size-2", 512))
|
||||
|
||||
if firstpass_width == 0 or firstpass_height == 0:
|
||||
# old algorithm for auto-calculating first pass size
|
||||
desired_pixel_count = 512 * 512
|
||||
actual_pixel_count = width * height
|
||||
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
||||
firstpass_width = math.ceil(scale * width / 64) * 64
|
||||
firstpass_height = math.ceil(scale * height / 64) * 64
|
||||
|
||||
res['Size-1'] = firstpass_width
|
||||
res['Size-2'] = firstpass_height
|
||||
res['Hires resize-1'] = width
|
||||
res['Hires resize-2'] = height
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
||||
```
|
||||
@@ -221,6 +275,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
hypernet_hash = res.get("Hypernet hash", None)
|
||||
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
|
||||
|
||||
if "Hires resize-1" not in res:
|
||||
res["Hires resize-1"] = 0
|
||||
res["Hires resize-2"] = 0
|
||||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
@@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
@@ -417,6 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
|
||||
shared.state.job = "train-hypernetwork"
|
||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
shared.state.job_count = steps
|
||||
|
||||
@@ -447,6 +446,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
||||
if clip_grad:
|
||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
@@ -465,7 +468,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
shared.parallel_processing_allowed = False
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
|
||||
weights = hypernetwork.weights()
|
||||
hypernetwork.train_mode()
|
||||
|
||||
@@ -524,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
if clip_grad:
|
||||
clip_grad_sched.step(hypernetwork.step)
|
||||
|
||||
with devices.autocast():
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
if tag_drop_out != 0 or shuffle_tags:
|
||||
@@ -538,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
||||
# scaler.unscale_(optimizer)
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||
|
||||
if clip_grad:
|
||||
clip_grad(weights, clip_grad_sched.learn_rate)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
hypernetwork.step += 1
|
||||
|
@@ -39,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
|
||||
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
|
||||
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
|
||||
script_callbacks.image_grid_callback(params)
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
|
||||
|
||||
for i, img in enumerate(params.imgs):
|
||||
grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
|
||||
|
||||
return grid
|
||||
|
||||
@@ -227,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||
|
||||
|
||||
def resize_image(resize_mode, im, width, height):
|
||||
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
"""
|
||||
Resizes an image with the specified resize_mode, width, and height.
|
||||
|
||||
Args:
|
||||
resize_mode: The mode to use when resizing the image.
|
||||
0: Resize the image to the specified width and height.
|
||||
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
||||
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
||||
im: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
|
||||
"""
|
||||
|
||||
upscaler_name = upscaler_name or opts.upscaler_for_img2img
|
||||
|
||||
def resize(im, w, h):
|
||||
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
|
||||
if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
|
||||
return im.resize((w, h), resample=LANCZOS)
|
||||
|
||||
scale = max(w / im.width, h / im.height)
|
||||
|
||||
if scale > 1.0:
|
||||
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
|
||||
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
|
||||
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
|
||||
assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
|
||||
|
||||
upscaler = upscalers[0]
|
||||
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
||||
@@ -525,6 +544,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||
|
||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||
if image_to_save.mode == 'RGBA':
|
||||
image_to_save = image_to_save.convert("RGB")
|
||||
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
|
||||
|
||||
if opts.enable_pnginfo and info is not None:
|
||||
|
@@ -162,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
|
@@ -135,8 +135,9 @@ class InterrogateModels:
|
||||
return caption[0]
|
||||
|
||||
def interrogate(self, pil_image):
|
||||
res = None
|
||||
|
||||
res = ""
|
||||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
try:
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
@@ -177,5 +178,6 @@ class InterrogateModels:
|
||||
res += "<error>"
|
||||
|
||||
self.unload()
|
||||
shared.state.end()
|
||||
|
||||
return res
|
||||
|
@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
|
||||
def read(self):
|
||||
if not self.disabled:
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
self.data["free"] = free
|
||||
self.data["total"] = total
|
||||
|
||||
torch_stats = torch.cuda.memory_stats(self.device)
|
||||
self.data["active"] = torch_stats["active.all.current"]
|
||||
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
|
||||
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
|
||||
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
|
||||
self.data["system_peak"] = total - self.data["min_free"]
|
||||
|
||||
|
@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
pass
|
||||
|
||||
|
||||
builtin_upscaler_classes = []
|
||||
forbidden_upscaler_classes = set()
|
||||
|
||||
|
||||
def list_builtin_upscalers():
|
||||
load_upscalers()
|
||||
|
||||
builtin_upscaler_classes.clear()
|
||||
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
|
||||
|
||||
|
||||
def forbid_loaded_nonbuiltin_upscalers():
|
||||
for cls in Upscaler.__subclasses__():
|
||||
if cls not in builtin_upscaler_classes:
|
||||
forbidden_upscaler_classes.add(cls)
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||
@@ -139,6 +156,9 @@ def load_upscalers():
|
||||
datas = []
|
||||
commandline_options = vars(shared.cmd_opts)
|
||||
for cls in Upscaler.__subclasses__():
|
||||
if cls in forbidden_upscaler_classes:
|
||||
continue
|
||||
|
||||
name = cls.__name__
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||
scaler = cls(commandline_options.get(cmd_name, None))
|
||||
|
@@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
|
||||
return image
|
||||
|
||||
|
||||
def txt2img_image_conditioning(sd_model, x, width, height):
|
||||
if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
|
||||
class StableDiffusionProcessing():
|
||||
"""
|
||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
@@ -136,28 +154,12 @@ class StableDiffusionProcessing():
|
||||
self.all_negative_prompts = None
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
self.iteration = 0
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
return x.new_zeros(x.shape[0], 5, 1, 1)
|
||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
height = height or self.height
|
||||
width = width or self.width
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
||||
|
||||
def depth2img_image_conditioning(self, source_image):
|
||||
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
||||
@@ -239,7 +241,7 @@ class StableDiffusionProcessing():
|
||||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
||||
self.images = images_list
|
||||
self.prompt = p.prompt
|
||||
self.negative_prompt = p.negative_prompt
|
||||
@@ -247,6 +249,7 @@ class Processed:
|
||||
self.subseed = subseed
|
||||
self.subseed_strength = p.subseed_strength
|
||||
self.info = info
|
||||
self.comments = comments
|
||||
self.width = p.width
|
||||
self.height = p.height
|
||||
self.sampler_name = p.sampler_name
|
||||
@@ -338,13 +341,14 @@ def slerp(val, low, high):
|
||||
|
||||
|
||||
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
|
||||
eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
|
||||
xs = []
|
||||
|
||||
# if we have multiple seeds, this means we are working with batch size>1; this then
|
||||
# enables the generation of additional tensors with noise that the sampler will use during its processing.
|
||||
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
|
||||
# produce the same images as with two batches [100], [101].
|
||||
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
|
||||
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
|
||||
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
||||
else:
|
||||
sampler_noises = None
|
||||
@@ -384,8 +388,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
if sampler_noises is not None:
|
||||
cnt = p.sampler.number_of_needed_noises(p)
|
||||
|
||||
if opts.eta_noise_seed_delta > 0:
|
||||
torch.manual_seed(seed + opts.eta_noise_seed_delta)
|
||||
if eta_noise_seed_delta > 0:
|
||||
torch.manual_seed(seed + eta_noise_seed_delta)
|
||||
|
||||
for j in range(cnt):
|
||||
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
||||
@@ -418,7 +422,7 @@ def fix_seed(p):
|
||||
p.subseed = get_fixed_seed(p.subseed)
|
||||
|
||||
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||
index = position_in_batch + iteration * p.batch_size
|
||||
|
||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||
@@ -544,6 +548,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
state.job_count = p.n_iter
|
||||
|
||||
for n in range(p.n_iter):
|
||||
p.iteration = n
|
||||
|
||||
if state.skipped:
|
||||
state.skipped = False
|
||||
|
||||
@@ -647,7 +653,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
@@ -658,76 +664,114 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
|
||||
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
|
||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.enable_hr = enable_hr
|
||||
self.denoising_strength = denoising_strength
|
||||
self.firstphase_width = firstphase_width
|
||||
self.firstphase_height = firstphase_height
|
||||
self.hr_scale = hr_scale
|
||||
self.hr_upscaler = hr_upscaler
|
||||
self.hr_second_pass_steps = hr_second_pass_steps
|
||||
self.hr_resize_x = hr_resize_x
|
||||
self.hr_resize_y = hr_resize_y
|
||||
self.hr_upscale_to_x = hr_resize_x
|
||||
self.hr_upscale_to_y = hr_resize_y
|
||||
|
||||
if firstphase_width != 0 or firstphase_height != 0:
|
||||
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
|
||||
self.hr_scale = self.width / firstphase_width
|
||||
self.width = firstphase_width
|
||||
self.height = firstphase_height
|
||||
|
||||
self.truncate_x = 0
|
||||
self.truncate_y = 0
|
||||
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
if self.enable_hr:
|
||||
if state.job_count == -1:
|
||||
state.job_count = self.n_iter * 2
|
||||
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
|
||||
self.extra_generation_params["Hires upscale"] = self.hr_scale
|
||||
self.hr_upscale_to_x = int(self.width * self.hr_scale)
|
||||
self.hr_upscale_to_y = int(self.height * self.hr_scale)
|
||||
else:
|
||||
state.job_count = state.job_count * 2
|
||||
self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
|
||||
|
||||
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
|
||||
|
||||
if self.firstphase_width == 0 or self.firstphase_height == 0:
|
||||
desired_pixel_count = 512 * 512
|
||||
actual_pixel_count = self.width * self.height
|
||||
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
||||
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
|
||||
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
|
||||
firstphase_width_truncated = int(scale * self.width)
|
||||
firstphase_height_truncated = int(scale * self.height)
|
||||
|
||||
else:
|
||||
|
||||
width_ratio = self.width / self.firstphase_width
|
||||
height_ratio = self.height / self.firstphase_height
|
||||
|
||||
if width_ratio > height_ratio:
|
||||
firstphase_width_truncated = self.firstphase_width
|
||||
firstphase_height_truncated = self.firstphase_width * self.height / self.width
|
||||
if self.hr_resize_y == 0:
|
||||
self.hr_upscale_to_x = self.hr_resize_x
|
||||
self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
|
||||
elif self.hr_resize_x == 0:
|
||||
self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
|
||||
self.hr_upscale_to_y = self.hr_resize_y
|
||||
else:
|
||||
firstphase_width_truncated = self.firstphase_height * self.width / self.height
|
||||
firstphase_height_truncated = self.firstphase_height
|
||||
target_w = self.hr_resize_x
|
||||
target_h = self.hr_resize_y
|
||||
src_ratio = self.width / self.height
|
||||
dst_ratio = self.hr_resize_x / self.hr_resize_y
|
||||
|
||||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||
if src_ratio < dst_ratio:
|
||||
self.hr_upscale_to_x = self.hr_resize_x
|
||||
self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
|
||||
else:
|
||||
self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
|
||||
self.hr_upscale_to_y = self.hr_resize_y
|
||||
|
||||
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
||||
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
||||
|
||||
# special case: the user has chosen to do nothing
|
||||
if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
|
||||
self.enable_hr = False
|
||||
self.denoising_strength = None
|
||||
self.extra_generation_params.pop("Hires upscale", None)
|
||||
self.extra_generation_params.pop("Hires resize", None)
|
||||
return
|
||||
|
||||
if not state.processing_has_refined_job_count:
|
||||
if state.job_count == -1:
|
||||
state.job_count = self.n_iter
|
||||
|
||||
shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
|
||||
state.job_count = state.job_count * 2
|
||||
state.processing_has_refined_job_count = True
|
||||
|
||||
if self.hr_second_pass_steps:
|
||||
self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
|
||||
|
||||
if self.hr_upscaler is not None:
|
||||
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||
if self.enable_hr and latent_scale_mode is None:
|
||||
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
|
||||
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
|
||||
if not self.enable_hr:
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
return samples
|
||||
|
||||
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
|
||||
target_width = self.hr_upscale_to_x
|
||||
target_height = self.hr_upscale_to_y
|
||||
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||
|
||||
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
|
||||
def save_intermediate(image, index):
|
||||
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
|
||||
|
||||
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
|
||||
return
|
||||
|
||||
if not isinstance(image, Image.Image):
|
||||
image = sd_samplers.sample_to_image(image, index)
|
||||
image = sd_samplers.sample_to_image(image, index, approximation=0)
|
||||
|
||||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
|
||||
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
|
||||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
|
||||
|
||||
if opts.use_scale_latent_for_hires_fix:
|
||||
if latent_scale_mode is not None:
|
||||
for i in range(samples.shape[0]):
|
||||
save_intermediate(samples, i)
|
||||
|
||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
|
||||
|
||||
# Avoid making the inpainting conditioning unless necessary as
|
||||
# this does need some extra compute to decode / encode the image again.
|
||||
@@ -747,7 +791,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
save_intermediate(image, i)
|
||||
|
||||
image = images.resize_image(0, image, self.width, self.height)
|
||||
image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
batch_images.append(image)
|
||||
@@ -764,13 +808,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
||||
|
||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
||||
|
||||
# GC now before running the next img2img to prevent running out of memory
|
||||
x = None
|
||||
devices.torch_gc()
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
return samples
|
||||
|
||||
|
@@ -51,6 +51,13 @@ class UiTrainTabParams:
|
||||
self.txt2img_preview_params = txt2img_preview_params
|
||||
|
||||
|
||||
class ImageGridLoopParams:
|
||||
def __init__(self, imgs, cols, rows):
|
||||
self.imgs = imgs
|
||||
self.cols = cols
|
||||
self.rows = rows
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
@@ -63,6 +70,7 @@ callback_map = dict(
|
||||
callbacks_cfg_denoiser=[],
|
||||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
callbacks_image_grid=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -155,6 +163,14 @@ def after_component_callback(component, **kwargs):
|
||||
report_exception(c, 'after_component_callback')
|
||||
|
||||
|
||||
def image_grid_callback(params: ImageGridLoopParams):
|
||||
for c in callback_map['callbacks_image_grid']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_grid')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
@@ -255,3 +271,11 @@ def on_before_component(callback):
|
||||
def on_after_component(callback):
|
||||
"""register a function to be called after a component is created. See on_before_component for more."""
|
||||
add_callback(callback_map['callbacks_after_component'], callback)
|
||||
|
||||
|
||||
def on_image_grid(callback):
|
||||
"""register a function to be called before making an image grid.
|
||||
The callback is called with one argument:
|
||||
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_image_grid'], callback)
|
||||
|
@@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
|
||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||
|
||||
@@ -35,26 +35,35 @@ def apply_optimizations():
|
||||
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
optimization_method = None
|
||||
|
||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||
print("Applying xformers cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||
optimization_method = 'xformers'
|
||||
elif cmd_opts.opt_split_attention_v1:
|
||||
print("Applying v1 cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
optimization_method = 'V1'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
|
||||
if not invokeAI_mps_available and shared.device.type == 'mps':
|
||||
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
|
||||
print("Applying v1 cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
optimization_method = 'V1'
|
||||
else:
|
||||
print("Applying cross attention optimization (InvokeAI).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
||||
optimization_method = 'InvokeAI'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||
print("Applying cross attention optimization (Doggettx).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||
optimization_method = 'Doggettx'
|
||||
|
||||
return optimization_method
|
||||
|
||||
|
||||
def undo_optimizations():
|
||||
@@ -68,27 +77,37 @@ def fix_checkpoint():
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
|
||||
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
fixes = None
|
||||
comments = []
|
||||
layers = None
|
||||
circular_enabled = False
|
||||
clip = None
|
||||
optimization_method = None
|
||||
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||
|
||||
def hijack(self, m):
|
||||
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
self.optimization_method = apply_optimizations()
|
||||
|
||||
apply_optimizations()
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
fix_checkpoint()
|
||||
|
||||
def flatten(el):
|
||||
@@ -101,7 +120,11 @@ class StableDiffusionModelHijack:
|
||||
self.layers = flatten(m)
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
@@ -129,8 +152,8 @@ class StableDiffusionModelHijack:
|
||||
|
||||
def tokenize(self, text):
|
||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
||||
|
||||
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
|
@@ -5,7 +5,6 @@ import torch
|
||||
from modules import prompt_parser, devices
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def get_target_prompt_token_count(token_count):
|
||||
return math.ceil(max(token_count, 1) / 75) * 75
|
||||
|
||||
@@ -254,10 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
|
||||
self.comma_token = vocab.get(',</w>', None)
|
||||
|
||||
self.token_mults = {}
|
||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
@@ -296,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.transformer.text_model.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
@@ -12,191 +12,6 @@ from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch DDIMSampler methods from RunwayML repo directly.
|
||||
# Adapted from:
|
||||
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
|
||||
# =================================================================================================
|
||||
@torch.no_grad()
|
||||
def sample_ddim(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch PLMSSampler methods.
|
||||
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
|
||||
# Adapted from:
|
||||
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
|
||||
# =================================================================================================
|
||||
@torch.no_grad()
|
||||
def sample_plms(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
@@ -280,61 +95,17 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
|
||||
# Adapted from:
|
||||
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
|
||||
# =================================================================================================
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||
if null_label is not None:
|
||||
xc = null_label
|
||||
if isinstance(xc, ListConfig):
|
||||
xc = list(xc)
|
||||
if isinstance(xc, dict) or isinstance(xc, list):
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
if hasattr(xc, "to"):
|
||||
xc = xc.to(self.device)
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
# todo: get null label from cond_stage_model
|
||||
raise NotImplementedError()
|
||||
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
|
||||
return c
|
||||
|
||||
|
||||
class LatentInpaintDiffusion(LatentDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
concat_keys=("mask", "masked_image"),
|
||||
masked_image_key="masked_image",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.masked_image_key = masked_image_key
|
||||
assert self.masked_image_key in concat_keys
|
||||
self.concat_keys = concat_keys
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
from modules import sd_models
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(checkpoint_info.config).lower()
|
||||
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||
|
||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# most of this stuff seems to no longer be needed because it is already included into SD2.0
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
# this file should be cleaned up later if everything turns out to work fine
|
||||
|
||||
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||
# ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||
# ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||
|
34
modules/sd_hijack_xlmr.py
Normal file
34
modules/sd_hijack_xlmr.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import open_clip.tokenizer
|
||||
import torch
|
||||
|
||||
from modules import sd_hijack_clip, devices
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.id_start = wrapped.config.bos_token_id
|
||||
self.id_end = wrapped.config.eos_token_id
|
||||
self.id_pad = wrapped.config.pad_token_id
|
||||
|
||||
self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
# there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
|
||||
# trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
|
||||
# layer to work with - you have to use the last
|
||||
|
||||
attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
|
||||
features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
|
||||
z = features['projection_state']
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.roberta.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
|
||||
return embedded
|
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||
checkpoints_list = {}
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
@@ -48,6 +48,14 @@ def checkpoint_tiles():
|
||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
||||
|
||||
|
||||
def find_checkpoint_config(info):
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return shared.cmd_opts.config
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
|
||||
@@ -73,7 +81,7 @@ def list_models():
|
||||
if os.path.exists(cmd_ckpt):
|
||||
h = model_hash(cmd_ckpt)
|
||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||
shared.opts.data['sd_model_checkpoint'] = title
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
@@ -81,12 +89,7 @@ def list_models():
|
||||
h = model_hash(filename)
|
||||
title, short_model_name = modeltitle(filename, h)
|
||||
|
||||
basename, _ = os.path.splitext(filename)
|
||||
config = basename + ".yaml"
|
||||
if not os.path.exists(config):
|
||||
config = shared.cmd_opts.config
|
||||
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
||||
|
||||
|
||||
def get_closet_checkpoint_match(searchString):
|
||||
@@ -168,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd):
|
||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
|
||||
device = map_location or shared.weight_load_location
|
||||
if device is None:
|
||||
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
||||
@@ -228,6 +234,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
@@ -276,12 +284,14 @@ def enable_midas_autodownload():
|
||||
|
||||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
||||
|
||||
if checkpoint_info.config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_info.config}")
|
||||
if checkpoint_config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_config}")
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
@@ -289,7 +299,7 @@ def load_model(checkpoint_info=None):
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
# Hardcoded config for now...
|
||||
@@ -298,9 +308,6 @@ def load_model(checkpoint_info=None):
|
||||
sd_config.model.params.unet_config.params.in_channels = 9
|
||||
sd_config.model.params.finetune_keys = None
|
||||
|
||||
# Create a "fake" config with a different name so that we know to unload it when switching models.
|
||||
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
@@ -310,6 +317,7 @@ def load_model(checkpoint_info=None):
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
@@ -322,23 +330,29 @@ def load_model(checkpoint_info=None):
|
||||
sd_model.eval()
|
||||
shared.sd_model = sd_model
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
print("Model loaded.")
|
||||
|
||||
return sd_model
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
||||
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info)
|
||||
@@ -351,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
except Exception as e:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info)
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print("Weights loaded.")
|
||||
|
||||
return sd_model
|
||||
|
@@ -97,8 +97,9 @@ sampler_extra_params = {
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
if opts.img2img_fix_steps or steps is not None:
|
||||
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||
t_enc = p.steps - 1
|
||||
requested_steps = (steps or p.steps)
|
||||
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||
t_enc = requested_steps - 1
|
||||
else:
|
||||
steps = p.steps
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
||||
@@ -465,7 +466,9 @@ class KDiffusionSampler:
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import os
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from modules import shared, devices, script_callbacks
|
||||
from modules.paths import models_path
|
||||
@@ -30,6 +31,7 @@ base_vae = None
|
||||
loaded_vae_file = None
|
||||
checkpoint_info = None
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
def get_base_vae(model):
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||
@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
|
||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
||||
# save_settings = False
|
||||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
|
||||
if vae_file:
|
||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
store_base_vae(model)
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
if cache_enabled and vae_file in checkpoints_loaded:
|
||||
# use vae checkpoint cache
|
||||
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
||||
store_base_vae(model)
|
||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
else:
|
||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
store_base_vae(model)
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded vae
|
||||
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
||||
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
# If vae used is not in dict, update it
|
||||
# It will be removed on refresh though
|
||||
|
@@ -14,7 +14,7 @@ import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, sd_vae, extensions, script_loading
|
||||
from modules import localization, sd_vae, extensions, script_loading, errors
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ demo = None
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
@@ -82,6 +82,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
@@ -109,6 +110,17 @@ restricted_opts = {
|
||||
"outdir_save",
|
||||
}
|
||||
|
||||
ui_reorder_categories = [
|
||||
"sampler",
|
||||
"dimensions",
|
||||
"cfg",
|
||||
"seed",
|
||||
"checkboxes",
|
||||
"hires_fix",
|
||||
"batch",
|
||||
"scripts",
|
||||
]
|
||||
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||
@@ -141,6 +153,7 @@ class State:
|
||||
job = ""
|
||||
job_no = 0
|
||||
job_count = 0
|
||||
processing_has_refined_job_count = False
|
||||
job_timestamp = '0'
|
||||
sampling_step = 0
|
||||
sampling_steps = 0
|
||||
@@ -168,9 +181,10 @@ class State:
|
||||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.skipped,
|
||||
"interrupted": self.interrupted,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_timestamp": self.job_timestamp,
|
||||
"job_no": self.job_no,
|
||||
"sampling_step": self.sampling_step,
|
||||
"sampling_steps": self.sampling_steps,
|
||||
@@ -181,6 +195,7 @@ class State:
|
||||
def begin(self):
|
||||
self.sampling_step = 0
|
||||
self.job_count = -1
|
||||
self.processing_has_refined_job_count = False
|
||||
self.job_no = 0
|
||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.current_latent = None
|
||||
@@ -201,12 +216,13 @@ class State:
|
||||
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
def set_current_image(self):
|
||||
if not parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
|
||||
self.do_set_current_image()
|
||||
|
||||
def do_set_current_image(self):
|
||||
if not parallel_processing_allowed:
|
||||
return
|
||||
if self.current_latent is None:
|
||||
return
|
||||
|
||||
@@ -218,6 +234,7 @@ class State:
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
|
||||
state = State()
|
||||
|
||||
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
|
||||
@@ -327,7 +344,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
@@ -345,7 +361,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
@@ -356,6 +372,7 @@ options_templates.update(options_section(('training', "Training"), {
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
@@ -367,13 +384,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
||||
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
||||
@@ -405,7 +426,10 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
|
||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
|
||||
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
}))
|
||||
|
||||
@@ -476,7 +500,12 @@ class Options:
|
||||
return False
|
||||
|
||||
if self.data_labels[key].onchange is not None:
|
||||
self.data_labels[key].onchange()
|
||||
try:
|
||||
self.data_labels[key].onchange()
|
||||
except Exception as e:
|
||||
errors.display(e, f"changing setting {key} to {value}")
|
||||
setattr(self, key, oldval)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -539,6 +568,15 @@ opts = Options()
|
||||
if os.path.exists(config_filename):
|
||||
opts.load(config_filename)
|
||||
|
||||
latent_upscale_default_mode = "Latent"
|
||||
latent_upscale_modes = {
|
||||
"Latent": {"mode": "bilinear", "antialias": False},
|
||||
"Latent (antialiased)": {"mode": "bilinear", "antialias": True},
|
||||
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
|
||||
"Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
|
||||
"Latent (nearest)": {"mode": "nearest", "antialias": False},
|
||||
}
|
||||
|
||||
sd_upscalers = []
|
||||
|
||||
sd_model = None
|
||||
@@ -572,7 +610,7 @@ class TotalTQDM:
|
||||
return
|
||||
if self._tqdm is None:
|
||||
self.reset()
|
||||
self._tqdm.total=new_total
|
||||
self._tqdm.total = new_total
|
||||
|
||||
def clear(self):
|
||||
if self._tqdm is not None:
|
||||
|
@@ -58,14 +58,19 @@ class LearnRateScheduler:
|
||||
|
||||
self.finished = False
|
||||
|
||||
def apply(self, optimizer, step_number):
|
||||
def step(self, step_number):
|
||||
if step_number < self.end_step:
|
||||
return
|
||||
return False
|
||||
|
||||
try:
|
||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||
except Exception:
|
||||
except StopIteration:
|
||||
self.finished = True
|
||||
return False
|
||||
return True
|
||||
|
||||
def apply(self, optimizer, step_number):
|
||||
if not self.step(step_number):
|
||||
return
|
||||
|
||||
if self.verbose:
|
||||
|
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
|
||||
files = listfiles(src)
|
||||
|
||||
shared.state.job = "preprocess"
|
||||
shared.state.textinfo = "Preprocessing..."
|
||||
shared.state.job_count = len(files)
|
||||
|
||||
|
@@ -23,9 +23,12 @@ class Embedding:
|
||||
self.vec = vec
|
||||
self.name = name
|
||||
self.step = step
|
||||
self.shape = None
|
||||
self.vectors = 0
|
||||
self.cached_checksum = None
|
||||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.optimizer_state_dict = None
|
||||
|
||||
def save(self, filename):
|
||||
embedding_data = {
|
||||
@@ -39,6 +42,13 @@ class Embedding:
|
||||
|
||||
torch.save(embedding_data, filename)
|
||||
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
|
||||
optimizer_saved_dict = {
|
||||
'hash': self.checksum(),
|
||||
'optimizer_state_dict': self.optimizer_state_dict,
|
||||
}
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
def checksum(self):
|
||||
if self.cached_checksum is not None:
|
||||
return self.cached_checksum
|
||||
@@ -57,8 +67,10 @@ class EmbeddingDatabase:
|
||||
def __init__(self, embeddings_dir):
|
||||
self.ids_lookup = {}
|
||||
self.word_embeddings = {}
|
||||
self.skipped_embeddings = {}
|
||||
self.dir_mtime = None
|
||||
self.embeddings_dir = embeddings_dir
|
||||
self.expected_shape = -1
|
||||
|
||||
def register_embedding(self, embedding, model):
|
||||
|
||||
@@ -75,21 +87,26 @@ class EmbeddingDatabase:
|
||||
|
||||
return embedding
|
||||
|
||||
def load_textual_inversion_embeddings(self):
|
||||
def get_expected_shape(self):
|
||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
return vec.shape[1]
|
||||
|
||||
def load_textual_inversion_embeddings(self, force_reload = False):
|
||||
mt = os.path.getmtime(self.embeddings_dir)
|
||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||
return
|
||||
|
||||
self.dir_mtime = mt
|
||||
self.ids_lookup.clear()
|
||||
self.word_embeddings.clear()
|
||||
self.skipped_embeddings.clear()
|
||||
self.expected_shape = self.get_expected_shape()
|
||||
|
||||
def process_file(path, filename):
|
||||
name = os.path.splitext(filename)[0]
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
|
||||
data = []
|
||||
|
||||
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
@@ -97,8 +114,10 @@ class EmbeddingDatabase:
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
else:
|
||||
return
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
@@ -122,7 +141,13 @@ class EmbeddingDatabase:
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
else:
|
||||
self.skipped_embeddings[name] = embedding
|
||||
|
||||
for fn in os.listdir(self.embeddings_dir):
|
||||
try:
|
||||
@@ -137,8 +162,9 @@ class EmbeddingDatabase:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
continue
|
||||
|
||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||
print("Embeddings:", ', '.join(self.word_embeddings.keys()))
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
|
||||
def find_embedding_at_position(self, tokens, offset):
|
||||
token = tokens[offset]
|
||||
@@ -225,11 +251,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||
if save_model_every or create_image_every:
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
|
||||
shared.state.job = "train-embedding"
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
shared.state.job_count = steps
|
||||
|
||||
@@ -267,7 +294,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
return embedding, filename
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
|
||||
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
|
||||
None
|
||||
if clip_grad:
|
||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
@@ -285,6 +317,19 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||
if shared.opts.save_optimizer_state:
|
||||
optimizer_state_dict = None
|
||||
if os.path.exists(filename + '.optim'):
|
||||
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
|
||||
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
||||
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
|
||||
if optimizer_state_dict is not None:
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
else:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
batch_size = ds.batch_size
|
||||
@@ -295,12 +340,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
loss_step = 0
|
||||
_loss_step = 0 #internal
|
||||
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||
img_c = None
|
||||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
@@ -318,14 +365,22 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
if clip_grad:
|
||||
clip_grad_sched.step(embedding.step)
|
||||
|
||||
with devices.autocast():
|
||||
# c = stack_conds(batch.cond).to(devices.device)
|
||||
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
|
||||
# print(mask)
|
||||
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||
|
||||
if is_training_inpainting_model:
|
||||
if img_c is None:
|
||||
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
|
||||
|
||||
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
||||
else:
|
||||
cond = c
|
||||
|
||||
loss = shared.sd_model(x, cond)[0] / gradient_step
|
||||
del x
|
||||
|
||||
_loss_step += loss.item()
|
||||
@@ -334,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
|
||||
if clip_grad:
|
||||
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
embedding.step += 1
|
||||
@@ -352,9 +411,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
#if shared.opts.save_optimizer_state:
|
||||
#embedding.optimizer_state_dict = optimizer.state_dict()
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||
@@ -444,7 +501,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
except Exception:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
pass
|
||||
@@ -456,7 +513,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
|
||||
return embedding, filename
|
||||
|
||||
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
old_embedding_name = embedding.name
|
||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||
@@ -467,6 +524,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
|
||||
if remove_cached_checksum:
|
||||
embedding.cached_checksum = None
|
||||
embedding.name = embedding_name
|
||||
embedding.optimizer_state_dict = optimizer.state_dict()
|
||||
embedding.save(filename)
|
||||
except:
|
||||
embedding.sd_checkpoint = old_sd_checkpoint
|
||||
|
@@ -8,7 +8,7 @@ import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
|
||||
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
|
||||
p = StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
@@ -33,8 +33,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
||||
tiling=tiling,
|
||||
enable_hr=enable_hr,
|
||||
denoising_strength=denoising_strength if enable_hr else None,
|
||||
firstphase_width=firstphase_width if enable_hr else None,
|
||||
firstphase_height=firstphase_height if enable_hr else None,
|
||||
hr_scale=hr_scale,
|
||||
hr_upscaler=hr_upscaler,
|
||||
hr_second_pass_steps=hr_second_pass_steps,
|
||||
hr_resize_x=hr_resize_x,
|
||||
hr_resize_y=hr_resize_y,
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
@@ -59,4 +62,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
|
595
modules/ui.py
595
modules/ui.py
File diff suppressed because it is too large
Load Diff
25
modules/ui_components.py
Normal file
25
modules/ui_components.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool", **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class FormRow(gr.Row, gr.components.FormComponent):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "row"
|
||||
|
||||
|
||||
class FormGroup(gr.Group, gr.components.FormComponent):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
def get_block_name(self):
|
||||
return "group"
|
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@@ -12,10 +13,29 @@ from modules import shared
|
||||
Savedfile = namedtuple("Savedfile", ["name"])
|
||||
|
||||
|
||||
def register_tmp_file(gradio, filename):
|
||||
if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
|
||||
gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
|
||||
|
||||
if hasattr(gradio, 'temp_dirs'): # gradio 3.9
|
||||
gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
|
||||
|
||||
|
||||
def check_tmp_file(gradio, filename):
|
||||
if hasattr(gradio, 'temp_file_sets'):
|
||||
return any([filename in fileset for fileset in gradio.temp_file_sets])
|
||||
|
||||
if hasattr(gradio, 'temp_dirs'):
|
||||
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
already_saved_as = getattr(pil_image, 'already_saved_as', None)
|
||||
if already_saved_as and os.path.isfile(already_saved_as):
|
||||
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
|
||||
register_tmp_file(shared.demo, already_saved_as)
|
||||
|
||||
file_obj = Savedfile(already_saved_as)
|
||||
return file_obj
|
||||
|
||||
@@ -44,7 +64,7 @@ def on_tmpdir_changed():
|
||||
|
||||
os.makedirs(shared.opts.temp_dir, exist_ok=True)
|
||||
|
||||
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
|
||||
register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
|
||||
|
||||
|
||||
def cleanup_tmpdr():
|
||||
|
@@ -53,10 +53,10 @@ class Upscaler:
|
||||
def do_upscale(self, img: PIL.Image, selected_model: str):
|
||||
return img
|
||||
|
||||
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
|
||||
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
|
||||
self.scale = scale
|
||||
dest_w = img.width * scale
|
||||
dest_h = img.height * scale
|
||||
dest_w = int(img.width * scale)
|
||||
dest_h = int(img.height * scale)
|
||||
|
||||
for i in range(3):
|
||||
shape = (img.width, img.height)
|
||||
|
137
modules/xlmr.py
Normal file
137
modules/xlmr.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from transformers import BertPreTrainedModel,BertModel,BertConfig
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
|
||||
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||
self.project_dim = project_dim
|
||||
self.pooler_fn = pooler_fn
|
||||
self.learn_encoder = learn_encoder
|
||||
|
||||
class RobertaSeriesConfig(XLMRobertaConfig):
|
||||
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
self.project_dim = project_dim
|
||||
self.pooler_fn = pooler_fn
|
||||
self.learn_encoder = learn_encoder
|
||||
|
||||
|
||||
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
config_class = BertSeriesConfig
|
||||
|
||||
def __init__(self, config=None, **kargs):
|
||||
# modify initialization for autoloading
|
||||
if config is None:
|
||||
config = XLMRobertaConfig()
|
||||
config.attention_probs_dropout_prob= 0.1
|
||||
config.bos_token_id=0
|
||||
config.eos_token_id=2
|
||||
config.hidden_act='gelu'
|
||||
config.hidden_dropout_prob=0.1
|
||||
config.hidden_size=1024
|
||||
config.initializer_range=0.02
|
||||
config.intermediate_size=4096
|
||||
config.layer_norm_eps=1e-05
|
||||
config.max_position_embeddings=514
|
||||
|
||||
config.num_attention_heads=16
|
||||
config.num_hidden_layers=24
|
||||
config.output_past=True
|
||||
config.pad_token_id=1
|
||||
config.position_embedding_type= "absolute"
|
||||
|
||||
config.type_vocab_size= 1
|
||||
config.use_cache=True
|
||||
config.vocab_size= 250002
|
||||
config.project_dim = 768
|
||||
config.learn_encoder = False
|
||||
super().__init__(config)
|
||||
self.roberta = XLMRobertaModel(config)
|
||||
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||
self.pooler = lambda x: x[:,0]
|
||||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = next(self.parameters()).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt")
|
||||
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||
text["attention_mask"] = torch.tensor(
|
||||
text['attention_mask']).to(device)
|
||||
features = self(**text)
|
||||
return features['projection_state']
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) :
|
||||
r"""
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
|
||||
outputs = self.roberta(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# last module outputs
|
||||
sequence_output = outputs[0]
|
||||
|
||||
|
||||
# project every module
|
||||
sequence_output_ln = self.pre_LN(sequence_output)
|
||||
|
||||
# pooler
|
||||
pooler_output = self.pooler(sequence_output_ln)
|
||||
pooler_output = self.transformation(pooler_output)
|
||||
projection_state = self.transformation(outputs.last_hidden_state)
|
||||
|
||||
return {
|
||||
'pooler_output':pooler_output,
|
||||
'last_hidden_state':outputs.last_hidden_state,
|
||||
'hidden_states':outputs.hidden_states,
|
||||
'attentions':outputs.attentions,
|
||||
'projection_state':projection_state,
|
||||
'sequence_out': sequence_output
|
||||
}
|
||||
|
||||
|
||||
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||
base_model_prefix = 'roberta'
|
||||
config_class= RobertaSeriesConfig
|
Reference in New Issue
Block a user