mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Merge branch 'dev' into feature/restore-progress
This commit is contained in:
85
webui.py
85
webui.py
@@ -5,6 +5,7 @@ import importlib
|
||||
import signal
|
||||
import re
|
||||
import warnings
|
||||
import json
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
@@ -20,6 +21,9 @@ startup_timer = timer.Timer()
|
||||
import torch
|
||||
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
|
||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||
|
||||
|
||||
startup_timer.record("import torch")
|
||||
|
||||
import gradio
|
||||
@@ -37,7 +41,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||
torch.__long_version__ = torch.__version__
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||
|
||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
|
||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
|
||||
import modules.codeformer_model as codeformer
|
||||
import modules.face_restoration
|
||||
import modules.gfpgan_model as gfpgan
|
||||
@@ -67,11 +71,51 @@ else:
|
||||
server_name = "0.0.0.0" if cmd_opts.listen else None
|
||||
|
||||
|
||||
def fix_asyncio_event_loop_policy():
|
||||
"""
|
||||
The default `asyncio` event loop policy only automatically creates
|
||||
event loops in the main threads. Other threads must create event
|
||||
loops explicitly or `asyncio.get_event_loop` (and therefore
|
||||
`.IOLoop.current`) will fail. Installing this policy allows event
|
||||
loops to be created automatically on any thread, matching the
|
||||
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
|
||||
# "Any thread" and "selector" should be orthogonal, but there's not a clean
|
||||
# interface for composing policies so pick the right base.
|
||||
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
|
||||
else:
|
||||
_BasePolicy = asyncio.DefaultEventLoopPolicy
|
||||
|
||||
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
|
||||
"""Event loop policy that allows loop creation on any thread.
|
||||
Usage::
|
||||
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
"""
|
||||
|
||||
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
return super().get_event_loop()
|
||||
except (RuntimeError, AssertionError):
|
||||
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
|
||||
# and changed to a RuntimeError in 3.4.3.
|
||||
# "There is no current event loop in thread %r"
|
||||
loop = self.new_event_loop()
|
||||
self.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
|
||||
|
||||
def check_versions():
|
||||
if shared.cmd_opts.skip_version_check:
|
||||
return
|
||||
|
||||
expected_torch_version = "1.13.1"
|
||||
expected_torch_version = "2.0.0"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
errors.print_error_explanation(f"""
|
||||
@@ -84,7 +128,7 @@ there are reports of issues with training tab on the latest version.
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
expected_xformers_version = "0.0.16rc425"
|
||||
expected_xformers_version = "0.0.17"
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
|
||||
@@ -99,12 +143,27 @@ Use --skip-version-check commandline argument to disable this check.
|
||||
|
||||
|
||||
def initialize():
|
||||
fix_asyncio_event_loop_policy()
|
||||
|
||||
check_versions()
|
||||
|
||||
extensions.list_extensions()
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
startup_timer.record("list extensions")
|
||||
|
||||
config_state_file = shared.opts.restore_config_state_file
|
||||
shared.opts.restore_config_state_file = ""
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
if os.path.isfile(config_state_file):
|
||||
print(f"*** About to restore extension state from file: {config_state_file}")
|
||||
with open(config_state_file, "r", encoding="utf-8") as f:
|
||||
config_state = json.load(f)
|
||||
config_states.restore_extension_config(config_state)
|
||||
startup_timer.record("restore extension config")
|
||||
elif config_state_file:
|
||||
print(f"!!! Config state backup not found: {config_state_file}")
|
||||
|
||||
if cmd_opts.ui_debug_mode:
|
||||
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||
modules.scripts.load_scripts()
|
||||
@@ -126,9 +185,6 @@ def initialize():
|
||||
modules.scripts.load_scripts()
|
||||
startup_timer.record("load scripts")
|
||||
|
||||
modelloader.load_upscalers()
|
||||
startup_timer.record("load upscalers")
|
||||
|
||||
modules.sd_vae.refresh_vae_list()
|
||||
startup_timer.record("refresh VAE")
|
||||
|
||||
@@ -150,6 +206,7 @@ def initialize():
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
startup_timer.record("opts onchange")
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
@@ -212,6 +269,8 @@ def wait_on_server(demo=None):
|
||||
time.sleep(0.5)
|
||||
demo.close()
|
||||
time.sleep(0.5)
|
||||
|
||||
modules.script_callbacks.app_reload_callback()
|
||||
break
|
||||
|
||||
|
||||
@@ -260,6 +319,7 @@ def webui():
|
||||
server_port=cmd_opts.port,
|
||||
ssl_keyfile=cmd_opts.tls_keyfile,
|
||||
ssl_certfile=cmd_opts.tls_certfile,
|
||||
ssl_verify=cmd_opts.disable_tls_verify,
|
||||
debug=cmd_opts.gradio_debug,
|
||||
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
|
||||
inbrowser=cmd_opts.autolaunch,
|
||||
@@ -302,6 +362,19 @@ def webui():
|
||||
extensions.list_extensions()
|
||||
startup_timer.record("list extensions")
|
||||
|
||||
config_state_file = shared.opts.restore_config_state_file
|
||||
shared.opts.restore_config_state_file = ""
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
if os.path.isfile(config_state_file):
|
||||
print(f"*** About to restore extension state from file: {config_state_file}")
|
||||
with open(config_state_file, "r", encoding="utf-8") as f:
|
||||
config_state = json.load(f)
|
||||
config_states.restore_extension_config(config_state)
|
||||
startup_timer.record("restore extension config")
|
||||
elif config_state_file:
|
||||
print(f"!!! Config state backup not found: {config_state_file}")
|
||||
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
modelloader.forbid_loaded_nonbuiltin_upscalers()
|
||||
|
Reference in New Issue
Block a user