sam3_video_base.py 82 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import datetime
  4. import logging
  5. import math
  6. import os
  7. from collections import defaultdict
  8. from copy import deepcopy
  9. from enum import Enum
  10. from typing import Any, Dict, List, Set
  11. import numpy as np
  12. import numpy.typing as npt
  13. import torch
  14. import torch.distributed as dist
  15. import torch.nn.functional as F
  16. from sam3 import perflib
  17. from sam3.logger import get_logger
  18. from sam3.model.box_ops import fast_diag_box_iou
  19. from sam3.model.data_misc import BatchedDatapoint
  20. from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box
  21. from sam3.perflib.masks_ops import mask_iou
  22. from sam3.train.masks_ops import rle_encode
  23. from torch import nn, Tensor
  24. logger = get_logger(__name__)
class MaskletConfirmationStatus(Enum):
    """Detector-confirmation lifecycle state of a tracked masklet.

    Used when ``masklet_confirmation_enable`` is on to suppress masklets
    that have not yet been corroborated by the detector.
    """

    UNCONFIRMED = 1  # newly added masklet, not confirmed by any detection yet
    CONFIRMED = 2  # confirmed by at least one detection
  28. class Sam3VideoBase(nn.Module):
    def __init__(
        self,
        detector: nn.Module,
        tracker: nn.Module,
        # prob threshold for detection outputs -- only keep detections above this threshold
        # enters NMS and det-to-track matching
        score_threshold_detection=0.5,
        # IoU threshold for detection NMS
        det_nms_thresh=0.0,
        # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it
        # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1
        assoc_iou_thresh=0.5,
        # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched"
        # by any detections -- it is often a stricter threshold like 0.5
        trk_assoc_iou_thresh=0.5,
        # prob threshold for a detection to be added as a new object
        new_det_thresh=0.0,
        # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and
        # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh`
        # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh`
        hotstart_delay=0,
        hotstart_unmatch_thresh=3,
        hotstart_dup_thresh=3,
        # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period.
        suppress_unmatched_only_within_hotstart=True,
        # keep-alive counters controlling how long a masklet survives without support
        # (initial value, and the max/min the counter is clamped to)
        init_trk_keep_alive=0,
        max_trk_keep_alive=8,
        min_trk_keep_alive=-4,
        # Threshold for suppressing overlapping objects based on recent occlusion
        suppress_overlapping_based_on_recent_occlusion_threshold=0.0,
        decrease_trk_keep_alive_for_empty_masklets=False,
        o2o_matching_masklets_enable=False,  # Enable hungarian matching to match existing masklets
        suppress_det_close_to_boundary=False,
        fill_hole_area=16,
        # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1)
        max_num_objects=-1,
        recondition_every_nth_frame=-1,
        # masket confirmation status (to suppress unconfirmed masklets)
        masklet_confirmation_enable=False,
        # a masklet is confirmed after being consecutively detected and matched for
        # `masklet_confirmation_consecutive_det_thresh`
        masklet_confirmation_consecutive_det_thresh=3,
        # bbox heuristic parameters
        reconstruction_bbox_iou_thresh=0.0,
        reconstruction_bbox_det_score=0.0,
    ):
        """Combine an image `detector` with a SAM2-style `tracker` for video inference.

        The module is put into eval mode at construction time and reads its
        distributed placement (rank / world size) from the environment. All
        threshold parameters above are stored as-is on the instance; see the
        inline comments on each keyword for their meaning.
        """
        super().__init__()
        self.detector = detector
        self.tracker = tracker
        self.score_threshold_detection = score_threshold_detection
        self.det_nms_thresh = det_nms_thresh
        self.assoc_iou_thresh = assoc_iou_thresh
        self.trk_assoc_iou_thresh = trk_assoc_iou_thresh
        self.new_det_thresh = new_det_thresh
        # hotstart parameters: both removal thresholds must fit inside the delay
        # window, otherwise they could never trigger before outputs are released
        if hotstart_delay > 0:
            assert hotstart_unmatch_thresh <= hotstart_delay
            assert hotstart_dup_thresh <= hotstart_delay
        self.hotstart_delay = hotstart_delay
        self.hotstart_unmatch_thresh = hotstart_unmatch_thresh
        self.hotstart_dup_thresh = hotstart_dup_thresh
        self.suppress_unmatched_only_within_hotstart = (
            suppress_unmatched_only_within_hotstart
        )
        self.init_trk_keep_alive = init_trk_keep_alive
        self.max_trk_keep_alive = max_trk_keep_alive
        self.min_trk_keep_alive = min_trk_keep_alive
        self.suppress_overlapping_based_on_recent_occlusion_threshold = (
            suppress_overlapping_based_on_recent_occlusion_threshold
        )
        self.suppress_det_close_to_boundary = suppress_det_close_to_boundary
        self.decrease_trk_keep_alive_for_empty_masklets = (
            decrease_trk_keep_alive_for_empty_masklets
        )
        self.o2o_matching_masklets_enable = o2o_matching_masklets_enable
        self.fill_hole_area = fill_hole_area
        # inference-only module: switch to eval mode right away
        self.eval()
        self.rank = int(os.getenv("RANK", "0"))
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self._dist_pg_cpu = None  # CPU process group (lazy-initialized on first use)
        # the maximum object number; `num_obj_for_compile` is the per-GPU object
        # count used for compiled kernels (ceil-divided across the world size)
        if max_num_objects > 0:
            num_obj_for_compile = math.ceil(max_num_objects / self.world_size)
        else:
            max_num_objects = 10000  # no limit
            num_obj_for_compile = 16
        logger.info(f"setting {max_num_objects=} and {num_obj_for_compile=}")
        self.max_num_objects = max_num_objects
        self.num_obj_for_compile = num_obj_for_compile
        self.recondition_every_nth_frame = recondition_every_nth_frame
        self.masklet_confirmation_enable = masklet_confirmation_enable
        self.masklet_confirmation_consecutive_det_thresh = (
            masklet_confirmation_consecutive_det_thresh
        )
        self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh
        self.reconstruction_bbox_det_score = reconstruction_bbox_det_score
  125. @property
  126. def device(self):
  127. self._device = getattr(self, "_device", None) or next(self.parameters()).device
  128. return self._device
  129. def _init_dist_pg_cpu(self):
  130. # a short 3-min timeout to quickly detect any synchronization failures
  131. timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180"))
  132. timeout = datetime.timedelta(seconds=timeout_sec)
  133. self._dist_pg_cpu = dist.new_group(backend="gloo", timeout=timeout)
  134. def broadcast_python_obj_cpu(self, python_obj_list, src):
  135. if self._dist_pg_cpu is None:
  136. self._init_dist_pg_cpu()
  137. dist.broadcast_object_list(python_obj_list, src=src, group=self._dist_pg_cpu)
    def _det_track_one_frame(
        self,
        frame_idx: int,
        num_frames: int,
        reverse: bool,
        input_batch: BatchedDatapoint,
        geometric_prompt: Any,
        tracker_states_local: List[Any],
        tracker_metadata_prev: Dict[str, Any],
        feature_cache: Dict,
        orig_vid_height: int,
        orig_vid_width: int,
        is_image_only: bool = False,
        allow_new_detections: bool = True,
    ):
        """
        This function handles one-step inference for the DenseTracking model in an SPMD manner.
        At a high-level, all GPUs execute the same function calls as if it's done on a single GPU,
        while under the hood, some function calls involve distributed computation based on sharded
        SAM2 states.

        - `input_batch` contains image and other inputs on the entire video; it should be identical across GPUs
        - `tracker_states_local` holds the local masklet information in this GPU shard
        - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is hold on which GPUs
          it contains both global and local masklet information

        Returns a 6-tuple; see the inline comments on the `return` statement below.
        The per-object mask/score dicts are only populated on rank 0.
        """
        # Step 1: run backbone and detector in a distributed manner -- this is done via Sam3ImageOnVideoMultiGPU,
        # a MultiGPU model (assigned to `self.detector`) that shards frames in a round-robin manner.
        # It returns a "det_out" dict for `frame_idx` and fills SAM2 backbone features for `frame_idx`
        # into `feature_cache`. Despite its distributed inference under the hood, the results would be
        # the same as if it is running backbone and detector for every frame on a single GPU.
        det_out = self.run_backbone_and_detection(
            frame_idx=frame_idx,
            num_frames=num_frames,
            reverse=reverse,
            input_batch=input_batch,
            geometric_prompt=geometric_prompt,
            feature_cache=feature_cache,
            allow_new_detections=allow_new_detections,
        )
        # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks.
        # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions
        # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only
        # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks;
        # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics.
        if tracker_metadata_prev == {}:
            # initialize masklet metadata if it's uninitialized (empty dict);
            # mutated in place so the caller's dict is also initialized
            tracker_metadata_prev.update(self._initialize_metadata())
        tracker_low_res_masks_global, tracker_obj_scores_global = (
            self.run_tracker_propagation(
                frame_idx=frame_idx,
                num_frames=num_frames,
                reverse=reverse,
                tracker_states_local=tracker_states_local,
                tracker_metadata_prev=tracker_metadata_prev,
            )
        )
        # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans
        # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc).
        # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints.
        # **This step should involve all the heuristics needed for any updates.** Most of the update
        # planning will be done on the master rank (GPU 0) and the resulting plan `tracker_update_plan` is
        # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the
        # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`).
        tracker_update_plan, tracker_metadata_new = (
            self.run_tracker_update_planning_phase(
                frame_idx=frame_idx,
                num_frames=num_frames,
                reverse=reverse,
                det_out=det_out,
                tracker_low_res_masks_global=tracker_low_res_masks_global,
                tracker_obj_scores_global=tracker_obj_scores_global,
                tracker_metadata_prev=tracker_metadata_prev,
                tracker_states_local=tracker_states_local,
                is_image_only=is_image_only,
            )
        )
        # Get reconditioning info from the update plan (defaults keep rank>0 and
        # recondition-disabled paths working when the plan omits these keys)
        reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set())
        det_to_matched_trk_obj_ids = tracker_update_plan.get(
            "det_to_matched_trk_obj_ids", {}
        )
        # Step 4: based on `tracker_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states
        tracker_states_local_new = self.run_tracker_update_execution_phase(
            frame_idx=frame_idx,
            num_frames=num_frames,
            reverse=reverse,
            det_out=det_out,
            tracker_states_local=tracker_states_local,
            tracker_update_plan=tracker_update_plan,
            orig_vid_height=orig_vid_height,
            orig_vid_width=orig_vid_width,
            feature_cache=feature_cache,
        )
        # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since
        # only GPU 0 will send outputs to the server).
        if self.rank == 0:
            obj_id_to_mask = self.build_outputs(
                frame_idx=frame_idx,
                num_frames=num_frames,
                reverse=reverse,
                det_out=det_out,
                tracker_low_res_masks_global=tracker_low_res_masks_global,
                tracker_obj_scores_global=tracker_obj_scores_global,
                tracker_metadata_prev=tracker_metadata_prev,
                tracker_update_plan=tracker_update_plan,
                orig_vid_height=orig_vid_height,
                orig_vid_width=orig_vid_width,
                reconditioned_obj_ids=reconditioned_obj_ids,
                det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
            )
            obj_id_to_score = tracker_metadata_new["obj_id_to_score"]
        else:
            obj_id_to_mask, obj_id_to_score = {}, {}  # dummy outputs on other GPUs
        # a few statistics for the current frame as a part of the output
        frame_stats = {
            "num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]),
            "num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"],
        }
        # add tracker scores to metadata, it should be fired for frames except the first frame
        # (on the first frame there are no propagated masklets yet, so shape[0] == 0)
        if tracker_obj_scores_global.shape[0] > 0:
            # Convert tracker_obj_scores_global to sigmoid scores before updating
            tracker_obj_scores_global = tracker_obj_scores_global.sigmoid().tolist()
            tracker_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
            tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
                frame_idx
            ].update(dict(zip(tracker_obj_ids, tracker_obj_scores_global)))
        return (
            obj_id_to_mask,  # a dict: obj_id --> output mask
            obj_id_to_score,  # a dict: obj_id --> output score (prob)
            tracker_states_local_new,
            tracker_metadata_new,
            frame_stats,
            # per-object tracker scores for this frame: a list of sigmoid probs
            # (order matches `obj_ids_all_gpu`), or the empty tensor when no
            # masklets were propagated (e.g. the first frame)
            tracker_obj_scores_global,
        )
  272. def _suppress_detections_close_to_boundary(self, boxes, margin=0.025):
  273. """
  274. Suppress detections too close to image edges (for normalized boxes).
  275. boxes: (N, 4) in xyxy format, normalized [0,1]
  276. margin: fraction of image
  277. """
  278. x_min, y_min, x_max, y_max = boxes.unbind(-1)
  279. x_c = (x_min + x_max) / 2
  280. y_c = (y_min + y_max) / 2
  281. keep = (
  282. (x_c > margin)
  283. & (x_c < 1.0 - margin)
  284. & (y_c > margin)
  285. & (y_c < 1.0 - margin)
  286. )
  287. return keep
    def run_backbone_and_detection(
        self,
        frame_idx: int,
        num_frames: int,
        input_batch: BatchedDatapoint,
        geometric_prompt: Any,
        feature_cache: Dict,
        reverse: bool,
        allow_new_detections: bool,
    ):
        """Run the (distributed) backbone + detector for one frame.

        Returns `det_out`, a dict with keys "bbox", "mask" and "scores" holding
        the post-NMS detections above `self.score_threshold_detection`.

        Side effects on `feature_cache`:
        - caches the text features of the most recent prompt under "text"
        - maintains the detector's "multigpu_buffer"
        - stores this frame's image + SAM2 backbone features under `frame_idx`
          and evicts the previous frame's entry to bound GPU memory
        """
        # Step 1: if text feature is not cached in `feature_cache`, compute and cache it
        text_batch_key = tuple(input_batch.find_text_batch)
        if "text" not in feature_cache or text_batch_key not in feature_cache["text"]:
            text_outputs = self.detector.backbone.forward_text(
                input_batch.find_text_batch, device=self.device
            )
            # note: we only cache the text feature of the most recent prompt
            # (the whole "text" sub-dict is replaced, not merged)
            feature_cache["text"] = {text_batch_key: text_outputs}
        else:
            text_outputs = feature_cache["text"][text_batch_key]
        # Step 2: run backbone, detector, and post-processing with NMS
        if "multigpu_buffer" not in feature_cache:
            # "multigpu_buffer" is a buffer cache used by `self.detector` and it needs
            # to be passed to `forward_video_grounding_multigpu` for every call
            feature_cache["multigpu_buffer"] = {}
        # Extract max_frame_num_to_track from feature_cache if available
        tracking_bounds = feature_cache.get("tracking_bounds", {})
        max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track")
        start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx")
        sam3_image_out, _ = self.detector.forward_video_grounding_multigpu(
            backbone_out={
                "img_batch_all_stages": input_batch.img_batch,
                **text_outputs,
            },
            find_inputs=input_batch.find_inputs,
            geometric_prompt=geometric_prompt,
            frame_idx=frame_idx,
            num_frames=num_frames,
            multigpu_buffer=feature_cache["multigpu_buffer"],
            track_in_reverse=reverse,
            # also get the SAM2 backbone features
            return_tracker_backbone_feats=True,
            # run NMS as a part of distributed computation
            run_nms=self.det_nms_thresh > 0.0,
            nms_prob_thresh=self.score_threshold_detection,
            nms_iou_thresh=self.det_nms_thresh,
            # pass max_frame_num_to_track to respect tracking limits
            max_frame_num_to_track=max_frame_num_to_track,
            propagate_in_video_start_frame_idx=start_frame_idx,
        )
        # note: detections in `sam3_image_out` has already gone through NMS
        pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid()
        if not allow_new_detections:
            # push all probs far below any threshold so no detections are kept
            pred_probs = pred_probs - 1e8
        pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"]
        pred_masks = sam3_image_out["pred_masks"]
        # get the positive detection outputs above threshold
        pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
        det_out = {
            "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]],
            "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]],
            "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]],
        }
        # Step 3: build SAM2 backbone features and store them in `feature_cache`
        backbone_cache = {}
        sam_mask_decoder = self.tracker.sam_mask_decoder
        tracker_backbone_fpn = [
            sam_mask_decoder.conv_s0(sam3_image_out["tracker_backbone_fpn_0"]),
            sam_mask_decoder.conv_s1(sam3_image_out["tracker_backbone_fpn_1"]),
            sam3_image_out["tracker_backbone_fpn_2"],  # fpn_2 doesn't need conv
        ]
        tracker_backbone_out = {
            "vision_features": tracker_backbone_fpn[-1],  # top-level feature
            "vision_pos_enc": sam3_image_out["tracker_backbone_pos_enc"],
            "backbone_fpn": tracker_backbone_fpn,
        }
        backbone_cache["tracker_backbone_out"] = tracker_backbone_out
        feature_cache[frame_idx] = (
            input_batch.img_batch[frame_idx],
            backbone_cache,
        )
        # remove from `feature_cache` old features to save GPU memory
        # (the "previous" frame is frame_idx+1 when tracking in reverse)
        feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None)
        return det_out
    def run_tracker_propagation(
        self,
        frame_idx: int,
        num_frames: int,
        reverse: bool,
        tracker_states_local: List[Any],
        tracker_metadata_prev: Dict[str, npt.NDArray],
    ):
        """Propagate local SAM2 states one frame and all-gather the predictions.

        Returns `(low_res_masks_global, obj_scores_global)` where the global
        tensors concatenate the per-GPU results in rank order, matching
        `tracker_metadata_prev["obj_ids_all_gpu"]`.
        """
        # Step 1: propagate the local SAM2 states to get the current frame's prediction
        # `low_res_masks_local` of the existing masklets on this GPU
        # - obj_ids_local: List[int] -- list of object IDs
        # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask)
        obj_ids_local, low_res_masks_local, obj_scores_local = (
            self._propogate_tracker_one_frame_local_gpu(
                tracker_states_local, frame_idx=frame_idx, reverse=reverse
            )
        )
        # sanity check: local object IDs must agree with the shard layout in the
        # metadata, otherwise the all-gather below would misalign objects
        assert np.all(
            obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
        ), "{} != {}".format(
            obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
        )
        # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global`
        # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask)
        _, H_mask, W_mask = low_res_masks_local.shape
        if self.world_size > 1:
            # `low_res_masks_local` and `obj_scores_local` need to be contiguous and float32
            # (they could be non-contiguous due to slicing and/or bfloat16 due to autocast)
            low_res_masks_local = low_res_masks_local.float().contiguous()
            obj_scores_local = obj_scores_local.float().contiguous()
            num_obj_this_gpu = tracker_metadata_prev["num_obj_per_gpu"][self.rank]
            assert low_res_masks_local.size(0) == num_obj_this_gpu
            assert obj_scores_local.size(0) == num_obj_this_gpu
            # pre-allocate receive buffers sized per the metadata's shard layout
            low_res_masks_peers = [
                low_res_masks_local.new_empty(num_obj, H_mask, W_mask)
                for num_obj in tracker_metadata_prev["num_obj_per_gpu"]
            ]
            obj_scores_peers = [
                obj_scores_local.new_empty(num_obj)
                for num_obj in tracker_metadata_prev["num_obj_per_gpu"]
            ]
            dist.all_gather(low_res_masks_peers, low_res_masks_local)
            dist.all_gather(obj_scores_peers, obj_scores_local)
            low_res_masks_global = torch.cat(low_res_masks_peers, dim=0)
            obj_scores_global = torch.cat(obj_scores_peers, dim=0)
        else:
            # single-GPU: the local predictions are already the global ones
            low_res_masks_global = low_res_masks_local
            obj_scores_global = obj_scores_local
        return low_res_masks_global, obj_scores_global
  421. def _recondition_masklets(
  422. self,
  423. frame_idx,
  424. det_out: Dict[str, Tensor],
  425. trk_id_to_max_iou_high_conf_det: List[int],
  426. tracker_states_local: List[Any],
  427. tracker_metadata: Dict[str, npt.NDArray],
  428. tracker_obj_scores_global: Tensor,
  429. ):
  430. # Recondition the masklets based on the new detections
  431. for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
  432. new_mask = det_out["mask"][det_idx : det_idx + 1]
  433. input_mask_res = self.tracker.input_mask_size
  434. new_mask_binary = (
  435. F.interpolate(
  436. new_mask.unsqueeze(1),
  437. size=(input_mask_res, input_mask_res),
  438. mode="bilinear",
  439. align_corners=False,
  440. ).squeeze(1)[0]
  441. > 0
  442. )
  443. HIGH_CONF_THRESH = 0.8
  444. reconditioned_states_idx = set()
  445. obj_idx = np.where(tracker_metadata["obj_ids_all_gpu"] == trk_obj_id)[
  446. 0
  447. ].item()
  448. obj_score = tracker_obj_scores_global[obj_idx]
  449. for state_idx, inference_state in enumerate(tracker_states_local):
  450. if (
  451. trk_obj_id in inference_state["obj_ids"]
  452. # NOTE: Goal of this condition is to avoid reconditioning masks that are occluded/low qualiy.
  453. # Unfortunately, these can get reconditioned anyway due to batching. We should consider removing these heuristics.
  454. and obj_score > HIGH_CONF_THRESH
  455. ):
  456. logger.debug(
  457. f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned."
  458. )
  459. self.tracker.add_new_mask(
  460. inference_state=inference_state,
  461. frame_idx=frame_idx,
  462. obj_id=trk_obj_id,
  463. mask=new_mask_binary,
  464. )
  465. reconditioned_states_idx.add(state_idx)
  466. for idx in reconditioned_states_idx:
  467. self.tracker.propagate_in_video_preflight(
  468. tracker_states_local[idx], run_mem_encoder=True
  469. )
  470. return tracker_states_local
  471. def run_tracker_update_planning_phase(
  472. self,
  473. frame_idx: int,
  474. num_frames: int,
  475. reverse: bool,
  476. det_out: Dict[str, Tensor],
  477. tracker_low_res_masks_global: Tensor,
  478. tracker_obj_scores_global: Tensor,
  479. tracker_metadata_prev: Dict[str, npt.NDArray],
  480. tracker_states_local: List[Any],
  481. is_image_only: bool = False,
  482. ):
  483. # initialize new metadata from previous metadata (its values will be updated later)
  484. tracker_metadata_new = {
  485. "obj_ids_per_gpu": deepcopy(tracker_metadata_prev["obj_ids_per_gpu"]),
  486. "obj_ids_all_gpu": None, # will be filled later
  487. "num_obj_per_gpu": deepcopy(tracker_metadata_prev["num_obj_per_gpu"]),
  488. "obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]),
  489. "obj_id_to_tracker_score_frame_wise": deepcopy(
  490. tracker_metadata_prev["obj_id_to_tracker_score_frame_wise"]
  491. ),
  492. "obj_id_to_last_occluded": {}, # will be filled later
  493. "max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]),
  494. }
  495. # Initialize reconditioned_obj_ids early to avoid UnboundLocalError
  496. reconditioned_obj_ids = set()
  497. # Step 1: make the update plan and resolve heuristics on GPU 0
  498. det_mask_preds: Tensor = det_out["mask"] # low-res mask logits
  499. det_scores_np: npt.NDArray = det_out["scores"].float().cpu().numpy()
  500. det_bbox_xyxy: Tensor = det_out["bbox"]
  501. if self.rank == 0:
  502. # a) match detector and tracker masks and find new objects
  503. (
  504. new_det_fa_inds,
  505. unmatched_trk_obj_ids,
  506. det_to_matched_trk_obj_ids,
  507. trk_id_to_max_iou_high_conf_det,
  508. empty_trk_obj_ids,
  509. ) = self._associate_det_trk(
  510. det_masks=det_mask_preds,
  511. det_scores_np=det_scores_np,
  512. trk_masks=tracker_low_res_masks_global,
  513. trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"],
  514. )
  515. if self.suppress_det_close_to_boundary:
  516. keep = self._suppress_detections_close_to_boundary(
  517. det_bbox_xyxy[new_det_fa_inds]
  518. )
  519. new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()]
  520. # check whether we've hit the maximum number of objects we can track (and if so, drop some detections)
  521. prev_obj_num = np.sum(tracker_metadata_prev["num_obj_per_gpu"])
  522. new_det_num = len(new_det_fa_inds)
  523. num_obj_dropped_due_to_limit = 0
  524. if not is_image_only and prev_obj_num + new_det_num > self.max_num_objects:
  525. logger.warning(
  526. f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}"
  527. )
  528. new_det_num_to_keep = self.max_num_objects - prev_obj_num
  529. num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep
  530. new_det_fa_inds = self._drop_new_det_with_obj_limit(
  531. new_det_fa_inds, det_scores_np, new_det_num_to_keep
  532. )
  533. assert len(new_det_fa_inds) == new_det_num_to_keep
  534. new_det_num = len(new_det_fa_inds)
  535. # assign object IDs to new detections and decide which GPU to place them
  536. new_det_start_obj_id = tracker_metadata_prev["max_obj_id"] + 1
  537. new_det_obj_ids = new_det_start_obj_id + np.arange(new_det_num)
  538. prev_workload_per_gpu = tracker_metadata_prev["num_obj_per_gpu"]
  539. new_det_gpu_ids = self._assign_new_det_to_gpus(
  540. new_det_num=new_det_num,
  541. prev_workload_per_gpu=prev_workload_per_gpu,
  542. )
  543. # b) handle hotstart heuristics to remove objects
  544. # here `rank0_metadata` contains metadata stored on (and only accessible to) GPU 0;
  545. # we avoid broadcasting them to other GPUs to save communication cost, assuming
  546. # that `rank0_metadata` is not needed by other GPUs
  547. rank0_metadata_new = deepcopy(tracker_metadata_prev["rank0_metadata"])
  548. if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
  549. obj_ids_newly_removed, rank0_metadata_new = self._process_hotstart(
  550. frame_idx=frame_idx,
  551. num_frames=num_frames,
  552. reverse=reverse,
  553. det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
  554. new_det_obj_ids=new_det_obj_ids,
  555. empty_trk_obj_ids=empty_trk_obj_ids,
  556. unmatched_trk_obj_ids=unmatched_trk_obj_ids,
  557. rank0_metadata=rank0_metadata_new,
  558. tracker_metadata=tracker_metadata_prev,
  559. )
  560. else:
  561. # if warm-up is not complete, we don't remove any objects
  562. obj_ids_newly_removed = set()
  563. tracker_metadata_new["rank0_metadata"] = rank0_metadata_new
  564. # Step 2: broadcast the update plan to other GPUs
  565. NUM_BROADCAST_ITEMS = 9
  566. if self.rank == 0 and self.world_size > 1:
  567. # `num_obj_per_gpu_on_rank0` is used for metadata consistency check on other GPUs
  568. # (it's a small array with length==self.world_size, so broadcasting it is cheap)
  569. num_obj_per_gpu_on_rank0 = tracker_metadata_prev["num_obj_per_gpu"]
  570. update_plan = [
  571. new_det_fa_inds,
  572. new_det_obj_ids,
  573. new_det_gpu_ids,
  574. num_obj_per_gpu_on_rank0,
  575. unmatched_trk_obj_ids,
  576. det_to_matched_trk_obj_ids,
  577. obj_ids_newly_removed,
  578. num_obj_dropped_due_to_limit,
  579. trk_id_to_max_iou_high_conf_det,
  580. ]
  581. assert len(update_plan) == NUM_BROADCAST_ITEMS, (
  582. f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
  583. )
  584. self.broadcast_python_obj_cpu(update_plan, src=0)
  585. elif self.rank > 0 and self.world_size > 1:
  586. update_plan = [
  587. None
  588. ] * NUM_BROADCAST_ITEMS # other ranks receive the plan from rank 0
  589. self.broadcast_python_obj_cpu(update_plan, src=0)
  590. (
  591. new_det_fa_inds,
  592. new_det_obj_ids,
  593. new_det_gpu_ids,
  594. num_obj_per_gpu_on_rank0,
  595. unmatched_trk_obj_ids,
  596. det_to_matched_trk_obj_ids,
  597. obj_ids_newly_removed,
  598. num_obj_dropped_due_to_limit,
  599. trk_id_to_max_iou_high_conf_det,
  600. ) = update_plan
  601. # metadata consistency check: verify that the received `num_obj_per_gpu_on_rank0` is consistent with the local metadata
  602. # it's critical that all GPUs agree on the previous number of objects (otherwise the inference might hang or fail silently)
  603. if not np.all(
  604. num_obj_per_gpu_on_rank0 == tracker_metadata_prev["num_obj_per_gpu"]
  605. ):
  606. raise RuntimeError(
  607. f"{self.rank=} received {num_obj_per_gpu_on_rank0=}, which is inconsistent with local record "
  608. f"{tracker_metadata_prev['num_obj_per_gpu']=}. There's likely a bug in update planning or execution."
  609. )
  610. # `tracker_update_plan` should be identical on all GPUs after broadcasting
  611. tracker_update_plan = {
  612. "new_det_fa_inds": new_det_fa_inds, # npt.NDArray
  613. "new_det_obj_ids": new_det_obj_ids, # npt.NDArray
  614. "new_det_gpu_ids": new_det_gpu_ids, # npt.NDArray
  615. "unmatched_trk_obj_ids": unmatched_trk_obj_ids, # npt.NDArray
  616. "det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids, # dict
  617. "obj_ids_newly_removed": obj_ids_newly_removed, # set
  618. "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit, # int
  619. "trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det, # dict
  620. "reconditioned_obj_ids": reconditioned_obj_ids, # set
  621. }
  622. # Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding
  623. # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results
  624. should_recondition_iou = False
  625. # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections
  626. if (
  627. self.reconstruction_bbox_iou_thresh > 0
  628. and len(trk_id_to_max_iou_high_conf_det) > 0
  629. ):
  630. for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
  631. det_box = det_out["bbox"][det_idx]
  632. det_score = det_out["scores"][det_idx]
  633. try:
  634. trk_idx = list(tracker_metadata_prev["obj_ids_all_gpu"]).index(
  635. trk_obj_id
  636. )
  637. except ValueError:
  638. continue # Skip if tracklet not found
  639. tracker_mask = tracker_low_res_masks_global[trk_idx]
  640. mask_binary = tracker_mask > 0
  641. mask_area = mask_binary.sum().item()
  642. if mask_area == 0:
  643. continue # Skip tracklets with zero mask area
  644. # Get bounding box from SAM2 mask and convert to normalized coordinates
  645. tracker_box_pixels = (
  646. mask_to_box(mask_binary.unsqueeze(0).unsqueeze(0))
  647. .squeeze(0)
  648. .squeeze(0)
  649. )
  650. mask_height, mask_width = tracker_mask.shape[-2:]
  651. tracker_box_normalized = torch.tensor(
  652. [
  653. tracker_box_pixels[0] / mask_width,
  654. tracker_box_pixels[1] / mask_height,
  655. tracker_box_pixels[2] / mask_width,
  656. tracker_box_pixels[3] / mask_height,
  657. ],
  658. device=tracker_box_pixels.device,
  659. )
  660. # Compute IoU between detection and SAM2 tracklet bounding boxes
  661. det_box_batch = det_box.unsqueeze(0)
  662. tracker_box_batch = tracker_box_normalized.unsqueeze(0)
  663. iou = fast_diag_box_iou(det_box_batch, tracker_box_batch)[0]
  664. if (
  665. iou < self.reconstruction_bbox_iou_thresh
  666. and det_score >= self.reconstruction_bbox_det_score
  667. ):
  668. should_recondition_iou = True
  669. reconditioned_obj_ids.add(trk_obj_id)
  670. should_recondition_periodic = (
  671. self.recondition_every_nth_frame > 0
  672. and frame_idx % self.recondition_every_nth_frame == 0
  673. and len(trk_id_to_max_iou_high_conf_det) > 0
  674. )
  675. # Recondition if periodic or IoU condition met
  676. if should_recondition_periodic or should_recondition_iou:
  677. self._recondition_masklets(
  678. frame_idx,
  679. det_out,
  680. trk_id_to_max_iou_high_conf_det,
  681. tracker_states_local,
  682. tracker_metadata_prev,
  683. tracker_obj_scores_global,
  684. )
  685. # Step 4: Run SAM2 memory encoder on the current frame's prediction masks
  686. # This is done on all GPUs
  687. batch_size = tracker_low_res_masks_global.size(0)
  688. if batch_size > 0:
  689. if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
  690. if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0:
  691. # NOTE: tracker_low_res_masks_global is updated in-place then returned
  692. tracker_low_res_masks_global = (
  693. self._suppress_overlapping_based_on_recent_occlusion(
  694. frame_idx,
  695. tracker_low_res_masks_global,
  696. tracker_metadata_prev,
  697. tracker_metadata_new,
  698. obj_ids_newly_removed,
  699. reverse,
  700. )
  701. )
  702. self._tracker_update_memories(
  703. tracker_states_local,
  704. frame_idx,
  705. tracker_metadata=tracker_metadata_prev,
  706. low_res_masks=tracker_low_res_masks_global,
  707. )
  708. # Step 4: update the SAM2 metadata based on the update plan
  709. # note: except for "rank0_metadata" (that is only available on GPU 0),
  710. # the updated `tracker_metadata_new` should be identical on all GPUs
  711. for rank in range(self.world_size):
  712. new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank]
  713. updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank]
  714. if len(new_det_obj_ids_this_gpu) > 0:
  715. updated_obj_ids_this_gpu = np.concatenate(
  716. [updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu]
  717. )
  718. if len(obj_ids_newly_removed) > 0:
  719. is_removed = np.isin(
  720. updated_obj_ids_this_gpu, list(obj_ids_newly_removed)
  721. )
  722. updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed]
  723. tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu
  724. tracker_metadata_new["num_obj_per_gpu"][rank] = len(
  725. updated_obj_ids_this_gpu
  726. )
  727. tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate(
  728. tracker_metadata_new["obj_ids_per_gpu"]
  729. )
  730. # update object scores and the maximum object ID assigned so far
  731. if len(new_det_obj_ids) > 0:
  732. tracker_metadata_new["obj_id_to_score"].update(
  733. zip(new_det_obj_ids, det_scores_np[new_det_fa_inds])
  734. )
  735. # tracker scores are not available for new objects, use det score instead.
  736. tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
  737. frame_idx
  738. ].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]))
  739. tracker_metadata_new["max_obj_id"] = max(
  740. tracker_metadata_new["max_obj_id"],
  741. np.max(new_det_obj_ids),
  742. )
  743. # for removed objects, we set their scores to a very low value (-1e4) but still
  744. # keep them in "obj_id_to_score" (it's easier to handle outputs this way)
  745. for obj_id in obj_ids_newly_removed:
  746. tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4
  747. tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][
  748. obj_id
  749. ] = -1e4
  750. tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None)
  751. # check that "rank0_metadata" is in tracker_metadata_new if and only if it's GPU 0
  752. assert ("rank0_metadata" in tracker_metadata_new) == (self.rank == 0)
  753. if self.rank == 0 and self.masklet_confirmation_enable:
  754. rank0_metadata = self.update_masklet_confirmation_status(
  755. rank0_metadata=tracker_metadata_new["rank0_metadata"],
  756. obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"],
  757. obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"],
  758. det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
  759. new_det_obj_ids=new_det_obj_ids,
  760. )
  761. tracker_metadata_new["rank0_metadata"] = rank0_metadata
  762. return tracker_update_plan, tracker_metadata_new
