mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-09 13:49:48 +00:00
Merge branch 'AUTOMATIC1111:master' into master
This commit is contained in:
@@ -7,6 +7,7 @@ import uvicorn
|
||||
from fastapi import Body, APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, Json
|
||||
from typing import List
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
@@ -15,12 +16,12 @@ from PIL import Image
|
||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
|
||||
@@ -65,7 +66,7 @@ class Api:
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
|
||||
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
|
||||
|
||||
|
||||
|
||||
@@ -111,7 +112,11 @@ class Api:
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
|
||||
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
|
||||
if (not img2imgreq.include_init_images):
|
||||
img2imgreq.init_images = None
|
||||
img2imgreq.mask = None
|
||||
|
||||
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
|
||||
|
||||
def extrasapi(self):
|
||||
raise NotImplementedError
|
||||
|
@@ -31,6 +31,7 @@ class ModelDef(BaseModel):
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
field_exclude: bool = False
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
@@ -78,7 +79,8 @@ class PydanticModelGenerator:
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"]))
|
||||
field_value=fields["default"],
|
||||
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
@@ -86,7 +88,7 @@ class PydanticModelGenerator:
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
@@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
|
||||
).generate_model()
|
@@ -5,6 +5,7 @@ import html
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import inspect
|
||||
|
||||
import modules.textual_inversion.dataset
|
||||
import torch
|
||||
@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from statistics import stdev, mean
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
}
|
||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
|
||||
else:
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||
layer.bias.data.zero_()
|
||||
|
||||
w, b = layer.weight.data, layer.bias.data
|
||||
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
||||
normal_(w, mean=0.0, std=0.01)
|
||||
normal_(b, mean=0.0, std=0.005)
|
||||
elif weight_init == 'XavierUniform':
|
||||
xavier_uniform_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'XavierNormal':
|
||||
xavier_normal_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingUniform':
|
||||
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingNormal':
|
||||
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
else:
|
||||
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
||||
self.to(devices.device)
|
||||
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
@@ -105,7 +126,7 @@ class Hypernetwork:
|
||||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
@@ -114,13 +135,14 @@ class Hypernetwork:
|
||||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.activation_func = activation_func
|
||||
self.weight_init = weight_init
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.use_dropout = use_dropout
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
||||
)
|
||||
|
||||
def weights(self):
|
||||
@@ -144,6 +166,7 @@ class Hypernetwork:
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['weight_initialization'] = self.weight_init
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
@@ -158,15 +181,21 @@ class Hypernetwork:
|
||||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
print(self.layer_structure)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
print(f"Activation function is {self.activation_func}")
|
||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
print(f"Weight initialization is {self.weight_init}")
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||
self.use_dropout = state_dict.get('use_dropout', False)
|
||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
@@ -458,7 +487,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
@@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack, shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
|
||||
@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
layer_structure=layer_structure,
|
||||
activation_func=activation_func,
|
||||
weight_init=weight_init,
|
||||
add_layer_norm=add_layer_norm,
|
||||
use_dropout=use_dropout,
|
||||
)
|
||||
|
@@ -277,7 +277,7 @@ invalid_filename_chars = '<>:"/\\|?*\n'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
|
||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||
max_filename_part_length = 128
|
||||
|
||||
@@ -343,7 +343,7 @@ class FilenameGenerator:
|
||||
def datetime(self, *args):
|
||||
time_datetime = datetime.datetime.now()
|
||||
|
||||
time_format = args[0] if len(args) > 0 else self.default_time_format
|
||||
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
||||
try:
|
||||
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
||||
except pytz.exceptions.UnknownTimeZoneError as _:
|
||||
@@ -362,9 +362,9 @@ class FilenameGenerator:
|
||||
|
||||
for m in re_pattern.finditer(x):
|
||||
text, pattern = m.groups()
|
||||
res += text
|
||||
|
||||
if pattern is None:
|
||||
res += text
|
||||
continue
|
||||
|
||||
pattern_args = []
|
||||
@@ -385,12 +385,9 @@ class FilenameGenerator:
|
||||
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if replacement is None:
|
||||
res += f'[{pattern}]'
|
||||
else:
|
||||
if replacement is not None:
|
||||
res += str(replacement)
|
||||
|
||||
continue
|
||||
continue
|
||||
|
||||
res += f'[{pattern}]'
|
||||
|
||||
@@ -454,17 +451,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
"""
|
||||
namegen = FilenameGenerator(p, seed, prompt)
|
||||
|
||||
if extension == 'png' and opts.enable_pnginfo and info is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if existing_info is not None:
|
||||
for k, v in existing_info.items():
|
||||
pnginfo.add_text(k, str(v))
|
||||
|
||||
pnginfo.add_text(pnginfo_section_name, info)
|
||||
else:
|
||||
pnginfo = None
|
||||
|
||||
if save_to_dirs is None:
|
||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||
|
||||
@@ -492,19 +478,27 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
if add_number:
|
||||
basecount = get_next_sequence_number(path, basename)
|
||||
fullfn = None
|
||||
fullfn_without_extension = None
|
||||
for i in range(500):
|
||||
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
||||
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
||||
if not os.path.exists(fullfn):
|
||||
break
|
||||
else:
|
||||
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, file_decoration)
|
||||
else:
|
||||
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, forced_filename)
|
||||
|
||||
pnginfo = existing_info or {}
|
||||
if info is not None:
|
||||
pnginfo[pnginfo_section_name] = info
|
||||
|
||||
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
|
||||
script_callbacks.before_image_saved_callback(params)
|
||||
|
||||
image = params.image
|
||||
fullfn = params.filename
|
||||
info = params.pnginfo.get(pnginfo_section_name, None)
|
||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||
|
||||
def exif_bytes():
|
||||
return piexif.dump({
|
||||
@@ -513,12 +507,20 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
},
|
||||
})
|
||||
|
||||
if extension.lower() in ("jpg", "jpeg", "webp"):
|
||||
if extension.lower() == '.png':
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
for k, v in params.pnginfo.items():
|
||||
pnginfo_data.add_text(k, str(v))
|
||||
|
||||
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||
|
||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||
image.save(fullfn, quality=opts.jpeg_quality)
|
||||
|
||||
if opts.enable_pnginfo and info is not None:
|
||||
piexif.insert(exif_bytes(), fullfn)
|
||||
else:
|
||||
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
|
||||
image.save(fullfn, quality=opts.jpeg_quality)
|
||||
|
||||
target_side_length = 4000
|
||||
oversize = image.width > target_side_length or image.height > target_side_length
|
||||
@@ -541,7 +543,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
else:
|
||||
txt_fullfn = None
|
||||
|
||||
script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn)
|
||||
script_callbacks.image_saved_callback(params)
|
||||
|
||||
return fullfn, txt_fullfn
|
||||
|
||||
|
||||
|
@@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args):
|
||||
break
|
||||
|
||||
img = Image.open(image)
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
img = ImageOps.exif_transpose(img)
|
||||
p.init_images = [img] * p.batch_size
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
@@ -61,19 +63,25 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||
is_batch = mode == 2
|
||||
|
||||
if is_inpaint:
|
||||
# Drawn mask
|
||||
if mask_mode == 0:
|
||||
image = init_img_with_mask['image']
|
||||
mask = init_img_with_mask['mask']
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||
image = image.convert('RGB')
|
||||
# Uploaded mask
|
||||
else:
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
# No mask
|
||||
else:
|
||||
image = init_img
|
||||
mask = None
|
||||
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
image = ImageOps.exif_transpose(image)
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
|
||||
p = StableDiffusionProcessingImg2Img(
|
||||
|
@@ -77,9 +77,8 @@ def get_correct_sampler(p):
|
||||
class StableDiffusionProcessing():
|
||||
"""
|
||||
The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
|
||||
"""
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
|
||||
self.sd_model = sd_model
|
||||
self.outpath_samples: str = outpath_samples
|
||||
self.outpath_grids: str = outpath_grids
|
||||
@@ -109,13 +108,14 @@ class StableDiffusionProcessing():
|
||||
self.do_not_reload_embeddings = do_not_reload_embeddings
|
||||
self.paste_to = None
|
||||
self.color_corrections = None
|
||||
self.denoising_strength: float = 0
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.sampler_noise_scheduler_override = None
|
||||
self.ddim_discretize = opts.ddim_discretize
|
||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
||||
self.s_churn = s_churn or opts.s_churn
|
||||
self.s_tmin = s_tmin or opts.s_tmin
|
||||
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||
self.s_noise = s_noise or opts.s_noise
|
||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||
|
||||
if not seed_enable_extras:
|
||||
self.subseed = -1
|
||||
@@ -129,7 +129,6 @@ class StableDiffusionProcessing():
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
pass
|
||||
|
||||
@@ -351,6 +350,22 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||
|
||||
|
||||
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run process_images_inner with p.override_settings temporarily applied to opts.data.

    The current value of every overridden key is captured before anything is
    changed and written back in a ``finally`` clause, so the previous values are
    restored whether processing returns normally or raises.
    """
    overrides = p.override_settings
    previous_values = {name: opts.data[name] for name in overrides}

    try:
        # Values are written straight into opts.data: we don't call onchange
        # for simplicity, which makes changing model, hypernet impossible.
        for name, value in overrides.items():
            opts.data[name] = value

        return process_images_inner(p)
    finally:
        for name, value in previous_values.items():
            opts.data[name] = value
