mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
Add support for checkpoint merging
This commit is contained in:
@@ -3,6 +3,8 @@ import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
|
||||
from modules import processing, shared, images, devices
|
||||
from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
@@ -135,3 +137,25 @@ def run_pnginfo(image):
|
||||
info = f"<div><p>{message}<p></div>"
|
||||
|
||||
return '', geninfo, info
|
||||
|
||||
|
||||
def run_modelmerger(modelname_0, modelname_1, alpha):
|
||||
model_0 = torch.load('models/' + modelname_0 + '.ckpt')
|
||||
model_1 = torch.load('models/' + modelname_1 + '.ckpt')
|
||||
|
||||
theta_0 = model_0['state_dict']
|
||||
theta_1 = model_1['state_dict']
|
||||
|
||||
for key in theta_0.keys():
|
||||
if 'model' in key and key in theta_1:
|
||||
theta_0[key] = (1 - alpha) * theta_0[key] + alpha * theta_1[key]
|
||||
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
theta_0[key] = theta_1[key]
|
||||
|
||||
output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt';
|
||||
|
||||
torch.save(model_0, output_modelname)
|
||||
|
||||
return "<p>Model saved to " + output_modelname + "</p>"
|
||||
|
Reference in New Issue
Block a user