def _suppress_overlapping_based_on_recent_occlusion(
    self,
    frame_idx: int,
    tracker_low_res_masks_global: Tensor,
    tracker_metadata_prev: Dict[str, Any],
    tracker_metadata_new: Dict[str, Any],
    obj_ids_newly_removed: Set[int],
    reverse: bool = False,
):
    """
    Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object.

    Args:
        frame_idx (int): The current frame index.
        tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame.
        tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame.
        tracker_metadata_new (Dict[str, Any]): The metadata for the current frame.
            Its "obj_id_to_last_occluded" entry is overwritten here.
        obj_ids_newly_removed (Set[int]): The object IDs that have been removed.
        reverse (bool): Whether propagation runs backward in time (forwarded to
            the suppression helper, which flips its recency comparison).
    Return:
        Tensor: The updated low-resolution masks with some objects suppressed.
        NOTE: `tracker_low_res_masks_global` is modified in place and returned.
    """
    obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"]
    # Binarize logits once; reused both for pairwise-overlap suppression and
    # for detecting currently occluded (all-empty) masks further below.
    binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
    batch_size = tracker_low_res_masks_global.size(0)
    if batch_size > 0:
        assert len(obj_ids_global) == batch_size, (
            f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
        )
    NEVER_OCCLUDED = -1
    ALWAYS_OCCLUDED = 100000  # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
    # Build a (batch_size,) long tensor of each object's most recent occlusion
    # frame. Objects with no recorded entry default to NEVER_OCCLUDED, except
    # hotstart-removed objects, which get ALWAYS_OCCLUDED so that they lose
    # every overlap contest and are always suppressed.
    last_occluded_prev = torch.cat(
        [
            tracker_metadata_prev["obj_id_to_last_occluded"].get(
                obj_id,
                torch.full(
                    (1,),
                    fill_value=(
                        NEVER_OCCLUDED
                        if obj_id not in obj_ids_newly_removed
                        else ALWAYS_OCCLUDED
                    ),
                    device=binary_tracker_low_res_masks_global.device,
                    dtype=torch.long,
                ),
            )
            for obj_id in obj_ids_global
        ],
        dim=0,
    )
    to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded(
        binary_tracker_low_res_masks_global,
        last_occluded_prev,
        obj_ids_global,
        frame_idx,
        reverse,
    )
    # Update metadata with occlusion information: an object counts as occluded
    # on this frame if its mask is empty OR it was just suppressed above.
    is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2)))
    is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress
    last_occluded_new = last_occluded_prev.clone()
    last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx
    # Slice out the last occluded frame for each object (kept as 1-element
    # tensors so the torch.cat above keeps working on the next frame)
    tracker_metadata_new["obj_id_to_last_occluded"] = {
        obj_id: last_occluded_new[obj_idx : obj_idx + 1]
        for obj_idx, obj_id in enumerate(obj_ids_global)
    }
    # Zero out suppressed masks before memory encoding, using a strongly
    # negative logit so the binarized mask is empty everywhere.
    NO_OBJ_LOGIT = -10
    tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT
    return tracker_low_res_masks_global
  832. def run_tracker_update_execution_phase(
  833. self,
  834. frame_idx: int,
  835. num_frames: int,
  836. reverse: bool,
  837. det_out: Dict[str, Tensor],
  838. tracker_states_local: List[Any],
  839. tracker_update_plan: Dict[str, npt.NDArray],
  840. orig_vid_height: int,
  841. orig_vid_width: int,
  842. feature_cache: Dict,
  843. ):
  844. # initialize tracking scores with detection scores
  845. new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
  846. new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]
  847. new_det_gpu_ids: npt.NDArray = tracker_update_plan["new_det_gpu_ids"]
  848. is_on_this_gpu: npt.NDArray = new_det_gpu_ids == self.rank
  849. new_det_obj_ids_local: npt.NDArray = new_det_obj_ids[is_on_this_gpu]
  850. new_det_fa_inds_local: npt.NDArray = new_det_fa_inds[is_on_this_gpu]
  851. obj_ids_newly_removed: Set[int] = tracker_update_plan["obj_ids_newly_removed"]
  852. # Step 1: add new objects from the detector to SAM2 inference states
  853. if len(new_det_fa_inds_local) > 0:
  854. new_det_fa_inds_local_t = torch.from_numpy(new_det_fa_inds_local)
  855. new_det_masks: Tensor = det_out["mask"][new_det_fa_inds_local_t]
  856. # initialize SAM2 with new object masks
  857. tracker_states_local = self._tracker_add_new_objects(
  858. frame_idx=frame_idx,
  859. num_frames=num_frames,
  860. new_obj_ids=new_det_obj_ids_local,
  861. new_obj_masks=new_det_masks,
  862. tracker_states_local=tracker_states_local,
  863. orig_vid_height=orig_vid_height,
  864. orig_vid_width=orig_vid_width,
  865. feature_cache=feature_cache,
  866. )
  867. # Step 2: remove from SAM2 inference states those objects removed by heuristics
  868. if len(obj_ids_newly_removed) > 0:
  869. self._tracker_remove_objects(tracker_states_local, obj_ids_newly_removed)
  870. return tracker_states_local
  871. def build_outputs(
  872. self,
  873. frame_idx: int,
  874. num_frames: int,
  875. reverse: bool,
  876. det_out: Dict[str, Tensor],
  877. tracker_low_res_masks_global: Tensor,
  878. tracker_obj_scores_global: Tensor,
  879. tracker_metadata_prev: Dict[str, npt.NDArray],
  880. tracker_update_plan: Dict[str, npt.NDArray],
  881. orig_vid_height: int,
  882. orig_vid_width: int,
  883. reconditioned_obj_ids: set = None,
  884. det_to_matched_trk_obj_ids: dict = None,
  885. ):
  886. new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
  887. new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]
  888. obj_id_to_mask = {} # obj_id --> output mask tensor
  889. # Part 1: masks from previous SAM2 propagation
  890. existing_masklet_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
  891. existing_masklet_video_res_masks = F.interpolate(
  892. tracker_low_res_masks_global.unsqueeze(1),
  893. size=(orig_vid_height, orig_vid_width),
  894. mode="bilinear",
  895. align_corners=False,
  896. ) # (num_obj, 1, H_video, W_video)
  897. existing_masklet_binary = existing_masklet_video_res_masks > 0
  898. assert len(existing_masklet_obj_ids) == len(existing_masklet_binary)
  899. for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary):
  900. obj_id_to_mask[obj_id] = mask # (1, H_video, W_video)
  901. # Part 2: masks from new detections
  902. new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds)
  903. new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1)
  904. new_det_low_res_masks = fill_holes_in_mask_scores(
  905. new_det_low_res_masks,
  906. max_area=self.fill_hole_area,
  907. fill_holes=True,
  908. remove_sprinkles=True,
  909. )
  910. new_masklet_video_res_masks = F.interpolate(
  911. new_det_low_res_masks,
  912. size=(orig_vid_height, orig_vid_width),
  913. mode="bilinear",
  914. align_corners=False,
  915. ) # (num_obj, 1, H_video, W_video)
  916. new_masklet_binary = new_masklet_video_res_masks > 0
  917. assert len(new_det_obj_ids) == len(new_masklet_video_res_masks)
  918. for obj_id, mask in zip(new_det_obj_ids, new_masklet_binary):
  919. obj_id_to_mask[obj_id] = mask # (1, H_video, W_video)
  920. # Part 3: Override masks for reconditioned objects using detection masks
  921. if reconditioned_obj_ids is not None and len(reconditioned_obj_ids) > 0:
  922. trk_id_to_max_iou_high_conf_det = tracker_update_plan.get(
  923. "trk_id_to_max_iou_high_conf_det", {}
  924. )
  925. for obj_id in reconditioned_obj_ids:
  926. det_idx = trk_id_to_max_iou_high_conf_det.get(obj_id)
  927. if det_idx is not None:
  928. det_mask = det_out["mask"][det_idx]
  929. det_mask = det_mask.unsqueeze(0).unsqueeze(0)
  930. det_mask_resized = (
  931. F.interpolate(
  932. det_mask.float(),
  933. size=(orig_vid_height, orig_vid_width),
  934. mode="bilinear",
  935. align_corners=False,
  936. )
  937. > 0
  938. )
  939. det_mask_final = det_mask_resized.squeeze(0)
  940. obj_id_to_mask[obj_id] = det_mask_final
  941. return obj_id_to_mask
