# transformer.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import math
from functools import partial
from typing import Optional, Tuple, Type

import torch
import torch.nn.functional as F
from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis
from torch import nn, Tensor

from .common import MLPBlock


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
            depth (int): number of layers in the transformer
            embedding_dim (int): the channel dimension for the input embeddings
            num_heads (int): the number of heads for multihead attention. Must
                divide embedding_dim
            mlp_dim (int): the channel dimension internal to the MLP block
            activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            image_embedding (torch.Tensor): image to attend to. Should be shape
                B x embedding_dim x h x w for any h and w.
            image_pe (torch.Tensor): the positional encoding to add to the image. Must
                have the same shape as image_embedding.
            point_embedding (torch.Tensor): the embedding to add to the query points.
                Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
            torch.Tensor: the processed point_embedding
            torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
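

# Illustrative usage sketch (not from the original file). The concrete sizes below
# (embedding_dim=256, a 64x64 feature grid, 5 point tokens) are assumptions chosen
# only to match the shapes documented in TwoWayTransformer.forward.
#
#   transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
#   image_embedding = torch.randn(1, 256, 64, 64)  # B x embedding_dim x h x w
#   image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
#   point_embedding = torch.randn(1, 5, 256)       # B x N_points x embedding_dim
#   queries, keys = transformer(image_embedding, image_pe, point_embedding)
#   # queries: B x N_points x embedding_dim, keys: B x (h*w) x embedding_dim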


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
            embedding_dim (int): the channel dimension of the embeddings
            num_heads (int): the number of heads in the attention layers
            mlp_dim (int): the hidden dimension of the mlp block
            activation (nn.Module): the activation of the mlp block
            skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
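

# Illustrative usage sketch (not from the original file). Shapes are assumptions:
# 5 sparse point tokens and a 64x64 grid of dense image tokens at embedding_dim=256.
#
#   block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
#   queries = torch.randn(1, 5, 256)       # sparse tokens: B x N_points x C
#   keys = torch.randn(1, 64 * 64, 256)    # dense image tokens: B x (h*w) x C
#   query_pe = torch.randn(1, 5, 256)
#   key_pe = torch.randn(1, 64 * 64, 256)
#   queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
#   # both streams are updated: queries -> (1, 5, 256), keys -> (1, 4096, 256)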


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: Optional[int] = None,
        use_fa3: bool = False,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        self.use_fa3 = use_fa3
        assert self.internal_dim % num_heads == 0, (
            "num_heads must divide embedding_dim // downsample_rate."
        )

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        # with torch.backends.cuda.sdp_kernel(
        #     enable_flash=USE_FLASH_ATTN,
        #     # if Flash attention kernel is off, then math kernel needs to be enabled
        #     enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
        #     enable_mem_efficient=OLD_GPU,
        # ):
        # Let's trust the dispatcher....
        if self.use_fa3:
            from sam3.perflib.fa3 import flash_attn_func

            assert dropout_p == 0.0
            out = flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            ).transpose(1, 2)
        else:
            torch.backends.cuda.enable_flash_sdp(True)
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
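

# Illustrative usage sketch (not from the original file). With the assumed
# embedding_dim=256 and downsample_rate=2, attention runs in a 128-dim internal
# space (q/k/v projections map 256 -> 128) and out_proj maps back 128 -> 256.
#
#   attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
#   q = torch.randn(2, 10, 256)        # B x N_queries x embedding_dim
#   k = v = torch.randn(2, 4096, 256)  # B x N_keys x embedding_dim (kv_in_dim defaults to embedding_dim)
#   out = attn(q=q, k=k, v=v)          # -> (2, 10, 256)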


class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
        use_rope_real=False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.use_rope_real = use_rope_real

        self.compute_cis = partial(
            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
        )
        device = torch.device("cuda") if torch.cuda.is_available() else None
        self.freqs_cis = self.compute_cis(
            end_x=feat_sizes[0], end_y=feat_sizes[1], device=device
        )
        if self.use_rope_real:
            self.freqs_cis_real = self.freqs_cis.real
            self.freqs_cis_imag = self.freqs_cis.imag
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Apply rotary position encoding
        w = h = math.sqrt(q.shape[-2])
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h, device=q.device)
            self.freqs_cis_real = self.freqs_cis.real
            self.freqs_cis_imag = self.freqs_cis.imag
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        if self.use_rope_real:
            q, k[:, :, :num_k_rope] = apply_rotary_enc_real(
                q,
                k[:, :, :num_k_rope],
                freqs_cis_real=self.freqs_cis_real,
                freqs_cis_imag=self.freqs_cis_imag,
                repeat_freqs_k=self.rope_k_repeat,
            )
        else:
            q, k[:, :, :num_k_rope] = apply_rotary_enc(
                q,
                k[:, :, :num_k_rope],
                self.freqs_cis,
                repeat_freqs_k=self.rope_k_repeat,
            )

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        # with torch.backends.cuda.sdp_kernel(
        #     enable_flash=USE_FLASH_ATTN,
        #     # if Flash attention kernel is off, then math kernel needs to be enabled
        #     enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
        #     enable_mem_efficient=OLD_GPU,
        # ):
        # Let's trust the dispatcher....
        if self.use_fa3:
            from sam3.perflib.fa3 import flash_attn_func

            assert dropout_p == 0.0
            out = flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            ).transpose(1, 2)
        else:
            torch.backends.cuda.enable_flash_sdp(True)
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
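

# Illustrative usage sketch (not from the original file). Assumes self-attention
# over a 64x64 grid of image tokens at embedding_dim=256, so the axial RoPE table
# precomputed in __init__ for feat_sizes=(64, 64) already matches the 4096 positions.
#
#   rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, feat_sizes=(64, 64))
#   x = torch.randn(1, 64 * 64, 256)
#   out = rope_attn(q=x, k=x, v=x)   # -> (1, 4096, 256)
#
# For cross-attention where keys are longer than queries (e.g. attending to
# concatenated memories), pass rope_k_repeat=True; a positive num_k_exclude_rope
# leaves that many trailing keys without rotary encoding.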