|
||||
|
||||
|
||||
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
if type(p.prompt) == list:
|
||||
|
@@ -9,15 +9,34 @@ def report_exception(c, job):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
class ImageSaveParams:
    """Plain container for the values describing one image save.

    Attributes:
        image: the PIL image itself
        p: p object with processing parameters; either StableDiffusionProcessing
            or an object with same fields
        filename: name of file that the image would be saved to
        pnginfo: dictionary with parameters for image's PNG info data; infotext
            will have the key 'parameters'
    """

    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        self.p = p
        self.filename = filename
        self.pnginfo = pnginfo
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callbacks_model_loaded = []
|
||||
callbacks_ui_tabs = []
|
||||
callbacks_ui_settings = []
|
||||
callbacks_before_image_saved = []
|
||||
callbacks_image_saved = []
|
||||
|
||||
|
||||
def clear_callbacks():
    """Empty every module-level callback list in place."""
    for registry in (
        callbacks_model_loaded,
        callbacks_ui_tabs,
        callbacks_ui_settings,
        callbacks_before_image_saved,
        callbacks_image_saved,
    ):
        registry.clear()
|
||||
|
||||
|
||||
@@ -49,10 +68,18 @@ def ui_settings_callback():
|
||||
report_exception(c, 'ui_settings_callback')
|
||||
|
||||
|
||||
def image_saved_callback(image, p, fullfn, txt_fullfn):
|
||||
def before_image_saved_callback(params: ImageSaveParams):
    """Invoke every callback registered for the "before image saved" event.

    Args:
        params: ImageSaveParams describing the pending save. The same object is
            handed to each callback in turn, so a change made by one callback is
            visible to the next one and to the caller.

    An exception raised by a callback is passed to report_exception and does not
    prevent the remaining callbacks from running.
    """
    # Iterate the before-save list (the one on_before_image_saved appends to),
    # not callbacks_image_saved, which holds the after-save callbacks.
    for c in callbacks_before_image_saved:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'before_image_saved_callback')
