Gradient accumulation, autocast fix, new latent sampling method, etc

This commit is contained in:
flamelaw
2022-11-20 12:35:26 +09:00
parent 47a44c7e42
commit bd68e35de3
7 changed files with 408 additions and 273 deletions

View File

@@ -1262,7 +1262,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
with gr.Row():
interrupt_preprocessing = gr.Button("Interrupt")
run_preprocess = gr.Button(value="Preprocess", variant='primary')
run_preprocess = gr.Button(value="Preprocess", variant='primary')
process_split.change(
fn=lambda show: gr_show(show),
@@ -1289,6 +1289,7 @@ def create_ui(wrap_gradio_gpu_call):
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0)
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@@ -1299,6 +1300,11 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
with gr.Row():
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
@@ -1387,11 +1393,15 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name,
embedding_learn_rate,
batch_size,
gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,
@@ -1412,11 +1422,15 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,