change StableDiffusionProcessing to internally use sampler name instead of sampler index

This commit is contained in:
AUTOMATIC
2022-11-19 12:01:51 +03:00
parent d9fd4525a5
commit cdc8020d13
11 changed files with 49 additions and 54 deletions

View File

@@ -6,9 +6,9 @@ from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules import sd_samplers
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
@@ -25,8 +25,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
if config is None:
raise HTTPException(status_code=404, detail="Sampler not found")
return name
def setUpscalers(req: dict):
reqDict = vars(req)
@@ -82,14 +86,9 @@ class Api:
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
}
@@ -109,12 +108,6 @@ class Api:
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -123,10 +116,9 @@ class Api:
if mask:
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"sampler_name": validate_sampler_name(img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
"mask": mask
@@ -272,7 +264,7 @@ class Api:
return vars(shared.cmd_opts)
def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self):
upscalers = []