|
||||
|
||||
|
||||
def image_saved_callback(params: ImageSaveParams):
    """Call each entry of callbacks_image_saved with ``params``.

    A callback that raises is handed to report_exception and the loop moves on
    to the next entry.
    """
    for registered in callbacks_image_saved:
        try:
            registered.callback(params)
        except Exception:
            report_exception(registered, 'image_saved_callback')
|
||||
|
||||
@@ -64,7 +91,6 @@ def add_callback(callbacks, fun):
|
||||
callbacks.append(ScriptCallback(filename, fun))
|
||||
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument"""
|
||||
@@ -90,11 +116,17 @@ def on_ui_settings(callback):
|
||||
add_callback(callbacks_ui_settings, callback)
|
||||
|
||||
|
||||
def on_save_imaged(callback):
|
||||
"""register a function to be called after modules.images.save_image is called.
|
||||
The callback is called with three arguments:
|
||||
- p - procesing object (or a dummy object with same fields if the image is saved using save button)
|
||||
- fullfn - image filename
|
||||
- txt_fullfn - text file with parameters; may be None
|
||||
def on_before_image_saved(callback):
    """Register ``callback`` to run before an image is saved to a file.

    The callback receives one argument:
        - params: ImageSaveParams - the parameters the image is about to be
          saved with; the callback may change fields of this object.
    """
    add_callback(callbacks_before_image_saved, callback)
