Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git

Commit: Merge remote-tracking branch 'upstream/master' into sub-quad_attn_opt
modules/api/api.py

@@ -1,11 +1,12 @@
import base64
import io
import time
import datetime
import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

@@ -18,7 +19,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list
from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List
@@ -67,6 +68,27 @@ def encode_pil_to_base64(image):
        bytes_data = output_bytes.getvalue()
    return base64.b64encode(bytes_data)

def api_middleware(app: FastAPI):
    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code = res.status_code,
                ver = req.scope.get('http_version', '0.0'),
                cli = req.scope.get('client', ('0:0.0.0', 0))[0],
                prot = req.scope.get('scheme', 'err'),
                method = req.scope.get('method', 'err'),
                endpoint = endpoint,
                duration = duration,
            ))
        return res


class Api:
    def __init__(self, app: FastAPI, queue_lock: Lock):
@@ -79,6 +101,7 @@ class Api:
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
@@ -303,7 +326,7 @@ class Api:
        return upscalers

    def get_sd_models(self):
        return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
        return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]

    def get_hypernetworks(self):
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
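Note: the api_middleware hunk above attaches a timing middleware to every request and logs /sdapi calls. A self-contained sketch of the same FastAPI pattern, runnable on its own (the app and /ping endpoint here are illustrative, not part of the webui):

import time

from fastapi import FastAPI, Request, Response

app = FastAPI()

@app.middleware("http")
async def add_process_time_header(req: Request, call_next):
    ts = time.time()
    res: Response = await call_next(req)  # run the actual endpoint
    res.headers["X-Process-Time"] = str(round(time.time() - ts, 4))
    return res

@app.get("/ping")
def ping():
    return {"ok": True}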
modules/errors.py

@@ -2,9 +2,30 @@ import sys
import traceback


def print_error_explanation(message):
    lines = message.strip().split("\n")
    max_len = max([len(x) for x in lines])

    print('=' * max_len, file=sys.stderr)
    for line in lines:
        print(line, file=sys.stderr)
    print('=' * max_len, file=sys.stderr)


def display(e: Exception, task):
    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)

    message = str(e)
    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
        print_error_explanation("""
The most likely cause of this is you are trying to load the Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
        """)


def run(code, task):
    try:
        code()
    except Exception as e:
        print(f"{task}: {type(e).__name__}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        display(e, task)
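For context, errors.run wraps an arbitrary callable so startup steps report failures uniformly through display(). A hedged usage sketch (setup_model is a stand-in callable, not a webui function; the import only resolves inside the repo):

from modules import errors

def setup_model():
    raise RuntimeError("example failure")

# prints "setting up model: RuntimeError" plus the traceback to stderr
errors.run(setup_model, "setting up model")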
modules/extras.py

@@ -19,8 +19,6 @@ from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
import piexif
import piexif.helper
import gradio as gr
import safetensors.torch

@@ -58,6 +56,9 @@ cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
    devices.torch_gc()

    shared.state.begin()
    shared.state.job = 'extras'

    imageArr = []
    # Also keep track of original file names
    imageNameArr = []
@@ -94,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
    # Extra operation definitions

    def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        shared.state.job = 'extras-gfpgan'
        restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
        res = Image.fromarray(restored_img)

@@ -104,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
        return (res, info)

    def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        shared.state.job = 'extras-codeformer'
        restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
        res = Image.fromarray(restored_img)

@@ -114,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
        return (res, info)

    def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
        shared.state.job = 'extras-upscale'
        upscaler = shared.sd_upscalers[scaler_index]
        res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
        if mode == 1 and crop:
@@ -180,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''

        shared.state.textinfo = f'Processing image {image_name}'

        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
@@ -193,6 +200,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
        else:
            basename = ''

        if opts.enable_pnginfo: # append info before save
            image.info = existing_pnginfo
            image.info["extras"] = info

        if save_output:
            # Add upscaler name as a suffix.
            suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
@@ -203,10 +214,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
            images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
                              no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)

        if opts.enable_pnginfo:
            image.info = existing_pnginfo
            image.info["extras"] = info

        if extras_mode != 2 or show_extras_results :
            outputs.append(image)

@@ -242,6 +249,9 @@ def run_pnginfo(image):


def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
    shared.state.begin()
    shared.state.job = 'model-merge'

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

@@ -263,8 +273,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    theta_func1, theta_func2 = theta_funcs[interp_method]

    if theta_func1 and not tertiary_model_info:
        shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
        shared.state.end()
        return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]

    shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
    print(f"Loading {secondary_model_info.filename}...")
    theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')

@@ -281,6 +294,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
                    theta_1[key] = torch.zeros_like(theta_1[key])
        del theta_2

    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
    print(f"Loading {primary_model_info.filename}...")
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')

@@ -291,6 +305,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
        a = theta_0[key]
        b = theta_1[key]

        shared.state.textinfo = f'Merging layer {key}'
        # this enables merging an inpainting model (A) with another one (B);
        # where normal model would have 4 channels, for latent space, inpainting model would
        # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -303,8 +318,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
            theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
            result_is_inpainting_model = True
        else:
            assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'

            theta_0[key] = theta_func2(a, b, multiplier)

        if save_as_half:
@@ -332,6 +345,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam

    output_modelname = os.path.join(ckpt_dir, filename)

    shared.state.textinfo = f"Saving to {output_modelname}..."
    print(f"Saving to {output_modelname}...")

    _, extension = os.path.splitext(output_modelname)
@@ -343,4 +357,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    sd_models.list_models()

    print("Checkpoint saved.")
    shared.state.textinfo = "Checkpoint saved to " + output_modelname
    shared.state.end()

    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
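run_modelmerger picks theta_func2 from an interpolation table; weighted_sum appears in the hunk above, and add-difference is the variant that requires the tertiary model. A small sketch of the arithmetic on toy tensors (add_difference is shown here as commonly defined, an assumption rather than a quote of this file):

import torch

def weighted_sum(theta0, theta1, alpha):
    return ((1 - alpha) * theta0) + (alpha * theta1)

def add_difference(theta0, theta1_minus_theta2, alpha):
    # assumption: the tertiary model C was already subtracted from B
    return theta0 + (alpha * theta1_minus_theta2)

a, b = torch.ones(2, 2), torch.zeros(2, 2)
print(weighted_sum(a, b, 0.3))  # tensor of 0.7s: 70% of A, 30% of B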
modules/generation_parameters_copypaste.py

@@ -212,11 +212,10 @@ def restore_old_hires_fix_params(res):
    firstpass_width = math.ceil(scale * width / 64) * 64
    firstpass_height = math.ceil(scale * height / 64) * 64

    hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires upscale'] = hr_scale
    res['Hires resize-1'] = width
    res['Hires resize-2'] = height


def parse_generation_parameters(x: str):
@@ -276,6 +275,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
    hypernet_hash = res.get("Hypernet hash", None)
    res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)

    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    restore_old_hires_fix_params(res)

    return res
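A worked sketch of the conversion above from the old firstphase parameters to the new Hires upscale factor, using illustrative sizes:

import math

width, height, scale = 1024, 1024, 0.5  # final size and old firstphase scale
firstpass_width = math.ceil(scale * width / 64) * 64    # rounded up to a multiple of 64
firstpass_height = math.ceil(scale * height / 64) * 64
hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height

print(firstpass_width, firstpass_height, hr_scale)  # 512 512 2.0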
modules/hypernetworks/hypernetwork.py

@@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,

    shared.reload_hypernetworks()

    return fn


def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

@@ -417,6 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
        shared.loaded_hypernetwork = Hypernetwork()
        shared.loaded_hypernetwork.load(path)

    shared.state.job = "train-hypernetwork"
    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

@@ -447,6 +446,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -465,7 +468,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
        shared.parallel_processing_allowed = False
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)


    weights = hypernetwork.weights()
    hypernetwork.train_mode()

@@ -524,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(hypernetwork.step)

                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
@@ -538,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,

                _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
                # scaler.unscale_(optimizer)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")

                if clip_grad:
                    clip_grad(weights, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
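The clip_grad selection added above maps a mode string to one of torch's two clipping helpers, and drives the clip value from its own LearnRateScheduler. A toy sketch of the selection and the call made once per accumulated optimizer step (tensors are illustrative):

import torch

clip_grad_mode, clip_grad_value = "norm", 1.0
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
    torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None

w = torch.nn.Parameter(torch.randn(4))
w.grad = torch.randn(4) * 100  # pretend backward() produced large gradients

if clip_grad:
    clip_grad([w], clip_grad_value)
print(w.grad.norm())  # <= 1.0 for "norm" mode; per-element clamp for "value" mode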
modules/interrogate.py

@@ -136,7 +136,8 @@ class InterrogateModels:

    def interrogate(self, pil_image):
        res = ""

        shared.state.begin()
        shared.state.job = 'interrogate'
        try:

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -177,5 +178,6 @@ class InterrogateModels:
            res += "<error>"

        self.unload()
        shared.state.end()

        return res
modules/processing.py

@@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
    return image


def txt2img_image_conditioning(sd_model, x, width, height):
    if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
        # Dummy zero conditioning if we're not using inpainting model.
        # Still takes up a bit of memory, but no encoder call.
        # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)

    # The "masked-image" in this case will just be all zeros since the entire image is masked.
    image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
    image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))

    # Add the fake full 1s mask to the first dimension.
    image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
    image_conditioning = image_conditioning.to(x.dtype)

    return image_conditioning


class StableDiffusionProcessing():
    """
    The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
@@ -136,28 +154,12 @@ class StableDiffusionProcessing():
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None
        self.iteration = 0

    def txt2img_image_conditioning(self, x, width=None, height=None):
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            # Still takes up a bit of memory, but no encoder call.
            # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
            return x.new_zeros(x.shape[0], 5, 1, 1)
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        self.is_using_inpainting_conditioning = True

        height = height or self.height
        width = width or self.width

        # The "masked-image" in this case will just be all zeros since the entire image is masked.
        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
        image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))

        # Add the fake full 1s mask to the first dimension.
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning
        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
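The padding call above is what turns an encoded 4-channel "masked image" latent into the 5-channel conditioning an inpainting UNet expects. A shape-only sketch with a faked encoder output (sizes match a 512x512 image; latents are 8x smaller):

import torch

x = torch.randn(2, 4, 64, 64)                      # denoiser latents, batch of 2
fake_encoded = torch.zeros(x.shape[0], 4, 64, 64)  # stand-in for encode_first_stage output

# pad one constant 1.0 channel onto the front of the channel dimension
cond = torch.nn.functional.pad(fake_encoded, (0, 0, 0, 0, 1, 0), value=1.0)
print(cond.shape)  # torch.Size([2, 5, 64, 64]) -- full-ones mask + masked-image latents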
@@ -420,7 +422,7 @@ def fix_seed(p):
    p.subseed = get_fixed_seed(p.subseed)


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -544,6 +546,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False
@@ -658,12 +662,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.hr_scale = hr_scale
        self.hr_upscaler = hr_upscaler
        self.hr_second_pass_steps = hr_second_pass_steps
        self.hr_resize_x = hr_resize_x
        self.hr_resize_y = hr_resize_y
        self.hr_upscale_to_x = hr_resize_x
        self.hr_upscale_to_y = hr_resize_y

        if firstphase_width != 0 or firstphase_height != 0:
            print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
@@ -671,14 +680,60 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            self.width = firstphase_width
            self.height = firstphase_height

        self.truncate_x = 0
        self.truncate_y = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            if self.hr_resize_x == 0 and self.hr_resize_y == 0:
                self.extra_generation_params["Hires upscale"] = self.hr_scale
                self.hr_upscale_to_x = int(self.width * self.hr_scale)
                self.hr_upscale_to_y = int(self.height * self.hr_scale)
            else:
                state.job_count = state.job_count * 2
                self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

                if self.hr_resize_y == 0:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                elif self.hr_resize_x == 0:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y
                else:
                    target_w = self.hr_resize_x
                    target_h = self.hr_resize_y
                    src_ratio = self.width / self.height
                    dst_ratio = self.hr_resize_x / self.hr_resize_y

                    if src_ratio < dst_ratio:
                        self.hr_upscale_to_x = self.hr_resize_x
                        self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                    else:
                        self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                        self.hr_upscale_to_y = self.hr_resize_y

                    self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                    self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f

            # special case: the user has chosen to do nothing
            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
                self.enable_hr = False
                self.denoising_strength = None
                self.extra_generation_params.pop("Hires upscale", None)
                self.extra_generation_params.pop("Hires resize", None)
                return

            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter

                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True

            if self.hr_second_pass_steps:
                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

            self.extra_generation_params["Hires upscale"] = self.hr_scale
            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

@@ -695,8 +750,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        if not self.enable_hr:
            return samples

        target_width = int(self.width * self.hr_scale)
        target_height = int(self.height * self.hr_scale)
        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
@@ -705,15 +760,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index)
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")

        if latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
@@ -750,13 +806,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        return samples
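A worked sketch of the new hr_resize branch above: when both target dimensions are given, the image is first upscaled to fill the target aspect ratio, then the excess is truncated in latent space (opt_f is the 8x latent downscale factor). Values here are illustrative:

width, height = 512, 768                # first-pass size
hr_resize_x, hr_resize_y = 1024, 1024   # requested final size
opt_f = 8

src_ratio = width / height
dst_ratio = hr_resize_x / hr_resize_y

if src_ratio < dst_ratio:
    upscale_to_x = hr_resize_x
    upscale_to_y = hr_resize_x * height // width
else:
    upscale_to_x = hr_resize_y * width // height
    upscale_to_y = hr_resize_y

truncate_x = (upscale_to_x - hr_resize_x) // opt_f
truncate_y = (upscale_to_y - hr_resize_y) // opt_f
print(upscale_to_x, upscale_to_y, truncate_x, truncate_y)  # 1024 1536 0 64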
modules/sd_hijack.py

@@ -33,25 +33,34 @@ def apply_optimizations():

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    optimization_method = None

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
        optimization_method = 'xformers'
    elif cmd_opts.opt_sub_quad_attention:
        print("Applying sub-quadratic cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
        optimization_method = 'sub-quadratic'
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        optimization_method = 'V1'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
        print("Applying cross attention optimization (InvokeAI).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
        optimization_method = 'InvokeAI'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
        optimization_method = 'Doggettx'

    return optimization_method


def undo_optimizations():
@@ -72,6 +81,7 @@ class StableDiffusionModelHijack:
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)

@@ -91,7 +101,7 @@ class StableDiffusionModelHijack:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_optimizations()
        self.optimization_method = apply_optimizations()

        self.clip = m.cond_stage_model
modules/sd_hijack_inpainting.py

@@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F


def should_hijack_inpainting(checkpoint_info):
    from modules import sd_models

    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
    cfg_basename = os.path.basename(checkpoint_info.config).lower()
    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()

    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
modules/sd_models.py

@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()

@@ -48,6 +48,14 @@ def checkpoint_tiles():
    return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)


def find_checkpoint_config(info):
    config = os.path.splitext(info.filename)[0] + ".yaml"
    if os.path.exists(config):
        return config

    return shared.cmd_opts.config


def list_models():
    checkpoints_list.clear()
    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@@ -73,7 +81,7 @@ def list_models():
    if os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
        shared.opts.data['sd_model_checkpoint'] = title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)
@@ -81,12 +89,7 @@ def list_models():
        h = model_hash(filename)
        title, short_model_name = modeltitle(filename, h)

        basename, _ = os.path.splitext(filename)
        config = basename + ".yaml"
        if not os.path.exists(config):
            config = shared.cmd_opts.config

        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)


def get_closet_checkpoint_match(searchString):
@@ -168,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
        device = map_location or shared.weight_load_location
        if device is None:
            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

@@ -278,12 +284,14 @@ def enable_midas_autodownload():

    midas.api.load_model = load_model_wrapper


def load_model(checkpoint_info=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
    checkpoint_config = find_checkpoint_config(checkpoint_info)

    if checkpoint_info.config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_info.config}")
    if checkpoint_config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_config}")

    if shared.sd_model:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -291,7 +299,7 @@ def load_model(checkpoint_info=None):
    gc.collect()
    devices.torch_gc()

    sd_config = OmegaConf.load(checkpoint_info.config)
    sd_config = OmegaConf.load(checkpoint_config)

    if should_hijack_inpainting(checkpoint_info):
        # Hardcoded config for now...
@@ -300,9 +308,6 @@ def load_model(checkpoint_info=None):
        sd_config.model.params.unet_config.params.in_channels = 9
        sd_config.model.params.finetune_keys = None

        # Create a "fake" config with a different name so that we know to unload it when switching models.
        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

@@ -312,6 +317,7 @@ def load_model(checkpoint_info=None):
        sd_config.model.params.unet_config.params.use_fp16 = False

    sd_model = instantiate_from_config(sd_config.model)

    load_model_weights(sd_model, checkpoint_info)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -336,14 +342,17 @@ def load_model(checkpoint_info=None):
def reload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()


    if not sd_model:
        sd_model = shared.sd_model

    current_checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_config = find_checkpoint_config(current_checkpoint_info)

    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        return

    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
    if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info)
@@ -356,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None):

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_model_weights(sd_model, checkpoint_info)
    try:
        load_model_weights(sd_model, checkpoint_info)
    except Exception as e:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        script_callbacks.model_loaded_callback(sd_model)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)
    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print("Weights loaded.")

    return sd_model
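find_checkpoint_config (added above) replaces the config field that was removed from CheckpointInfo: the lookup is now done on demand instead of being cached per checkpoint. A standalone sketch of the rule, with an illustrative default path:

import os

def find_config(checkpoint_filename, default_config="configs/v1-inference.yaml"):
    # a .yaml sitting next to the checkpoint wins; otherwise the global default applies
    config = os.path.splitext(checkpoint_filename)[0] + ".yaml"
    if os.path.exists(config):
        return config
    return default_config

print(find_config("models/Stable-diffusion/sd-v1-5.ckpt"))  # prints the default unless the .yaml exists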
modules/sd_samplers.py

@@ -97,8 +97,9 @@ sampler_extra_params = {

def setup_img2img_steps(p, steps=None):
    if opts.img2img_fix_steps or steps is not None:
        steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
        t_enc = p.steps - 1
        requested_steps = (steps or p.steps)
        steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
        t_enc = requested_steps - 1
    else:
        steps = p.steps
        t_enc = int(min(p.denoising_strength, 0.999) * steps)
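A worked sketch of the setup_img2img_steps change above, with illustrative numbers: the requested count must be captured before it is inflated, otherwise t_enc is computed from p.steps even when an explicit steps value was passed:

requested_steps = 20
denoising_strength = 0.5

steps = int(requested_steps / min(denoising_strength, 0.999))  # 40 steps scheduled in total
t_enc = requested_steps - 1                                    # denoise the last 19 of them
print(steps, t_enc)  # 40 19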
modules/shared.py

@@ -14,7 +14,7 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, sd_vae, extensions, script_loading
from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path


@@ -86,6 +86,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
@@ -156,6 +157,7 @@ class State:
    job = ""
    job_no = 0
    job_count = 0
    processing_has_refined_job_count = False
    job_timestamp = '0'
    sampling_step = 0
    sampling_steps = 0
@@ -186,6 +188,7 @@ class State:
            "interrupted": self.interrupted,
            "job": self.job,
            "job_count": self.job_count,
            "job_timestamp": self.job_timestamp,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
@@ -196,6 +199,7 @@ class State:
    def begin(self):
        self.sampling_step = 0
        self.job_count = -1
        self.processing_has_refined_job_count = False
        self.job_no = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.current_latent = None
@@ -216,12 +220,13 @@ class State:

    """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
    def set_current_image(self):
        if not parallel_processing_allowed:
            return

        if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
            self.do_set_current_image()

    def do_set_current_image(self):
        if not parallel_processing_allowed:
            return
        if self.current_latent is None:
            return
@@ -233,6 +238,7 @@ class State:

        self.current_image_sampling_step = self.sampling_step


state = State()

artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -359,7 +365,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
    "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
    "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
    "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
    "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -498,7 +504,12 @@ class Options:
            return False

        if self.data_labels[key].onchange is not None:
            self.data_labels[key].onchange()
            try:
                self.data_labels[key].onchange()
            except Exception as e:
                errors.display(e, f"changing setting {key} to {value}")
                setattr(self, key, oldval)
                return False

        return True

@@ -563,8 +574,11 @@ if os.path.exists(config_filename):

latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
    "Latent": "bilinear",
    "Latent (nearest)": "nearest",
    "Latent": {"mode": "bilinear", "antialias": False},
    "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
    "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
    "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
    "Latent (nearest)": {"mode": "nearest", "antialias": False},
}

sd_upscalers = []
@@ -600,7 +614,7 @@ class TotalTQDM:
            return
        if self._tqdm is None:
            self.reset()
        self._tqdm.total=new_total
        self._tqdm.total = new_total

    def clear(self):
        if self._tqdm is not None:
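Each latent_upscale_modes entry is now a dict consumed directly by torch.nn.functional.interpolate, as in the processing.py hunk earlier. A minimal sketch with illustrative tensor sizes:

import torch

latent_scale_mode = {"mode": "bicubic", "antialias": True}
samples = torch.randn(1, 4, 64, 64)  # a latent batch at 512x512 pixel scale

upscaled = torch.nn.functional.interpolate(
    samples, size=(128, 128),
    mode=latent_scale_mode["mode"],
    antialias=latent_scale_mode["antialias"],  # antialias requires bilinear/bicubic
)
print(upscaled.shape)  # torch.Size([1, 4, 128, 128])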
modules/textual_inversion/learn_schedule.py

@@ -58,14 +58,19 @@ class LearnRateScheduler:

        self.finished = False

    def apply(self, optimizer, step_number):
    def step(self, step_number):
        if step_number < self.end_step:
            return
            return False

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except Exception:
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        if not self.step(step_number):
            return

        if self.verbose:
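The step()/apply() split above lets a second LearnRateScheduler drive the gradient-clip value without ever touching an optimizer. A condensed sketch of the same stepping semantics (TinyScheduler is illustrative, not the real class):

class TinyScheduler:
    def __init__(self, pairs):  # pairs of (value, end_step)
        self.schedules = iter(pairs)
        self.value, self.end_step = next(self.schedules)
        self.finished = False

    def step(self, step_number):
        # advance the schedule; report whether the value changed
        if step_number < self.end_step:
            return False
        try:
            (self.value, self.end_step) = next(self.schedules)
        except StopIteration:
            self.finished = True
            return False
        return True

sched = TinyScheduler([(0.1, 100), (0.01, 200)])
print(sched.step(50), sched.value)   # False 0.1
print(sched.step(100), sched.value)  # True 0.01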
modules/textual_inversion/preprocess.py

@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre

    files = listfiles(src)

    shared.state.job = "preprocess"
    shared.state.textinfo = "Preprocessing..."
    shared.state.job_count = len(files)
modules/textual_inversion/textual_inversion.py

@@ -28,6 +28,7 @@ class Embedding:
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None

    def save(self, filename):
        embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, filename + '.optim')

    def checksum(self):
        if self.cached_checksum is not None:
            return self.cached_checksum
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
        self.expected_shape = self.get_expected_shape()

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]
            name, ext = os.path.splitext(filename)
            ext = ext.upper()

            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                embed_image = Image.open(path)
                if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                    data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
                else:
                    data = extract_image_data_embed(embed_image)
                    name = data.get('name', name)
            else:
            elif ext in ['.BIN', '.PT']:
                data = torch.load(path, map_location="cpu")
            else:
                return

            # textual inversion embeddings
            if 'string_to_param' in data:
@@ -240,11 +251,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"

def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")

    shared.state.job = "train-embedding"
    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps

@@ -282,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
        return embedding, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -300,6 +317,19 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    if shared.opts.save_optimizer_state:
        optimizer_state_dict = None
        if os.path.exists(filename + '.optim'):
            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
@@ -315,6 +345,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
    forced_filename = "<none>"
    embedding_yet_to_be_embedded = False

    is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
    img_c = None

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        for i in range((steps-initial_step) * gradient_step):
@@ -332,14 +365,22 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(embedding.step)

                with devices.autocast():
                    # c = stack_conds(batch.cond).to(devices.device)
                    # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
                    # print(mask)
                    # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)
                    loss = shared.sd_model(x, c)[0] / gradient_step

                    if is_training_inpainting_model:
                        if img_c is None:
                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)

                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
                    else:
                        cond = c

                    loss = shared.sd_model(x, cond)[0] / gradient_step
                    del x

                _loss_step += loss.item()
@@ -348,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue

                if clip_grad:
                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                embedding.step += 1
@@ -366,9 +411,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                        # Before saving, change name to match current checkpoint.
                        embedding_name_every = f'{embedding_name}-{steps_done}'
                        last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                        #if shared.opts.save_optimizer_state:
                        #embedding.optimizer_state_dict = optimizer.state_dict()
                        save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                        save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                        embedding_yet_to_be_embedded = True

                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -458,7 +501,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
        save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        pass
@@ -470,7 +513,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>

    return embedding, filename

def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    old_embedding_name = embedding.name
    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -481,6 +524,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.optimizer_state_dict = optimizer.state_dict()
        embedding.save(filename)
    except:
        embedding.sd_checkpoint = old_sd_checkpoint
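The .optim mechanism above persists optimizer state next to the embedding, keyed by the embedding's checksum so resuming never loads stale state into a different embedding. A standalone sketch of the save/load round trip (path and checksum are illustrative stand-ins):

import os
import torch

vec = torch.nn.Parameter(torch.randn(8))
optimizer = torch.optim.AdamW([vec], lr=0.005, weight_decay=0.0)
filename = "my-embedding.pt"
checksum = "deadbeef"  # stand-in for Embedding.checksum()

# save: optimizer state stored alongside the embedding file
torch.save({'hash': checksum, 'optimizer_state_dict': optimizer.state_dict()}, filename + '.optim')

# load: guarded by the hash check before restoring
if os.path.exists(filename + '.optim'):
    saved = torch.load(filename + '.optim', map_location='cpu')
    if saved.get('hash', None) == checksum:
        optimizer.load_state_dict(saved['optimizer_state_dict'])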
modules/txt2img.py

@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html


def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -35,6 +35,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
        denoising_strength=denoising_strength if enable_hr else None,
        hr_scale=hr_scale,
        hr_upscaler=hr_upscaler,
        hr_second_pass_steps=hr_second_pass_steps,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
    )

    p.scripts = modules.scripts.scripts_txt2img
@@ -162,16 +162,14 @@ def save_files(js_data, images, do_make_zip, index):
     return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
 
 
-
-
-def calc_time_left(progress, threshold, label, force_display):
+def calc_time_left(progress, threshold, label, force_display, show_eta):
     if progress == 0:
         return ""
     else:
         time_since_start = time.time() - shared.state.time_start
         eta = (time_since_start/progress)
         eta_relative = eta-time_since_start
-        if (eta_relative > threshold and progress > 0.02) or force_display:
+        if (eta_relative > threshold and show_eta) or force_display:
             if eta_relative > 3600:
                 return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
             elif eta_relative > 60:
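
Note: calc_time_left estimates remaining time by linear extrapolation: if a fraction progress of the job took time_since_start seconds, the whole job is estimated at time_since_start / progress, and the ETA is that total minus what has already elapsed. The same arithmetic as a standalone snippet:

import time

def eta_seconds(progress, time_since_start):
    # Linear extrapolation: total ~= elapsed / fraction_done; remainder = total - elapsed.
    if progress == 0:
        return None
    total = time_since_start / progress
    return total - time_since_start

# 30% done after 12s -> 40s estimated total -> 28s left.
print(eta_seconds(0.3, 12.0))                                        # 28.0
print(time.strftime('%M:%S', time.gmtime(eta_seconds(0.3, 12.0))))   # 00:28
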
@@ -193,7 +191,10 @@ def check_progress_call(id_part):
         if shared.state.sampling_steps > 0:
             progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
 
-    time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
+    # Show progress percentage and time left at the same moment, and base it also on steps done
+    show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
+
+    time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
     if time_left != "":
         shared.state.time_left_force_display = True
 
@@ -201,7 +202,7 @@ def check_progress_call(id_part):
 
     progressbar = ""
     if opts.show_progressbar:
-        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
+        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
 
     image = gr_show(False)
     preview_visibility = gr_show(False)
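
Note: the new show_eta flag replaces the old hard-coded progress > 0.02 and progress > 0.01 cutoffs in both places, so the percentage label and the ETA appear together once either 1% of the job or 10 sampling steps are done, giving short jobs with few total steps earlier feedback. In isolation:

def should_show_eta(progress, sampling_step):
    # Mirror of the gating in the diff: 1% overall progress or 10 steps done.
    return progress >= 0.01 or sampling_step >= 10

print(should_show_eta(0.005, 4))   # False: too early on both counts
print(should_show_eta(0.005, 12))  # True: enough steps done even at low progress
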
@@ -635,10 +636,11 @@ def create_sampler_and_steps_selection(choices, tabname):
     if opts.samplers_in_dropdown:
         with FormRow(elem_id=f"sampler_selection_{tabname}"):
             sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
-            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+            sampler_index.save_to_config = True
+            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
     else:
         with FormGroup(elem_id=f"sampler_selection_{tabname}"):
-            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
             sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
 
     return steps, sampler_index
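
Note: sampler_index.save_to_config = True sets a plain attribute flag that other parts of the UI code can check when deciding whether to persist the component's value across sessions. The general opt-in pattern looks like this; the loader below is hypothetical, only the attribute name comes from the diff:

# Sketch of the flag-attribute pattern (the collecting loader is invented for illustration).
class Component:
    def __init__(self, value):
        self.value = value

sampler_index = Component("Euler a")
sampler_index.save_to_config = True    # opt this component in to persistence

def collect_config(components):
    return {name: c.value for name, c in components.items() if getattr(c, "save_to_config", False)}

print(collect_config({"txt2img_sampling": sampler_index}))  # {'txt2img_sampling': 'Euler a'}
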
@@ -707,10 +709,16 @@ def create_ui():
                 enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
 
             elif category == "hires_fix":
-                with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
-                    hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
-                    hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
-                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+                with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+                    with FormRow(elem_id="txt2img_hires_fix_row1"):
+                        hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+                        hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+                        denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+                    with FormRow(elem_id="txt2img_hires_fix_row2"):
+                        hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+                        hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+                        hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
 
             elif category == "batch":
                 if not opts.dimensions_and_batch_together:
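
Note: the single FormRow becomes a FormGroup holding two FormRows, so the six hires controls lay out on two lines. FormGroup and FormRow are webui's form-styled wrappers from modules.ui_components; with stock Gradio the equivalent two-row arrangement would look roughly like this (a sketch, not the project's actual components):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Group(visible=True) as hr_options:   # shown here; the real group starts hidden
        with gr.Row():                           # row 1: upscaler, steps, denoising
            upscaler = gr.Dropdown(label="Upscaler", choices=["Latent", "ESRGAN"], value="Latent")
            hires_steps = gr.Slider(0, 150, step=1, label="Hires steps", value=0)
            denoising = gr.Slider(0.0, 1.0, step=0.01, label="Denoising strength", value=0.7)
        with gr.Row():                           # row 2: scale and absolute resize targets
            scale = gr.Slider(1.0, 4.0, step=0.05, label="Upscale by", value=2.0)
            resize_x = gr.Slider(0, 2048, step=8, label="Resize width to", value=0)
            resize_y = gr.Slider(0, 2048, step=8, label="Resize height to", value=0)

# demo.launch()  # uncomment to view
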
@@ -751,6 +759,9 @@ def create_ui():
                     denoising_strength,
                     hr_scale,
                     hr_upscaler,
+                    hr_second_pass_steps,
+                    hr_resize_x,
+                    hr_resize_y,
                 ] + custom_inputs,
 
                 outputs=[
@@ -802,6 +813,9 @@ def create_ui():
         (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
         (hr_scale, "Hires upscale"),
         (hr_upscaler, "Hires upscaler"),
+        (hr_second_pass_steps, "Hires steps"),
+        (hr_resize_x, "Hires resize-1"),
+        (hr_resize_y, "Hires resize-2"),
         *modules.scripts.scripts_txt2img.infotext_fields
     ]
     parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
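
Note: txt2img_paste_fields pairs each component with the label its value carries in the generation-parameters infotext, which is how pasting parameters routes values back into the right widgets; the three new entries register the hires fields. A toy version of that label-to-destination routing (the real parser lives in generation_parameters_copypaste.py; these variable names are illustrative):

# Sketch: route parsed infotext values back to named settings (illustrative only).
paste_fields = {
    "Hires upscale": "hr_scale",
    "Hires steps": "hr_second_pass_steps",
    "Hires resize-1": "hr_resize_x",
    "Hires resize-2": "hr_resize_y",
}

infotext = {"Hires upscale": "2.0", "Hires steps": "14"}

settings = {dest: infotext[label] for label, dest in paste_fields.items() if label in infotext}
print(settings)  # {'hr_scale': '2.0', 'hr_second_pass_steps': '14'}
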
@@ -1279,38 +1293,48 @@ def create_ui():
 
                 with gr.Tab(label="Train"):
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
-                    with gr.Row():
+                    with FormRow():
                         train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
                         create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
-                    with gr.Row():
+
                         train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
                         create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
-                    with gr.Row():
+
+                    with FormRow():
                         embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
                         hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
 
-                    batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
-                    gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
+                    with FormRow():
+                        clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
+                        clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
+
+                    with FormRow():
+                        batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
+                        gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
+
                     dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
                     log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
                     template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
                     training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
                     training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
                     steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
-                    create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
-                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+
+                    with FormRow():
+                        create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
+                        save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+
                     save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
                     preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
-                    with gr.Row():
-                        shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
-                        tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
-                    with gr.Row():
-                        latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
+
+                    shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
+                    tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
+
+                    latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
+
                     with gr.Row():
-                        train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
                         interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
                         train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
+                        train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
 
                 params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
 
@@ -1400,6 +1424,8 @@ def create_ui():
                 training_width,
                 training_height,
                 steps,
+                clip_grad_mode,
+                clip_grad_value,
                 shuffle_tags,
                 tag_drop_out,
                 latent_sampling_method,
@@ -1429,6 +1455,8 @@ def create_ui():
                 training_width,
                 training_height,
                 steps,
+                clip_grad_mode,
+                clip_grad_value,
                 shuffle_tags,
                 tag_drop_out,
                 latent_sampling_method,
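
Note: both training entry points now receive clip_grad_mode and clip_grad_value. With PyTorch's built-in utilities, "value" clamps each gradient element to +/- the limit while "norm" rescales the whole gradient so its norm does not exceed it; a training step applies one of them between backward() and step(). A sketch assuming a plain PyTorch loop (the webui code wires this up differently):

import torch

params = [torch.nn.Parameter(torch.randn(4, 4))]
optimizer = torch.optim.AdamW(params, lr=1e-3)

loss = (params[0] ** 2).sum()
loss.backward()

clip_grad_mode, clip_grad_value = "norm", 0.1   # mirrors the new dropdown/textbox values
if clip_grad_mode == "value":
    torch.nn.utils.clip_grad_value_(params, clip_value=clip_grad_value)   # clamp each element
elif clip_grad_mode == "norm":
    torch.nn.utils.clip_grad_norm_(params, max_norm=clip_grad_value)      # rescale total norm

optimizer.step()
optimizer.zero_grad()
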
@@ -1793,6 +1821,7 @@ def create_ui():
     visit(img2img_interface, loadsave, "img2img")
     visit(extras_interface, loadsave, "extras")
     visit(modelmerger_interface, loadsave, "modelmerger")
+    visit(train_interface, loadsave, "train")
 
     if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
         with open(ui_config_file, "w", encoding="utf8") as file: