geometry_encoders.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. from typing import Tuple
  4. import torch
  5. import torch.nn as nn
  6. import torchvision
  7. from typing_extensions import override
  8. from .act_ckpt_utils import activation_ckpt_wrapper
  9. from .box_ops import box_cxcywh_to_xyxy
  10. from .model_misc import get_clones
  11. def is_right_padded(mask):
  12. """Given a padding mask (following pytorch convention, 1s for padded values),
  13. returns whether the padding is on the right or not."""
  14. return (mask.long() == torch.sort(mask.long(), dim=-1)[0]).all()
  15. def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
  16. """
  17. Concatenates two right-padded sequences, such that the resulting sequence
  18. is contiguous and also right-padded.
  19. Following pytorch's convention, tensors are sequence first, and the mask are
  20. batch first, with 1s for padded values.
  21. :param seq1: A tensor of shape (seq1_length, batch_size, hidden_size).
  22. :param mask1: A tensor of shape (batch_size, seq1_length).
  23. :param seq2: A tensor of shape (seq2_length, batch_size, hidden_size).
  24. :param mask2: A tensor of shape (batch_size, seq2_length).
  25. :param return_index: If True, also returns the index of the ids of the element of seq2
  26. in the concatenated sequence. This can be used to retrieve the elements of seq2
  27. :return: A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
  28. otherwise (concatenated_sequence, concatenated_mask, index).
  29. """
  30. seq1_length, batch_size, hidden_size = seq1.shape
  31. seq2_length, batch_size, hidden_size = seq2.shape
  32. assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0)
  33. assert hidden_size == seq1.size(2) == seq2.size(2)
  34. assert seq1_length == mask1.size(1)
  35. assert seq2_length == mask2.size(1)
  36. torch._assert_async(is_right_padded(mask1))
  37. torch._assert_async(is_right_padded(mask2))
  38. actual_seq1_lengths = (~mask1).sum(dim=-1)
  39. actual_seq2_lengths = (~mask2).sum(dim=-1)
  40. final_lengths = actual_seq1_lengths + actual_seq2_lengths
  41. max_length = seq1_length + seq2_length
  42. concatenated_mask = (
  43. torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1)
  44. >= final_lengths[:, None]
  45. )
  46. # (max_len, batch_size, hidden_size)
  47. concatenated_sequence = torch.zeros(
  48. (max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype
  49. )
  50. concatenated_sequence[:seq1_length, :, :] = seq1
  51. # At this point, the element of seq1 are in the right place
  52. # We just need to shift the elements of seq2
  53. index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
  54. index = index + actual_seq1_lengths[None]
  55. concatenated_sequence = concatenated_sequence.scatter(
  56. 0, index[:, :, None].expand(-1, -1, hidden_size), seq2
  57. )
  58. if return_index:
  59. return concatenated_sequence, concatenated_mask, index
  60. return concatenated_sequence, concatenated_mask
  61. class Prompt:
  62. """Utility class to manipulate geometric prompts.
  63. We expect the sequences in pytorch convention, that is sequence first, batch second
  64. The dimensions are expected as follows:
  65. box_embeddings shape: N_boxes x B x C_box
  66. box_mask shape: B x N_boxes. Can be None if nothing is masked out
  67. point_embeddings shape: N_points x B x C_point
  68. point_mask shape: B x N_points. Can be None if nothing is masked out
  69. mask_embeddings shape: N_masks x B x 1 x H_mask x W_mask
  70. mask_mask shape: B x N_masks. Can be None if nothing is masked out
  71. We also store positive/negative labels. These tensors are also stored batch-first
  72. If they are None, we'll assume positive labels everywhere
  73. box_labels: long tensor of shape N_boxes x B
  74. point_labels: long tensor of shape N_points x B
  75. mask_labels: long tensor of shape N_masks x B
  76. """
  77. def __init__(
  78. self,
  79. box_embeddings=None,
  80. box_mask=None,
  81. point_embeddings=None,
  82. point_mask=None,
  83. box_labels=None,
  84. point_labels=None,
  85. mask_embeddings=None,
  86. mask_mask=None, # Attention mask for mask prompt
  87. mask_labels=None,
  88. ):
  89. # Check for null prompt
  90. if (
  91. box_embeddings is None
  92. and point_embeddings is None
  93. and mask_embeddings is None
  94. ):
  95. self.box_embeddings = None
  96. self.box_labels = None
  97. self.box_mask = None
  98. self.point_embeddings = None
  99. self.point_labels = None
  100. self.point_mask = None
  101. self.mask_embeddings = None
  102. self.mask_mask = None
  103. # Masks are assumed positive only for now.
  104. self.mask_labels = None
  105. return
  106. # Get sequence lengths and device
  107. box_seq_len, point_seq_len, mask_seq_len, bs, device = (
  108. self._init_seq_len_and_device(
  109. box_embeddings, point_embeddings, mask_embeddings
  110. )
  111. )
  112. # Initialize embeds, labels, attention masks.
  113. box_embeddings, box_labels, box_mask = self._init_box(
  114. box_embeddings, box_labels, box_mask, box_seq_len, bs, device
  115. )
  116. point_embeddings, point_labels, point_mask = self._init_point(
  117. point_embeddings, point_labels, point_mask, point_seq_len, bs, device
  118. )
  119. mask_embeddings, mask_labels, mask_mask = self._init_mask(
  120. mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device
  121. )
  122. # Dimension checks
  123. assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
  124. box_seq_len,
  125. bs,
  126. ], (
  127. f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
  128. )
  129. assert box_mask is not None and list(box_mask.shape) == [
  130. bs,
  131. box_seq_len,
  132. ], (
  133. f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
  134. )
  135. assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
  136. point_seq_len,
  137. bs,
  138. ], (
  139. f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
  140. )
  141. assert point_mask is not None and list(point_mask.shape) == [
  142. bs,
  143. point_seq_len,
  144. ], (
  145. f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
  146. )
  147. assert box_labels is not None and list(box_labels.shape) == [
  148. box_seq_len,
  149. bs,
  150. ], (
  151. f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
  152. )
  153. assert point_labels is not None and list(point_labels.shape) == [
  154. point_seq_len,
  155. bs,
  156. ], (
  157. f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
  158. )
  159. assert (
  160. # Allowed to be None, we leave it to the encoder to check for validity before encoding.
  161. mask_embeddings is None
  162. or list(mask_embeddings.shape[:2])
  163. == [
  164. mask_seq_len,
  165. bs,
  166. ]
  167. ), (
  168. f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
  169. )
  170. assert mask_mask is None or list(mask_mask.shape) == [
  171. bs,
  172. mask_seq_len,
  173. ], (
  174. f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
  175. )
  176. # Device checks
  177. assert box_embeddings is not None and box_embeddings.device == device, (
  178. f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
  179. )
  180. assert box_mask is not None and box_mask.device == device, (
  181. f"Expected box mask to be on device {device}, got {box_mask.device}"
  182. )
  183. assert box_labels is not None and box_labels.device == device, (
  184. f"Expected box labels to be on device {device}, got {box_labels.device}"
  185. )
  186. assert point_embeddings is not None and point_embeddings.device == device, (
  187. f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
  188. )
  189. assert point_mask is not None and point_mask.device == device, (
  190. f"Expected point mask to be on device {device}, got {point_mask.device}"
  191. )
  192. assert point_labels is not None and point_labels.device == device, (
  193. f"Expected point labels to be on device {device}, got {point_labels.device}"
  194. )
  195. assert mask_embeddings is None or mask_embeddings.device == device, (
  196. f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
  197. )
  198. assert mask_mask is None or mask_mask.device == device, (
  199. f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
  200. )
  201. self.box_embeddings = box_embeddings
  202. self.point_embeddings = point_embeddings
  203. self.box_mask = box_mask
  204. self.point_mask = point_mask
  205. self.box_labels = box_labels
  206. self.point_labels = point_labels
  207. self.mask_embeddings = mask_embeddings
  208. self.mask_labels = mask_labels
  209. self.mask_mask = mask_mask
  210. def _init_seq_len_and_device(
  211. self, box_embeddings, point_embeddings, mask_embeddings
  212. ):
  213. box_seq_len = point_seq_len = mask_seq_len = 0
  214. bs = None
  215. device = None
  216. if box_embeddings is not None:
  217. bs = box_embeddings.shape[1]
  218. box_seq_len = box_embeddings.shape[0]
  219. device = box_embeddings.device
  220. if point_embeddings is not None:
  221. point_seq_len = point_embeddings.shape[0]
  222. if bs is not None:
  223. assert bs == point_embeddings.shape[1], (
  224. f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
  225. )
  226. else:
  227. bs = point_embeddings.shape[1]
  228. if device is not None:
  229. assert device == point_embeddings.device, (
  230. "Device mismatch between box and point embeddings"
  231. )
  232. else:
  233. device = point_embeddings.device
  234. if mask_embeddings is not None:
  235. mask_seq_len = mask_embeddings.shape[0]
  236. if bs is not None:
  237. assert bs == mask_embeddings.shape[1], (
  238. f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
  239. )
  240. else:
  241. bs = mask_embeddings.shape[1]
  242. if device is not None:
  243. assert device == mask_embeddings.device, (
  244. "Device mismatch between box/point and mask embeddings."
  245. )
  246. else:
  247. device = mask_embeddings.device
  248. return box_seq_len, point_seq_len, mask_seq_len, bs, device
  249. def _init_box(self, box_embeddings, box_labels, box_mask, box_seq_len, bs, device):
  250. if box_embeddings is None:
  251. box_embeddings = torch.zeros(box_seq_len, bs, 4, device=device)
  252. if box_labels is None:
  253. box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
  254. if box_mask is None:
  255. box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)
  256. return box_embeddings, box_labels, box_mask
  257. def _init_point(
  258. self, point_embeddings, point_labels, point_mask, point_seq_len, bs, device
  259. ):
  260. """
  261. Identical to _init_box. Except that C=2 for points (vs. 4 for boxes).
  262. """
  263. if point_embeddings is None:
  264. point_embeddings = torch.zeros(point_seq_len, bs, 2, device=device)
  265. if point_labels is None:
  266. point_labels = torch.ones(
  267. point_seq_len, bs, device=device, dtype=torch.long
  268. )
  269. if point_mask is None:
  270. point_mask = torch.zeros(bs, point_seq_len, device=device, dtype=torch.bool)
  271. return point_embeddings, point_labels, point_mask
  272. def _init_mask(
  273. self, mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device
  274. ):
  275. # NOTE: Mask embeddings can be of arbitrary resolution, so we don't initialize it here.
  276. # In case we append new mask, we check that its resolution matches exisiting ones (if any).
  277. # In case mask_embeddings is None, we should never encode it.
  278. if mask_labels is None:
  279. mask_labels = torch.ones(mask_seq_len, bs, device=device, dtype=torch.long)
  280. if mask_mask is None:
  281. mask_mask = torch.zeros(bs, mask_seq_len, device=device, dtype=torch.bool)
  282. return mask_embeddings, mask_labels, mask_mask
  283. def append_boxes(self, boxes, labels, mask=None):
  284. if self.box_embeddings is None:
  285. self.box_embeddings = boxes
  286. self.box_labels = labels
  287. self.box_mask = mask
  288. return
  289. bs = self.box_embeddings.shape[1]
  290. assert boxes.shape[1] == labels.shape[1] == bs
  291. assert list(boxes.shape[:2]) == list(labels.shape[:2])
  292. if mask is None:
  293. mask = torch.zeros(
  294. bs, boxes.shape[0], dtype=torch.bool, device=boxes.device
  295. )
  296. self.box_labels, _ = concat_padded_sequences(
  297. self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
  298. )
  299. self.box_labels = self.box_labels.squeeze(-1)
  300. self.box_embeddings, self.box_mask = concat_padded_sequences(
  301. self.box_embeddings, self.box_mask, boxes, mask
  302. )
  303. def append_points(self, points, labels, mask=None):
  304. if self.point_embeddings is None:
  305. self.point_embeddings = points
  306. self.point_labels = labels
  307. self.point_mask = mask
  308. return
  309. bs = self.point_embeddings.shape[1]
  310. assert points.shape[1] == labels.shape[1] == bs
  311. assert list(points.shape[:2]) == list(labels.shape[:2])
  312. if mask is None:
  313. mask = torch.zeros(
  314. bs, points.shape[0], dtype=torch.bool, device=points.device
  315. )
  316. self.point_labels, _ = concat_padded_sequences(
  317. self.point_labels.unsqueeze(-1), self.point_mask, labels.unsqueeze(-1), mask
  318. )
  319. self.point_labels = self.point_labels.squeeze(-1)
  320. self.point_embeddings, self.point_mask = concat_padded_sequences(
  321. self.point_embeddings, self.point_mask, points, mask
  322. )
  323. def append_masks(self, masks, labels=None, attn_mask=None):
  324. if labels is not None:
  325. assert list(masks.shape[:2]) == list(labels.shape[:2])
  326. if self.mask_embeddings is None:
  327. self.mask_embeddings = masks
  328. mask_seq_len, bs = masks.shape[:2]
  329. if labels is None:
  330. self.mask_labels = torch.ones(
  331. mask_seq_len, bs, device=masks.device, dtype=torch.long
  332. )
  333. else:
  334. self.mask_labels = labels
  335. if attn_mask is None:
  336. self.mask_mask = torch.zeros(
  337. bs, mask_seq_len, device=masks.device, dtype=torch.bool
  338. )
  339. else:
  340. self.mask_mask = attn_mask
  341. else:
  342. raise NotImplementedError("Only one mask per prompt is supported.")
  343. def clone(self):
  344. return Prompt(
  345. box_embeddings=(
  346. None if self.box_embeddings is None else self.box_embeddings.clone()
  347. ),
  348. box_mask=None if self.box_mask is None else self.box_mask.clone(),
  349. point_embeddings=(
  350. None if self.point_embeddings is None else self.point_embeddings.clone()
  351. ),
  352. point_mask=None if self.point_mask is None else self.point_mask.clone(),
  353. box_labels=None if self.box_labels is None else self.box_labels.clone(),
  354. point_labels=(
  355. None if self.point_labels is None else self.point_labels.clone()
  356. ),
  357. )
  358. class MaskEncoder(nn.Module):
  359. """
  360. Base class for mask encoders.
  361. """
  362. def __init__(
  363. self,
  364. mask_downsampler: nn.Module,
  365. position_encoding: nn.Module,
  366. ):
  367. super().__init__()
  368. self.mask_downsampler = mask_downsampler
  369. self.position_encoding = position_encoding
  370. def forward(self, masks, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
  371. masks = self.mask_downsampler(masks)
  372. masks_pos = self.position_encoding(masks).to(masks.dtype)
  373. return masks, masks_pos
  374. class FusedMaskEncoder(MaskEncoder):
  375. """
  376. Identical to memory.SimpleMaskEncoder but follows the interface of geometry_encoders.MaskEncoder.
  377. We also remove the `skip_mask_sigmoid` option (to be handled outside the MaskEncoder).
  378. Fuses backbone image features with mask features.
  379. """
  380. def __init__(
  381. self,
  382. mask_downsampler: nn.Module,
  383. position_encoding: nn.Module,
  384. fuser: nn.Module,
  385. in_dim: int = 256,
  386. out_dim: int = 256,
  387. ):
  388. super().__init__(mask_downsampler, position_encoding)
  389. self.fuser = fuser
  390. self.out_proj = nn.Identity()
  391. if out_dim != in_dim:
  392. self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
  393. self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
  394. @override
  395. def forward(
  396. self,
  397. masks: torch.Tensor,
  398. pix_feat: torch.Tensor,
  399. **kwargs,
  400. ) -> Tuple[torch.Tensor, torch.Tensor]:
  401. masks = self.mask_downsampler(masks)
  402. ## Fuse pix_feats and downsampled masks
  403. # in case the visual features are on CPU, cast them to CUDA
  404. pix_feat = pix_feat.to(masks.device)
  405. x = self.pix_feat_proj(pix_feat)
  406. x = x + masks
  407. x = self.fuser(x)
  408. x = self.out_proj(x)
  409. pos = self.position_encoding(x).to(x.dtype)
  410. return x, pos
class SequenceGeometryEncoder(nn.Module):
    """
    This a fully fledged encoder for geometric prompts.
    It assumes boxes are passed in the "normalized CxCyWH" format, and points in normalized xy
    This allows flexibility in how to encode the features (eg do pooling)
    Points and boxes can be encoded with any of the three possibilities:
    - direct projection: we just compute a linear from coordinate space to d_model
    - pooling: pool features from the backbone in the requested location.
        For boxes, it's a roi align
        For points it's a grid sample
    - pos encoder: Take the position encoding of the point or box center
    These three options are mutually compatible. If several are selected, we'll take a simple addition
    As an alternative, we offer the possibility to encode points only.
    In that case, the boxes are converted to two points for the top left and bottom right corners (with appropriate labels)
    On top of these encodings, we offer the possibility to further encode the prompt sequence with a transformer.
    """

    def __init__(
        self,
        encode_boxes_as_points: bool,
        points_direct_project: bool,
        points_pool: bool,
        points_pos_enc: bool,
        boxes_direct_project: bool,
        boxes_pool: bool,
        boxes_pos_enc: bool,
        d_model: int,
        pos_enc,
        num_layers: int,
        layer: nn.Module,
        roi_size: int = 7,  # for boxes pool
        add_cls: bool = True,
        add_post_encode_proj: bool = True,
        mask_encoder: MaskEncoder = None,
        add_mask_label: bool = False,
        use_act_ckpt: bool = False,
    ):
        """
        :param encode_boxes_as_points: convert each box to its two corner
            points instead of encoding boxes directly.
        :param points_direct_project / points_pool / points_pos_enc: which of
            the three point-encoding paths to enable (at least one required).
        :param boxes_direct_project / boxes_pool / boxes_pos_enc: same for
            boxes (at least one required unless encode_boxes_as_points).
        :param d_model: hidden size of all produced embeddings.
        :param pos_enc: position-encoding module; must expose `_encode_xy`
            (used for points) and `encode_boxes` (used for boxes).
        :param num_layers: number of transformer layers applied on top of the
            encoded prompt sequence (0 disables the transformer).
        :param layer: transformer layer module to clone `num_layers` times.
        :param roi_size: output resolution of the roi_align pooling for boxes.
        :param add_cls: prepend-able CLS token; also guarantees a non-empty
            encoded sequence.
        :param add_post_encode_proj: apply Linear + LayerNorm after encoding.
        :param mask_encoder: optional MaskEncoder for mask prompts.
        :param add_mask_label: add a learned positive/negative label embedding
            to mask tokens.
        :param use_act_ckpt: enable activation checkpointing (training only).
        """
        super().__init__()
        self.d_model = d_model
        self.pos_enc = pos_enc
        self.encode_boxes_as_points = encode_boxes_as_points
        self.roi_size = roi_size
        # There usually are two labels: positive and negatives.
        # If we encode boxes as points, we have 3 types of points: regular, top left, bottom right
        # These 3 types can be positives or negatives, hence 2*3 = 6 labels
        num_labels = 6 if self.encode_boxes_as_points else 2
        self.label_embed = torch.nn.Embedding(num_labels, self.d_model)
        # This is a cls token, can be used for pooling if need be.
        # It also ensures that the encoded sequences are always non-empty
        self.cls_embed = None
        if add_cls:
            self.cls_embed = torch.nn.Embedding(1, self.d_model)
        assert points_direct_project or points_pos_enc or points_pool, (
            "Error: need at least one way to encode points"
        )
        assert (
            encode_boxes_as_points
            or boxes_direct_project
            or boxes_pos_enc
            or boxes_pool
        ), "Error: need at least one way to encode boxes"
        self.points_direct_project = None
        if points_direct_project:
            # 2 coords (x, y) -> d_model
            self.points_direct_project = nn.Linear(2, self.d_model)
        self.points_pool_project = None
        if points_pool:
            self.points_pool_project = nn.Linear(self.d_model, self.d_model)
        self.points_pos_enc_project = None
        if points_pos_enc:
            self.points_pos_enc_project = nn.Linear(self.d_model, self.d_model)
        self.boxes_direct_project = None
        self.boxes_pool_project = None
        self.boxes_pos_enc_project = None
        if not encode_boxes_as_points:
            if boxes_direct_project:
                # 4 coords (cx, cy, w, h) -> d_model
                self.boxes_direct_project = nn.Linear(4, self.d_model)
            if boxes_pool:
                # Conv with kernel == roi_size collapses the pooled patch to 1x1.
                self.boxes_pool_project = nn.Conv2d(
                    self.d_model, self.d_model, self.roi_size
                )
            if boxes_pos_enc:
                # pos enc of the center (+2 for the w, h coordinates).
                self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)
        self.final_proj = None
        if add_post_encode_proj:
            self.final_proj = nn.Linear(self.d_model, self.d_model)
            # `self.norm` is only read in forward() when final_proj is set.
            self.norm = nn.LayerNorm(self.d_model)
        self.img_pre_norm = nn.Identity()
        if self.points_pool_project is not None or self.boxes_pool_project is not None:
            # Normalize backbone features before pooling from them.
            self.img_pre_norm = nn.LayerNorm(self.d_model)
        self.encode = None
        if num_layers > 0:
            assert add_cls, (
                "It's currently highly recommended to add a CLS when using a transformer"
            )
            self.encode = get_clones(layer, num_layers)
            # `self.encode_norm` is only read in forward() when encode is set.
            self.encode_norm = nn.LayerNorm(self.d_model)
        if mask_encoder is not None:
            assert isinstance(mask_encoder, MaskEncoder), (
                f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
            )
        if add_mask_label:
            self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
        self.add_mask_label = add_mask_label
        self.mask_encoder = mask_encoder
        self.use_act_ckpt = use_act_ckpt

    def _encode_points(self, points, points_mask, points_labels, img_feats):
        """Encode a point prompt sequence into d_model embeddings.

        :param points: (n_points, bs, 2) normalized xy in [0, 1].
        :param points_mask: (bs, n_points) padding mask, passed through.
        :param points_labels: (n_points, bs) label ids for `label_embed`.
        :param img_feats: (bs, d_model, H, W) backbone features; only used by
            the pooling path.
        :return: ((n_points, bs, d_model) embeddings, points_mask).
        """
        points_embed = None
        n_points, bs = points.shape[:2]
        if self.points_direct_project is not None:
            proj = self.points_direct_project(points)
            assert points_embed is None
            points_embed = proj
        if self.points_pool_project is not None:
            # points are [Num_points, bs, 2], normalized in [0, 1]
            # the grid needs to be [Bs, H_out, W_out, 2] normalized in [-1,1]
            # Will take H_out = num_points, w_out = 1
            grid = points.transpose(0, 1).unsqueeze(2)
            # re normalize to [-1, 1]
            grid = (grid * 2) - 1
            sampled = torch.nn.functional.grid_sample(
                img_feats, grid, align_corners=False
            )
            assert list(sampled.shape) == [bs, self.d_model, n_points, 1]
            # back to sequence-first: (n_points, bs, d_model)
            sampled = sampled.squeeze(-1).permute(2, 0, 1)
            proj = self.points_pool_project(sampled)
            if points_embed is None:
                points_embed = proj
            else:
                points_embed = points_embed + proj
        if self.points_pos_enc_project is not None:
            x, y = points.unbind(-1)
            enc_x, enc_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
            enc_x = enc_x.view(n_points, bs, enc_x.shape[-1])
            enc_y = enc_y.view(n_points, bs, enc_y.shape[-1])
            enc = torch.cat([enc_x, enc_y], -1)
            proj = self.points_pos_enc_project(enc)
            if points_embed is None:
                points_embed = proj
            else:
                points_embed = points_embed + proj
        # Add the learned label (e.g. positive/negative) embedding.
        type_embed = self.label_embed(points_labels.long())
        return type_embed + points_embed, points_mask

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
        """Encode a box prompt sequence into d_model embeddings.

        :param boxes: (n_boxes, bs, 4) normalized cxcywh in [0, 1].
        :param boxes_mask: (bs, n_boxes) padding mask, passed through.
        :param boxes_labels: (n_boxes, bs) label ids for `label_embed`.
        :param img_feats: (bs, d_model, H, W) backbone features; only used by
            the roi_align pooling path.
        :return: ((n_boxes, bs, d_model) embeddings, boxes_mask).
        """
        boxes_embed = None
        n_boxes, bs = boxes.shape[:2]
        if self.boxes_direct_project is not None:
            proj = self.boxes_direct_project(boxes)
            assert boxes_embed is None
            boxes_embed = proj
        if self.boxes_pool_project is not None:
            H, W = img_feats.shape[-2:]
            # boxes are [Num_boxes, bs, 4], normalized in [0, 1]
            # We need to denormalize, and convert to [x, y, x, y]
            boxes_xyxy = box_cxcywh_to_xyxy(boxes)
            scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
            # pinned + non-blocking transfer to overlap the host-to-device copy
            scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
            scale = scale.view(1, 1, 4)
            boxes_xyxy = boxes_xyxy * scale
            # roi_align takes one (n_boxes, 4) tensor per batch element.
            sampled = torchvision.ops.roi_align(
                img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
            )
            assert list(sampled.shape) == [
                bs * n_boxes,
                self.d_model,
                self.roi_size,
                self.roi_size,
            ]
            proj = self.boxes_pool_project(sampled)
            # back to sequence-first: (n_boxes, bs, d_model)
            proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj
        if self.boxes_pos_enc_project is not None:
            cx, cy, w, h = boxes.unbind(-1)
            enc = self.pos_enc.encode_boxes(
                cx.flatten(), cy.flatten(), w.flatten(), h.flatten()
            )
            enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])
            proj = self.boxes_pos_enc_project(enc)
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj
        # Add the learned label (e.g. positive/negative) embedding.
        type_embed = self.label_embed(boxes_labels.long())
        return type_embed + boxes_embed, boxes_mask

    def _encode_masks(
        self,
        masks: torch.Tensor,
        attn_mask: torch.Tensor,
        mask_labels: torch.Tensor,
        img_feats: torch.Tensor = None,
    ):
        """Encode a mask prompt into a sequence of H*W tokens per mask.

        :param masks: (n_masks, bs, 1, H_mask, W_mask); currently n_masks == 1.
        :param attn_mask: (bs, n_masks) padding mask; expanded to one entry per
            produced token.
        :param mask_labels: (n_masks, bs) labels, used when add_mask_label.
        :param img_feats: backbone features forwarded to the mask encoder
            (used by FusedMaskEncoder-style encoders).
        :return: ((n_masks * H*W, bs, d_model) tokens, (bs, n_masks * H*W) mask).
        """
        n_masks, bs = masks.shape[:2]
        assert n_masks == 1, (
            "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
        )
        assert list(attn_mask.shape) == [
            bs,
            n_masks,
        ], (
            f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
        )
        masks, pos = self.mask_encoder(
            masks=masks.flatten(0, 1).float(),
            pix_feat=img_feats,
        )
        H, W = masks.shape[-2:]
        n_tokens_per_mask = H * W
        # NOTE: We directly add pos enc here as we usually don't keep track of pos encoding for the concatenated prompt (text, other geometric prompts). Might need to do some refactoring for more flexibility.
        masks = masks + pos
        masks = masks.view(n_masks, bs, *masks.shape[1:]).flatten(
            -2
        )  # n_masks x bs x C x H*W
        masks = masks.permute(0, 3, 1, 2).flatten(0, 1)  # n_masks * H*W x bs x C
        # Each mask contributes H*W tokens; expand its attention mask entry accordingly.
        attn_mask = attn_mask.repeat_interleave(n_tokens_per_mask, dim=1)
        if self.add_mask_label:
            masks = masks + self.mask_label_embed(mask_labels.long())
        return masks, attn_mask

    def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
        """Encode a full geometric prompt (points + boxes + optional mask).

        :param geo_prompt: the Prompt holding embeddings/masks/labels.
        :param img_feats: list of backbone feature maps; the last one is used,
            in sequence-first layout [H*W, B, C].
        :param img_sizes: list of (H, W) matching img_feats; used by pooling.
        :param img_pos_embeds: optional matching position embeddings (zeros
            are used when None).
        :return: (final_embeds, final_mask) — the encoded, right-padded prompt
            sequence (seq-first) and its (bs, seq) padding mask.
        """
        points = geo_prompt.point_embeddings
        points_mask = geo_prompt.point_mask
        points_labels = geo_prompt.point_labels
        boxes = geo_prompt.box_embeddings
        boxes_mask = geo_prompt.box_mask
        boxes_labels = geo_prompt.box_labels
        masks = geo_prompt.mask_embeddings
        masks_mask = geo_prompt.mask_mask
        masks_labels = geo_prompt.mask_labels
        seq_first_img_feats = img_feats[-1]  # [H*W, B, C]
        seq_first_img_pos_embeds = (
            img_pos_embeds[-1]
            if img_pos_embeds is not None
            else torch.zeros_like(seq_first_img_feats)
        )
        # NOTE: relies on modules being truthy and None falsy (no __bool__ on nn.Module).
        if self.points_pool_project or self.boxes_pool_project:
            assert len(img_feats) == len(img_sizes)
            cur_img_feat = img_feats[-1]
            cur_img_feat = self.img_pre_norm(cur_img_feat)
            H, W = img_sizes[-1]
            assert cur_img_feat.shape[0] == H * W
            N, C = cur_img_feat.shape[-2:]
            # Put back in NxCxHxW
            cur_img_feat = cur_img_feat.permute(1, 2, 0)
            cur_img_feat = cur_img_feat.view(N, C, H, W)
            img_feats = cur_img_feat
        if self.encode_boxes_as_points:
            assert boxes is not None
            assert geo_prompt.box_mask is not None
            assert geo_prompt.box_labels is not None
            assert boxes.shape[-1] == 4
            # Each box becomes two corner points; labels are offset by +2 (top
            # left) and +4 (bottom right) to address the 6-entry label_embed.
            boxes_xyxy = box_cxcywh_to_xyxy(boxes)
            top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)
            labels_tl = geo_prompt.box_labels + 2
            labels_br = geo_prompt.box_labels + 4
            # Append to the existing points
            points, _ = concat_padded_sequences(
                points, points_mask, top_left, boxes_mask
            )
            points_labels, points_mask = concat_padded_sequences(
                points_labels.unsqueeze(-1),
                points_mask,
                labels_tl.unsqueeze(-1),
                boxes_mask,
            )
            points_labels = points_labels.squeeze(-1)
            points, _ = concat_padded_sequences(
                points, points_mask, bottom_right, boxes_mask
            )
            points_labels, points_mask = concat_padded_sequences(
                points_labels.unsqueeze(-1),
                points_mask,
                labels_br.unsqueeze(-1),
                boxes_mask,
            )
            points_labels = points_labels.squeeze(-1)
        final_embeds, final_mask = self._encode_points(
            points=points,
            points_mask=points_mask,
            points_labels=points_labels,
            img_feats=img_feats,
        )
        if not self.encode_boxes_as_points:
            boxes_embeds, boxes_mask = self._encode_boxes(
                boxes=boxes,
                boxes_mask=boxes_mask,
                boxes_labels=boxes_labels,
                img_feats=img_feats,
            )
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, boxes_embeds, boxes_mask
            )
        if masks is not None and self.mask_encoder is not None:
            masks_embed, masks_mask = self._encode_masks(
                masks=masks,
                attn_mask=masks_mask,
                mask_labels=masks_labels,
                img_feats=img_feats,
            )
            # Mask-only prompt: skip CLS/projection/transformer entirely.
            if points.size(0) == boxes.size(0) == 0:
                return masks_embed, masks_mask
        bs = final_embeds.shape[1]
        assert final_mask.shape[0] == bs
        if self.cls_embed is not None:
            # Append the (never-padded) CLS token to every sequence.
            cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
            cls_mask = torch.zeros(
                bs, 1, dtype=final_mask.dtype, device=final_mask.device
            )
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, cls, cls_mask
            )
        if self.final_proj is not None:
            final_embeds = self.norm(self.final_proj(final_embeds))
        if self.encode is not None:
            # Cross-attend the prompt tokens to the image features.
            for lay in self.encode:
                final_embeds = activation_ckpt_wrapper(lay)(
                    tgt=final_embeds,
                    memory=seq_first_img_feats,
                    tgt_key_padding_mask=final_mask,
                    pos=seq_first_img_pos_embeds,
                    act_ckpt_enable=self.training and self.use_act_ckpt,
                )
            final_embeds = self.encode_norm(final_embeds)
        # Finally, concat mask embeddings if any
        # (mask tokens deliberately bypass the transformer above).
        if masks is not None and self.mask_encoder is not None:
            final_embeds, final_mask = concat_padded_sequences(
                final_embeds, final_mask, masks_embed, masks_mask
            )
        return final_embeds, final_mask