def _get_objects_to_suppress_based_on_most_recently_occluded(
    self,
    binary_low_res_masks: Tensor,
    last_occluded: Tensor,
    obj_ids: List[int],
    frame_idx: int = None,
    reverse: bool = False,
):
    """
    Pick which objects to suppress among strongly overlapping mask pairs.

    For each pair (i, j) whose mask IoU is at least
    `self.suppress_overlapping_based_on_recent_occlusion_threshold`, object i
    is suppressed in favor of j when `cmp(last_occluded[i], last_occluded[j])`
    holds (cmp is `>` forward, `<` in reverse) and `last_occluded[j] > -1`
    (and symmetrically for j vs i).

    Args:
        binary_low_res_masks (Tensor): (N, H, W) boolean masks.
        last_occluded (Tensor): (N,) long tensor with each object's most recent
            occlusion frame index; -1 means never occluded.
        obj_ids: sequence of N object IDs aligned with the mask batch
            (only used for debug logging).
        frame_idx (int): current frame index; may be None, in which case the
            debug logging is skipped.
        reverse (bool): backward propagation; flips the recency comparison
            since frame indices decrease over time in that case.

    Returns:
        Tensor: (N,) boolean tensor, True for objects to suppress.
    """
    # Suppress overlapping masks for objects that were most recently occluded
    assert binary_low_res_masks.dtype == torch.bool, (
        f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
    )
    to_suppress = torch.zeros(
        binary_low_res_masks.size(0),
        device=binary_low_res_masks.device,
        dtype=torch.bool,
    )
    # With 0 or 1 objects there can be no overlapping pair.
    if len(obj_ids) <= 1:
        return to_suppress
    iou = mask_iou(binary_low_res_masks, binary_low_res_masks)  # [N,N]
    # Create masks for upper triangular matrix (i < j) and IoU threshold;
    # the triu keeps each unordered pair exactly once and drops the diagonal.
    mask_iou_thresh = (
        iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold
    )
    overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1)  # [N,N]
    last_occ_expanded_i = last_occluded.unsqueeze(1)  # (N, 1)
    last_occ_expanded_j = last_occluded.unsqueeze(0)  # (1, N)
    # Suppress most recently occluded ("most recent" = larger frame index
    # forward, smaller frame index in reverse)
    cmp_op = torch.gt if not reverse else torch.lt
    suppress_i_mask = (
        overlapping_pairs
        & cmp_op(
            last_occ_expanded_i, last_occ_expanded_j
        )  # forward: last_occ_expanded_i > last_occ_expanded_j
        & (
            last_occ_expanded_j > -1
        )  # i can be suppressed in favor of j only if j has an occlusion record
    )
    suppress_j_mask = (
        overlapping_pairs
        & cmp_op(last_occ_expanded_j, last_occ_expanded_i)
        & (
            last_occ_expanded_i > -1
        )  # j can be suppressed in favor of i only if i has an occlusion record
    )
    # Apply suppression: an object is suppressed if any pair votes against it.
    to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0)
    # Log for debugging (only on rank 0, only at DEBUG level; the O(N^2)
    # python loops below never run otherwise)
    if (
        self.rank == 0
        and logger.isEnabledFor(logging.DEBUG)
        and frame_idx is not None
    ):
        suppress_i_mask = suppress_i_mask.cpu().numpy()
        suppress_j_mask = suppress_j_mask.cpu().numpy()
        last_occluded = last_occluded.cpu().numpy()
        # Find all suppression pairs without using torch.where
        batch_size = suppress_i_mask.shape[0]
        # Log i-suppression cases (where i gets suppressed in favor of j)
        for i in range(batch_size):
            for j in range(batch_size):
                if suppress_i_mask[i, j]:
                    logger.debug(
                        f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}"
                    )
        # Log j-suppression cases (where j gets suppressed in favor of i)
        for i in range(batch_size):
            for j in range(batch_size):
                if suppress_j_mask[i, j]:
                    logger.debug(
                        f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}"
                    )
    return to_suppress
