mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-11 03:09:47 +00:00
Merge branch 'master' into embed-embeddings-in-images
This commit is contained in:
@@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
|
||||
def randn(seed, shape):
|
||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
@@ -59,9 +60,12 @@ def randn_without_seed(shape):
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def autocast():
|
||||
def autocast(disable=False):
|
||||
from modules import shared
|
||||
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
@@ -259,6 +259,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
return x
|
||||
|
||||
|
||||
def decode_first_stage(model, x):
|
||||
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
||||
x = model.decode_first_stage(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def get_fixed_seed(seed):
|
||||
if seed is None or seed == '' or seed == -1:
|
||||
return int(random.randrange(4294967294))
|
||||
@@ -398,9 +405,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
# use the image collected previously in sampler loop
|
||||
samples_ddim = shared.state.current_latent
|
||||
|
||||
samples_ddim = samples_ddim.to(devices.dtype)
|
||||
|
||||
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
del samples_ddim
|
||||
@@ -533,7 +539,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if self.scale_latent:
|
||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||
else:
|
||||
decoded_samples = self.sd_model.decode_first_stage(samples)
|
||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||
|
||||
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
|
||||
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
|
||||
|
@@ -12,6 +12,10 @@ import _codecs
|
||||
import zipfile
|
||||
|
||||
|
||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||
|
||||
|
||||
def encode(*args):
|
||||
out = _codecs.encode(*args)
|
||||
return out
|
||||
@@ -20,7 +24,7 @@ def encode(*args):
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def persistent_load(self, saved_id):
|
||||
assert saved_id[0] == 'storage'
|
||||
return torch.storage._TypedStorage()
|
||||
return TypedStorage()
|
||||
|
||||
def find_class(self, module, name):
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
|
@@ -149,6 +149,7 @@ def load_model_weights(model, checkpoint_info):
|
||||
model.half()
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
|
||||
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
|
||||
if os.path.exists(vae_file):
|
||||
@@ -158,6 +159,8 @@ def load_model_weights(model, checkpoint_info):
|
||||
|
||||
model.first_stage_model.load_state_dict(vae_dict)
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
|
@@ -7,7 +7,7 @@ import inspect
|
||||
import k_diffusion.sampling
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
from modules import prompt_parser
|
||||
from modules import prompt_parser, devices, processing
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
@@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
|
||||
|
||||
|
||||
def sample_to_image(samples):
|
||||
x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
@@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
|
@@ -15,11 +15,10 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.size = size
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
@@ -7,8 +7,9 @@ import tqdm
|
||||
from modules import shared, images
|
||||
|
||||
|
||||
def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
|
||||
size = 512
|
||||
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
|
||||
width = process_width
|
||||
height = process_height
|
||||
src = os.path.abspath(process_src)
|
||||
dst = os.path.abspath(process_dst)
|
||||
|
||||
@@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
|
||||
is_wide = ratio < 1 / 1.35
|
||||
|
||||
if process_split and is_tall:
|
||||
img = img.resize((size, size * img.height // img.width))
|
||||
img = img.resize((width, height * img.height // img.width))
|
||||
|
||||
top = img.crop((0, 0, size, size))
|
||||
top = img.crop((0, 0, width, height))
|
||||
save_pic(top, index)
|
||||
|
||||
bot = img.crop((0, img.height - size, size, img.height))
|
||||
bot = img.crop((0, img.height - height, width, img.height))
|
||||
save_pic(bot, index)
|
||||
elif process_split and is_wide:
|
||||
img = img.resize((size * img.width // img.height, size))
|
||||
img = img.resize((width * img.width // img.height, height))
|
||||
|
||||
left = img.crop((0, 0, size, size))
|
||||
left = img.crop((0, 0, width, height))
|
||||
save_pic(left, index)
|
||||
|
||||
right = img.crop((img.width - size, 0, img.width, size))
|
||||
right = img.crop((img.width - width, 0, img.width, height))
|
||||
save_pic(right, index)
|
||||
else:
|
||||
img = images.resize_image(1, img, size, size)
|
||||
img = images.resize_image(1, img, width, height)
|
||||
save_pic(img, index)
|
||||
|
||||
shared.state.nextjob()
|
||||
|
@@ -190,7 +190,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||
return fn
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
@@ -222,7 +222,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
@@ -240,6 +240,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
|
||||
epoch_len = (tr_img_len * num_repeats) + tr_img_len
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, (x, text) in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
@@ -263,7 +266,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
pbar.set_description(f"loss: {losses.mean():.7f}")
|
||||
epoch_num = embedding.step // epoch_len
|
||||
epoch_step = embedding.step - (epoch_num * epoch_len) + 1
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
|
||||
|
||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||
@@ -276,6 +282,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
||||
sd_model=shared.sd_model,
|
||||
prompt=text,
|
||||
steps=20,
|
||||
height=training_height,
|
||||
width=training_width,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
|
@@ -1029,6 +1029,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
|
||||
process_src = gr.Textbox(label='Source directory')
|
||||
process_dst = gr.Textbox(label='Destination directory')
|
||||
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||
|
||||
with gr.Row():
|
||||
process_flip = gr.Checkbox(label='Create flipped copies')
|
||||
@@ -1043,13 +1045,16 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||
|
||||
with gr.Group():
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||
@@ -1093,6 +1098,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
inputs=[
|
||||
process_src,
|
||||
process_dst,
|
||||
process_width,
|
||||
process_height,
|
||||
process_flip,
|
||||
process_split,
|
||||
process_caption,
|
||||
@@ -1111,7 +1118,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
learn_rate,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
training_width,
|
||||
training_height,
|
||||
steps,
|
||||
num_repeats,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
|
Reference in New Issue
Block a user