decoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""
Transformer decoder.
Inspired by PyTorch's version; adds the pre-norm variant.
"""
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from sam3.sam.transformer import RoPEAttention
from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign

from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import (
    gen_sineembed_for_position,
    get_activation_fn,
    get_clones,
    inverse_sigmoid,
    MLP,
)

class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        activation: str,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        cross_attention: nn.Module,
        n_heads: int,
        use_text_cross_attention: bool = False,
    ):
        super().__init__()
        # cross attention
        self.cross_attn = cross_attention
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)
        # cross attention to text
        self.use_text_cross_attention = use_text_cross_attention
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
            self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            self.catext_norm = nn.LayerNorm(d_model)
        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)
        # ffn
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        # Run the FFN projection in full precision by disabling autocast.
        with torch.amp.autocast(device_type="cuda", enabled=False):
            tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],  # nq, bs, d_model
        tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
        tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4
        memory_text: Optional[Tensor] = None,  # num_token, bs, d_model
        text_attention_mask: Optional[Tensor] = None,  # bs, num_token
        # for memory
        memory: Optional[Tensor] = None,  # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        memory_pos: Optional[Tensor] = None,  # pos for memory
        # sa
        self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
        cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
        # dac
        dac=False,
        dac_use_selfatt_ln=True,
        presence_token=None,
        # skip inside deformable attn
        identity=0.0,
        **kwargs,  # additional kwargs for compatibility
    ):
        """
        Input:
            - tgt/tgt_query_pos: nq, bs, d_model
        """
        # self attention
        if self.self_attn is not None:
            if dac:
                # we only apply self attention to the first half of the queries
                assert tgt.shape[0] % 2 == 0
                num_o2o_queries = tgt.shape[0] // 2
                tgt_o2o = tgt[:num_o2o_queries]
                tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
                tgt_o2m = tgt[num_o2o_queries:]
            else:
                tgt_o2o = tgt
                tgt_query_pos_o2o = tgt_query_pos
            if presence_token is not None:
                tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
                tgt_query_pos_o2o = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0
                )
                tgt_query_pos = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos], dim=0
                )
            q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
            tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0]
            tgt_o2o = tgt_o2o + self.dropout2(tgt2)
            if dac:
                if not dac_use_selfatt_ln:
                    tgt_o2o = self.norm2(tgt_o2o)
                tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)  # Recombine
                if dac_use_selfatt_ln:
                    tgt = self.norm2(tgt)
            else:
                tgt = tgt_o2o
                tgt = self.norm2(tgt)
        if self.use_text_cross_attention:
            tgt2 = self.ca_text(
                self.with_pos_embed(tgt, tgt_query_pos),
                memory_text,
                memory_text,
                key_padding_mask=text_attention_mask,
            )[0]
            tgt = tgt + self.catext_dropout(tgt2)
            tgt = self.catext_norm(tgt)
        if presence_token is not None:
            presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
            cross_attn_mask = torch.cat(
                [presence_token_mask, cross_attn_mask], dim=1
            )  # (bs*nheads, 1+nq, hw)
        # Cross attention to image
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory,
            attn_mask=cross_attn_mask,
            key_padding_mask=(
                memory_key_padding_mask.transpose(0, 1)
                if memory_key_padding_mask is not None
                else None
            ),
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # ffn
        tgt = self.forward_ffn(tgt)
        presence_token_out = None
        if presence_token is not None:
            presence_token_out = tgt[:1]
            tgt = tgt[1:]
        return tgt, presence_token_out

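# Minimal usage sketch for TransformerDecoderLayer (illustrative only): the concrete
# cross-attention module is injected by the caller; here a plain nn.MultiheadAttention
# stands in for it, and all shapes/hyperparameters below are assumptions rather than
# the actual SAM3 configuration.
#
#   layer = TransformerDecoderLayer(
#       activation="relu",
#       d_model=256,
#       dim_feedforward=2048,
#       dropout=0.1,
#       cross_attention=nn.MultiheadAttention(256, 8, dropout=0.1),
#       n_heads=8,
#   )
#   queries = torch.zeros(100, 2, 256)       # nq, bs, d_model
#   features = torch.zeros(64 * 64, 2, 256)  # hw, bs, d_model
#   out, presence = layer(tgt=queries, memory=features)
#   # out: (100, 2, 256); presence is None because no presence_token was passed
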
class TransformerDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        frozen: bool,
        interaction_layer,
        layer,
        num_layers: int,
        num_queries: int,
        return_intermediate: bool,
        box_refine: bool = False,
        num_o2m_queries: int = 0,
        dac: bool = False,
        boxRPB: str = "none",
        # Experimental: an object query for SAM 2 tasks
        instance_query: bool = False,
        # Number of additional instance queries;
        # 1 or 4 are the most likely choices for single- vs. multi-mask support
        num_instances: int = 1,  # irrelevant if instance_query is False
        dac_use_selfatt_ln: bool = True,
        use_act_checkpoint: bool = False,
        compile_mode=None,
        presence_token: bool = False,
        clamp_presence_logits: bool = True,
        clamp_presence_logit_max_val: float = 10.0,
        use_normed_output_consistently: bool = True,
        separate_box_head_instance: bool = False,
        separate_norm_instance: bool = False,
        resolution: Optional[int] = None,
        stride: Optional[int] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.fine_layers = (
            get_clones(interaction_layer, num_layers)
            if interaction_layer is not None
            else [None] * num_layers
        )
        self.num_layers = num_layers
        self.num_queries = num_queries
        self.dac = dac
        if dac:
            self.num_o2m_queries = num_queries
            tot_num_queries = num_queries
        else:
            self.num_o2m_queries = num_o2m_queries
            tot_num_queries = num_queries + num_o2m_queries
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
        self.bbox_embed = MLP(d_model, d_model, 4, 3)
        self.query_embed = nn.Embedding(tot_num_queries, d_model)
        self.instance_query_embed = None
        self.instance_query_reference_points = None
        self.use_instance_query = instance_query
        self.num_instances = num_instances
        self.use_normed_output_consistently = use_normed_output_consistently
        self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
        self.instance_bbox_embed = None
        if separate_box_head_instance:
            self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
        if instance_query:
            self.instance_query_embed = nn.Embedding(num_instances, d_model)
        self.box_refine = box_refine
        if box_refine:
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
            self.reference_points = nn.Embedding(num_queries, 4)
            if instance_query:
                self.instance_reference_points = nn.Embedding(num_instances, 4)
        assert boxRPB in ["none", "log", "linear", "both"]
        self.boxRPB = boxRPB
        if boxRPB != "none":
            try:
                nheads = self.layers[0].cross_attn_image.num_heads
            except AttributeError:
                nheads = self.layers[0].cross_attn.num_heads
            n_input = 4 if boxRPB == "both" else 2
            self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
            self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
            self.compilable_cord_cache = None
            self.compilable_stored_size = None
            self.coord_cache = {}
            if resolution is not None and stride is not None:
                feat_size = resolution // stride
                coords_h, coords_w = self._get_coords(
                    feat_size, feat_size, device="cuda"
                )
                self.compilable_cord_cache = (coords_h, coords_w)
                self.compilable_stored_size = (feat_size, feat_size)
        self.roi_pooler = (
            RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
            if interaction_layer is not None
            else None
        )
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)
        self.presence_token = None
        self.clamp_presence_logits = clamp_presence_logits
        self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
        if presence_token:
            self.presence_token = nn.Embedding(1, d_model)
            self.presence_token_head = MLP(d_model, d_model, 1, 3)
            self.presence_token_out_norm = nn.LayerNorm(d_model)
        self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
        self.dac_use_selfatt_ln = dac_use_selfatt_ln
        self.use_act_checkpoint = use_act_checkpoint
        nn.init.normal_(self.query_embed.weight.data)
        if self.instance_query_embed is not None:
            nn.init.normal_(self.instance_query_embed.weight.data)
        assert self.roi_pooler is None
        assert self.return_intermediate, "support return_intermediate only"
        assert self.box_refine, "support box refine only"
        self.compile_mode = compile_mode
        self.compiled = False
        # We defer compilation until after the first forward pass, so the boxRPB
        # cache is warmed up first.
        # Assign a layer index to each layer so that individual layers can decide
        # what to do based on their position (e.g. cross attention to the memory
        # bank only in selected layers).
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

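    # Note on DAC (a hedged reading of the code in this file): when enabled, the
    # queries are duplicated into a one-to-one ("o2o") half that uses self-attention
    # and a one-to-many ("o2m") half that skips it; forward() repeats `tgt` and the
    # reference boxes accordingly, and the decoder layer recombines the two halves
    # after self-attention.
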
    @staticmethod
    def _get_coords(H, W, device):
        coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H
        coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W
        return coords_h, coords_w

    def _get_rpb_matrix(self, reference_boxes, feat_size):
        H, W = feat_size
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
        bs, num_queries, _ = boxes_xyxy.shape
        if self.compilable_cord_cache is None:
            self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
            self.compilable_stored_size = (H, W)
        if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
            H,
            W,
        ):
            # Good: we hit the cache, so this path stays compilable.
            coords_h, coords_w = self.compilable_cord_cache
        else:
            # Cache miss, which would create a compilation issue.
            # When not compiling, we still rely on the dict-based cache.
            if feat_size not in self.coord_cache:
                self.coord_cache[feat_size] = self._get_coords(
                    H, W, reference_boxes.device
                )
            coords_h, coords_w = self.coord_cache[feat_size]
        assert coords_h.shape == (H,)
        assert coords_w.shape == (W,)
        deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
        deltas_y = deltas_y.view(bs, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
        deltas_x = deltas_x.view(bs, num_queries, -1, 2)
        if self.boxRPB in ["log", "both"]:
            deltas_x_log = deltas_x * 8  # normalize to [-8, 8]
            deltas_x_log = (
                torch.sign(deltas_x_log)
                * torch.log2(torch.abs(deltas_x_log) + 1.0)
                / np.log2(8)
            )
            deltas_y_log = deltas_y * 8  # normalize to [-8, 8]
            deltas_y_log = (
                torch.sign(deltas_y_log)
                * torch.log2(torch.abs(deltas_y_log) + 1.0)
                / np.log2(8)
            )
            if self.boxRPB == "log":
                deltas_x = deltas_x_log
                deltas_y = deltas_y_log
            else:
                deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
                deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
        if self.training:
            assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
        deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)(
            x=deltas_x,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, W, n_heads
        deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)(
            x=deltas_y,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, H, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert deltas_x.shape[:3] == (bs, num_queries, W)
            assert deltas_y.shape[:3] == (bs, num_queries, H)
        B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2)  # bs, num_queries, H, W, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[:4] == (bs, num_queries, H, W)
        B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
        B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
        B = B.contiguous()  # mem-efficient attention prefers ordered strides
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[2:] == (num_queries, H * W)
        return B

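    # Shape sketch for the boxRPB bias produced above (illustrative numbers, not the
    # actual configuration): with bs=2, num_queries=300, an H=W=64 feature map and
    # n_heads=8, _get_rpb_matrix returns a tensor of shape (2, 8, 300, 64 * 64);
    # forward() then flattens the first two dims to get a per-head cross-attention
    # bias of shape (bs * n_heads, num_queries, H * W):
    #
    #   B = self._get_rpb_matrix(reference_boxes, (64, 64))  # (2, 8, 300, 4096)
    #   memory_mask = B.flatten(0, 1)                        # (16, 300, 4096)
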
    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        reference_boxes: Optional[Tensor] = None,  # num_queries, bs, 4
        # for memory
        level_start_index: Optional[Tensor] = None,  # num_levels
        spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        valid_ratios: Optional[Tensor] = None,
        # for text
        memory_text: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
        # if `apply_dac` is None, it will default to `self.dac`
        apply_dac: Optional[bool] = None,
        is_instance_prompt=False,
        decoder_extra_kwargs: Optional[Dict] = None,
        # ROI memory bank
        obj_roi_memory_feat=None,
        obj_roi_memory_mask=None,
        box_head_trk=None,
    ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: \\sum{hw}, bs, d_model
            - pos: \\sum{hw}, bs, d_model
            - reference_boxes: nq, bs, 4 (after sigmoid)
            - valid_ratios/spatial_shapes: bs, nlevel, 2
        """
        if memory_mask is not None:
            assert self.boxRPB == "none", (
                "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
            )
        apply_dac = apply_dac if apply_dac is not None else self.dac
        if apply_dac:
            assert (tgt.shape[0] == self.num_queries) or (
                self.use_instance_query
                and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
            )
            tgt = tgt.repeat(2, 1, 1)
            # Note that we don't tile tgt_mask, since DAC doesn't use
            # self-attention in o2m queries.
            if reference_boxes is not None:
                assert (reference_boxes.shape[0] == self.num_queries) or (
                    self.use_instance_query
                    and (
                        reference_boxes.shape[0]
                        == self.instance_query_embed.num_embeddings
                    )
                )
                reference_boxes = reference_boxes.repeat(2, 1, 1)
        bs = tgt.shape[1]
        intermediate = []
        intermediate_presence_logits = []
        presence_feats = None
        if self.box_refine:
            if reference_boxes is None:
                # In this case, we're in a one-stage model, so we generate the reference boxes
                reference_boxes = self.reference_points.weight.unsqueeze(1)
                reference_boxes = (
                    reference_boxes.repeat(2, bs, 1)
                    if apply_dac
                    else reference_boxes.repeat(1, bs, 1)
                )
                reference_boxes = reference_boxes.sigmoid()
            intermediate_ref_boxes = [reference_boxes]
        else:
            reference_boxes = None
            intermediate_ref_boxes = None
        output = tgt
        presence_out = None
        if self.presence_token is not None and is_instance_prompt is False:
            # expand to batch dim
            presence_out = self.presence_token.weight[None].expand(1, bs, -1)
        box_head = self.bbox_embed
        if is_instance_prompt and self.instance_bbox_embed is not None:
            box_head = self.instance_bbox_embed
        out_norm = self.norm
        if is_instance_prompt and self.instance_norm is not None:
            out_norm = self.instance_norm
        for layer_idx, layer in enumerate(self.layers):
            reference_points_input = (
                reference_boxes[:, :, None]
                * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
            )  # nq, bs, nlevel, 4
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :], self.d_model
            )  # nq, bs, d_model*2
            # conditional query
            query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, d_model
            if self.boxRPB != "none" and reference_boxes is not None:
                assert spatial_shapes.shape[0] == 1, (
                    "only single scale support implemented"
                )
                memory_mask = self._get_rpb_matrix(
                    reference_boxes,
                    (spatial_shapes[0, 0], spatial_shapes[0, 1]),
                )
                memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)
            if self.training:
                assert self.use_act_checkpoint, (
                    "Activation checkpointing not enabled in the decoder"
                )
            output, presence_out = activation_ckpt_wrapper(layer)(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,
                memory_text=memory_text,
                text_attention_mask=text_attention_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                dac=apply_dac,
                dac_use_selfatt_ln=self.dac_use_selfatt_ln,
                presence_token=presence_out,
                **(decoder_extra_kwargs or {}),
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                # ROI memory bank
                obj_roi_memory_feat=obj_roi_memory_feat,
                obj_roi_memory_mask=obj_roi_memory_mask,
            )
            # iterative update
            if self.box_refine:
                reference_before_sigmoid = inverse_sigmoid(reference_boxes)
                if box_head_trk is None:
                    if not self.use_normed_output_consistently:
                        delta_unsig = box_head(output)
                    else:
                        delta_unsig = box_head(out_norm(output))
                else:
                    # box_head_trk is a separate box head for tracking queries
                    Q_det = decoder_extra_kwargs["Q_det"]
                    assert output.size(0) >= Q_det
                    delta_unsig_det = self.bbox_embed(output[:Q_det])
                    delta_unsig_trk = box_head_trk(output[Q_det:])
                    delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()
                reference_boxes = new_reference_points.detach()
                if layer_idx != self.num_layers - 1:
                    intermediate_ref_boxes.append(new_reference_points)
            else:
                raise NotImplementedError("not implemented yet")
            intermediate.append(out_norm(output))
            if self.presence_token is not None and is_instance_prompt is False:
                # norm, then MLP head
                intermediate_layer_presence_logits = self.presence_token_head(
                    self.presence_token_out_norm(presence_out)
                ).squeeze(-1)
                # clamp to mitigate numerical issues
                if self.clamp_presence_logits:
                    intermediate_layer_presence_logits = (
                        intermediate_layer_presence_logits.clamp(
                            min=-self.clamp_presence_logit_max_val,
                            max=self.clamp_presence_logit_max_val,
                        )
                    )
                intermediate_presence_logits.append(intermediate_layer_presence_logits)
                presence_feats = presence_out.clone()
        if not self.compiled and self.compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=self.compile_mode, fullgraph=True
            )
            self.compiled = True
        return (
            torch.stack(intermediate),
            torch.stack(intermediate_ref_boxes),
            (
                torch.stack(intermediate_presence_logits)
                if self.presence_token is not None and is_instance_prompt is False
                else None
            ),
            presence_feats,
        )

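# Minimal call sketch for TransformerDecoder.forward (illustrative only): variable
# names and shapes follow the docstring above; how `tgt`, `memory`, `pos` and the
# reference boxes are produced upstream is not shown here and is an assumption.
#
#   hs, ref_boxes, presence_logits, presence_feats = decoder(
#       tgt=queries,                 # nq, bs, d_model
#       memory=image_tokens,         # sum(hw), bs, d_model
#       pos=pos_embed,               # sum(hw), bs, d_model
#       reference_boxes=init_boxes,  # nq, bs, 4 (after sigmoid)
#       spatial_shapes=shapes,       # see docstring; single scale expected with boxRPB
#       valid_ratios=valid_ratios,   # bs, nlevel, 2
#   )
#   # hs:        (num_layers, nq, bs, d_model)  per-layer normalized query features
#   # ref_boxes: (num_layers, nq, bs, 4)        per-layer refined boxes (sigmoid space)
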
class TransformerEncoderCrossAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        frozen: bool,
        pos_enc_at_input: bool,
        layer,
        num_layers: int,
        use_act_checkpoint: bool = False,
        batch_first: bool = False,  # do layers expect batch-first input?
        # Layers in which to drop cross attention; the default (None) means that
        # all layers use cross attention.
        remove_cross_attention_layers: Optional[list] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.use_act_checkpoint = use_act_checkpoint
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)
        self.batch_first = batch_first
        # remove cross attention layers if specified
        self.remove_cross_attention_layers = [False] * self.num_layers
        if remove_cross_attention_layers is not None:
            for i in remove_cross_attention_layers:
                self.remove_cross_attention_layers[i] = True
        assert len(self.remove_cross_attention_layers) == len(self.layers)
        for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers):
            if remove_cross_attention:
                self.layers[i].cross_attn_image = None
                self.layers[i].norm2 = None
                self.layers[i].dropout2 = None

    def forward(
        self,
        src,  # self-attention inputs
        prompt,  # cross-attention inputs
        src_mask: Optional[Tensor] = None,  # attn. mask for self-attention inputs
        prompt_mask: Optional[Tensor] = None,  # attn. mask for cross-attention inputs
        src_key_padding_mask: Optional[Tensor] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        src_pos: Optional[Tensor] = None,  # pos enc for self-attention inputs
        prompt_pos: Optional[Tensor] = None,  # pos enc for cross-attention inputs
        feat_sizes: Optional[list] = None,
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
        if isinstance(src, list):
            assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list)
            assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1
            src, src_key_padding_mask, src_pos = (
                src[0],
                src_key_padding_mask[0],
                src_pos[0],
            )
        assert src.shape[1] == prompt.shape[1], (
            "Batch size must be the same for src and prompt"
        )
        output = src
        if self.pos_enc_at_input and src_pos is not None:
            output = output + 0.1 * src_pos
        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            src_pos = src_pos.transpose(0, 1)
            prompt = prompt.transpose(0, 1)
            prompt_pos = prompt_pos.transpose(0, 1)
        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
            output = activation_ckpt_wrapper(layer)(
                tgt=output,
                memory=prompt,
                tgt_mask=src_mask,
                memory_mask=prompt_mask,
                tgt_key_padding_mask=src_key_padding_mask,
                memory_key_padding_mask=prompt_key_padding_mask,
                pos=prompt_pos,
                query_pos=src_pos,
                dac=False,
                attn_bias=None,
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                **kwds,
            )
        normed_output = self.norm(output)
        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            src_pos = src_pos.transpose(0, 1)
        return {
            "memory": normed_output,
            "pos_embed": src_pos,
            "padding_mask": src_key_padding_mask,
        }

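# Minimal call sketch for TransformerEncoderCrossAttention (illustrative only): the
# list-wrapped inputs mirror the unwrapping logic in forward, and every tensor name
# here is a placeholder; which positional encodings are actually required depends on
# the injected layer's pos_enc_* flags.
#
#   out = prompt_encoder(
#       src=[image_tokens],            # hw, bs, d_model
#       prompt=prompt_tokens,          # num_prompt, bs, d_model
#       src_key_padding_mask=[None],
#       src_pos=[image_pos],
#       prompt_pos=prompt_pos_embed,
#   )
#   memory, pos_embed = out["memory"], out["pos_embed"]
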
class TransformerDecoderLayerv1(nn.Module):
    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation_str = activation
        self.activation = get_activation_fn(activation)
        self.pre_norm = pre_norm
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        **kwargs,
    ):
        q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
        # Self attention
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross attention to image
        tgt2 = self.cross_attn_image(
            query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # FFN
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwargs,
    ):
        if dac:
            # we only apply self attention to the first half of the queries
            assert tgt.shape[0] % 2 == 0
            other_tgt = tgt[tgt.shape[0] // 2 :]
            tgt = tgt[: tgt.shape[0] // 2]
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt2,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        if dac:
            # Recombine
            tgt = torch.cat((tgt, other_tgt), dim=0)
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            attn_bias=attn_bias,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwds: Any,
    ) -> torch.Tensor:
        fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
        return fwd_fn(
            tgt,
            memory,
            dac=dac,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
            attn_bias=attn_bias,
            **kwds,
        )

class TransformerDecoderLayerv2(TransformerDecoderLayerv1):
    def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any):
        super().__init__(*args, **kwds)
        self.cross_attention_first = cross_attention_first

    def _forward_sa(self, tgt, query_pos):
        # Self-Attention
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        if self.cross_attn_image is None:
            return tgt
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
        # Cross-Attention
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ):
        assert dac is False
        assert tgt_mask is None
        assert memory_mask is None
        assert tgt_key_padding_mask is None
        assert memory_key_padding_mask is None
        assert attn_bias is None
        if self.cross_attention_first:
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
            tgt = self._forward_sa(tgt, query_pos)
        else:
            tgt = self._forward_sa(tgt, query_pos)
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, *args: Any, **kwds: Any) -> torch.Tensor:
        if self.pre_norm:
            return self.forward_pre(*args, **kwds)
        raise NotImplementedError

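# Note on the two layer variants above (hedged summary): TransformerDecoderLayerv1
# supports both post-norm (forward_post) and pre-norm (forward_pre) orderings, while
# TransformerDecoderLayerv2 only implements the pre-norm path and additionally lets
# cross-attention run before self-attention. An illustrative construction follows;
# the attention modules are placeholders, since their real constructors live in
# sam3.sam.transformer and are not shown in this file.
#
#   layer_v2 = TransformerDecoderLayerv2(
#       cross_attention_first=True,
#       activation="relu",
#       cross_attention=...,   # e.g. a RoPEAttention instance
#       d_model=256,
#       dim_feedforward=2048,
#       dropout=0.1,
#       pos_enc_at_attn=False,
#       pos_enc_at_cross_attn_keys=True,
#       pos_enc_at_cross_attn_queries=False,
#       pre_norm=True,
#       self_attention=...,    # e.g. a self-attention module with a q/k/v interface
#   )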