mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
change StableDiffusionProcessing to internally use sampler name instead of sampler index
This commit is contained in:
@@ -6,9 +6,9 @@ from threading import Lock
|
||||
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.sd_samplers import all_samplers
|
||||
from modules.extras import run_extras, run_pnginfo
|
||||
from PIL import PngImagePlugin
|
||||
from modules.sd_models import checkpoints_list
|
||||
@@ -25,8 +25,12 @@ def upscaler_to_index(name: str):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
|
||||
|
||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||
def validate_sampler_name(name):
|
||||
config = sd_samplers.all_samplers_map.get(name, None)
|
||||
if config is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
return name
|
||||
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
@@ -82,14 +86,9 @@ class Api:
|
||||
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True
|
||||
}
|
||||
@@ -109,12 +108,6 @@ class Api:
|
||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(img2imgreq.sampler_index)
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
@@ -123,10 +116,9 @@ class Api:
|
||||
if mask:
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"sampler_name": validate_sampler_name(img2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True,
|
||||
"mask": mask
|
||||
@@ -272,7 +264,7 @@ class Api:
|
||||
return vars(shared.cmd_opts)
|
||||
|
||||
def get_samplers(self):
|
||||
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
|
||||
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||
|
||||
def get_upscalers(self):
|
||||
upscalers = []
|
||||
|
Reference in New Issue
Block a user