def _propogate_tracker_one_frame_local_gpu(
    self,
    inference_states: List[Any],
    frame_idx: int,
    reverse: bool,
    # by default, we disable memory encoding until we gather all outputs
    run_mem_encoder: bool = False,
):
    """
    Propagate all local SAM2 inference states by exactly one frame.

    NOTE(review): "propogate" is a misspelling of "propagate", kept to avoid
    breaking callers.

    inference_states: List of inference states, each state corresponds to a different set of objects.

    Returns a 3-tuple:
        - obj_ids_local: list of object IDs across all local states
        - low_res_masks_local: (num_obj, H_mask, W_mask) mask logits
          (hole-filled), empty tensor if there are no local objects
        - obj_scores_local: (num_obj,) object scores, empty if none
    """
    obj_ids_local = []
    low_res_masks_list = []
    obj_scores_list = []
    for inference_state in inference_states:
        if len(inference_state["obj_ids"]) == 0:
            continue  # skip propagation on empty inference states
        # propagate one frame
        num_frames_propagated = 0
        for out in self.tracker.propagate_in_video(
            inference_state,
            start_frame_idx=frame_idx,
            # end_frame_idx = start_frame_idx + max_frame_num_to_track
            # (i.e. propagating 1 frame since end_frame_idx is inclusive)
            max_frame_num_to_track=0,
            reverse=reverse,
            tqdm_disable=True,
            run_mem_encoder=run_mem_encoder,
        ):
            out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = out
            num_frames_propagated += 1
        # only 1 frames should be propagated; the loop variables bound on the
        # single iteration above are intentionally read after the loop
        assert num_frames_propagated == 1 and out_frame_idx == frame_idx, (
            f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
        )
        assert isinstance(out_obj_ids, list)
        obj_ids_local.extend(out_obj_ids)
        # drop the singleton channel dim: (num_obj, 1, H, W) -> (num_obj, H, W)
        low_res_masks_list.append(out_low_res_masks.squeeze(1))
        obj_scores_list.append(out_obj_scores.squeeze(1))
    # concatenate the output masklets from all local inference states
    H_mask = W_mask = self.tracker.low_res_mask_size
    if len(low_res_masks_list) > 0:
        low_res_masks_local = torch.cat(low_res_masks_list, dim=0)
        obj_scores_local = torch.cat(obj_scores_list, dim=0)
        assert low_res_masks_local.shape[1:] == (H_mask, W_mask)
        # Apply hole filling to the masks (expects a channel dim, hence the
        # unsqueeze/squeeze pair around the call)
        low_res_masks_local = fill_holes_in_mask_scores(
            low_res_masks_local.unsqueeze(1),
            max_area=self.fill_hole_area,
            fill_holes=True,
            remove_sprinkles=True,
        )
        low_res_masks_local = low_res_masks_local.squeeze(1)
    else:
        # no local objects: return correctly-shaped empty tensors so callers
        # can concatenate/gather without special-casing
        low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device)
        obj_scores_local = torch.zeros(0, device=self.device)
    return obj_ids_local, low_res_masks_local, obj_scores_local
