provide sampler by name

This commit is contained in:
arcticfaded
2022-10-18 19:04:56 +00:00
parent 8d5d863a9d
commit e7f4808505
2 changed files with 24 additions and 4 deletions

View File

@@ -42,7 +42,8 @@ class PydanticModelGenerator:
def __init__(
self,
model_name: str = None,
class_instance = None
class_instance = None,
additional_fields = None,
):
def field_type_generator(k, v):
# field_type = str if not overrides.get(k) else overrides[k]["type"]
@@ -70,6 +71,13 @@ class PydanticModelGenerator:
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
for fields in additional_fields:
self._model_def.append(ModelDef(
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
field_value=fields["default"]))
def generate_model(self):
"""
@@ -84,4 +92,8 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_mutation = True
return DynamicModel
StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()
StableDiffusionProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "k_euler_a"}]
).generate_model()