|
||||
|
||||
|
||||
def on_image_saved(callback):
    """Register ``callback`` to run after an image has been saved to a file.

    The callback receives one argument:
        - params: ImageSaveParams - the parameters the image was saved with;
          changing fields of this object has no effect.
    """
    add_callback(callbacks_image_saved, callback)
|
||||
|
@@ -84,7 +84,7 @@ parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load mod
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
restricted_opts = [
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
"outdir_samples",
|
||||
@@ -94,7 +94,7 @@ restricted_opts = [
|
||||
"outdir_grids",
|
||||
"outdir_txt2img_grids",
|
||||
"outdir_save",
|
||||
]
|
||||
}
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
|
||||
|
341
modules/textual_inversion/autocrop.py
Normal file
341
modules/textual_inversion/autocrop.py
Normal file
@@ -0,0 +1,341 @@
|
||||
import cv2
|
||||
import requests
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from math import log, sqrt
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
GREEN = "#0F0"
|
||||
BLUE = "#00F"
|
||||
RED = "#F00"
|
||||
|
||||
|
||||
def crop_image(im, settings):
    """Intelligently crop an image to the subject matter.

    The image is rescaled, its focal point is located with focal_point(), and a
    settings.crop_width x settings.crop_height window is centred on that point
    and then shifted back into the frame.

    Returns a list whose first item is the cropped image. When
    settings.annotate_image is truthy, the annotated debug copy of the rescaled
    image is appended as a second item.
    """
    crop_w = settings.crop_width
    crop_h = settings.crop_height

    # Pick the scale factor from the orientation of the source image; a square
    # source falls back to the orientation of the requested crop.
    scale_by = 1
    if is_landscape(im.width, im.height):
        scale_by = crop_h / im.height
    elif is_portrait(im.width, im.height):
        scale_by = crop_w / im.width
    elif is_square(im.width, im.height):
        if is_square(crop_w, crop_h) or is_landscape(crop_w, crop_h):
            scale_by = crop_w / im.width
        elif is_portrait(crop_w, crop_h):
            scale_by = crop_h / im.height

    scaled_size = (int(im.width * scale_by), int(im.height * scale_by))
    im = im.resize(scaled_size)
    im_debug = im.copy()

    focus = focal_point(im_debug, settings)

    def window_origin(center, window, limit):
        # Centre the window on `center`, then push it back inside [0, limit].
        origin = center - int(window / 2)
        if origin < 0:
            return 0
        if origin + window > limit:
            return limit - window
        return origin

    x1 = window_origin(focus.x, crop_w, im.width)
    y1 = window_origin(focus.y, crop_h, im.height)
    x2 = x1 + crop_w
    y2 = y1 + crop_h

    results = [im.crop((x1, y1, x2, y2))]

    if settings.annotate_image:
        draw = ImageDraw.Draw(im_debug)
        # The outline is drawn one pixel short of the crop's right/bottom edge.
        draw.rectangle([x1, y1, x2 - 1, y2 - 1], outline=GREEN)
        results.append(im_debug)
        # Attribute name is spelled "destop" on the settings object.
        if settings.destop_view_image:
            im_debug.show()

    return results