def _associate_det_trk(
    self,
    det_masks: Tensor,
    det_scores_np: npt.NDArray,
    trk_masks: Tensor,
    trk_obj_ids: npt.NDArray,
):
    """
    Match detections on the current frame with the existing masklets.

    Args:
    - det_masks: (N, H, W) tensor of predicted masks
    - det_scores_np: (N,) array of detection scores
    - trk_masks: (M, H, W) tensor of track masks
    - trk_obj_ids: (M,) array of object IDs corresponding to trk_masks

    Returns:
    - new_det_fa_inds: array of new object indices.
    - unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched
      to any detections on this frame (for unmatched, we only count masklets with >0 area)
    - det_to_matched_trk_obj_ids: dict[int, npt.NDArray]: mapping from detector's detection indices
      to the list of matched tracklet object IDs
    - trk_id_to_max_iou_high_conf_det: dict[int, int]: mapping from a tracklet object ID to the
      single high-confidence, high-IoU detection index that best overlaps it
    - empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction
    """
    iou_threshold = self.assoc_iou_thresh
    iou_threshold_trk = self.trk_assoc_iou_thresh
    new_det_thresh = self.new_det_thresh
    # Masks must stay as float logits here: any resizing below uses bilinear
    # interpolation and binarization happens only afterwards.
    assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
    assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
    assert trk_masks.size(0) == len(trk_obj_ids), (
        f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
    )
    # Degenerate case: no existing tracks.
    if trk_masks.size(0) == 0:
        # all detections are new
        new_det_fa_inds = np.arange(det_masks.size(0))
        unmatched_trk_obj_ids = np.array([], np.int64)
        empty_trk_obj_ids = np.array([], np.int64)
        det_to_matched_trk_obj_ids = {}
        trk_id_to_max_iou_high_conf_det = {}
        return (
            new_det_fa_inds,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            trk_id_to_max_iou_high_conf_det,
            empty_trk_obj_ids,
        )
    # Degenerate case: no detections.
    elif det_masks.size(0) == 0:
        # all previous tracklets are unmatched if they have a non-zero area
        new_det_fa_inds = np.array([], np.int64)
        trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy()
        unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty]
        empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
        det_to_matched_trk_obj_ids = {}
        trk_id_to_max_iou_high_conf_det = {}
        return (
            new_det_fa_inds,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            trk_id_to_max_iou_high_conf_det,
            empty_trk_obj_ids,
        )
    if det_masks.shape[-2:] != trk_masks.shape[-2:]:
        # resize to the smaller size to save GPU memory
        if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]):
            trk_masks = F.interpolate(
                trk_masks.unsqueeze(1),
                size=det_masks.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)
        else:
            # resize detections to track size
            det_masks = F.interpolate(
                det_masks.unsqueeze(1),
                size=trk_masks.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)
    det_masks_binary = det_masks > 0
    trk_masks_binary = trk_masks > 0
    ious = mask_iou(det_masks_binary, trk_masks_binary)  # (N, M)
    ious_np = ious.cpu().numpy()
    if self.o2o_matching_masklets_enable:
        from scipy.optimize import linear_sum_assignment

        # Hungarian matching for tracks (one-to-one: each track matches at most one detection)
        cost_matrix = 1 - ious_np  # Hungarian solves for minimum cost
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool)
        for d, t in zip(row_ind, col_ind):
            # only accept an assignment if its IoU clears the track threshold
            if ious_np[d, t] >= iou_threshold_trk:
                trk_is_matched[t] = True
    else:
        # many-to-one mode: a track is matched if any detection clears the threshold
        trk_is_matched = (ious_np >= iou_threshold_trk).any(axis=0)
    # Non-empty tracks not matched by Hungarian assignment above threshold are unmatched
    trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)).cpu().numpy()
    trk_is_unmatched = np.logical_and(trk_is_nonempty, ~trk_is_matched)
    unmatched_trk_obj_ids = trk_obj_ids[trk_is_unmatched]
    # also record masklets that have zero area in SAM 2 prediction
    empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
    # For detections: allow many tracks to match to the same detection (many-to-one)
    # So, a detection is 'new' if it does not match any track above threshold
    # (and only if its score clears the new-detection threshold)
    is_new_det = np.logical_and(
        det_scores_np >= new_det_thresh,
        np.logical_not(np.any(ious_np >= iou_threshold, axis=1)),
    )
    new_det_fa_inds = np.nonzero(is_new_det)[0]
    # for each detection, which tracks it matched to (above threshold)
    det_to_matched_trk_obj_ids = {}
    trk_id_to_max_iou_high_conf_det = {}  # trk id --> exactly one detection idx
    HIGH_CONF_THRESH = 0.8
    HIGH_IOU_THRESH = 0.8
    det_to_max_iou_trk_idx = np.argmax(ious_np, axis=1)
    # a "high conf" detection must also be matched to some track (not new)
    det_is_high_conf = (det_scores_np >= HIGH_CONF_THRESH) & ~is_new_det
    det_is_high_iou = np.max(ious_np, axis=1) >= HIGH_IOU_THRESH
    det_is_high_conf_and_iou = set(
        np.nonzero(det_is_high_conf & det_is_high_iou)[0]
    )
    for d in range(det_masks.size(0)):
        det_to_matched_trk_obj_ids[d] = trk_obj_ids[ious_np[d, :] >= iou_threshold]
        if d in det_is_high_conf_and_iou:
            # record the best-overlapping track for this high-quality detection;
            # later detections with the same best track overwrite earlier ones
            trk_obj_id = trk_obj_ids[det_to_max_iou_trk_idx[d]].item()
            trk_id_to_max_iou_high_conf_det[trk_obj_id] = d
    return (
        new_det_fa_inds,
        unmatched_trk_obj_ids,
        det_to_matched_trk_obj_ids,
        trk_id_to_max_iou_high_conf_det,
        empty_trk_obj_ids,
    )
  1199. def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu):
  1200. """Distribute the new objects to the GPUs with the least workload."""
  1201. workload_per_gpu: npt.NDArray = prev_workload_per_gpu.copy()
  1202. new_det_gpu_ids = np.zeros(new_det_num, np.int64)
  1203. # assign the objects one by one
  1204. for i in range(len(new_det_gpu_ids)):
  1205. # find the GPU with the least workload
  1206. min_gpu = np.argmin(workload_per_gpu)
  1207. new_det_gpu_ids[i] = min_gpu
  1208. workload_per_gpu[min_gpu] += 1
  1209. return new_det_gpu_ids
