let user choose his own prompt token count limit
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -18,7 +18,6 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 
-
 def apply_optimizations():
     undo_optimizations()
 
@@ -83,7 +82,7 @@ class StableDiffusionModelHijack:
             layer.padding_mode = 'circular' if enable else 'zeros'
 
     def tokenize(self, text):
-        max_length = self.clip.max_length - 2
+        max_length = opts.max_prompt_tokens - 2
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
         return remade_batch_tokens[0], token_count, max_length
 
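The change above swaps the fixed CLIP window (self.clip.max_length, i.e. 77) for the user-set opts.max_prompt_tokens; the "- 2" reserves slots for the BOS and EOS tokens that CLIP wraps around every prompt. A minimal standalone sketch of that budget, using a hypothetical Opts stand-in (in the repo, opts comes from modules.shared):

# Illustration only: a hypothetical stand-in for modules.shared.opts.
class Opts:
    max_prompt_tokens = 77  # CLIP's default window; the commit lets users raise it

opts = Opts()
max_length = opts.max_prompt_tokens - 2  # two slots reserved for BOS/EOS
print(max_length)  # 75 usable prompt tokens at the default setting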
@@ -94,7 +93,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         self.wrapped = wrapped
         self.hijack: StableDiffusionModelHijack = hijack
         self.tokenizer = wrapped.tokenizer
-        self.max_length = wrapped.max_length
         self.token_mults = {}
 
         tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
@@ -116,7 +114,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def tokenize_line(self, line, used_custom_terms, hijack_comments):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
-        maxlen = self.wrapped.max_length
+        maxlen = opts.max_prompt_tokens
 
         if opts.enable_emphasis:
             parsed = prompt_parser.parse_prompt_attention(line)
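For context, parse_prompt_attention (called in the hunk above when emphasis is enabled) splits a prompt line into (text, weight) pairs, where (word) boosts attention by 1.1x and [word] divides it by 1.1. Below is a rough one-level re-implementation of that contract, for illustration only; the real parser in modules/prompt_parser.py also handles nesting and escaped brackets:

import re

# Toy sketch of the (text, weight) contract behind
# prompt_parser.parse_prompt_attention; illustration, not the repo's code.
def parse_prompt_attention_sketch(text):
    result = []
    for chunk in re.split(r'(\(.*?\)|\[.*?\])', text):
        if not chunk:
            continue
        if chunk.startswith('('):
            result.append([chunk[1:-1], 1.1])        # (word) -> weight * 1.1
        elif chunk.startswith('['):
            result.append([chunk[1:-1], 1.0 / 1.1])  # [word] -> weight / 1.1
        else:
            result.append([chunk, 1.0])
    return result

print(parse_prompt_attention_sketch("a (red) [cat]"))
# [['a ', 1.0], ['red', 1.1], [' ', 1.0], ['cat', 0.9090909090909091]]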
@@ -191,7 +189,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def process_text_old(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
-        maxlen = self.wrapped.max_length
+        maxlen = self.wrapped.max_length  # you get to stay at 77
         used_custom_terms = []
         remade_batch_tokens = []
         overflowing_words = []
@@ -268,8 +266,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         if len(used_custom_terms) > 0:
             self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
 
+        position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76]
+        position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
+
         tokens = torch.asarray(remade_batch_tokens).to(device)
-        outputs = self.wrapped.transformer(input_ids=tokens)
+        outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
         z = outputs.last_hidden_state
 
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
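The three added lines are what make a raised token limit workable at all: the CLIP text encoder's position-embedding table has only 77 entries, so a longer input cannot use the default arange positions. Instead, every token past index 75 is clamped to reuse position embedding 75, while the final slot keeps position 76 (the end-of-sequence position). A standalone sketch of that construction, assuming a 150-token remade prompt (the length is illustrative):

import torch

# Sketch of the clamped position ids built in the hunk above, for an
# assumed prompt of 150 remade tokens. Positions 0..74 pass through,
# everything later is clamped to 75, and the last slot keeps 76 so the
# EOS token still gets its own position embedding.
num_tokens = 150  # illustrative stand-in for len(remade_batch_tokens[0])
position_ids_array = [min(x, 75) for x in range(num_tokens - 1)] + [76]
position_ids = torch.asarray(position_ids_array).expand((1, -1))

print(position_ids.shape)      # torch.Size([1, 150])
print(position_ids[0, 72:78])  # tensor([72, 73, 74, 75, 75, 75])

Passing these ids explicitly to self.wrapped.transformer keeps every lookup inside the 77-entry table. Whether reusing embedding 75 for all overflow tokens is the right choice is an open question, much like the mean-restoration caveat the original comment already flags.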