import math
from collections import namedtuple

import torch

from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts


class PromptChunk:
    """
    This object contains token ids, weights (multipliers, e.g. 1.4) and textual inversion embedding info for a chunk of prompt.
    If a prompt is short, it is represented by one PromptChunk; otherwise, multiple are necessary.
    Each PromptChunk contains exactly 77 tokens, which includes one start token and one end token,
    so just 75 tokens from the prompt.
    """

    def __init__(self):
        self.tokens = []
        self.multipliers = []
        self.fixes = []
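
# For example, the prompt "a (red:1.2) car" fits into a single PromptChunk: .tokens holds the
# start token, the token ids of the words, and end-of-text padding up to 77 entries; .multipliers
# holds 1.2 at the positions produced by "red" and 1.0 everywhere else; .fixes stays empty
# because no textual inversion embedding is referenced.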


PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
"""An object of this type is a marker showing that textual inversion embedding's vectors have to be placed at offset in the prompt
chunk. These objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
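
# For example, if an embedding is found when a chunk already holds 10 tokens, tokenize_line()
# records PromptChunkFix(offset=10, embedding=<that embedding>) and fills the corresponding
# positions of the chunk with placeholder token ids (0); EmbeddingsWithFixes later replaces
# those placeholder positions with the embedding's actual vectors.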


class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    """A PyTorch module that wraps the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
    have an unlimited prompt length and to assign weights to tokens in the prompt.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()

        self.wrapped = wrapped
        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
        depending on the model."""

        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
        self.chunk_length = 75

        self.is_trainable = getattr(wrapped, 'is_trainable', False)
        self.input_key = getattr(wrapped, 'input_key', 'txt')
        self.legacy_ucg_val = None

    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""

        chunk = PromptChunk()
        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
        chunk.multipliers = [1.0] * (self.chunk_length + 2)
        return chunk

    def get_target_prompt_token_count(self, token_count):
        """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""

        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
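
    # For example, with chunk_length=75:
    #   get_target_prompt_token_count(30) -> 75   (one chunk)
    #   get_target_prompt_token_count(76) -> 150  (two chunks)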

    def tokenize(self, texts):
        """Converts a batch of texts into a batch of token ids"""

        raise NotImplementedError

    def encode_with_transformers(self, tokens):
        """
        converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
        All python lists with tokens are assumed to have the same length, usually 77.
        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
        the model - it can be 768 or 1024.
        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
        """

        raise NotImplementedError

    def encode_embedding_init_text(self, init_text, nvpt):
        """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
        transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""

        raise NotImplementedError

    def tokenize_line(self, line):
        """
        this transforms a single prompt into a list of PromptChunk objects - as many as needed to
        represent the prompt.
        Returns the list and the total number of tokens in the prompt.
        """

        if opts.enable_emphasis:
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0
        last_comma = -1

        def next_chunk(is_last=False):
            """puts the current chunk into the list of results and produces the next one - empty;
            if is_last is true, the padding <end-of-text> tokens at the end won't add to token_count"""
            nonlocal token_count
            nonlocal last_comma
            nonlocal chunk

            if is_last:
                token_count += len(chunk.tokens)
            else:
                token_count += self.chunk_length

            to_add = self.chunk_length - len(chunk.tokens)
            if to_add > 0:
                chunk.tokens += [self.id_end] * to_add
                chunk.multipliers += [1.0] * to_add

            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]

            last_comma = -1
            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]

                if token == self.comma_token:
                    last_comma = len(chunk.tokens)

                # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
                    break_location = last_comma + 1

                    reloc_tokens = chunk.tokens[break_location:]
                    reloc_mults = chunk.multipliers[break_location:]

                    chunk.tokens = chunk.tokens[:break_location]
                    chunk.multipliers = chunk.multipliers[:break_location]

                    next_chunk()
                    chunk.tokens = reloc_tokens
                    chunk.multipliers = reloc_mults

                if len(chunk.tokens) == self.chunk_length:
                    next_chunk()

                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
                if embedding is None:
                    chunk.tokens.append(token)
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

                emb_len = int(embedding.vec.shape[0])
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))

                chunk.tokens += [0] * emb_len
                chunk.multipliers += [weight] * emb_len
                position += embedding_length_in_tokens

        if chunk.tokens or not chunks:
            next_chunk(is_last=True)

        return chunks, token_count
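
    # Illustrative result: a prompt of roughly 100 tokens becomes two PromptChunk objects - the
    # first holds up to 75 prompt tokens (fewer if comma backtracking moved the text after the last
    # comma into the next chunk), the second holds the remainder padded with <end-of-text> tokens;
    # both get a start and an end token, for 77 entries each. The keyword BREAK (parsed with
    # weight -1) forces a chunk boundary wherever it appears.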

    def process_texts(self, texts):
        """
        Accepts a list of texts and calls tokenize_line() on each, with caching. Returns the list of results and the maximum
        length, in tokens, of all texts.
        """

        token_count = 0

        cache = {}
        batch_chunks = []
        for line in texts:
            if line in cache:
                chunks = cache[line]
            else:
                chunks, current_token_count = self.tokenize_line(line)
                token_count = max(current_token_count, token_count)

                cache[line] = chunks

            batch_chunks.append(chunks)

        return batch_chunks, token_count

    def forward(self, texts):
        """
        Accepts an array of texts; passes the texts through the transformers network to create a tensor with the numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is the length of the array; T is the length, in tokens, of the texts (including padding) - T will
        be a multiple of 77; and C is the dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
        An example shape returned by this function can be: (2, 77, 768).
        For SDXL, instead of returning only the tensor above, it returns a tuple of two tensors: the second one has shape (B, 1280) and holds the pooled values.
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
        is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """

        if opts.use_old_emphasis_implementation:
            import modules.sd_hijack_clip_old
            return modules.sd_hijack_clip_old.forward_old(self, texts)

        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
        chunk_count = max([len(x) for x in batch_chunks])

        zs = []
        for i in range(chunk_count):
            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]

            tokens = [x.tokens for x in batch_chunk]
            multipliers = [x.multipliers for x in batch_chunk]
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for _position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)
            zs.append(z)

        if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
            hashes = []
            for name, embedding in used_embeddings.items():
                shorthash = embedding.shorthash
                if not shorthash:
                    continue

                name = name.replace(":", "").replace(",", "")
                hashes.append(f"{name}: {shorthash}")

            if hashes:
                self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)

        if getattr(self.wrapped, 'return_pooled', False):
            return torch.hstack(zs), zs[0].pooled
        else:
            return torch.hstack(zs)
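
    # Illustrative shapes: for SD1, a prompt that tokenizes into two chunks produces two
    # (1, 77, 768) tensors, and torch.hstack(zs) concatenates them along the token axis
    # into a (1, 154, 768) tensor.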

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        sends one prompt chunk per batch element to be encoded by the transformers neural network.
        remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
        there are exactly 77 tokens in each list. batch_multipliers is the same but for multipliers instead of tokens.
        Multipliers are used to give more or less weight to the outputs of the transformers network. Each multiplier
        corresponds to one token.
        """
        tokens = torch.asarray(remade_batch_tokens).to(devices.device)

        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for batch_pos in range(len(remade_batch_tokens)):
                index = remade_batch_tokens[batch_pos].index(self.id_end)
                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        pooled = getattr(z, 'pooled', None)

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
        original_mean = z.mean()
        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z = z * (original_mean / new_mean)

        if pooled is not None:
            z.pooled = pooled

        return z
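
    # Illustrative arithmetic for the mean restoration above: if applying the multipliers raises
    # z.mean() from 0.20 to 0.24, the final `z * (original_mean / new_mean)` rescales everything by
    # 0.20 / 0.24 ≈ 0.833, so emphasis shifts the relative weight of tokens while keeping the overall
    # magnitude of the conditioning roughly unchanged.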


class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

        self.id_start = self.wrapped.tokenizer.bos_token_id
        self.id_end = self.wrapped.tokenizer.eos_token_id
        self.id_pad = self.id_end
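
    # Illustrative multipliers collected above: a vocab entry containing '((' ends up with
    # 1.1 * 1.1 = 1.21, one containing '[' ends up with 1 / 1.1 ≈ 0.909, and entries whose
    # brackets cancel out (e.g. '()') stay at 1.0 and are not stored in token_mults.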

    def tokenize(self, texts):
        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

        return tokenized

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)

        if opts.CLIP_stop_at_last_layers > 1:
            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
            z = self.wrapped.transformer.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z
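
    # For example, with the "Clip skip" setting (opts.CLIP_stop_at_last_layers) at 2, the penultimate
    # hidden state hidden_states[-2] is used and then passed through the text model's final layer norm;
    # at the default value of 1, the regular last_hidden_state is returned instead.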

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded


class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

    def encode_with_transformers(self, tokens):
        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")

        if self.wrapped.layer == "last":
            z = outputs.last_hidden_state
        else:
            z = outputs.hidden_states[self.wrapped.layer_idx]

        return z
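
    # For example, when wrapped.layer == "hidden" the call requests all hidden states and returns the
    # one at wrapped.layer_idx (for SDXL this is typically an intermediate layer rather than the final
    # one, and no final layer norm is applied); wrapped.layer == "last" falls back to last_hidden_state.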