mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
support scheduler selection in hires fix
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common
|
||||
import functools
|
||||
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
|
||||
@@ -64,4 +66,60 @@ def visible_samplers():
|
||||
return [x for x in samplers if x.name not in samplers_hidden]
|
||||
|
||||
|
||||
def get_sampler_from_infotext(d: dict):
|
||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
|
||||
|
||||
|
||||
def get_scheduler_from_infotext(d: dict):
|
||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
|
||||
|
||||
|
||||
def get_hr_sampler_and_scheduler(d: dict):
|
||||
hr_sampler = d.get("Hires sampler", "Use same sampler")
|
||||
sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
|
||||
|
||||
hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
|
||||
scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
|
||||
|
||||
sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
|
||||
|
||||
sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
|
||||
scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
|
||||
|
||||
return sampler, scheduler
|
||||
|
||||
|
||||
def get_hr_sampler_from_infotext(d: dict):
|
||||
return get_hr_sampler_and_scheduler(d)[0]
|
||||
|
||||
|
||||
def get_hr_scheduler_from_infotext(d: dict):
|
||||
return get_hr_sampler_and_scheduler(d)[1]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_sampler_and_scheduler(sampler_name, scheduler_name):
|
||||
default_sampler = samplers[0]
|
||||
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
|
||||
|
||||
name = sampler_name or default_sampler.name
|
||||
|
||||
for scheduler in sd_schedulers.schedulers:
|
||||
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
|
||||
|
||||
for name_option in name_options:
|
||||
if name.endswith(" " + name_option):
|
||||
found_scheduler = scheduler
|
||||
name = name[0:-(len(name_option) + 1)]
|
||||
break
|
||||
|
||||
sampler = all_samplers_map.get(name, default_sampler)
|
||||
|
||||
# revert back to Automatic if it's the default scheduler for the selected sampler
|
||||
if sampler.options.get('scheduler', None) == found_scheduler.name:
|
||||
found_scheduler = sd_schedulers.schedulers[0]
|
||||
|
||||
return sampler.name, found_scheduler.label
|
||||
|
||||
|
||||
set_samplers()
|
||||
|
Reference in New Issue
Block a user