transformer.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from functools import partial
from typing import Tuple, Type

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
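

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of driving TwoWayTransformer end to end, assuming
# hypothetical sizes: embedding_dim=256, a 64x64 image feature map, and 5
# point tokens. The positional encodings here are random placeholders; in the
# full model they are produced elsewhere (e.g. by the prompt/position
# encoders). The helper name below is ours, for illustration only.
def _example_two_way_transformer() -> Tuple[Tensor, Tensor]:
    transformer = TwoWayTransformer(
        depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048
    )
    image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
    image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
    point_embedding = torch.randn(1, 5, 256)       # B x N_points x C
    # Returns (queries, keys): B x N_points x C and B x H*W x C
    return transformer(image_embedding, image_pe, point_embedding)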


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
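

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hypothetical call of a single TwoWayAttentionBlock on already
# flattened tokens: 5 sparse query tokens and 64*64 dense image tokens, both
# with embedding_dim=256. Positional encodings are random placeholders, and
# the helper name is ours, for illustration only.
def _example_two_way_attention_block() -> Tuple[Tensor, Tensor]:
    block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
    queries = torch.randn(1, 5, 256)     # sparse tokens: B x N_points x C
    keys = torch.randn(1, 64 * 64, 256)  # dense tokens:  B x H*W x C
    query_pe = torch.randn_like(queries)
    key_pe = torch.randn_like(keys)
    # Returns updated (queries, keys) with the same shapes as the inputs
    return block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)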


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
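

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hypothetical example of the downsampling behaviour: with
# embedding_dim=256 and downsample_rate=2, q/k/v are projected to an internal
# dimension of 128 for attention and projected back to 256 on output. The
# helper name and sizes are ours, for illustration only.
def _example_attention() -> Tensor:
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    q = torch.randn(1, 5, 256)     # B x N_q x C
    k = torch.randn(1, 4096, 256)  # B x N_k x C
    v = torch.randn(1, 4096, 256)  # B x N_k x C
    return attn(q=q, k=k, v=v)     # B x N_q x C (back at embedding_dim)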


class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.compute_cis = partial(
            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
        )
        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        self.freqs_cis = (
            freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
        )
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Apply rotary position encoding
        w = h = math.sqrt(q.shape[-2])
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k[:, :, :num_k_rope] = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
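

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hypothetical self-attention call over a 64x64 grid of tokens,
# matching the default feat_sizes so the cached freqs_cis can be reused.
# Other (square) token counts trigger the recomputation branch above. The
# helper name and sizes are ours, for illustration only.
def _example_rope_attention() -> Tensor:
    rope_attn = RoPEAttention(embedding_dim=256, num_heads=8)
    x = torch.randn(1, 64 * 64, 256)  # B x H*W x C
    return rope_attn(q=x, k=x, v=x)   # B x H*W x C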