Merge branch 'master' into api-authorization

This commit is contained in:
Maiko Tan
2022-11-19 20:13:07 +08:00
28 changed files with 193 additions and 150 deletions

View File

@@ -9,9 +9,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
@@ -28,8 +28,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
if config is None:
raise HTTPException(status_code=404, detail="Sampler not found")
return name
def setUpscalers(req: dict):
reqDict = vars(req)
@@ -77,6 +81,7 @@ class Api:
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
@@ -103,14 +108,9 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
}
@@ -130,12 +130,6 @@ class Api:
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -144,10 +138,9 @@ class Api:
if mask:
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"sampler_name": validate_sampler_name(img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
"mask": mask
@@ -266,6 +259,9 @@ class Api:
return {}
def skip(self):
shared.state.skip()
def get_config(self):
options = {}
for key in shared.opts.data.keys():
@@ -277,14 +273,10 @@ class Api:
return options
def set_config(self, req: OptionsModel):
# currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
# overwrite all options with default values.
raise RuntimeError('Setting options via API is not supported')
reqDict = vars(req)
for o in reqDict:
setattr(shared.opts, o, reqDict[o])
def set_config(self, req: Dict[str, Any]):
for o in req:
setattr(shared.opts, o, req[o])
shared.opts.save(shared.config_filename)
return
@@ -293,7 +285,7 @@ class Api:
return vars(shared.cmd_opts)
def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self):
upscalers = []