# sam3_image.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import os
from copy import deepcopy
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from sam3.model.model_misc import SAM3Output
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.perflib.nms import nms_masks
from sam3.train.data.collator import BatchedDatapoint

from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .geometry_encoders import Prompt
from .model_misc import inverse_sigmoid


def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
    out[out_name] = out_value[-1] if auxiliary else out_value
    if auxiliary and update_aux:
        if "aux_outputs" not in out:
            out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
        assert len(out["aux_outputs"]) == len(out_value) - 1
        for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
            aux_output[out_name] = aux_value
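

# A small illustrative sketch (not used anywhere in this module) of how
# `_update_out` lays out per-decoder-layer predictions: the last layer goes under
# the main key and the earlier layers are spread over `out["aux_outputs"]`.
#
#   out = {}
#   _update_out(out, "pred_logits", [logits_l0, logits_l1, logits_l2])
#   # out["pred_logits"]                   -> logits_l2  (last decoder layer)
#   # out["aux_outputs"][0]["pred_logits"] -> logits_l0
#   # out["aux_outputs"][1]["pred_logits"] -> logits_l1
#
# Here `logits_l0..logits_l2` are hypothetical per-layer tensors.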


class Sam3Image(torch.nn.Module):
    TEXT_ID_FOR_TEXT = 0
    TEXT_ID_FOR_VISUAL = 1
    TEXT_ID_FOR_GEOMETRIC = 2

    def __init__(
        self,
        backbone: SAM3VLBackbone,
        transformer,
        input_geometry_encoder,
        segmentation_head=None,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=None,
        use_instance_query: bool = True,
        multimask_output: bool = True,
        use_act_checkpoint_seg_head: bool = True,
        interactivity_in_encoder: bool = True,
        matcher=None,
        use_dot_prod_scoring=True,
        supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
        detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
        separate_scorer_for_instance: bool = False,
        num_interactive_steps_val: int = 0,
        inst_interactive_predictor: SAM3InteractiveImagePredictor = None,
        **kwargs,
    ):
        super().__init__()
        self.backbone = backbone
        self.geometry_encoder = input_geometry_encoder
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.segmentation_head = segmentation_head
        self.o2m_mask_predict = o2m_mask_predict
        self.dot_prod_scoring = dot_prod_scoring
        self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
        self.interactivity_in_encoder = interactivity_in_encoder
        self.matcher = matcher
        self.num_interactive_steps_val = num_interactive_steps_val
        self.use_dot_prod_scoring = use_dot_prod_scoring
        if self.use_dot_prod_scoring:
            assert dot_prod_scoring is not None
            self.dot_prod_scoring = dot_prod_scoring
            self.instance_dot_prod_scoring = None
            if separate_scorer_for_instance:
                self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
        else:
            self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
            self.instance_class_embed = None
            if separate_scorer_for_instance:
                self.instance_class_embed = deepcopy(self.class_embed)
        self.supervise_joint_box_scores = supervise_joint_box_scores
        self.detach_presence_in_joint_score = detach_presence_in_joint_score
        # verify the number of queries for O2O and O2M
        num_o2o_static = self.transformer.decoder.num_queries
        num_o2m_static = self.transformer.decoder.num_o2m_queries
        assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
        self.dac = self.transformer.decoder.dac
        self.use_instance_query = use_instance_query
        self.multimask_output = multimask_output
        self.inst_interactive_predictor = inst_interactive_predictor

    @property
    def device(self):
        self._device = getattr(self, "_device", None) or next(self.parameters()).device
        return self._device

    def to(self, *args, **kwargs):
        # clear cached _device in case the model is moved to a different device
        self._device = None
        return super().to(*args, **kwargs)

    def _get_img_feats(self, backbone_out, img_ids):
        """Retrieve correct image features from backbone output."""
        if "backbone_fpn" in backbone_out:
            if "id_mapping" in backbone_out and backbone_out["id_mapping"] is not None:
                img_ids = backbone_out["id_mapping"][img_ids]
                # If this assert fails, it likely means we're requesting different img_ids (perhaps a different frame?)
                # We currently don't expect this to happen. We could technically trigger a recompute here,
                # but likely at the cost of a cpu<->gpu sync point, which would deteriorate perf
                torch._assert_async((img_ids >= 0).all())
            vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :]
            vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
            vis_feat_sizes = [x.shape[-2:] for x in vis_pos_enc]  # (H, W) shapes
            # index and flatten visual features NxCxHxW => HWxNxC (batch-first => seq-first)
            img_feats = [x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_feats]
            img_pos_embeds = [
                x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_pos_enc
            ]
            return backbone_out, img_feats, img_pos_embeds, vis_feat_sizes

        # Image features not available in backbone output, so we compute them on the fly.
        # This case likely occurs for video. In that case, we want to forward only the current frame
        img_batch = backbone_out["img_batch_all_stages"]
        if img_ids.numel() > 1:
            # Only forward backbone on unique image ids to avoid repetitive computation
            unique_ids, _ = torch.unique(img_ids, return_inverse=True)
        else:
            unique_ids, _ = img_ids, slice(None)
        # Compute the image features on those unique image ids
        # note: we allow using a list (or other indexable types) of tensors as img_batch
        # (e.g. for async frame loading in demo). In this case we index img_batch.tensors directly
        if isinstance(img_batch, torch.Tensor):
            image = img_batch[unique_ids]
        elif unique_ids.numel() == 1:
            image = img_batch[unique_ids.item()].unsqueeze(0)
        else:
            image = torch.stack([img_batch[i] for i in unique_ids.tolist()])
        # `img_batch` might be fp16 and offloaded to CPU
        image = image.to(dtype=torch.float32, device=self.device)
        # Next time we call this function, we want to remember which indices we computed
        id_mapping = torch.full(
            (len(img_batch),), -1, dtype=torch.long, device=self.device
        )
        id_mapping[unique_ids] = torch.arange(len(unique_ids), device=self.device)
        backbone_out = {
            **backbone_out,
            **self.backbone.forward_image(image),
            "id_mapping": id_mapping,
        }
        assert "backbone_fpn" in backbone_out
        return self._get_img_feats(backbone_out, img_ids=img_ids)
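
    # A descriptive sketch of the `id_mapping` caching above (assuming a 10-frame
    # `img_batch` and a previous call that computed unique_ids = [3, 7]):
    #
    #   id_mapping = [-1, -1, -1, 0, -1, -1, -1, 1, -1, -1]
    #
    # A later call with img_ids = [7] is remapped to index 1 of the cached
    # "backbone_fpn" features, while requesting a frame that was never computed
    # would map to -1 and trip the async assert at the top of this method.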

    def _encode_prompt(
        self,
        backbone_out,
        find_input,
        geometric_prompt,
        visual_prompt_embed=None,
        visual_prompt_mask=None,
        encode_text=True,
        prev_mask_pred=None,
    ):
        # index text features (note that regardless of early or late fusion, the batch size of
        # `txt_feats` is always the number of *prompts* in the encoder)
        txt_ids = find_input.text_ids
        txt_feats = backbone_out["language_features"][:, txt_ids]
        txt_masks = backbone_out["language_mask"][txt_ids]
        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple
        if prev_mask_pred is not None:
            img_feats = [img_feats[-1] + prev_mask_pred]
        # Encode geometry
        geo_feats, geo_masks = self.geometry_encoder(
            geo_prompt=geometric_prompt,
            img_feats=img_feats,
            img_sizes=vis_feat_sizes,
            img_pos_embeds=img_pos_embeds,
        )
        if visual_prompt_embed is None:
            visual_prompt_embed = torch.zeros(
                (0, *geo_feats.shape[1:]), device=geo_feats.device
            )
            visual_prompt_mask = torch.zeros(
                (*geo_masks.shape[:-1], 0),
                device=geo_masks.device,
                dtype=geo_masks.dtype,
            )
        if encode_text:
            prompt = torch.cat([txt_feats, geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([txt_masks, geo_masks, visual_prompt_mask], dim=1)
        else:
            prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
        return prompt, prompt_mask, backbone_out
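
    # Prompt layout note (mirroring the concatenation above): with `encode_text=True`,
    # the sequence-first prompt fed to the encoder/decoder is
    #
    #   [text tokens | geometric tokens | visual-prompt tokens]   (concat along dim 0)
    #
    # with `prompt_mask` concatenated along dim 1 in the same order; when
    # `encode_text=False`, the text block is simply omitted.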

    def _run_encoder(
        self,
        backbone_out,
        find_input,
        prompt,
        prompt_mask,
        encoder_extra_kwargs: Optional[Dict] = None,
    ):
        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple
        # Run the encoder
        prompt_pos_embed = torch.zeros_like(prompt)
        # make a copy of the image feature lists since the encoder may modify these lists in-place
        memory = self.transformer.encoder(
            src=img_feats.copy(),
            src_key_padding_mask=None,
            src_pos=img_pos_embeds.copy(),
            prompt=prompt,
            prompt_pos=prompt_pos_embed,
            prompt_key_padding_mask=prompt_mask,
            feat_sizes=vis_feat_sizes,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )
        encoder_out = {
            # encoded image features
            "encoder_hidden_states": memory["memory"],
            "pos_embed": memory["pos_embed"],
            "padding_mask": memory["padding_mask"],
            "level_start_index": memory["level_start_index"],
            "spatial_shapes": memory["spatial_shapes"],
            "valid_ratios": memory["valid_ratios"],
            "vis_feat_sizes": vis_feat_sizes,
            # encoded text features (or other prompts)
            "prompt_before_enc": prompt,
            "prompt_after_enc": memory.get("memory_text", prompt),
            "prompt_mask": prompt_mask,
        }
        return backbone_out, encoder_out, feat_tuple

    def _run_decoder(
        self,
        pos_embed,
        memory,
        src_mask,
        out,
        prompt,
        prompt_mask,
        encoder_out,
    ):
        bs = memory.shape[1]
        query_embed = self.transformer.decoder.query_embed.weight
        tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
        apply_dac = self.transformer.decoder.dac and self.training
        hs, reference_boxes, dec_presence_out, dec_presence_feats = (
            self.transformer.decoder(
                tgt=tgt,
                memory=memory,
                memory_key_padding_mask=src_mask,
                pos=pos_embed,
                reference_boxes=None,
                level_start_index=encoder_out["level_start_index"],
                spatial_shapes=encoder_out["spatial_shapes"],
                valid_ratios=encoder_out["valid_ratios"],
                tgt_mask=None,
                memory_text=prompt,
                text_attention_mask=prompt_mask,
                apply_dac=apply_dac,
            )
        )
        hs = hs.transpose(1, 2)  # seq-first to batch-first
        reference_boxes = reference_boxes.transpose(1, 2)  # seq-first to batch-first
        if dec_presence_out is not None:
            # seq-first to batch-first
            dec_presence_out = dec_presence_out.transpose(1, 2)
            out["presence_feats"] = dec_presence_feats
        self._update_scores_and_boxes(
            out,
            hs,
            reference_boxes,
            prompt,
            prompt_mask,
            dec_presence_out=dec_presence_out,
        )
        return out, hs

    def _update_scores_and_boxes(
        self,
        out,
        hs,
        reference_boxes,
        prompt,
        prompt_mask,
        dec_presence_out=None,
        is_instance_prompt=False,
    ):
        apply_dac = self.transformer.decoder.dac and self.training
        num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
        num_o2m = hs.size(2) - num_o2o
        assert num_o2m == (num_o2o if apply_dac else 0)
        out["queries"] = hs[-1][:, :num_o2o]  # remove o2m queries if there are any
        # score prediction
        if self.use_dot_prod_scoring:
            dot_prod_scoring_head = self.dot_prod_scoring
            if is_instance_prompt and self.instance_dot_prod_scoring is not None:
                dot_prod_scoring_head = self.instance_dot_prod_scoring
            outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
        else:
            class_embed_head = self.class_embed
            if is_instance_prompt and self.instance_class_embed is not None:
                class_embed_head = self.instance_class_embed
            outputs_class = class_embed_head(hs)
        # box prediction
        box_head = self.transformer.decoder.bbox_embed
        if (
            is_instance_prompt
            and self.transformer.decoder.instance_bbox_embed is not None
        ):
            box_head = self.transformer.decoder.instance_bbox_embed
        anchor_box_offsets = box_head(hs)
        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
        outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
        outputs_boxes_xyxy = box_cxcywh_to_xyxy(outputs_coord)
        if dec_presence_out is not None:
            _update_out(
                out, "presence_logit_dec", dec_presence_out, update_aux=self.training
            )
        if self.supervise_joint_box_scores:
            assert dec_presence_out is not None
            prob_dec_presence_out = dec_presence_out.clone().sigmoid()
            if self.detach_presence_in_joint_score:
                prob_dec_presence_out = prob_dec_presence_out.detach()
            outputs_class = inverse_sigmoid(
                outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)
            ).clamp(min=-10.0, max=10.0)
        _update_out(
            out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out,
            "pred_boxes_xyxy",
            outputs_boxes_xyxy[:, :, :num_o2o],
            update_aux=self.training,
        )
        if num_o2m > 0 and self.training:
            _update_out(
                out,
                "pred_logits_o2m",
                outputs_class[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_o2m",
                outputs_coord[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_xyxy_o2m",
                outputs_boxes_xyxy[:, :, num_o2o:],
                update_aux=self.training,
            )
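
    # Box refinement note (a restatement of the code above, not additional logic):
    # each decoder layer predicts an offset in logit space that refines the
    # reference anchors,
    #
    #   outputs_coord = sigmoid(inverse_sigmoid(reference_boxes) + box_head(hs))
    #
    # and under DAC during training the query axis of `hs` holds the O2O queries
    # first and their O2M duplicates second, so the first half feeds "pred_*" and
    # the second half feeds "pred_*_o2m".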

    def _run_segmentation_heads(
        self,
        out,
        backbone_out,
        img_ids,
        vis_feat_sizes,
        encoder_hidden_states,
        prompt,
        prompt_mask,
        hs,
    ):
        apply_dac = self.transformer.decoder.dac and self.training
        if self.segmentation_head is not None:
            num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
            num_o2m = hs.size(2) - num_o2o
            obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
            seg_head_outputs = activation_ckpt_wrapper(self.segmentation_head)(
                backbone_feats=backbone_out["backbone_fpn"],
                obj_queries=obj_queries,
                image_ids=img_ids,
                encoder_hidden_states=encoder_hidden_states,
                act_ckpt_enable=self.training and self.use_act_checkpoint_seg_head,
                prompt=prompt,
                prompt_mask=prompt_mask,
            )
            aux_masks = False  # self.aux_loss and self.segmentation_head.aux_masks
            for k, v in seg_head_outputs.items():
                if k in self.segmentation_head.instance_keys:
                    _update_out(out, k, v[:, :num_o2o], auxiliary=aux_masks)
                    if (
                        self.o2m_mask_predict and num_o2m > 0
                    ):  # handle o2m mask prediction
                        _update_out(
                            out, f"{k}_o2m", v[:, num_o2o:], auxiliary=aux_masks
                        )
                else:
                    out[k] = v
        else:
            backbone_out.pop("backbone_fpn", None)

    def _get_best_mask(self, out):
        prev_mask_idx = out["pred_logits"].argmax(dim=1).squeeze(1)
        batch_idx = torch.arange(
            out["pred_logits"].shape[0], device=prev_mask_idx.device
        )
        prev_mask_pred = out["pred_masks"][batch_idx, prev_mask_idx][:, None]
        # Downsample the mask to match the image feature resolution.
        prev_mask_pred = self.geometry_encoder.mask_encoder.mask_downsampler(
            prev_mask_pred
        )
        prev_mask_pred = prev_mask_pred.flatten(-2).permute(2, 0, 1)
        return prev_mask_pred

    def forward_grounding(
        self,
        backbone_out,
        find_input,
        find_target,
        geometric_prompt: Prompt,
    ):
        with torch.profiler.record_function("SAM3Image._encode_prompt"):
            prompt, prompt_mask, backbone_out = self._encode_prompt(
                backbone_out, find_input, geometric_prompt
            )
        # Run the encoder
        with torch.profiler.record_function("SAM3Image._run_encoder"):
            backbone_out, encoder_out, _ = self._run_encoder(
                backbone_out, find_input, prompt, prompt_mask
            )
        out = {
            "encoder_hidden_states": encoder_out["encoder_hidden_states"],
            "prev_encoder_out": {
                "encoder_out": encoder_out,
                "backbone_out": backbone_out,
            },
        }
        # Run the decoder
        with torch.profiler.record_function("SAM3Image._run_decoder"):
            out, hs = self._run_decoder(
                memory=out["encoder_hidden_states"],
                pos_embed=encoder_out["pos_embed"],
                src_mask=encoder_out["padding_mask"],
                out=out,
                prompt=prompt,
                prompt_mask=prompt_mask,
                encoder_out=encoder_out,
            )
        # Run segmentation heads
        with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
            self._run_segmentation_heads(
                out=out,
                backbone_out=backbone_out,
                img_ids=find_input.img_ids,
                vis_feat_sizes=encoder_out["vis_feat_sizes"],
                encoder_hidden_states=out["encoder_hidden_states"],
                prompt=prompt,
                prompt_mask=prompt_mask,
                hs=hs,
            )
        if self.training or self.num_interactive_steps_val > 0:
            self._compute_matching(out, self.back_convert(find_target))
        return out

    def _postprocess_out(self, out: Dict, multimask_output: bool = False):
        # For multimask output during eval, we return the single best mask under the dict keys
        # expected by the evaluators, and also return the full multi-mask outputs under new
        # "multi_*" keys.
        num_mask_boxes = out["pred_boxes"].size(1)
        if not self.training and multimask_output and num_mask_boxes > 1:
            out["multi_pred_logits"] = out["pred_logits"]
            if "pred_masks" in out:
                out["multi_pred_masks"] = out["pred_masks"]
            out["multi_pred_boxes"] = out["pred_boxes"]
            out["multi_pred_boxes_xyxy"] = out["pred_boxes_xyxy"]
            best_mask_idx = out["pred_logits"].argmax(1).squeeze(1)
            batch_idx = torch.arange(len(best_mask_idx), device=best_mask_idx.device)
            out["pred_logits"] = out["pred_logits"][batch_idx, best_mask_idx].unsqueeze(
                1
            )
            if "pred_masks" in out:
                out["pred_masks"] = out["pred_masks"][
                    batch_idx, best_mask_idx
                ].unsqueeze(1)
            out["pred_boxes"] = out["pred_boxes"][batch_idx, best_mask_idx].unsqueeze(1)
            out["pred_boxes_xyxy"] = out["pred_boxes_xyxy"][
                batch_idx, best_mask_idx
            ].unsqueeze(1)
        return out

    def _get_dummy_prompt(self, num_prompts=1):
        device = self.device
        geometric_prompt = Prompt(
            box_embeddings=torch.zeros(0, num_prompts, 4, device=device),
            box_mask=torch.zeros(num_prompts, 0, device=device, dtype=torch.bool),
        )
        return geometric_prompt

    def forward(self, input: BatchedDatapoint):
        device = self.device
        backbone_out = {"img_batch_all_stages": input.img_batch}
        backbone_out.update(self.backbone.forward_image(input.img_batch))
        num_frames = len(input.find_inputs)
        assert num_frames == 1
        text_outputs = self.backbone.forward_text(input.find_text_batch, device=device)
        backbone_out.update(text_outputs)
        previous_stages_out = SAM3Output(
            iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE
        )
        find_input = input.find_inputs[0]
        find_target = input.find_targets[0]
        if find_input.input_points is not None and find_input.input_points.numel() > 0:
            print("Warning: Point prompts are ignored in PCS.")
        num_interactive_steps = 0 if self.training else self.num_interactive_steps_val
        geometric_prompt = Prompt(
            box_embeddings=find_input.input_boxes,
            box_mask=find_input.input_boxes_mask,
            box_labels=find_input.input_boxes_label,
        )
        # Init vars that are shared across the loop.
        stage_outs = []
        for cur_step in range(num_interactive_steps + 1):
            if cur_step > 0:
                # We sample interactive geometric prompts (boxes, points)
                geometric_prompt, _ = self.interactive_prompt_sampler.sample(
                    geo_prompt=geometric_prompt,
                    find_target=find_target,
                    previous_out=stage_outs[-1],
                )
            out = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_input,
                find_target=find_target,
                geometric_prompt=geometric_prompt.clone(),
            )
            stage_outs.append(out)
        previous_stages_out.append(stage_outs)
        return previous_stages_out
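
    # Flow note (matching `forward` above): training runs a single grounding pass,
    # while eval with `num_interactive_steps_val = N` runs N + 1 passes; every pass
    # after the first samples extra geometric prompts (boxes/points) from
    # `interactive_prompt_sampler` based on the previous step's output before
    # calling `forward_grounding` again.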

    def _compute_matching(self, out, targets):
        out["indices"] = self.matcher(out, targets)
        for aux_out in out.get("aux_outputs", []):
            aux_out["indices"] = self.matcher(aux_out, targets)

    def back_convert(self, targets):
        batched_targets = {
            "boxes": targets.boxes.view(-1, 4),
            "boxes_xyxy": box_cxcywh_to_xyxy(targets.boxes.view(-1, 4)),
            "boxes_padded": targets.boxes_padded,
            "positive_map": targets.boxes.new_ones(len(targets.boxes), 1),
            "num_boxes": targets.num_boxes,
            "masks": targets.segments,
            "semantic_masks": targets.semantic_segments,
            "is_valid_mask": targets.is_valid_segment,
            "is_exhaustive": targets.is_exhaustive,
            "object_ids_packed": targets.object_ids,
            "object_ids_padded": targets.object_ids_padded,
        }
        return batched_targets

    def predict_inst(
        self,
        inference_state,
        **kwargs,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        orig_h, orig_w = (
            inference_state["original_height"],
            inference_state["original_width"],
        )
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.inst_interactive_predictor.model._prepare_backbone_features(
            backbone_out
        )
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = (
            vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
        )
        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        self.inst_interactive_predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        self.inst_interactive_predictor._is_image_set = True
        self.inst_interactive_predictor._orig_hw = [(orig_h, orig_w)]
        res = self.inst_interactive_predictor.predict(**kwargs)
        self.inst_interactive_predictor._features = None
        self.inst_interactive_predictor._is_image_set = False
        return res
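
    # Hedged usage sketch (assumed keyword names, not exercised in this file):
    # `predict_inst` forwards **kwargs to the wrapped interactive predictor's
    # `predict()`, which follows a SAM-2-style interface, e.g. something like
    #
    #   masks, scores, low_res_masks = model.predict_inst(
    #       inference_state,
    #       point_coords=np.array([[x, y]]),  # hypothetical click prompt
    #       point_labels=np.array([1]),
    #       multimask_output=True,
    #   )
    #
    # The exact signature is whatever `SAM3InteractiveImagePredictor.predict` accepts.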

    def predict_inst_batch(
        self,
        inference_state,
        *args,
        **kwargs,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.inst_interactive_predictor.model._prepare_backbone_features(
            backbone_out
        )
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = (
            vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
        )
        batch_size = vision_feats[-1].shape[1]
        orig_heights, orig_widths = (
            inference_state["original_heights"],
            inference_state["original_widths"],
        )
        assert batch_size == len(orig_heights) == len(orig_widths), (
            f"Batch size mismatch in predict_inst_batch. "
            f"Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
        )
        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        self.inst_interactive_predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        self.inst_interactive_predictor._is_image_set = True
        self.inst_interactive_predictor._is_batch = True
        self.inst_interactive_predictor._orig_hw = [
            (orig_h, orig_w) for orig_h, orig_w in zip(orig_heights, orig_widths)
        ]
        res = self.inst_interactive_predictor.predict_batch(*args, **kwargs)
        self.inst_interactive_predictor._features = None
        self.inst_interactive_predictor._is_image_set = False
        self.inst_interactive_predictor._is_batch = False
        return res


class Sam3ImageOnVideoMultiGPU(Sam3Image):
    def __init__(
        self, *args, async_all_gather=True, gather_backbone_out=None, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.rank = int(os.getenv("RANK", "0"))
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.async_all_gather = async_all_gather
        # if gather_backbone_out is not set, default to gathering only for `SAM3VLBackbone`
        if gather_backbone_out is None:
            gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone)
        self.gather_backbone_out = gather_backbone_out

    def forward_video_grounding_multigpu(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx,
        num_frames,
        # `multigpu_buffer` is a dict to cache the detector's outputs in a chunk between different calls
        multigpu_buffer,
        track_in_reverse=False,
        # whether to also return the SAM 2 backbone features
        return_sam2_backbone_feats=False,
        # whether to perform NMS and suppress the scores of those detections removed by NMS
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
        **kwargs,
    ):
        """
        Compute the detector's outputs in a distributed manner, where all GPUs process
        a chunk of frames (as many frames as there are GPUs) at once and store the
        results in a cache.
        """
        # Step 1: fetch the detector outputs in the current chunk from the buffer
        frame_idx_curr_b = frame_idx - frame_idx % self.world_size
        frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames)
        # in case the current frame's detection results are not in the buffer yet, build the current chunk
        # (this should only happen on the first chunk, since we are also building the next chunk below)
        if frame_idx not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"):
                self._build_multigpu_buffer_next_chunk(
                    backbone_out=backbone_out,
                    find_inputs=find_inputs,
                    geometric_prompt=geometric_prompt,
                    frame_idx_begin=frame_idx_curr_b,
                    frame_idx_end=frame_idx_curr_e,
                    num_frames=num_frames,
                    multigpu_buffer=multigpu_buffer,
                    run_nms=run_nms,
                    nms_prob_thresh=nms_prob_thresh,
                    nms_iou_thresh=nms_iou_thresh,
                )
        # read out the current frame's results from `multigpu_buffer`
        out = {}
        for k, (v, handle) in multigpu_buffer[frame_idx].items():
            if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats:
                continue
            if handle is not None:
                handle.wait()  # wait for async all-gather to finish
            out[k] = v
        # Step 2: remove detection outputs of the previous chunk from the cache to save GPU memory
        if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
            frame_idx_prev_e = frame_idx_curr_b
            frame_idx_prev_b = frame_idx_curr_b - self.world_size
        elif track_in_reverse and frame_idx_curr_e < num_frames:
            frame_idx_prev_b = frame_idx_curr_e
            frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames)
        else:
            frame_idx_prev_b = frame_idx_prev_e = None
        if frame_idx_prev_b is not None:
            for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e):
                multigpu_buffer.pop(frame_idx_rm, None)
        # Step 3: compute and cache detection outputs of the next chunk ahead of time
        # (so that we can overlap computation with all-gather transfer)
        if not track_in_reverse and frame_idx_curr_e < num_frames:
            frame_idx_next_b = frame_idx_curr_e
            frame_idx_next_e = min(frame_idx_next_b + self.world_size, num_frames)
        elif track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
            frame_idx_next_e = frame_idx_curr_b
            frame_idx_next_b = frame_idx_curr_b - self.world_size
        else:
            frame_idx_next_b = frame_idx_next_e = None
        if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"):
                self._build_multigpu_buffer_next_chunk(
                    backbone_out=backbone_out,
                    find_inputs=find_inputs,
                    geometric_prompt=geometric_prompt,
                    frame_idx_begin=frame_idx_next_b,
                    frame_idx_end=frame_idx_next_e,
                    num_frames=num_frames,
                    multigpu_buffer=multigpu_buffer,
                    run_nms=run_nms,
                    nms_prob_thresh=nms_prob_thresh,
                    nms_iou_thresh=nms_iou_thresh,
                )
        return out, backbone_out
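
    # Chunk scheduling sketch (derived from the index arithmetic above): with
    # world_size = 4 and frame_idx = 10, the current chunk is frames [8, 12); rank r
    # computes frame 8 + r in `_build_multigpu_buffer_next_chunk`, the previous
    # chunk [4, 8) is evicted from `multigpu_buffer`, and the next chunk [12, 16)
    # is prefetched so its all-gather can overlap with work on the current chunk.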

    def _build_multigpu_buffer_next_chunk(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx_begin,
        frame_idx_end,
        num_frames,
        multigpu_buffer,
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
    ):
        """Compute detection outputs on a chunk of frames and store their results in `multigpu_buffer`."""
        # each GPU computes detections on one frame in the chunk (in a round-robin manner)
        frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1)
        # `forward_grounding` (from base class `Sam3ImageOnVideo`) runs the detector on a single frame
        with torch.profiler.record_function("forward_grounding"):
            out_local = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_inputs[frame_idx_local_gpu],
                find_target=None,
                geometric_prompt=geometric_prompt,
            )
        if run_nms:
            with torch.profiler.record_function("nms_masks"):
                # run NMS as a post-processing step on top of the detection outputs
                assert nms_prob_thresh is not None and nms_iou_thresh is not None
                pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid()
                pred_masks = out_local["pred_masks"]
                # loop over text prompts (not an overhead for demo where there's only 1 prompt)
                for prompt_idx in range(pred_probs.size(0)):
                    keep = nms_masks(
                        pred_probs=pred_probs[prompt_idx],
                        pred_masks=pred_masks[prompt_idx],
                        prob_threshold=nms_prob_thresh,
                        iou_threshold=nms_iou_thresh,
                    )
                    # assign a very low score to those detections removed by NMS
                    out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float()
        if self.gather_backbone_out:
            # gather the SAM 2 backbone features across GPUs
            feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"]
            assert len(feats["backbone_fpn"]) == 3  # the SAM 2 backbone always has 3 levels
            # cast the SAM 2 backbone features to bfloat16 for all-gather (this is usually
            # a no-op, SAM 2 backbone features are likely already in bfloat16 due to AMP)
            backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]]
            fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0])
            fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1])
            fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2])
            # vision_pos_enc is the same on all frames, so no need to all-gather it
            vision_pos_enc = feats["vision_pos_enc"]
        # trim the detector output to only include the necessary keys
        out_local = {
            "pred_logits": out_local["pred_logits"],
            "pred_boxes": out_local["pred_boxes"],
            "pred_boxes_xyxy": out_local["pred_boxes_xyxy"],
            "pred_masks": out_local["pred_masks"],
        }
        # gather the results: after this step, each GPU will receive detector outputs on
        # all frames in the chunk and store them in `multigpu_buffer`
        out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()}
        for rank in range(self.world_size):
            frame_idx_to_save = frame_idx_begin + rank
            if frame_idx_to_save >= num_frames:
                continue
            frame_buffer = {
                k: (v[rank], handle) for k, (v, handle) in out_gathered.items()
            }
            if self.gather_backbone_out:
                # also add gathered SAM 2 backbone features to frame_buffer
                frame_buffer["tracker_backbone_fpn_0"] = (fpn0[rank], fpn_handle0)
                frame_buffer["tracker_backbone_fpn_1"] = (fpn1[rank], fpn_handle1)
                frame_buffer["tracker_backbone_fpn_2"] = (fpn2[rank], fpn_handle2)
                frame_buffer["tracker_backbone_pos_enc"] = (vision_pos_enc, None)
            multigpu_buffer[frame_idx_to_save] = frame_buffer

    def _gather_tensor(self, x):
        if self.world_size == 1:
            return [x], None
        async_op = self.async_all_gather
        # here `.contiguous()` is required -- otherwise NCCL all_gather
        # sometimes gives wrong results
        x = x.contiguous()  # ensure contiguous memory for NCCL
        output_list = [torch.empty_like(x) for _ in range(self.world_size)]
        handle = torch.distributed.all_gather(output_list, x, async_op=async_op)
        return output_list, handle
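
    # Async note (describing the return value above): with `async_op=True`,
    # `torch.distributed.all_gather` returns a work handle, and the tensors in
    # `output_list` are only safe to read after `handle.wait()`, which is exactly
    # what `forward_video_grounding_multigpu` does before consuming a frame's
    # entries from `multigpu_buffer`.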