text_encoder_ve.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .model_misc import LayerScale


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        # Attention
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        # LayerNorm, LayerScale
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        # MLP
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )
    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x
        if attn_mask is not None:
            # Leave boolean masks as is
            if attn_mask.dtype != torch.bool:
                attn_mask = attn_mask.to(q_x.dtype)
        return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]
    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        k_x = (
            self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        )
        v_x = (
            self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        )
        x = q_x + self.ls_1(
            self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        )
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
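# Usage sketch (illustrative, not part of the original file). With batch_first=True the
# block maps [batch, seq_len, d_model] -> [batch, seq_len, d_model] via pre-norm
# residual attention followed by a pre-norm residual MLP:
#
#     block = ResidualAttentionBlock(d_model=512, n_head=8)
#     y = block(torch.randn(2, 77, 512))  # -> torch.Size([2, 77, 512])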
class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for _ in range(layers)
            ]
        )
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            if self.grad_checkpointing:
                # Dynamo's DDP graph splitting does not compose with activation
                # checkpointing, so disable it when both compile and checkpointing
                # are enabled.
                torch._dynamo.config.optimize_ddp = False
    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for r in self.resblocks:
            if (
                self.grad_checkpointing
                and not torch.jit.is_scripting()
                and self.training
            ):
                x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = r(
                    x,
                    attn_mask=attn_mask,
                )
        return x
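# Illustrative note (not from the original file): activation checkpointing recomputes
# each block's activations in the backward pass, trading compute for memory, and is
# only applied in training mode outside TorchScript, e.g.:
#
#     trunk = Transformer(width=512, layers=12, heads=8, use_act_checkpoint=True)
#     y = trunk(torch.randn(2, 77, 512))  # checkpointed only while trunk.training is True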
def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
    if pool_type == "first":
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == "last":
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == "argmax":
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None
        pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
    else:
        pooled = tokens = x
    return pooled, tokens
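# Example of the "argmax" branch (illustrative ids, assuming a CLIP-style BPE vocab in
# which the end-of-text token has the largest id in every row and padding is 0):
#
#     text = torch.tensor([[49406, 320, 2368, 49407, 0, 0]])  # SOT, prompt tokens, EOT, pad
#     pooled, tokens = text_global_pool(x, text, pool_type="argmax")
#     # pooled == x[0, 3], i.e. the feature at the EOT position; tokens is x unchanged.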
class TextTransformer(nn.Module):
    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(self.vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()

        if no_causal_mask:
            self.attn_mask = None
        else:
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )

        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))
    def build_causal_mask(self) -> torch.Tensor:
        # pytorch uses an additive attention mask: -inf above the diagonal blocks
        # attention to future tokens
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and the lower triangle
        return mask
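    # For num_pos == 3, build_causal_mask() returns the additive mask (illustrative):
    #
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]
    #
    # i.e. row i leaves positions <= i unmasked and blocks attention to future tokens.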
    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_len = text.shape[1]
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        attn_mask = self.attn_mask
        if attn_mask is not None:
            attn_mask = attn_mask[:seq_len, :seq_len]

        x = x + self.positional_embedding[:seq_len]
        x = self.transformer(x, attn_mask=attn_mask)
        x = self.ln_final(x)
        pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection

        if self.output_tokens:
            return pooled, tokens

        return pooled
class VETextEncoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer
        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.resizer = nn.Linear(self.encoder.width, d_model)
    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if isinstance(text[0], str):
            # no use case for this
            assert input_boxes is None or len(input_boxes) == 0, "not supported"

            # Encode the text
            tokenized = self.tokenizer(text, context_length=self.context_length).to(
                device
            )  # [b, seq_len]
            text_attention_mask = (tokenized != 0).bool()
            # manually embed the tokens
            inputs_embeds = self.encoder.token_embedding(
                tokenized
            )  # [b, seq_len, d=1024]
            _, text_memory = self.encoder(tokenized)  # [b, seq_len, d=1024]
            assert text_memory.shape[1] == inputs_embeds.shape[1]
            # Invert the attention mask because pytorch's transformer uses the
            # opposite convention (True means "masked out")
            text_attention_mask = text_attention_mask.ne(1)
            # Transpose memory because pytorch's attention expects sequence first
            text_memory = text_memory.transpose(0, 1)
            # Resize the encoder hidden states to the decoder's d_model
            text_memory_resized = self.resizer(text_memory)
        else:
            # The text is already encoded, use it as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert input_boxes is None or len(input_boxes) == 0, (
                "Can't replace boxes in text if it's already encoded"
            )
        # Note that inputs_embeds are returned in pytorch's convention (sequence first)
        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )
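# Minimal smoke-test sketch (illustrative, not part of the original module). The
# dummy_tokenizer below is a hypothetical stand-in for the project's real CLIP-style
# tokenizer; it only mimics the expected interface (callable with a context_length
# kwarg, returning padded token ids) so the shape flow can be checked end to end.
if __name__ == "__main__":

    def dummy_tokenizer(texts: List[str], context_length: int = 32) -> torch.Tensor:
        # Hypothetical: SOT/EOT ids of the CLIP BPE vocab followed by zero padding.
        out = torch.zeros(len(texts), context_length, dtype=torch.long)
        out[:, 0], out[:, 1] = 49406, 49407
        return out

    encoder = VETextEncoder(
        d_model=256,
        tokenizer=dummy_tokenizer,
        width=512,
        heads=8,
        layers=2,
        use_act_checkpoint=False,
    ).eval()
    with torch.no_grad():
        mask, memory, embeds = encoder(
            ["a photo of a cat"], device=torch.device("cpu")
        )
    # Expected shapes: [1, 32], [32, 1, 256], [32, 1, 512]; memory and embeds are
    # returned sequence-first. Weights are uninitialized here, so only shapes are
    # meaningful.
    print(mask.shape, memory.shape, embeds.shape)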