def _process_hotstart(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
    new_det_obj_ids: npt.NDArray,
    empty_trk_obj_ids: npt.NDArray,
    unmatched_trk_obj_ids: npt.NDArray,
    rank0_metadata: Dict[str, Any],
    tracker_metadata: Dict[str, Any],
):
    """Handle hotstart heuristics to remove unmatched or duplicated objects.

    A masklet is "within hotstart" while fewer than `self.hotstart_delay` frames
    have elapsed (in the current propagation direction) since it first appeared.
    Such recent masklets may be retroactively removed when they repeatedly fail
    to match detections (Step 2) or persistently overlap an earlier-appearing
    masklet (Step 3). Independently of removal, a per-object keep-alive counter
    is updated (Step 1) and can trigger per-frame output suppression for
    long-unmatched objects.

    Args:
        frame_idx: index of the frame currently being processed.
        num_frames: total number of video frames (not referenced in this method).
        reverse: True when propagating backwards through the video; flips the
            hotstart-window comparisons and the "first appearing" selection.
        det_to_matched_trk_obj_ids: detection index --> array of tracked object
            IDs matched to that detection on this frame.
        new_det_obj_ids: object IDs newly created from detections on this frame.
        empty_trk_obj_ids: IDs of tracks whose SAM2 mask is empty on this frame.
        unmatched_trk_obj_ids: IDs of tracks with a non-empty mask that matched
            no detection on this frame.
        rank0_metadata: rank-0-only bookkeeping (first-frame indices, unmatched
            frame lists, keep-alive counters, overlap pairs, removed/suppressed
            ID sets); mutated in place.
        tracker_metadata: global tracker metadata (not referenced in this method).

    Returns:
        Tuple `(obj_ids_newly_removed, rank0_metadata)` where
        `obj_ids_newly_removed` is the set of object IDs removed on this frame.
    """
    # obj_id --> first frame index where the object was detected
    obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"]
    # obj_id --> [mismatched frame indices]
    unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"]
    # obj_id --> keep-alive counter; used only for suppression, not removal
    trk_keep_alive = rank0_metadata["trk_keep_alive"]
    # (first_appear_obj_id, obj_id) --> [overlap frame indices]
    overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"]
    # removed_obj_ids: object IDs that are suppressed via hot-start
    removed_obj_ids = rank0_metadata["removed_obj_ids"]
    suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx]
    obj_ids_newly_removed = set()  # object IDs to be newly removed on this frame
    # Threshold frame index: objects whose first appearance is "after" this index
    # (in propagation order) are still inside their hotstart window.
    hotstart_diff = (
        frame_idx - self.hotstart_delay
        if not reverse
        else frame_idx + self.hotstart_delay
    )
    # Step 1: log the frame index where each object ID first appears
    for obj_id in new_det_obj_ids:
        if obj_id not in obj_first_frame_idx:
            obj_first_frame_idx[obj_id] = frame_idx
            assert obj_id not in trk_keep_alive
            trk_keep_alive[obj_id] = self.init_trk_keep_alive
    matched_trks = set()
    # We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded
    for matched_trks_per_det in det_to_matched_trk_obj_ids.values():
        matched_trks.update(matched_trks_per_det)
    for obj_id in matched_trks:
        # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive
        trk_keep_alive[obj_id] = min(
            self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1
        )
    for obj_id in unmatched_trk_obj_ids:
        unmatched_frame_inds[obj_id].append(frame_idx)
        # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
        # The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough.
        trk_keep_alive[obj_id] = max(
            self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1
        )
    if self.decrease_trk_keep_alive_for_empty_masklets:
        for obj_id in empty_trk_obj_ids:
            # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
            trk_keep_alive[obj_id] = max(
                self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1
            )
    # Step 2: removed tracks that has not matched with detections for `hotstart_unmatch_thresh` frames with hotstart period
    # a) add unmatched frame indices for each existing object ID
    # note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask
    # doesn't match any detection; it excludes those frames where SAM2 gives an empty mask
    # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more
    # than `self.hotstart_unmatch_thresh` frames
    for obj_id, frame_indices in unmatched_frame_inds.items():
        if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
            continue  # skip if the object is already removed
        if len(frame_indices) >= self.hotstart_unmatch_thresh:
            is_within_hotstart = (
                obj_first_frame_idx[obj_id] > hotstart_diff and not reverse
            ) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse)
            if is_within_hotstart:
                obj_ids_newly_removed.add(obj_id)
                logger.debug(
                    f"Removing object {obj_id} at frame {frame_idx} "
                    f"since it is unmatched for frames: {frame_indices}"
                )
        # Suppression (not removal) path: objects that survived the hotstart
        # checks above but whose keep-alive counter has run out.
        if (
            trk_keep_alive[obj_id] <= 0  # Object has not been matched for too long
            and not self.suppress_unmatched_only_within_hotstart
            and obj_id not in removed_obj_ids
            and obj_id not in obj_ids_newly_removed
        ):
            logger.debug(
                f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched"
            )
            suppressed_obj_ids.add(obj_id)
    # Step 3: removed tracks that overlaps with another track for `hotstart_dup_thresh` frames
    # a) find overlaps tracks -- we consider overlap if they match to the same detection
    for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items():
        if len(matched_trk_obj_ids) < 2:
            continue  # only count detections that are matched to multiple (>=2) masklets
        # if there are multiple matched track ids, we need to find the one that appeared first;
        # these later appearing ids may be removed since they may be considered as duplicates
        first_appear_obj_id = (
            min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
            if not reverse
            else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
        )
        for obj_id in matched_trk_obj_ids:
            if obj_id != first_appear_obj_id:
                key = (first_appear_obj_id, obj_id)
                overlap_pair_to_frame_inds[key].append(frame_idx)
    # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another
    # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames
    for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items():
        if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
            continue  # skip if the object is already removed
        if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or (
            obj_first_frame_idx[obj_id] < hotstart_diff and reverse
        ):
            if len(frame_indices) >= self.hotstart_dup_thresh:
                obj_ids_newly_removed.add(obj_id)
                logger.debug(
                    f"Removing object {obj_id} at frame {frame_idx} "
                    f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}"
                )
    removed_obj_ids.update(obj_ids_newly_removed)
    return obj_ids_newly_removed, rank0_metadata
