Merge remote-tracking branch 'Melanpan/master'

This commit is contained in:
AUTOMATIC
2022-10-14 22:14:50 +03:00
3 changed files with 34 additions and 0 deletions

View File

@@ -6,6 +6,7 @@ import torch
import tqdm
import html
import datetime
import csv
from PIL import Image, PngImagePlugin
@@ -256,6 +257,21 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"])
if write_csv_header:
csv_writer.writeheader()
csv_writer.writerow({"epoch": epoch_num + 1,
"epoch_step": epoch_step - 1,
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')