mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
Merge remote-tracking branch 'Melanpan/master'
This commit is contained in:
@@ -5,6 +5,7 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
import tqdm
|
||||
import csv
|
||||
|
||||
import torch
|
||||
|
||||
@@ -262,6 +263,20 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
||||
hypernetwork.save(last_saved_file)
|
||||
|
||||
if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
|
||||
write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
|
||||
|
||||
with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
|
||||
|
||||
csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
|
||||
|
||||
if write_csv_header:
|
||||
csv_writer.writeheader()
|
||||
|
||||
csv_writer.writerow({"step": hypernetwork.step,
|
||||
"loss": f"{losses.mean():.7f}",
|
||||
"learn_rate": scheduler.learn_rate})
|
||||
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||
|
||||
|
Reference in New Issue
Block a user