# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import warnings
from functools import partial
from typing import Tuple, Type

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP
from sam2.utils.misc import get_sdpa_settings

warnings.simplefilter(action="ignore", category=FutureWarning)
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
          attention_downsample_rate (int): the factor by which the channel
            dimension is downsampled inside the attention layers
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
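

# Illustrative usage sketch (not part of the original module). It shows how
# TwoWayTransformer consumes a B x C x H x W image embedding plus a matching
# positional encoding and a B x N_points x C prompt embedding; the
# hyperparameters and shapes below are example values chosen for the sketch,
# not necessarily those of any particular SAM 2 configuration.
def _example_two_way_transformer() -> None:
    transformer = TwoWayTransformer(
        depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048
    )
    image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
    image_pe = torch.randn(1, 256, 64, 64)  # same shape as image_embedding
    point_embedding = torch.randn(1, 5, 256)  # B x N_points x C
    queries, keys = transformer(image_embedding, image_pe, point_embedding)
    assert queries.shape == (1, 5, 256)  # processed point embeddings
    assert keys.shape == (1, 64 * 64, 256)  # processed (flattened) image embeddings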


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
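

# Illustrative usage sketch (not part of the original module): a single
# TwoWayAttentionBlock updates both the sparse queries and the dense keys,
# returning them with unchanged shapes. Shapes below are example values.
def _example_two_way_attention_block() -> None:
    block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
    queries = torch.randn(1, 5, 256)  # sparse tokens: B x N_points x C
    keys = torch.randn(1, 4096, 256)  # dense tokens: B x (H*W) x C
    query_pe = torch.randn_like(queries)
    key_pe = torch.randn_like(keys)
    queries, keys = block(queries, keys, query_pe=query_pe, key_pe=key_pe)
    assert queries.shape == (1, 5, 256) and keys.shape == (1, 4096, 256)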


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        with torch.backends.cuda.sdp_kernel(
            enable_flash=USE_FLASH_ATTN,
            # if Flash attention kernel is off, then math kernel needs to be enabled
            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
            enable_mem_efficient=OLD_GPU,
        ):
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
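

# Illustrative usage sketch (not part of the original module): with
# downsample_rate=2 the queries, keys, and values are projected from
# embedding_dim=256 down to an internal dimension of 128 before attention,
# and the output is projected back to 256. Shapes are example values.
def _example_attention() -> None:
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    q = torch.randn(1, 5, 256)  # B x N_queries x C
    k = v = torch.randn(1, 4096, 256)  # B x N_keys x C
    out = attn(q=q, k=k, v=v)
    assert out.shape == (1, 5, 256)  # output keeps the query token count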


class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.compute_cis = partial(
            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
        )
        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        self.freqs_cis = freqs_cis
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Apply rotary position encoding
        w = h = math.sqrt(q.shape[-2])
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k[:, :, :num_k_rope] = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        with torch.backends.cuda.sdp_kernel(
            enable_flash=USE_FLASH_ATTN,
            # if Flash attention kernel is off, then math kernel needs to be enabled
            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
            enable_mem_efficient=OLD_GPU,
        ):
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
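

# Illustrative usage sketch (not part of the original module): RoPEAttention
# precomputes axial rotary frequencies for a feat_sizes grid (32 x 32 here),
# so the self-attention inputs below carry 32 * 32 = 1024 tokens per batch
# element. Shapes and hyperparameters are example values for the sketch.
def _example_rope_attention() -> None:
    rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, feat_sizes=(32, 32))
    x = torch.randn(1, 32 * 32, 256)  # B x (H*W) x C tokens on a 32x32 grid
    out = rope_attn(q=x, k=x, v=x)
    assert out.shape == (1, 32 * 32, 256)


if __name__ == "__main__":
    # Run the usage sketches above as a quick smoke test.
    _example_two_way_transformer()
    _example_two_way_attention_block()
    _example_attention()
    _example_rope_attention()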