|
||||
|
||||
def focal_point(im, settings):
    """Blend corner, entropy and face points of interest into one focal point.

    Each detector runs only when its weight in ``settings`` is greater than 0.
    For every detector that produced points, the centroid of those points is
    given the detector's share of the total weight of the detectors that
    produced points, and the centroids are combined with poi_average().

    When settings.annotate_image is truthy the centroids, the individual points
    and the averaged point are drawn onto ``im`` in place.

    Returns the point produced by poi_average().
    """
    corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
    entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
    face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []

    # One row per detector: (points, configured weight, annotation colour, label format)
    detectors = [
        (corner_points, settings.corner_points_weight, BLUE, "Edge: %.02f"),
        (entropy_points, settings.entropy_points_weight, "#ff0", "Entropy: %.02f"),
        (face_points, settings.face_points_weight, RED, "Face: %.02f"),
    ]

    # Only detectors that actually found something take part in the weighting.
    weight_pref_total = sum(weight for points, weight, _, _ in detectors if len(points) > 0)

    pois = []
    found = []
    for points, weight, color, label in detectors:
        if len(points) == 0:
            continue
        center = centroid(points)
        center.weight = weight / weight_pref_total
        pois.append(center)
        found.append((center, points, color, label))

    average_point = poi_average(pois, settings)

    if settings.annotate_image:
        draw = ImageDraw.Draw(im)
        max_size = min(im.width, im.height) * 0.07

        for center, points, color, label in found:
            box = center.bounding(max_size * center.weight)
            draw.text((box[0], box[1] - 15), label % center.weight, fill=color)
            draw.ellipse(box, outline=color)
            # Individual points are only marked when there is more than one.
            if len(points) > 1:
                for point in points:
                    draw.rectangle(point.bounding(4), outline=color)

        draw.ellipse(average_point.bounding(max_size), outline=GREEN)

    return average_point
