mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-10 18:59:49 +00:00
Merge branch 'master' into test_resolve_conflicts
This commit is contained in:
@@ -91,7 +91,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
|
||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
|
||||
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
|
||||
|
||||
c = cached_images.get(key)
|
||||
if c is None:
|
||||
@@ -175,11 +176,14 @@ def run_pnginfo(image):
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
|
||||
def weighted_sum(theta0, theta1, theta2, alpha):
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
def add_difference(theta0, theta1, theta2, alpha):
|
||||
return theta0 + (theta1 - theta2) * alpha
|
||||
def get_difference(theta1, theta2):
|
||||
return theta1 - theta2
|
||||
|
||||
def add_difference(theta0, theta1_2_diff, alpha):
|
||||
return theta0 + (alpha * theta1_2_diff)
|
||||
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
||||
@@ -198,23 +202,28 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
|
||||
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
|
||||
else:
|
||||
teritary_model = None
|
||||
theta_2 = None
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted sum": weighted_sum,
|
||||
"Add difference": add_difference,
|
||||
"Weighted sum": (None, weighted_sum),
|
||||
"Add difference": (get_difference, add_difference),
|
||||
}
|
||||
theta_func = theta_funcs[interp_method]
|
||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
|
||||
print(f"Merging...")
|
||||
|
||||
if theta_func1:
|
||||
for key in tqdm.tqdm(theta_1.keys()):
|
||||
if 'model' in key:
|
||||
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
||||
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||
del theta_2, teritary_model
|
||||
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
t2 = (theta_2 or {}).get(key)
|
||||
if t2 is None:
|
||||
t2 = torch.zeros_like(theta_0[key])
|
||||
|
||||
theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier)
|
||||
theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
|
||||
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
|
@@ -123,7 +123,7 @@ class InterrogateModels:
|
||||
|
||||
return caption[0]
|
||||
|
||||
def interrogate(self, pil_image, include_ranks=False):
|
||||
def interrogate(self, pil_image):
|
||||
res = None
|
||||
|
||||
try:
|
||||
@@ -156,10 +156,10 @@ class InterrogateModels:
|
||||
for name, topn, items in self.categories:
|
||||
matches = self.rank(image_features, items, top_count=topn)
|
||||
for match, score in matches:
|
||||
if include_ranks:
|
||||
res += ", " + match
|
||||
if shared.opts.interrogate_return_ranks:
|
||||
res += f", ({match}:{score/100:.3f})"
|
||||
else:
|
||||
res += f", ({match}:{score})"
|
||||
res += ", " + match
|
||||
|
||||
except Exception:
|
||||
print(f"Error interrogating", file=sys.stderr)
|
||||
|
31
modules/localization.py
Normal file
31
modules/localization.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
localizations = {}
|
||||
|
||||
|
||||
def list_localizations(dirname):
|
||||
localizations.clear()
|
||||
|
||||
for file in os.listdir(dirname):
|
||||
fn, ext = os.path.splitext(file)
|
||||
if ext.lower() != ".json":
|
||||
continue
|
||||
|
||||
localizations[fn] = os.path.join(dirname, file)
|
||||
|
||||
|
||||
def localization_js(current_localization_name):
|
||||
fn = localizations.get(current_localization_name, None)
|
||||
data = {}
|
||||
if fn is not None:
|
||||
try:
|
||||
with open(fn, "r", encoding="utf8") as file:
|
||||
data = json.load(file)
|
||||
except Exception:
|
||||
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return f"var localization = {json.dumps(data)}\n"
|
@@ -58,6 +58,9 @@ def load_scripts(basedir):
|
||||
for filename in sorted(os.listdir(basedir)):
|
||||
path = os.path.join(basedir, filename)
|
||||
|
||||
if os.path.splitext(path)[1].lower() != '.py':
|
||||
continue
|
||||
|
||||
if not os.path.isfile(path):
|
||||
continue
|
||||
|
||||
|
@@ -14,7 +14,7 @@ import modules.memmon
|
||||
import modules.sd_models
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import sd_samplers, sd_models
|
||||
from modules import sd_samplers, sd_models, localization
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
@@ -34,6 +34,7 @@ parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_pa
|
||||
parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'),
|
||||
help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
@@ -106,6 +107,7 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
loaded_hypernetwork = None
|
||||
|
||||
|
||||
aesthetic_embeddings = {}
|
||||
|
||||
def update_aesthetic_embeddings():
|
||||
@@ -116,6 +118,7 @@ def update_aesthetic_embeddings():
|
||||
|
||||
update_aesthetic_embeddings()
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
global hypernetworks
|
||||
|
||||
@@ -163,6 +166,8 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||
|
||||
face_restorers = []
|
||||
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
@@ -308,6 +313,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
|
@@ -137,6 +137,7 @@ class EmbeddingDatabase:
|
||||
continue
|
||||
|
||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||
print("Embeddings:", ', '.join(self.word_embeddings.keys()))
|
||||
|
||||
def find_embedding_at_position(self, tokens, offset):
|
||||
token = tokens[offset]
|
||||
|
@@ -23,7 +23,7 @@ import gradio as gr
|
||||
import gradio.utils
|
||||
import gradio.routes
|
||||
|
||||
from modules import sd_hijack, sd_models
|
||||
from modules import sd_hijack, sd_models, localization
|
||||
from modules.paths import script_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings
|
||||
@@ -1102,10 +1102,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
|
||||
|
||||
with gr.Group():
|
||||
@@ -1282,10 +1282,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Tab(label="Train"):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
||||
with gr.Row():
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||
with gr.Row():
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||
@@ -1452,16 +1452,18 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
else:
|
||||
raise Exception(f'bad options item type: {str(t)} for key {key}')
|
||||
|
||||
elem_id = "setting_"+key
|
||||
|
||||
if info.refresh is not None:
|
||||
if is_quicksettings:
|
||||
res = comp(label=info.label, value=fun, **(args or {}))
|
||||
refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
else:
|
||||
with gr.Row(variant="compact"):
|
||||
res = comp(label=info.label, value=fun, **(args or {}))
|
||||
refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
else:
|
||||
res = comp(label=info.label, value=fun, **(args or {}))
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
|
||||
|
||||
return res
|
||||
@@ -1585,6 +1587,9 @@ Requested path was: {f}
|
||||
|
||||
with gr.Row():
|
||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||
|
||||
with gr.Row():
|
||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
|
||||
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
|
||||
|
||||
@@ -1595,6 +1600,13 @@ Requested path was: {f}
|
||||
_js='function(){}'
|
||||
)
|
||||
|
||||
download_localization.click(
|
||||
fn=lambda: None,
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
_js='download_localization'
|
||||
)
|
||||
|
||||
def reload_scripts():
|
||||
modules.scripts.reload_script_body_only()
|
||||
|
||||
@@ -1843,6 +1855,7 @@ Requested path was: {f}
|
||||
visit(txt2img_interface, loadsave, "txt2img")
|
||||
visit(img2img_interface, loadsave, "img2img")
|
||||
visit(extras_interface, loadsave, "extras")
|
||||
visit(modelmerger_interface, loadsave, "modelmerger")
|
||||
|
||||
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
||||
with open(ui_config_file, "w", encoding="utf8") as file:
|
||||
@@ -1859,6 +1872,7 @@ for filename in sorted(os.listdir(jsdir)):
|
||||
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
|
||||
javascript += f"\n<script>{jsfile.read()}</script>"
|
||||
|
||||
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
|
||||
|
||||
if 'gradio_routes_templates_response' not in globals():
|
||||
def template_response(*args, **kwargs):
|
||||
|
Reference in New Issue
Block a user