transformer.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import math
import warnings
from functools import partial
from typing import Tuple, Type

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP
from sam2.utils.misc import get_sdpa_settings

warnings.simplefilter(action="ignore", category=FutureWarning)
# Check whether Flash Attention is available (and use it by default)
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
# A fallback setting to allow all available kernels if Flash Attention fails
ALLOW_ALL_KERNELS = False


def sdp_kernel_context(dropout_p):
    """
    Get the context for the attention scaled dot-product kernel. We use Flash Attention
    by default, but fall back to all available kernels if Flash Attention fails.
    """
    if ALLOW_ALL_KERNELS:
        return contextlib.nullcontext()

    return torch.backends.cuda.sdp_kernel(
        enable_flash=USE_FLASH_ATTN,
        # if Flash attention kernel is off, then math kernel needs to be enabled
        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
        enable_mem_efficient=OLD_GPU,
    )
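

# Hedged usage sketch (not part of the original file): sdp_kernel_context is meant to
# wrap a call to F.scaled_dot_product_attention, restricting which SDPA kernels PyTorch
# may select unless the module-level ALLOW_ALL_KERNELS fallback has been switched on.
# The tensor shape below is an illustrative assumption.
def _demo_sdp_kernel_context() -> None:
    q = k = v = torch.randn(1, 8, 16, 32)  # B x num_heads x N_tokens x C_per_head
    with sdp_kernel_context(dropout_p=0.0):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
    assert out.shape == q.shape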
class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
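

# Hedged usage sketch (not part of the original file): shows the input and output
# shapes documented in TwoWayTransformer.forward. The hyperparameters and tensor sizes
# below are illustrative assumptions, not values taken from any SAM 2 configuration.
def _demo_two_way_transformer() -> None:
    transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
    image_embedding = torch.randn(1, 256, 64, 64)  # B x embedding_dim x h x w
    image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
    point_embedding = torch.randn(1, 5, 256)       # B x N_points x embedding_dim
    queries, keys = transformer(image_embedding, image_pe, point_embedding)
    assert queries.shape == (1, 5, 256)     # processed point_embedding
    assert keys.shape == (1, 64 * 64, 256)  # processed (flattened) image_embedding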
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
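

# Hedged usage sketch (not part of the original file): one TwoWayAttentionBlock updates
# both the sparse query tokens and the dense image tokens they attend to. The token
# counts and zero positional encodings below are illustrative assumptions.
def _demo_two_way_attention_block() -> None:
    block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
    queries = torch.randn(1, 5, 256)  # sparse tokens, B x N_points x C
    keys = torch.randn(1, 4096, 256)  # dense image tokens, B x (h*w) x C
    query_pe = torch.zeros_like(queries)
    key_pe = torch.zeros_like(keys)
    queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
    assert queries.shape == (1, 5, 256) and keys.shape == (1, 4096, 256)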
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            global ALLOW_ALL_KERNELS
            ALLOW_ALL_KERNELS = True
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
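

# Hedged usage sketch (not part of the original file): with downsample_rate=2 the q/k/v
# projections map 256-dim tokens into a 128-dim internal space, and out_proj maps the
# attention output back to 256 dims. The shapes below are illustrative assumptions.
def _demo_attention_downsample() -> None:
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    assert attn.internal_dim == 128
    q = torch.randn(2, 5, 256)     # B x N_queries x embedding_dim
    kv = torch.randn(2, 100, 256)  # B x N_keys x embedding_dim
    out = attn(q=q, k=kv, v=kv)
    assert out.shape == (2, 5, 256)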
class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.compute_cis = partial(
            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
        )
        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        self.freqs_cis = freqs_cis
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Apply rotary position encoding
        w = h = math.sqrt(q.shape[-2])
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k[:, :, :num_k_rope] = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            global ALLOW_ALL_KERNELS
            ALLOW_ALL_KERNELS = True
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
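

# Hedged usage sketch (not part of the original file): RoPEAttention applies axial RoPE
# to the query/key tokens, so the query token count should correspond to a square
# spatial grid (w == h), matching the default feat_sizes of (32, 32). All shapes and
# hyperparameters below are illustrative assumptions.
def _demo_rope_attention() -> None:
    attn = RoPEAttention(embedding_dim=256, num_heads=8)
    tokens = torch.randn(1, 32 * 32, 256)  # B x (w*h) x embedding_dim
    out = attn(q=tokens, k=tokens, v=tokens)
    assert out.shape == (1, 32 * 32, 256)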