def _tracker_update_memories(
    self,
    tracker_inference_states: List[Any],
    frame_idx: int,
    tracker_metadata: Dict[str, Any],
    low_res_masks: Tensor,
):
    """
    Run Sam2 memory encoder, enforcing non-overlapping constraints globally.

    Upsamples `low_res_masks`, optionally applies pairwise area-shrinkage
    suppression, derives object score logits from mask emptiness, then encodes
    and stores memory features for this rank's slice of objects into each local
    inference state.

    Args:
        tracker_inference_states: SAM2 inference states local to this rank.
        frame_idx: index of the frame whose outputs are encoded into memory.
        tracker_metadata: global metadata; "num_obj_per_gpu" gives the object
            count per rank and is used to locate this rank's slice.
        low_res_masks: predicted masks stacked along dim 0 for all objects
            across GPUs — assumed shape (num_total_obj, H, W); TODO confirm
            against the caller.
    """
    if len(tracker_inference_states) == 0:
        return
    # Avoid an extra interpolation step by directly interpolating to `interpol_size`
    high_res_H, high_res_W = (
        self.tracker.maskmem_backbone.mask_downsampler.interpol_size
    )
    # NOTE: inspect this part if we observe OOMs in the demo
    high_res_masks = F.interpolate(
        low_res_masks.unsqueeze(1),
        size=(high_res_H, high_res_W),
        mode="bilinear",
        align_corners=False,
    )
    # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics.
    # Skipped only when `_warm_up_complete` exists and is falsy (warm-up still in progress).
    if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
        high_res_masks = self.tracker._suppress_object_pw_area_shrinkage(
            high_res_masks
        )
    # Instead of gathering the predicted object scores, we use mask areas as a proxy.
    # Any positive pixel --> logit +10.0; an all-empty mask --> -10.0.
    object_score_logits = torch.where(
        (high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0
    )
    # Run the memory encoder on local slices for each GPU
    start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank])
    start_idx_state = start_idx_gpu
    for tracker_state in tracker_inference_states:
        num_obj_per_state = len(tracker_state["obj_ids"])
        if num_obj_per_state == 0:
            continue  # nothing to encode; slice offset is unaffected (0 objects)
        # Get the local high-res masks and object score logits for this inference state
        end_idx_state = start_idx_state + num_obj_per_state
        local_high_res_masks = high_res_masks[start_idx_state:end_idx_state]
        local_object_score_logits = object_score_logits[
            start_idx_state:end_idx_state
        ]
        local_batch_size = local_high_res_masks.size(0)
        # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default
        encoded_mem = self.tracker._run_memory_encoder(
            tracker_state,
            frame_idx,
            local_batch_size,
            local_high_res_masks,
            local_object_score_logits,
            is_mask_from_pts=False,
        )
        local_maskmem_features, local_maskmem_pos_enc = encoded_mem
        # Store encoded memories in the local inference state
        output_dict = tracker_state["output_dict"]
        for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]:
            if frame_idx not in output_dict[storage_key]:
                continue
            output_dict[storage_key][frame_idx]["maskmem_features"] = (
                local_maskmem_features
            )
            output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [
                pos for pos in local_maskmem_pos_enc
            ]
            # for batched inference state, we also need to add per-object
            # memory slices to support instance interactivity
            self.tracker._add_output_per_object(
                inference_state=tracker_state,
                frame_idx=frame_idx,
                current_out=output_dict[storage_key][frame_idx],
                storage_key=storage_key,
            )
        start_idx_state += num_obj_per_state
  1404. def _tracker_add_new_objects(
  1405. self,
  1406. frame_idx: int,
  1407. num_frames: int,
  1408. new_obj_ids: List[int],
  1409. new_obj_masks: Tensor,
  1410. tracker_states_local: List[Any],
  1411. orig_vid_height: int,
  1412. orig_vid_width: int,
  1413. feature_cache: Dict,
  1414. ):
  1415. """Add a new object to SAM2 inference states."""
  1416. prev_tracker_state = (
  1417. tracker_states_local[0] if len(tracker_states_local) > 0 else None
  1418. )
  1419. # prepare inference_state
  1420. # batch objects that first appear on the same frame together
  1421. # Clear inference state. Keep the cached image features if available.
  1422. new_tracker_state = self.tracker.init_state(
  1423. cached_features=feature_cache,
  1424. video_height=orig_vid_height,
  1425. video_width=orig_vid_width,
  1426. num_frames=num_frames,
  1427. )
  1428. new_tracker_state["backbone_out"] = (
  1429. prev_tracker_state.get("backbone_out", None)
  1430. if prev_tracker_state is not None
  1431. else None
  1432. )
  1433. assert len(new_obj_ids) == new_obj_masks.size(0)
  1434. assert new_obj_masks.is_floating_point()
  1435. input_mask_res = self.tracker.input_mask_size
  1436. new_obj_masks = F.interpolate(
  1437. new_obj_masks.unsqueeze(1),
  1438. size=(input_mask_res, input_mask_res),
  1439. mode="bilinear",
  1440. align_corners=False,
  1441. ).squeeze(1)
  1442. new_obj_masks = new_obj_masks > 0
  1443. # add object one by one
  1444. for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks):
  1445. self.tracker.add_new_mask(
  1446. inference_state=new_tracker_state,
  1447. frame_idx=frame_idx,
  1448. obj_id=new_obj_id,
  1449. mask=new_mask,
  1450. add_mask_to_memory=True,
  1451. )
  1452. # NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects.
  1453. self.tracker.propagate_in_video_preflight(
  1454. new_tracker_state, run_mem_encoder=True
  1455. )
  1456. tracker_states_local.append(new_tracker_state)
  1457. return tracker_states_local
  1458. def _tracker_remove_object(self, tracker_states_local: List[Any], obj_id: int):
  1459. """
  1460. Remove an object from SAM2 inference states. This would remove the object from
  1461. all frames in the video.
  1462. """
  1463. tracker_states_local_before_removal = tracker_states_local.copy()
  1464. tracker_states_local.clear()
  1465. for tracker_inference_state in tracker_states_local_before_removal:
  1466. # we try to remove `obj_id` on every inference state with `strict=False`
  1467. # it will not do anything if an inference state doesn't contain `obj_id`
  1468. new_obj_ids, _ = self.tracker.remove_object(
  1469. tracker_inference_state, obj_id, strict=False, need_output=False
  1470. )
  1471. # only keep an inference state if it's non-empty after object removal
  1472. if len(new_obj_ids) > 0:
  1473. tracker_states_local.append(tracker_inference_state)
  1474. def _tracker_remove_objects(
  1475. self, tracker_states_local: List[Any], obj_ids: list[int]
  1476. ):
  1477. """
  1478. Remove an object from SAM2 inference states. This would remove the object from
  1479. all frames in the video.
  1480. """
  1481. for obj_id in obj_ids:
  1482. self._tracker_remove_object(tracker_states_local, obj_id)
  1483. def _initialize_metadata(self):
  1484. """Initialize metadata for the masklets."""
  1485. tracker_metadata = {
  1486. "obj_ids_per_gpu": [np.array([], np.int64) for _ in range(self.world_size)],
  1487. "obj_ids_all_gpu": np.array([], np.int64),
  1488. "num_obj_per_gpu": np.zeros(self.world_size, np.int64),
  1489. "max_obj_id": -1,
  1490. "obj_id_to_score": {},
  1491. "obj_id_to_tracker_score_frame_wise": defaultdict(dict),
  1492. "obj_id_to_last_occluded": {},
  1493. }
  1494. if self.rank == 0:
  1495. # "rank0_metadata" contains metadata that is only stored on (and accessible to) GPU 0
  1496. # - obj_first_frame_idx: obj_id --> first frame index where the object was detected
  1497. # - unmatched_frame_inds: obj_id --> [mismatched frame indices]
  1498. # - overlap_pair_to_frame_inds: (first_appear_obj_id, obj_id) --> [overlap frame indices]
  1499. # - removed_obj_ids: object IDs that are suppressed via hot-start
  1500. rank0_metadata = {
  1501. "obj_first_frame_idx": {},
  1502. "unmatched_frame_inds": defaultdict(list),
  1503. "trk_keep_alive": defaultdict(
  1504. int
  1505. ), # This is used only for object suppression not for removal
  1506. "overlap_pair_to_frame_inds": defaultdict(list),
  1507. "removed_obj_ids": set(),
  1508. "suppressed_obj_ids": defaultdict(
  1509. set
  1510. ), # frame_idx --> set of objects with suppressed outputs, but still continue to be tracked
  1511. }
  1512. if self.masklet_confirmation_enable:
  1513. # all the following are npt.NDArray with the same shape as `obj_ids_all_gpu`
  1514. rank0_metadata["masklet_confirmation"] = {
  1515. # "status" is the confirmation status of each masklet (in `MaskletConfirmationStatus`)
  1516. "status": np.array([], np.int64),
  1517. # "consecutive_det_num" is the number of consecutive frames where the masklet is
  1518. # detected by the detector (with a matched detection)
  1519. "consecutive_det_num": np.array([], np.int64),
  1520. }
  1521. tracker_metadata["rank0_metadata"] = rank0_metadata
  1522. return tracker_metadata