|
||||
|
||||
|
||||
def image_face_points(im, settings):
    """Detect faces in *im* and return them as a list of PointOfInterest.

    When ``settings.dnn_model_path`` is set, the OpenCV ``FaceDetectorYN``
    DNN detector is used and its result (possibly empty) is returned.
    Otherwise a series of Haar cascades is tried in order and the hits of
    the first cascade that finds anything are returned.

    Every returned point gets an equal share of the weight
    (``1 / number_of_detections``). An empty list means nothing was found.
    """
    if settings.dnn_model_path is not None:
        detector = cv2.FaceDetectorYN.create(
            settings.dnn_model_path,
            "",
            (im.width, im.height),
            0.9,  # score threshold
            0.3,  # nms threshold
            5000  # keep top k before nms
        )
        faces = detector.detect(np.array(im))
        results = []
        # faces[1] holds the detections and is None when there are none
        if faces[1] is not None:
            for face in faces[1]:
                x = face[0]
                y = face[1]
                w = face[2]
                h = face[3]
                results.append(
                    PointOfInterest(
                        int(x + (w * 0.5)),  # face focus left/right is center
                        int(y + (h * 0.33)),  # face focus up/down is close to the top of the head
                        size=w,
                        weight=1/len(faces[1])
                    )
                )
        return results
    else:
        np_im = np.array(im)
        # NOTE(review): if `im` is a PIL image the array is presumably in RGB
        # order, in which case COLOR_RGB2GRAY would be the matching
        # conversion - confirm before changing.
        gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)

        # [cascade file, minimum detection size as a fraction of the smallest image side]
        tries = [
            [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
            [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
            [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
            [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
            [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
            [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
            [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
            [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
        ]
        for t in tries:
            classifier = cv2.CascadeClassifier(t[0])
            minsize = int(min(im.width, im.height) * t[1])  # at least N percent of the smallest side
            try:
                faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
                    minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
            except Exception:
                # Best effort: a cascade that fails is skipped, but
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                continue

            if len(faces) > 0:
                # convert (x, y, w, h) to (left, top, right, bottom)
                rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
                return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
    return []
|
||||
|
||||
|
||||
def image_corner_points(im, settings):
    """Return corner features of *im* as equally weighted PointOfInterest objects.

    The image is converted to greyscale and up to 100 corners are collected
    with ``cv2.goodFeaturesToTrack``. Returns an empty list when the
    detector reports no corners. *settings* is accepted but not read here.
    """
    gray_im = im.convert("L")

    # naive attempt at preventing focal points from collecting at watermarks near the bottom
    ImageDraw.Draw(gray_im).rectangle([0, im.height*.9, im.width, im.height], fill="#999")

    corners = cv2.goodFeaturesToTrack(
        np.array(gray_im),
        maxCorners=100,
        qualityLevel=0.04,
        minDistance=min(gray_im.width, gray_im.height)*0.06,
        useHarrisDetector=False,
    )
    if corners is None:
        return []

    # each detected corner is flattened to an (x, y) pair
    return [
        PointOfInterest(px, py, size=4, weight=1/len(corners))
        for px, py in (corner.ravel() for corner in corners)
    ]
|
||||
|
||||
|
||||
def image_entropy_points(im, settings):
    """Find the crop window with the highest entropy score along the long axis.

    A window of ``settings.crop_width`` x ``settings.crop_height`` is slid in
    4 pixel steps along the longer side of *im*; each position is scored with
    ``image_entropy``. The centre of the best-scoring window is returned as a
    single PointOfInterest in a list. Square images return an empty list.
    """
    landscape = im.height < im.width
    portrait = im.height > im.width
    if landscape:
        # slide horizontally: move the left/right edges of the crop box
        move_idx = [0, 2]
        move_max = im.size[0]
    elif portrait:
        # slide vertically: move the top/bottom edges of the crop box
        move_idx = [1, 3]
        move_max = im.size[1]
    else:
        return []

    e_max = 0
    crop_current = [0, 0, settings.crop_width, settings.crop_height]
    # Copy instead of aliasing: crop_current is mutated in the loop below, so
    # sharing the list made "best" follow the window to its final position
    # whenever no window scored above e_max.
    crop_best = list(crop_current)
    while crop_current[move_idx[1]] < move_max:
        crop = im.crop(tuple(crop_current))
        e = image_entropy(crop)

        if (e > e_max):
            e_max = e
            crop_best = list(crop_current)

        crop_current[move_idx[0]] += 4
        crop_current[move_idx[1]] += 4

    x_mid = int(crop_best[0] + settings.crop_width/2)
    y_mid = int(crop_best[1] + settings.crop_height/2)

    return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
|
||||
|
||||
|
||||
def image_entropy(im):
    """Return an entropy-style score for *im*.

    The image is converted to mode "1" and its values are histogrammed; the
    score is the negated sum of ``log2`` of each non-empty bin's share of the
    total count.
    """
    pixels = np.asarray(im.convert("1"), dtype=np.uint8)
    counts, _ = np.histogram(pixels, bins=range(0, 256))
    occupied = counts[counts > 0]  # drop empty bins so log2 never sees 0
    shares = occupied / occupied.sum()
    return -np.log2(shares).sum()
|
||||
|
||||
def centroid(pois):
    """Return a PointOfInterest at the unweighted mean x/y of *pois*.

    *pois* must be non-empty; an empty sequence raises ZeroDivisionError.
    """
    count = len(pois)
    mean_x = sum(poi.x for poi in pois) / count
    mean_y = sum(poi.y for poi in pois) / count
    return PointOfInterest(mean_x, mean_y)
|
||||
|
||||
|
||||
def poi_average(pois, settings):
    """Return the weighted average position of *pois* as a PointOfInterest.

    Each point contributes its x/y scaled by its ``weight``; the result is
    rounded to whole numbers. When *pois* is empty or the weights sum to
    zero there is nothing to average, so a point at (0, 0) is returned
    instead of raising ZeroDivisionError. *settings* is accepted but not
    read here.
    """
    total_weight = 0.0
    sum_x = 0.0
    sum_y = 0.0
    for poi in pois:
        total_weight += poi.weight
        sum_x += poi.x * poi.weight
        sum_y += poi.y * poi.weight

    if not total_weight:
        # No usable points: avoid dividing by zero.
        return PointOfInterest(0, 0)

    avg_x = round(sum_x / total_weight)
    avg_y = round(sum_y / total_weight)

    return PointOfInterest(avg_x, avg_y)
|
||||
|
||||
|
||||
def is_landscape(w, h):
    """Return True when width *w* is strictly greater than height *h*."""
    wider_than_tall = w > h
    return wider_than_tall
|
||||
|
||||
|
||||
def is_portrait(w, h):
    """Return True when height *h* is strictly greater than width *w*."""
    taller_than_wide = h > w
    return taller_than_wide
|
||||
|
||||
|
||||
def is_square(w, h):
    """Return True when width *w* and height *h* are equal."""
    same_sides = w == h
    return same_sides
|
||||
|
||||
|
||||
def download_and_cache_models(dirname):
    """Ensure the YuNet face detection model is cached in *dirname*.

    Creates *dirname* if needed and downloads the model on first use.

    Returns the path of the cached model file, or None if it does not exist
    after the attempt.

    Raises whatever ``requests`` raises on network failure or a non-success
    HTTP status, so that an error page is never stored as the model.
    """
    download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
    model_file_name = 'face_detection_yunet.onnx'

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(dirname, exist_ok=True)

    cache_file = os.path.join(dirname, model_file_name)
    if not os.path.exists(cache_file):
        print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
        response = requests.get(download_url, timeout=60)
        # Fail loudly on HTTP errors instead of caching the error body forever.
        response.raise_for_status()

        # Write to a temporary name and rename, so an interrupted write never
        # leaves a truncated file under the final cache name.
        partial_file = cache_file + ".part"
        with open(partial_file, "wb") as f:
            f.write(response.content)
        os.replace(partial_file, cache_file)

    if os.path.exists(cache_file):
        return cache_file
    return None
|
||||
|
||||
|
||||
class PointOfInterest:
    """A weighted point in an image, with a nominal size."""

    def __init__(self, x, y, weight=1.0, size=10):
        self.x, self.y = x, y
        self.weight, self.size = weight, size

    def bounding(self, size):
        """Return ``[left, top, right, bottom]`` of a box centred on the point.

        The box extends ``size // 2`` in every direction; the *size* argument
        is used, not ``self.size``.
        """
        half = size//2
        left, top = self.x - half, self.y - half
        right, bottom = self.x + half, self.y + half
        return [left, top, right, bottom]
|
||||
|
||||
|
||||
class Settings:
    """Plain container for the options read by the auto-crop functions."""

    def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
        # target crop dimensions
        self.crop_width, self.crop_height = crop_width, crop_height

        # relative weights of the three kinds of points of interest
        self.face_points_weight = face_points_weight
        self.entropy_points_weight = entropy_points_weight
        self.corner_points_weight = corner_points_weight

        # path of the DNN face model; None selects the Haar cascade fallback
        self.dnn_model_path = dnn_model_path

        # debugging aids
        self.annotate_image = annotate_image
        self.destop_view_image = False
|
@@ -7,12 +7,14 @@ import tqdm
|
||||
import time
|
||||
|
||||
from modules import shared, images
|
||||
from modules.paths import models_path
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.textual_inversion import autocrop
|
||||
if cmd_opts.deepdanbooru:
|
||||
import modules.deepbooru as deepbooru
|
||||
|
||||
|
||||
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
|
||||
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
@@ -22,7 +24,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
|
||||
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
|
||||
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
|
||||
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
|
||||
|
||||
finally:
|
||||
|
||||
@@ -34,7 +36,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
|
||||
|
||||
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
width = process_width
|
||||
height = process_height
|
||||
src = os.path.abspath(process_src)
|
||||
@@ -113,6 +115,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
splitted = image.crop((0, y, to_w, y + to_h))
|
||||
yield splitted
|
||||
|
||||
|
||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||
subindex = [0]
|
||||
filename = os.path.join(src, imagefile)
|
||||
@@ -137,11 +140,36 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
ratio = (img.height * width) / (img.width * height)
|
||||
inverse_xy = True
|
||||
|
||||
process_default_resize = True
|
||||
|
||||
if process_split and ratio < 1.0 and ratio <= split_threshold:
|
||||
for splitted in split_pic(img, inverse_xy):
|
||||
save_pic(splitted, index, existing_caption=existing_caption)
|
||||
else:
|
||||
process_default_resize = False
|
||||
|
||||
if process_focal_crop and img.height != img.width:
|
||||
|
||||
dnn_model_path = None
|
||||
try:
|
||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
|
||||
except Exception as e:
|
||||
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
|
||||
|
||||
autocrop_settings = autocrop.Settings(
|
||||
crop_width = width,
|
||||
crop_height = height,
|
||||
face_points_weight = process_focal_crop_face_weight,
|
||||
entropy_points_weight = process_focal_crop_entropy_weight,
|
||||
corner_points_weight = process_focal_crop_edges_weight,
|
||||
annotate_image = process_focal_crop_debug,
|
||||
dnn_model_path = dnn_model_path,
|
||||
)
|
||||
for focal in autocrop.crop_image(img, autocrop_settings):
|
||||
save_pic(focal, index, existing_caption=existing_caption)
|
||||
process_default_resize = False
|
||||
|
||||
if process_default_resize:
|
||||
img = images.resize_image(1, img, width, height)
|
||||
save_pic(img, index, existing_caption=existing_caption)
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.nextjob()
|
@@ -10,7 +10,7 @@ import csv
|
||||
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||
|
||||
with devices.autocast():
|
||||
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
||||
|
||||
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||
@@ -164,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||
for i in range(num_vectors_per_token):
|
||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||
if not overwrite_old:
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
@@ -244,6 +249,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
@@ -283,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
|
||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding.name = f'{embedding_name}-{embedding.step}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
|
||||
embedding.save(last_saved_file)
|
||||
embedding_yet_to_be_embedded = True
|
||||
|
||||
@@ -293,8 +301,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
})
|
||||
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
forced_filename = f'{embedding_name}-{embedding.step}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
@@ -350,8 +358,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
@@ -371,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
embedding.sd_checkpoint = checkpoint.hash
|
||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||
embedding.cached_checksum = None
|
||||
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
|
||||
embedding.name = embedding_name
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
|
||||
embedding.save(filename)
|
||||
|
||||
return embedding, filename
|
||||
|
@@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
|
||||
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
|
||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
||||
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||
@@ -1260,6 +1261,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Row():
|
||||
process_flip = gr.Checkbox(label='Create flipped copies')
|
||||
process_split = gr.Checkbox(label='Split oversized images')
|
||||
process_focal_crop = gr.Checkbox(label='Auto focal point crop')
|
||||
process_caption = gr.Checkbox(label='Use BLIP for caption')
|
||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
|
||||
|
||||
@@ -1267,6 +1269,12 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
|
||||
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
|
||||
|
||||
with gr.Row(visible=False) as process_focal_crop_row:
|
||||
process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
|
||||
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
|
||||
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
|
||||
process_focal_crop_debug = gr.Checkbox(label='Create debug image')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
gr.HTML(value="")
|
||||
@@ -1280,6 +1288,12 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
outputs=[process_split_extra_row],
|
||||
)
|
||||
|
||||
process_focal_crop.change(
|
||||
fn=lambda show: gr_show(show),
|
||||
inputs=[process_focal_crop],
|
||||
outputs=[process_focal_crop_row],
|
||||
)
|
||||
|
||||
with gr.Tab(label="Train"):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||
with gr.Row():
|
||||
@@ -1342,6 +1356,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
overwrite_old_hypernetwork,
|
||||
new_hypernetwork_layer_structure,
|
||||
new_hypernetwork_activation_func,
|
||||
new_hypernetwork_initialization_option,
|
||||
new_hypernetwork_add_layer_norm,
|
||||
new_hypernetwork_use_dropout
|
||||
],
|
||||
@@ -1367,6 +1382,11 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
process_caption_deepbooru,
|
||||
process_split_threshold,
|
||||
process_overlap_ratio,
|
||||
process_focal_crop,
|
||||
process_focal_crop_face_weight,
|
||||
process_focal_crop_entropy_weight,
|
||||
process_focal_crop_edges_weight,
|
||||
process_focal_crop_debug,
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
|
Reference in New Issue
Block a user