mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-05 03:32:37 +00:00
Save Optimizer next to TI embedding
Also add check to load only .PT and .BIN files as embeddings. (since we add .optim files in the same directory)
This commit is contained in:
@@ -28,6 +28,7 @@ class Embedding:
|
||||
self.cached_checksum = None
|
||||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.optimizer_state_dict = None
|
||||
|
||||
def save(self, filename):
|
||||
embedding_data = {
|
||||
@@ -41,6 +42,13 @@ class Embedding:
|
||||
|
||||
torch.save(embedding_data, filename)
|
||||
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
|
||||
optimizer_saved_dict = {
|
||||
'hash': self.checksum(),
|
||||
'optimizer_state_dict': self.optimizer_state_dict,
|
||||
}
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
def checksum(self):
|
||||
if self.cached_checksum is not None:
|
||||
return self.cached_checksum
|
||||
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
|
||||
self.expected_shape = self.get_expected_shape()
|
||||
|
||||
def process_file(path, filename):
|
||||
name = os.path.splitext(filename)[0]
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
|
||||
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
else:
|
||||
return
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
@@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||
if shared.opts.save_optimizer_state:
|
||||
optimizer_state_dict = None
|
||||
if os.path.exists(filename + '.optim'):
|
||||
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
|
||||
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
||||
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
|
||||
if optimizer_state_dict is not None:
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
else:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
batch_size = ds.batch_size
|
||||
@@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||
#if shared.opts.save_optimizer_state:
|
||||
#embedding.optimizer_state_dict = optimizer.state_dict()
|
||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||
embedding_yet_to_be_embedded = True
|
||||
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||
@@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
except Exception:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
pass
|
||||
@@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
|
||||
return embedding, filename
|
||||
|
||||
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||
old_embedding_name = embedding.name
|
||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||
@@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
|
||||
if remove_cached_checksum:
|
||||
embedding.cached_checksum = None
|
||||
embedding.name = embedding_name
|
||||
embedding.optimizer_state_dict = optimizer.state_dict()
|
||||
embedding.save(filename)
|
||||
except:
|
||||
embedding.sd_checkpoint = old_sd_checkpoint
|
||||
|
Reference in New Issue
Block a user