def update_masklet_confirmation_status(
    self,
    rank0_metadata: Dict[str, Any],
    obj_ids_all_gpu_prev: npt.NDArray,
    obj_ids_all_gpu_updated: npt.NDArray,
    det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
    new_det_obj_ids: npt.NDArray,
):
    """Update per-masklet confirmation status after this frame's matching.

    Re-aligns the confirmation arrays from the previous object-ID list to the
    updated one (dropping removed IDs, initializing new IDs as UNCONFIRMED),
    increments the consecutive-detection counter for masklets matched to a
    detection on this frame (resetting it to 0 otherwise), and marks a masklet
    CONFIRMED once the counter reaches
    `self.masklet_confirmation_consecutive_det_thresh`.

    Args:
        rank0_metadata: rank-0 bookkeeping; its "masklet_confirmation" entry is
            mutated in place.
        obj_ids_all_gpu_prev: object IDs across all GPUs before this update.
        obj_ids_all_gpu_updated: object IDs across all GPUs after this update.
        det_to_matched_trk_obj_ids: detection index --> matched tracked IDs.
        new_det_obj_ids: object IDs newly created from detections this frame.

    Returns:
        The (mutated) `rank0_metadata`.
    """
    confirmation_data = rank0_metadata["masklet_confirmation"]
    # a) first, expand "confirmation_data" to include new masklets added in this frame
    status_prev = confirmation_data["status"]
    consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
    assert status_prev.shape == obj_ids_all_gpu_prev.shape, (
        f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
    )
    obj_id_to_updated_idx = {
        obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)
    }
    # Boolean mask over the previous ID list: which IDs survived the update.
    prev_elem_is_in_updated = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated)
    prev_elem_obj_ids_in_updated = obj_ids_all_gpu_prev[prev_elem_is_in_updated]
    # Positions of the surviving IDs within the updated ID list.
    prev_elem_inds_in_updated = np.array(
        [obj_id_to_updated_idx[obj_id] for obj_id in prev_elem_obj_ids_in_updated],
        dtype=np.int64,
    )
    # newly added masklets are initialized to "UNCONFIRMED" status
    unconfirmed_val = MaskletConfirmationStatus.UNCONFIRMED.value
    status = np.full_like(obj_ids_all_gpu_updated, fill_value=unconfirmed_val)
    status[prev_elem_inds_in_updated] = status_prev[prev_elem_is_in_updated]
    consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated)
    consecutive_det_num[prev_elem_inds_in_updated] = consecutive_det_num_prev[
        prev_elem_is_in_updated
    ]
    # b) update the confirmation status of all masklets based on the current frame
    # b.1) update "consecutive_det_num"
    # "is_matched": whether a masklet is matched to a detection on this frame
    is_matched = np.isin(obj_ids_all_gpu_updated, new_det_obj_ids)
    for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values():
        is_matched |= np.isin(obj_ids_all_gpu_updated, matched_trk_obj_ids)
    # Matched --> extend the streak; unmatched --> reset the streak to zero.
    consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0)
    # b.2) update "status"
    change_to_confirmed = (
        consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh
    )
    status[change_to_confirmed] = MaskletConfirmationStatus.CONFIRMED.value
    confirmation_data["status"] = status
    confirmation_data["consecutive_det_num"] = consecutive_det_num
    return rank0_metadata
  1570. def forward(self, input: BatchedDatapoint, is_inference: bool = False):
  1571. raise NotImplementedError("Evaluation outside demo is not implemented yet")
  1572. def _load_checkpoint(self, ckpt_path: str, strict: bool = True):
  1573. sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
  1574. missing_keys, unexpected_keys = self.load_state_dict(sd, strict=strict)
  1575. if len(missing_keys) > 0 or len(unexpected_keys) > 0:
  1576. logger.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}")
  1577. else:
  1578. logger.info("Loaded ckpt successfully without missing or unexpected keys")
  1579. def prep_for_evaluator(self, video_frames, tracking_res, scores_labels):
  1580. """This method is only used for benchmark eval (not used in the demo)."""
  1581. num_frames = len(video_frames)
  1582. w, h = video_frames[0].size
  1583. zero_mask = torch.zeros((1, h, w), dtype=torch.bool)
  1584. object_ids = list(scores_labels.keys())
  1585. preds = {"scores": [], "labels": [], "boxes": [], "masks_rle": []}
  1586. for oid in object_ids:
  1587. o_masks = []
  1588. o_score = scores_labels[oid][0].item()
  1589. o_label = scores_labels[oid][1]
  1590. for frame_idx in range(num_frames):
  1591. if frame_idx not in tracking_res:
  1592. o_masks.append(zero_mask)
  1593. else:
  1594. o_masks.append(tracking_res[frame_idx].get(oid, zero_mask))
  1595. o_masks = torch.cat(o_masks, dim=0) # (n_frames, H, W)
  1596. preds["scores"].append(o_score)
  1597. preds["labels"].append(o_label)
  1598. preds["boxes"].append(mask_to_box(o_masks.unsqueeze(1)).squeeze())
  1599. preds["masks_rle"].append(rle_encode(o_masks, return_areas=True))
  1600. preds["boxes"] = (
  1601. torch.stack(preds["boxes"], dim=0)
  1602. if len(preds["boxes"]) > 0
  1603. else torch.empty(
  1604. (0, num_frames, 4), dtype=torch.float32, device=self.device
  1605. )
  1606. )
  1607. preds["scores"] = (
  1608. torch.tensor(preds["scores"], device=self.device)
  1609. if len(preds["scores"]) > 0
  1610. else torch.empty((0,), device=self.device)
  1611. )
  1612. preds["per_frame_scores"] = preds["scores"]
  1613. preds["labels"] = (
  1614. torch.tensor(preds["labels"], device=self.device)
  1615. if len(preds["labels"]) > 0
  1616. else torch.empty((0,), device=self.device)
  1617. )
  1618. return preds
  1619. def _encode_prompt(self, **kwargs):
  1620. return self.detector._encode_prompt(**kwargs)
  1621. def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep):
  1622. """
  1623. Drop a few new detections based on the maximum number of objects. We drop new objects based
  1624. on their detection scores, keeping the high-scoring ones and dropping the low-scoring ones.
  1625. """
  1626. assert 0 <= num_to_keep <= len(new_det_fa_inds)
  1627. if num_to_keep == 0:
  1628. return np.array([], np.int64) # keep none
  1629. if num_to_keep == len(new_det_fa_inds):
  1630. return new_det_fa_inds # keep all
  1631. # keep the top-scoring detections
  1632. score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1]
  1633. new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]]
  1634. return new_det_fa_inds