tokenizer_ve.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""
Text Tokenizer.
Copied and lightly adapted from VE repo, which in turn copied
from open_clip and openAI CLIP.
"""

import gzip
import html
import io
import os
import string
from functools import lru_cache
from typing import List, Optional, Union

import ftfy
import regex as re
import torch
from iopath.common.file_io import g_pathmgr

# https://stackoverflow.com/q/62691279
os.environ["TOKENIZERS_PARALLELISM"] = "false"

DEFAULT_CONTEXT_LENGTH = 77


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
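

# Illustrative note (added, not in the original file): printable bytes map to
# themselves, while bytes the BPE code cannot handle are remapped past U+00FF, e.g.
#   bytes_to_unicode()[ord("A")] == "A"
#   bytes_to_unicode()[0] == "Ā"   # chr(256), the first remapped byte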


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
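

# Illustrative note (added, not in the original file): for the word
# ("l", "o", "w</w>") this yields {("l", "o"), ("o", "w</w>")},
# i.e. every pair of adjacent symbols.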


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def _clean_canonicalize(x):
    # basic, remove whitespace, remove punctuation, lower case
    return canonicalize_text(basic_clean(x))


def _clean_lower(x):
    # basic, remove whitespace, lower case
    return whitespace_clean(basic_clean(x)).lower()


def _clean_whitespace(x):
    # basic, remove whitespace
    return whitespace_clean(basic_clean(x))


def get_clean_fn(type: str):
    if type == "canonicalize":
        return _clean_canonicalize
    elif type == "lower":
        return _clean_lower
    elif type == "whitespace":
        return _clean_whitespace
    else:
        assert False, f"Invalid clean function ({type})."


def canonicalize_text(text, *, keep_punctuation_exact_string=None):
    """Returns canonicalized `text` (lowercase and punctuation removed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
        text: string to be canonicalized.
        keep_punctuation_exact_string: If provided, then this exact string is kept.
            For example, providing '{}' will keep any occurrences of '{}' (but will
            still remove '{' and '}' that appear separately).
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans("", "", string.punctuation))
            for part in text.split(keep_punctuation_exact_string)
        )
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()
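

# Illustrative note (added, not in the original file):
#   canonicalize_text("Hello, World_2!")  -> "hello world 2"
#   canonicalize_text("fill in {} here!", keep_punctuation_exact_string="{}")
#       -> "fill in {} here"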


class SimpleTokenizer(object):
    def __init__(
        self,
        bpe_path: Union[str, os.PathLike],
        additional_special_tokens: Optional[List[str]] = None,
        context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
        clean: str = "lower",
    ):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
            merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
        # merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        special_tokens = ["<start_of_text>", "<end_of_text>"]
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
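
    # Note (added, not in the original file): with the construction above, the
    # vocabulary holds 256 byte symbols + 256 "</w>" variants + 48894 merges
    # (49152 - 256 - 2) + the special tokens, i.e. 49408 ids with the two
    # default specials.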

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)
        if not pairs:
            return token + "</w>"
        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again; keep the rest of the word as-is.
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text
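
    # Illustrative note (added, not in the original file), assuming `tok` was
    # built from a standard CLIP-style BPE vocabulary:
    #   ids = tok.encode("a photo of a cat")
    #   tok.decode(ids)   # -> "a photo of a cat " (each "</w>" becomes a space)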

    def __call__(
        self, texts: Union[str, List[str]], context_length: Optional[int] = None
    ) -> torch.LongTensor:
        """Returns the tokenized representation of given input string(s).

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]
        context_length = context_length or self.context_length
        assert context_length, "Please set a valid context length"
        all_tokens = [
            [self.sot_token_id] + self.encode(text) + [self.eot_token_id]
            for text in texts
        ]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = self.eot_token_id
            result[i, : len(tokens)] = torch.tensor(tokens)
        return result
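

# Minimal usage sketch (added for illustration, not in the original file). The
# bpe path below is a placeholder; point it at a gzipped CLIP-style BPE merges
# file such as open_clip's bpe_simple_vocab_16e6.txt.gz.
if __name__ == "__main__":
    tokenizer = SimpleTokenizer(bpe_path="/path/to/bpe_simple_vocab_16e6.txt.gz")
    batch = tokenizer(["a photo of a cat", "a photo of a dog"])
    print(batch.shape)  # torch.Size([2, 77]) with the default context length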