Add job argument to State.begin()

This commit is contained in:
Aarni Koskela
2023-06-30 13:11:31 +03:00
parent b70001e618
commit f44feb6a10
6 changed files with 13 additions and 16 deletions

View File

@@ -327,7 +327,7 @@ class Api:
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin()
shared.state.begin(job="scripts_txt2img")
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
@@ -384,7 +384,7 @@ class Api:
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
shared.state.begin()
shared.state.begin(job="scripts_img2img")
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
@@ -599,7 +599,7 @@ class Api:
def create_embedding(self, args: dict):
try:
shared.state.begin()
shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
@@ -610,7 +610,7 @@ class Api:
def create_hypernetwork(self, args: dict):
try:
shared.state.begin()
shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
@@ -620,7 +620,7 @@ class Api:
def preprocess(self, args: dict):
try:
shared.state.begin()
shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info = 'preprocess complete')
@@ -636,7 +636,7 @@ class Api:
def train_embedding(self, args: dict):
try:
shared.state.begin()
shared.state.begin(job="train_embedding")
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -657,7 +657,7 @@ class Api:
def train_hypernetwork(self, args: dict):
try:
shared.state.begin()
shared.state.begin(job="train_hypernetwork")
shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None