hypernetwork training mk1

This commit is contained in:
AUTOMATIC
2022-10-07 23:22:22 +03:00
parent f7c787eb7c
commit 12c4d5c6b5
12 changed files with 414 additions and 107 deletions

View File

@@ -37,6 +37,7 @@ import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetwork.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
@@ -965,6 +966,18 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
new_hypernetwork_name = gr.Textbox(label="Name")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_hypernetwork = gr.Button(value="Create", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
@@ -986,6 +999,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -993,15 +1007,12 @@ def create_ui(wrap_gradio_gpu_call):
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
with gr.Row():
with gr.Column(scale=2):
gr.HTML(value="")
with gr.Column():
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_embedding = gr.Button(value="Train", variant='primary')
interrupt_training = gr.Button(value="Interrupt")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
train_embedding = gr.Button(value="Train Embedding", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
@@ -1027,6 +1038,18 @@ def create_ui(wrap_gradio_gpu_call):
]
)
create_hypernetwork.click(
fn=modules.hypernetwork.ui.create_hypernetwork,
inputs=[
new_hypernetwork_name,
],
outputs=[
train_hypernetwork_name,
ti_output,
ti_outcome,
]
)
run_preprocess.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
@@ -1062,12 +1085,33 @@ def create_ui(wrap_gradio_gpu_call):
]
)
train_hypernetwork.click(
fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
learn_rate,
dataset_directory,
log_directory,
steps,
create_image_every,
save_embedding_every,
template_file,
preview_image_prompt,
],
outputs=[
ti_output,
ti_outcome,
]
)
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default