# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved # pyre-unsafe import datetime import logging import math import os from collections import defaultdict from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Set import numpy as np import numpy.typing as npt import torch import torch.distributed as dist import torch.nn.functional as F from sam3 import perflib from sam3.logger import get_logger from sam3.model.box_ops import fast_diag_box_iou from sam3.model.data_misc import BatchedDatapoint from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box from sam3.perflib.masks_ops import mask_iou from sam3.train.masks_ops import rle_encode from torch import nn, Tensor logger = get_logger(__name__) class MaskletConfirmationStatus(Enum): UNCONFIRMED = 1 # newly added masklet, not confirmed by any detection yet CONFIRMED = 2 # confirmed by at least one detection class Sam3VideoBase(nn.Module): def __init__( self, detector: nn.Module, tracker: nn.Module, # prob threshold for detection outputs -- only keep detections above this threshold # enters NMS and det-to-track matching score_threshold_detection=0.5, # IoU threshold for detection NMS det_nms_thresh=0.0, # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1 assoc_iou_thresh=0.5, # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched" # by any detections -- it is often a stricter threshold like 0.5 trk_assoc_iou_thresh=0.5, # prob threshold for a detection to be added as a new object new_det_thresh=0.0, # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh` # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh` hotstart_delay=0, hotstart_unmatch_thresh=3, hotstart_dup_thresh=3, # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period. suppress_unmatched_only_within_hotstart=True, init_trk_keep_alive=0, max_trk_keep_alive=8, min_trk_keep_alive=-4, # Threshold for suppressing overlapping objects based on recent occlusion suppress_overlapping_based_on_recent_occlusion_threshold=0.0, decrease_trk_keep_alive_for_empty_masklets=False, o2o_matching_masklets_enable=False, # Enable hungarian matching to match existing masklets suppress_det_close_to_boundary=False, fill_hole_area=16, # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1) max_num_objects=-1, recondition_every_nth_frame=-1, # masket confirmation status (to suppress unconfirmed masklets) masklet_confirmation_enable=False, # a masklet is confirmed after being consecutively detected and matched for # `masklet_confirmation_consecutive_det_thresh` masklet_confirmation_consecutive_det_thresh=3, # bbox heuristic parameters reconstruction_bbox_iou_thresh=0.0, reconstruction_bbox_det_score=0.0, ): super().__init__() self.detector = detector self.tracker = tracker self.score_threshold_detection = score_threshold_detection self.det_nms_thresh = det_nms_thresh self.assoc_iou_thresh = assoc_iou_thresh self.trk_assoc_iou_thresh = trk_assoc_iou_thresh self.new_det_thresh = new_det_thresh # hotstart parameters if hotstart_delay > 0: assert hotstart_unmatch_thresh <= hotstart_delay assert hotstart_dup_thresh <= hotstart_delay self.hotstart_delay = hotstart_delay self.hotstart_unmatch_thresh = hotstart_unmatch_thresh self.hotstart_dup_thresh = hotstart_dup_thresh self.suppress_unmatched_only_within_hotstart = ( suppress_unmatched_only_within_hotstart ) self.init_trk_keep_alive = init_trk_keep_alive self.max_trk_keep_alive = max_trk_keep_alive self.min_trk_keep_alive = min_trk_keep_alive self.suppress_overlapping_based_on_recent_occlusion_threshold = ( suppress_overlapping_based_on_recent_occlusion_threshold ) self.suppress_det_close_to_boundary = suppress_det_close_to_boundary self.decrease_trk_keep_alive_for_empty_masklets = ( decrease_trk_keep_alive_for_empty_masklets ) self.o2o_matching_masklets_enable = o2o_matching_masklets_enable self.fill_hole_area = fill_hole_area self.eval() self.rank = int(os.getenv("RANK", "0")) self.world_size = int(os.getenv("WORLD_SIZE", "1")) self._dist_pg_cpu = None # CPU process group (lazy-initialized on first use) # the maximum object number if max_num_objects > 0: num_obj_for_compile = math.ceil(max_num_objects / self.world_size) else: max_num_objects = 10000 # no limit num_obj_for_compile = 16 logger.info(f"setting {max_num_objects=} and {num_obj_for_compile=}") self.max_num_objects = max_num_objects self.num_obj_for_compile = num_obj_for_compile self.recondition_every_nth_frame = recondition_every_nth_frame self.masklet_confirmation_enable = masklet_confirmation_enable self.masklet_confirmation_consecutive_det_thresh = ( masklet_confirmation_consecutive_det_thresh ) self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh self.reconstruction_bbox_det_score = reconstruction_bbox_det_score @property def device(self): self._device = getattr(self, "_device", None) or next(self.parameters()).device return self._device def _init_dist_pg_cpu(self): # a short 3-min timeout to quickly detect any synchronization failures timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180")) timeout = datetime.timedelta(seconds=timeout_sec) self._dist_pg_cpu = dist.new_group(backend="gloo", timeout=timeout) def broadcast_python_obj_cpu(self, python_obj_list, src): if self._dist_pg_cpu is None: self._init_dist_pg_cpu() dist.broadcast_object_list(python_obj_list, src=src, group=self._dist_pg_cpu) def _det_track_one_frame( self, frame_idx: int, num_frames: int, reverse: bool, input_batch: BatchedDatapoint, geometric_prompt: Any, tracker_states_local: List[Any], tracker_metadata_prev: Dict[str, Any], feature_cache: Dict, orig_vid_height: int, orig_vid_width: int, is_image_only: bool = False, allow_new_detections: bool = True, ): """ This function handles one-step inference for the DenseTracking model in an SPMD manner. At a high-level, all GPUs execute the same function calls as if it's done on a single GPU, while under the hood, some function calls involve distributed computation based on sharded SAM2 states. - `input_batch` contains image and other inputs on the entire video; it should be identical across GPUs - `tracker_states_local` holds the local masklet information in this GPU shard - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is hold on which GPUs it contains both global and local masklet information """ # Step 1: run backbone and detector in a distributed manner -- this is done via Sam3ImageOnVideoMultiGPU, # a MultiGPU model (assigned to `self.detector`) that shards frames in a round-robin manner. # It returns a "det_out" dict for `frame_idx` and fills SAM2 backbone features for `frame_idx` # into `feature_cache`. Despite its distributed inference under the hood, the results would be # the same as if it is running backbone and detector for every frame on a single GPU. det_out = self.run_backbone_and_detection( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, input_batch=input_batch, geometric_prompt=geometric_prompt, feature_cache=feature_cache, allow_new_detections=allow_new_detections, ) # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks. # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks; # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics. if tracker_metadata_prev == {}: # initialize masklet metadata if it's uninitialized (empty dict) tracker_metadata_prev.update(self._initialize_metadata()) tracker_low_res_masks_global, tracker_obj_scores_global = ( self.run_tracker_propagation( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, tracker_states_local=tracker_states_local, tracker_metadata_prev=tracker_metadata_prev, ) ) # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc). # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints. # **This step should involve all the heuristics needed for any updates.** Most of the update # planning will be done on the master rank (GPU 0) and the resulting plan `tracker_update_plan` is # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`). tracker_update_plan, tracker_metadata_new = ( self.run_tracker_update_planning_phase( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, det_out=det_out, tracker_low_res_masks_global=tracker_low_res_masks_global, tracker_obj_scores_global=tracker_obj_scores_global, tracker_metadata_prev=tracker_metadata_prev, tracker_states_local=tracker_states_local, is_image_only=is_image_only, ) ) # Get reconditioning info from the update plan reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set()) det_to_matched_trk_obj_ids = tracker_update_plan.get( "det_to_matched_trk_obj_ids", {} ) # Step 4: based on `tracker_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states tracker_states_local_new = self.run_tracker_update_execution_phase( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, det_out=det_out, tracker_states_local=tracker_states_local, tracker_update_plan=tracker_update_plan, orig_vid_height=orig_vid_height, orig_vid_width=orig_vid_width, feature_cache=feature_cache, ) # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since # only GPU 0 will send outputs to the server). if self.rank == 0: obj_id_to_mask = self.build_outputs( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, det_out=det_out, tracker_low_res_masks_global=tracker_low_res_masks_global, tracker_obj_scores_global=tracker_obj_scores_global, tracker_metadata_prev=tracker_metadata_prev, tracker_update_plan=tracker_update_plan, orig_vid_height=orig_vid_height, orig_vid_width=orig_vid_width, reconditioned_obj_ids=reconditioned_obj_ids, det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, ) obj_id_to_score = tracker_metadata_new["obj_id_to_score"] else: obj_id_to_mask, obj_id_to_score = {}, {} # dummy outputs on other GPUs # a few statistics for the current frame as a part of the output frame_stats = { "num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]), "num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"], } # add tracker scores to metadata, it should be fired for frames except the first frame if tracker_obj_scores_global.shape[0] > 0: # Convert tracker_obj_scores_global to sigmoid scores before updating tracker_obj_scores_global = tracker_obj_scores_global.sigmoid().tolist() tracker_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"] tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][ frame_idx ].update(dict(zip(tracker_obj_ids, tracker_obj_scores_global))) return ( obj_id_to_mask, # a dict: obj_id --> output mask obj_id_to_score, # a dict: obj_id --> output score (prob) tracker_states_local_new, tracker_metadata_new, frame_stats, tracker_obj_scores_global, # a dict: obj_id --> tracker frame-level scores ) def _suppress_detections_close_to_boundary(self, boxes, margin=0.025): """ Suppress detections too close to image edges (for normalized boxes). boxes: (N, 4) in xyxy format, normalized [0,1] margin: fraction of image """ x_min, y_min, x_max, y_max = boxes.unbind(-1) x_c = (x_min + x_max) / 2 y_c = (y_min + y_max) / 2 keep = ( (x_c > margin) & (x_c < 1.0 - margin) & (y_c > margin) & (y_c < 1.0 - margin) ) return keep def run_backbone_and_detection( self, frame_idx: int, num_frames: int, input_batch: BatchedDatapoint, geometric_prompt: Any, feature_cache: Dict, reverse: bool, allow_new_detections: bool, ): # Step 1: if text feature is not cached in `feature_cache`, compute and cache it text_batch_key = tuple(input_batch.find_text_batch) if "text" not in feature_cache or text_batch_key not in feature_cache["text"]: text_outputs = self.detector.backbone.forward_text( input_batch.find_text_batch, device=self.device ) # note: we only cache the text feature of the most recent prompt feature_cache["text"] = {text_batch_key: text_outputs} else: text_outputs = feature_cache["text"][text_batch_key] # Step 2: run backbone, detector, and post-processing with NMS if "multigpu_buffer" not in feature_cache: # "multigpu_buffer" is a buffer cache used by `self.detector` and it needs # to be passed to `forward_video_grounding_multigpu` for every call feature_cache["multigpu_buffer"] = {} # Extract max_frame_num_to_track from feature_cache if available tracking_bounds = feature_cache.get("tracking_bounds", {}) max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track") start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx") sam3_image_out, _ = self.detector.forward_video_grounding_multigpu( backbone_out={ "img_batch_all_stages": input_batch.img_batch, **text_outputs, }, find_inputs=input_batch.find_inputs, geometric_prompt=geometric_prompt, frame_idx=frame_idx, num_frames=num_frames, multigpu_buffer=feature_cache["multigpu_buffer"], track_in_reverse=reverse, # also get the SAM2 backbone features return_tracker_backbone_feats=True, # run NMS as a part of distributed computation run_nms=self.det_nms_thresh > 0.0, nms_prob_thresh=self.score_threshold_detection, nms_iou_thresh=self.det_nms_thresh, # pass max_frame_num_to_track to respect tracking limits max_frame_num_to_track=max_frame_num_to_track, propagate_in_video_start_frame_idx=start_frame_idx, ) # note: detections in `sam3_image_out` has already gone through NMS pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid() if not allow_new_detections: pred_probs = pred_probs - 1e8 # make sure no detections are kept pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"] pred_masks = sam3_image_out["pred_masks"] # get the positive detection outputs above threshold pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection) det_out = { "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]], "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]], "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]], } # Step 3: build SAM2 backbone features and store them in `feature_cache` backbone_cache = {} sam_mask_decoder = self.tracker.sam_mask_decoder tracker_backbone_fpn = [ sam_mask_decoder.conv_s0(sam3_image_out["tracker_backbone_fpn_0"]), sam_mask_decoder.conv_s1(sam3_image_out["tracker_backbone_fpn_1"]), sam3_image_out["tracker_backbone_fpn_2"], # fpn_2 doesn't need conv ] tracker_backbone_out = { "vision_features": tracker_backbone_fpn[-1], # top-level feature "vision_pos_enc": sam3_image_out["tracker_backbone_pos_enc"], "backbone_fpn": tracker_backbone_fpn, } backbone_cache["tracker_backbone_out"] = tracker_backbone_out feature_cache[frame_idx] = ( input_batch.img_batch[frame_idx], backbone_cache, ) # remove from `feature_cache` old features to save GPU memory feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None) return det_out def run_tracker_propagation( self, frame_idx: int, num_frames: int, reverse: bool, tracker_states_local: List[Any], tracker_metadata_prev: Dict[str, npt.NDArray], ): # Step 1: propagate the local SAM2 states to get the current frame's prediction # `low_res_masks_local` of the existing masklets on this GPU # - obj_ids_local: List[int] -- list of object IDs # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask) obj_ids_local, low_res_masks_local, obj_scores_local = ( self._propogate_tracker_one_frame_local_gpu( tracker_states_local, frame_idx=frame_idx, reverse=reverse ) ) assert np.all( obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank] ), "{} != {}".format( obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank] ) # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global` # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask) _, H_mask, W_mask = low_res_masks_local.shape if self.world_size > 1: # `low_res_masks_local` and `obj_scores_local` need to be contiguous and float32 # (they could be non-contiguous due to slicing and/or bfloat16 due to autocast) low_res_masks_local = low_res_masks_local.float().contiguous() obj_scores_local = obj_scores_local.float().contiguous() num_obj_this_gpu = tracker_metadata_prev["num_obj_per_gpu"][self.rank] assert low_res_masks_local.size(0) == num_obj_this_gpu assert obj_scores_local.size(0) == num_obj_this_gpu low_res_masks_peers = [ low_res_masks_local.new_empty(num_obj, H_mask, W_mask) for num_obj in tracker_metadata_prev["num_obj_per_gpu"] ] obj_scores_peers = [ obj_scores_local.new_empty(num_obj) for num_obj in tracker_metadata_prev["num_obj_per_gpu"] ] dist.all_gather(low_res_masks_peers, low_res_masks_local) dist.all_gather(obj_scores_peers, obj_scores_local) low_res_masks_global = torch.cat(low_res_masks_peers, dim=0) obj_scores_global = torch.cat(obj_scores_peers, dim=0) else: low_res_masks_global = low_res_masks_local obj_scores_global = obj_scores_local return low_res_masks_global, obj_scores_global def _recondition_masklets( self, frame_idx, det_out: Dict[str, Tensor], trk_id_to_max_iou_high_conf_det: List[int], tracker_states_local: List[Any], tracker_metadata: Dict[str, npt.NDArray], tracker_obj_scores_global: Tensor, ): # Recondition the masklets based on the new detections for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items(): new_mask = det_out["mask"][det_idx : det_idx + 1] input_mask_res = self.tracker.input_mask_size new_mask_binary = ( F.interpolate( new_mask.unsqueeze(1), size=(input_mask_res, input_mask_res), mode="bilinear", align_corners=False, ).squeeze(1)[0] > 0 ) HIGH_CONF_THRESH = 0.8 reconditioned_states_idx = set() obj_idx = np.where(tracker_metadata["obj_ids_all_gpu"] == trk_obj_id)[ 0 ].item() obj_score = tracker_obj_scores_global[obj_idx] for state_idx, inference_state in enumerate(tracker_states_local): if ( trk_obj_id in inference_state["obj_ids"] # NOTE: Goal of this condition is to avoid reconditioning masks that are occluded/low qualiy. # Unfortunately, these can get reconditioned anyway due to batching. We should consider removing these heuristics. and obj_score > HIGH_CONF_THRESH ): logger.debug( f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned." ) self.tracker.add_new_mask( inference_state=inference_state, frame_idx=frame_idx, obj_id=trk_obj_id, mask=new_mask_binary, ) reconditioned_states_idx.add(state_idx) for idx in reconditioned_states_idx: self.tracker.propagate_in_video_preflight( tracker_states_local[idx], run_mem_encoder=True ) return tracker_states_local def run_tracker_update_planning_phase( self, frame_idx: int, num_frames: int, reverse: bool, det_out: Dict[str, Tensor], tracker_low_res_masks_global: Tensor, tracker_obj_scores_global: Tensor, tracker_metadata_prev: Dict[str, npt.NDArray], tracker_states_local: List[Any], is_image_only: bool = False, ): # initialize new metadata from previous metadata (its values will be updated later) tracker_metadata_new = { "obj_ids_per_gpu": deepcopy(tracker_metadata_prev["obj_ids_per_gpu"]), "obj_ids_all_gpu": None, # will be filled later "num_obj_per_gpu": deepcopy(tracker_metadata_prev["num_obj_per_gpu"]), "obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]), "obj_id_to_tracker_score_frame_wise": deepcopy( tracker_metadata_prev["obj_id_to_tracker_score_frame_wise"] ), "obj_id_to_last_occluded": {}, # will be filled later "max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]), } # Initialize reconditioned_obj_ids early to avoid UnboundLocalError reconditioned_obj_ids = set() # Step 1: make the update plan and resolve heuristics on GPU 0 det_mask_preds: Tensor = det_out["mask"] # low-res mask logits det_scores_np: npt.NDArray = det_out["scores"].float().cpu().numpy() det_bbox_xyxy: Tensor = det_out["bbox"] if self.rank == 0: # a) match detector and tracker masks and find new objects ( new_det_fa_inds, unmatched_trk_obj_ids, det_to_matched_trk_obj_ids, trk_id_to_max_iou_high_conf_det, empty_trk_obj_ids, ) = self._associate_det_trk( det_masks=det_mask_preds, det_scores_np=det_scores_np, trk_masks=tracker_low_res_masks_global, trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"], ) if self.suppress_det_close_to_boundary: keep = self._suppress_detections_close_to_boundary( det_bbox_xyxy[new_det_fa_inds] ) new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()] # check whether we've hit the maximum number of objects we can track (and if so, drop some detections) prev_obj_num = np.sum(tracker_metadata_prev["num_obj_per_gpu"]) new_det_num = len(new_det_fa_inds) num_obj_dropped_due_to_limit = 0 if not is_image_only and prev_obj_num + new_det_num > self.max_num_objects: logger.warning( f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}" ) new_det_num_to_keep = self.max_num_objects - prev_obj_num num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep new_det_fa_inds = self._drop_new_det_with_obj_limit( new_det_fa_inds, det_scores_np, new_det_num_to_keep ) assert len(new_det_fa_inds) == new_det_num_to_keep new_det_num = len(new_det_fa_inds) # assign object IDs to new detections and decide which GPU to place them new_det_start_obj_id = tracker_metadata_prev["max_obj_id"] + 1 new_det_obj_ids = new_det_start_obj_id + np.arange(new_det_num) prev_workload_per_gpu = tracker_metadata_prev["num_obj_per_gpu"] new_det_gpu_ids = self._assign_new_det_to_gpus( new_det_num=new_det_num, prev_workload_per_gpu=prev_workload_per_gpu, ) # b) handle hotstart heuristics to remove objects # here `rank0_metadata` contains metadata stored on (and only accessible to) GPU 0; # we avoid broadcasting them to other GPUs to save communication cost, assuming # that `rank0_metadata` is not needed by other GPUs rank0_metadata_new = deepcopy(tracker_metadata_prev["rank0_metadata"]) if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: obj_ids_newly_removed, rank0_metadata_new = self._process_hotstart( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, new_det_obj_ids=new_det_obj_ids, empty_trk_obj_ids=empty_trk_obj_ids, unmatched_trk_obj_ids=unmatched_trk_obj_ids, rank0_metadata=rank0_metadata_new, tracker_metadata=tracker_metadata_prev, ) else: # if warm-up is not complete, we don't remove any objects obj_ids_newly_removed = set() tracker_metadata_new["rank0_metadata"] = rank0_metadata_new # Step 2: broadcast the update plan to other GPUs NUM_BROADCAST_ITEMS = 9 if self.rank == 0 and self.world_size > 1: # `num_obj_per_gpu_on_rank0` is used for metadata consistency check on other GPUs # (it's a small array with length==self.world_size, so broadcasting it is cheap) num_obj_per_gpu_on_rank0 = tracker_metadata_prev["num_obj_per_gpu"] update_plan = [ new_det_fa_inds, new_det_obj_ids, new_det_gpu_ids, num_obj_per_gpu_on_rank0, unmatched_trk_obj_ids, det_to_matched_trk_obj_ids, obj_ids_newly_removed, num_obj_dropped_due_to_limit, trk_id_to_max_iou_high_conf_det, ] assert len(update_plan) == NUM_BROADCAST_ITEMS, ( f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}" ) self.broadcast_python_obj_cpu(update_plan, src=0) elif self.rank > 0 and self.world_size > 1: update_plan = [ None ] * NUM_BROADCAST_ITEMS # other ranks receive the plan from rank 0 self.broadcast_python_obj_cpu(update_plan, src=0) ( new_det_fa_inds, new_det_obj_ids, new_det_gpu_ids, num_obj_per_gpu_on_rank0, unmatched_trk_obj_ids, det_to_matched_trk_obj_ids, obj_ids_newly_removed, num_obj_dropped_due_to_limit, trk_id_to_max_iou_high_conf_det, ) = update_plan # metadata consistency check: verify that the received `num_obj_per_gpu_on_rank0` is consistent with the local metadata # it's critical that all GPUs agree on the previous number of objects (otherwise the inference might hang or fail silently) if not np.all( num_obj_per_gpu_on_rank0 == tracker_metadata_prev["num_obj_per_gpu"] ): raise RuntimeError( f"{self.rank=} received {num_obj_per_gpu_on_rank0=}, which is inconsistent with local record " f"{tracker_metadata_prev['num_obj_per_gpu']=}. There's likely a bug in update planning or execution." ) # `tracker_update_plan` should be identical on all GPUs after broadcasting tracker_update_plan = { "new_det_fa_inds": new_det_fa_inds, # npt.NDArray "new_det_obj_ids": new_det_obj_ids, # npt.NDArray "new_det_gpu_ids": new_det_gpu_ids, # npt.NDArray "unmatched_trk_obj_ids": unmatched_trk_obj_ids, # npt.NDArray "det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids, # dict "obj_ids_newly_removed": obj_ids_newly_removed, # set "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit, # int "trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det, # dict "reconditioned_obj_ids": reconditioned_obj_ids, # set } # Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results should_recondition_iou = False # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections if ( self.reconstruction_bbox_iou_thresh > 0 and len(trk_id_to_max_iou_high_conf_det) > 0 ): for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items(): det_box = det_out["bbox"][det_idx] det_score = det_out["scores"][det_idx] try: trk_idx = list(tracker_metadata_prev["obj_ids_all_gpu"]).index( trk_obj_id ) except ValueError: continue # Skip if tracklet not found tracker_mask = tracker_low_res_masks_global[trk_idx] mask_binary = tracker_mask > 0 mask_area = mask_binary.sum().item() if mask_area == 0: continue # Skip tracklets with zero mask area # Get bounding box from SAM2 mask and convert to normalized coordinates tracker_box_pixels = ( mask_to_box(mask_binary.unsqueeze(0).unsqueeze(0)) .squeeze(0) .squeeze(0) ) mask_height, mask_width = tracker_mask.shape[-2:] tracker_box_normalized = torch.tensor( [ tracker_box_pixels[0] / mask_width, tracker_box_pixels[1] / mask_height, tracker_box_pixels[2] / mask_width, tracker_box_pixels[3] / mask_height, ], device=tracker_box_pixels.device, ) # Compute IoU between detection and SAM2 tracklet bounding boxes det_box_batch = det_box.unsqueeze(0) tracker_box_batch = tracker_box_normalized.unsqueeze(0) iou = fast_diag_box_iou(det_box_batch, tracker_box_batch)[0] if ( iou < self.reconstruction_bbox_iou_thresh and det_score >= self.reconstruction_bbox_det_score ): should_recondition_iou = True reconditioned_obj_ids.add(trk_obj_id) should_recondition_periodic = ( self.recondition_every_nth_frame > 0 and frame_idx % self.recondition_every_nth_frame == 0 and len(trk_id_to_max_iou_high_conf_det) > 0 ) # Recondition if periodic or IoU condition met if should_recondition_periodic or should_recondition_iou: self._recondition_masklets( frame_idx, det_out, trk_id_to_max_iou_high_conf_det, tracker_states_local, tracker_metadata_prev, tracker_obj_scores_global, ) # Step 4: Run SAM2 memory encoder on the current frame's prediction masks # This is done on all GPUs batch_size = tracker_low_res_masks_global.size(0) if batch_size > 0: if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0: # NOTE: tracker_low_res_masks_global is updated in-place then returned tracker_low_res_masks_global = ( self._suppress_overlapping_based_on_recent_occlusion( frame_idx, tracker_low_res_masks_global, tracker_metadata_prev, tracker_metadata_new, obj_ids_newly_removed, reverse, ) ) self._tracker_update_memories( tracker_states_local, frame_idx, tracker_metadata=tracker_metadata_prev, low_res_masks=tracker_low_res_masks_global, ) # Step 4: update the SAM2 metadata based on the update plan # note: except for "rank0_metadata" (that is only available on GPU 0), # the updated `tracker_metadata_new` should be identical on all GPUs for rank in range(self.world_size): new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank] updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank] if len(new_det_obj_ids_this_gpu) > 0: updated_obj_ids_this_gpu = np.concatenate( [updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu] ) if len(obj_ids_newly_removed) > 0: is_removed = np.isin( updated_obj_ids_this_gpu, list(obj_ids_newly_removed) ) updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed] tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu tracker_metadata_new["num_obj_per_gpu"][rank] = len( updated_obj_ids_this_gpu ) tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate( tracker_metadata_new["obj_ids_per_gpu"] ) # update object scores and the maximum object ID assigned so far if len(new_det_obj_ids) > 0: tracker_metadata_new["obj_id_to_score"].update( zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]) ) # tracker scores are not available for new objects, use det score instead. tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][ frame_idx ].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds])) tracker_metadata_new["max_obj_id"] = max( tracker_metadata_new["max_obj_id"], np.max(new_det_obj_ids), ) # for removed objects, we set their scores to a very low value (-1e4) but still # keep them in "obj_id_to_score" (it's easier to handle outputs this way) for obj_id in obj_ids_newly_removed: tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4 tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][ obj_id ] = -1e4 tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None) # check that "rank0_metadata" is in tracker_metadata_new if and only if it's GPU 0 assert ("rank0_metadata" in tracker_metadata_new) == (self.rank == 0) if self.rank == 0 and self.masklet_confirmation_enable: rank0_metadata = self.update_masklet_confirmation_status( rank0_metadata=tracker_metadata_new["rank0_metadata"], obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"], obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"], det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, new_det_obj_ids=new_det_obj_ids, ) tracker_metadata_new["rank0_metadata"] = rank0_metadata return tracker_update_plan, tracker_metadata_new def _suppress_overlapping_based_on_recent_occlusion( self, frame_idx: int, tracker_low_res_masks_global: Tensor, tracker_metadata_prev: Dict[str, Any], tracker_metadata_new: Dict[str, Any], obj_ids_newly_removed: Set[int], reverse: bool = False, ): """ Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object. Args: frame_idx (int): The current frame index. tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame. tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame. tracker_metadata_new (Dict[str, Any]): The metadata for the current frame. obj_ids_newly_removed (Set[int]): The object IDs that have been removed. Return: Tensor: The updated low-resolution masks with some objects suppressed. """ obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"] binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 batch_size = tracker_low_res_masks_global.size(0) if batch_size > 0: assert len(obj_ids_global) == batch_size, ( f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" ) NEVER_OCCLUDED = -1 ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic last_occluded_prev = torch.cat( [ tracker_metadata_prev["obj_id_to_last_occluded"].get( obj_id, torch.full( (1,), fill_value=( NEVER_OCCLUDED if obj_id not in obj_ids_newly_removed else ALWAYS_OCCLUDED ), device=binary_tracker_low_res_masks_global.device, dtype=torch.long, ), ) for obj_id in obj_ids_global ], dim=0, ) to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded( binary_tracker_low_res_masks_global, last_occluded_prev, obj_ids_global, frame_idx, reverse, ) # Update metadata with occlusion information is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2))) is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress last_occluded_new = last_occluded_prev.clone() last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx # Slice out the last occluded frame for each object tracker_metadata_new["obj_id_to_last_occluded"] = { obj_id: last_occluded_new[obj_idx : obj_idx + 1] for obj_idx, obj_id in enumerate(obj_ids_global) } # Zero out suppressed masks before memory encoding NO_OBJ_LOGIT = -10 tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT return tracker_low_res_masks_global def run_tracker_update_execution_phase( self, frame_idx: int, num_frames: int, reverse: bool, det_out: Dict[str, Tensor], tracker_states_local: List[Any], tracker_update_plan: Dict[str, npt.NDArray], orig_vid_height: int, orig_vid_width: int, feature_cache: Dict, ): # initialize tracking scores with detection scores new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"] new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"] new_det_gpu_ids: npt.NDArray = tracker_update_plan["new_det_gpu_ids"] is_on_this_gpu: npt.NDArray = new_det_gpu_ids == self.rank new_det_obj_ids_local: npt.NDArray = new_det_obj_ids[is_on_this_gpu] new_det_fa_inds_local: npt.NDArray = new_det_fa_inds[is_on_this_gpu] obj_ids_newly_removed: Set[int] = tracker_update_plan["obj_ids_newly_removed"] # Step 1: add new objects from the detector to SAM2 inference states if len(new_det_fa_inds_local) > 0: new_det_fa_inds_local_t = torch.from_numpy(new_det_fa_inds_local) new_det_masks: Tensor = det_out["mask"][new_det_fa_inds_local_t] # initialize SAM2 with new object masks tracker_states_local = self._tracker_add_new_objects( frame_idx=frame_idx, num_frames=num_frames, new_obj_ids=new_det_obj_ids_local, new_obj_masks=new_det_masks, tracker_states_local=tracker_states_local, orig_vid_height=orig_vid_height, orig_vid_width=orig_vid_width, feature_cache=feature_cache, ) # Step 2: remove from SAM2 inference states those objects removed by heuristics if len(obj_ids_newly_removed) > 0: self._tracker_remove_objects(tracker_states_local, obj_ids_newly_removed) return tracker_states_local def build_outputs( self, frame_idx: int, num_frames: int, reverse: bool, det_out: Dict[str, Tensor], tracker_low_res_masks_global: Tensor, tracker_obj_scores_global: Tensor, tracker_metadata_prev: Dict[str, npt.NDArray], tracker_update_plan: Dict[str, npt.NDArray], orig_vid_height: int, orig_vid_width: int, reconditioned_obj_ids: set = None, det_to_matched_trk_obj_ids: dict = None, ): new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"] new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"] obj_id_to_mask = {} # obj_id --> output mask tensor # Part 1: masks from previous SAM2 propagation existing_masklet_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"] existing_masklet_video_res_masks = F.interpolate( tracker_low_res_masks_global.unsqueeze(1), size=(orig_vid_height, orig_vid_width), mode="bilinear", align_corners=False, ) # (num_obj, 1, H_video, W_video) existing_masklet_binary = existing_masklet_video_res_masks > 0 assert len(existing_masklet_obj_ids) == len(existing_masklet_binary) for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary): obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) # Part 2: masks from new detections new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds) new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1) new_det_low_res_masks = fill_holes_in_mask_scores( new_det_low_res_masks, max_area=self.fill_hole_area, fill_holes=True, remove_sprinkles=True, ) new_masklet_video_res_masks = F.interpolate( new_det_low_res_masks, size=(orig_vid_height, orig_vid_width), mode="bilinear", align_corners=False, ) # (num_obj, 1, H_video, W_video) new_masklet_binary = new_masklet_video_res_masks > 0 assert len(new_det_obj_ids) == len(new_masklet_video_res_masks) for obj_id, mask in zip(new_det_obj_ids, new_masklet_binary): obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) # Part 3: Override masks for reconditioned objects using detection masks if reconditioned_obj_ids is not None and len(reconditioned_obj_ids) > 0: trk_id_to_max_iou_high_conf_det = tracker_update_plan.get( "trk_id_to_max_iou_high_conf_det", {} ) for obj_id in reconditioned_obj_ids: det_idx = trk_id_to_max_iou_high_conf_det.get(obj_id) if det_idx is not None: det_mask = det_out["mask"][det_idx] det_mask = det_mask.unsqueeze(0).unsqueeze(0) det_mask_resized = ( F.interpolate( det_mask.float(), size=(orig_vid_height, orig_vid_width), mode="bilinear", align_corners=False, ) > 0 ) det_mask_final = det_mask_resized.squeeze(0) obj_id_to_mask[obj_id] = det_mask_final return obj_id_to_mask def _get_objects_to_suppress_based_on_most_recently_occluded( self, binary_low_res_masks: Tensor, last_occluded: List[int], obj_ids: List[int], frame_idx: int = None, reverse: bool = False, ): # Suppress overlapping masks for objects that were most recently occluded assert binary_low_res_masks.dtype == torch.bool, ( f"Expected boolean tensor, got {binary_low_res_masks.dtype}" ) to_suppress = torch.zeros( binary_low_res_masks.size(0), device=binary_low_res_masks.device, dtype=torch.bool, ) if len(obj_ids) <= 1: return to_suppress iou = mask_iou(binary_low_res_masks, binary_low_res_masks) # [N,N] # Create masks for upper triangular matrix (i < j) and IoU threshold mask_iou_thresh = ( iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold ) overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1) # [N,N] last_occ_expanded_i = last_occluded.unsqueeze(1) # (N, 1) last_occ_expanded_j = last_occluded.unsqueeze(0) # (1, N) # Suppress most recently occluded cmp_op = torch.gt if not reverse else torch.lt suppress_i_mask = ( overlapping_pairs & cmp_op( last_occ_expanded_i, last_occ_expanded_j ) # (last_occ_expanded_i > last_occ_expanded_j) & ( last_occ_expanded_j > -1 ) # j can suppress i only if i was previously occluded ) suppress_j_mask = ( overlapping_pairs & cmp_op(last_occ_expanded_j, last_occ_expanded_i) & ( last_occ_expanded_i > -1 ) # i can suppress j only if j was previously occluded ) # Apply suppression to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0) # Log for debugging if ( self.rank == 0 and logger.isEnabledFor(logging.DEBUG) and frame_idx is not None ): suppress_i_mask = suppress_i_mask.cpu().numpy() suppress_j_mask = suppress_j_mask.cpu().numpy() last_occluded = last_occluded.cpu().numpy() # Find all suppression pairs without using torch.where batch_size = suppress_i_mask.shape[0] # Log i-suppression cases (where i gets suppressed in favor of j) for i in range(batch_size): for j in range(batch_size): if suppress_i_mask[i, j]: logger.debug( f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}" ) # Log j-suppression cases (where j gets suppressed in favor of i) for i in range(batch_size): for j in range(batch_size): if suppress_j_mask[i, j]: logger.debug( f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}" ) return to_suppress def _propogate_tracker_one_frame_local_gpu( self, inference_states: List[Any], frame_idx: int, reverse: bool, # by default, we disable memory encoding until we gather all outputs run_mem_encoder: bool = False, ): """ inference_states: List of inference states, each state corresponds to a different set of objects. """ obj_ids_local = [] low_res_masks_list = [] obj_scores_list = [] for inference_state in inference_states: if len(inference_state["obj_ids"]) == 0: continue # skip propagation on empty inference states # propagate one frame num_frames_propagated = 0 for out in self.tracker.propagate_in_video( inference_state, start_frame_idx=frame_idx, # end_frame_idx = start_frame_idx + max_frame_num_to_track # (i.e. propagating 1 frame since end_frame_idx is inclusive) max_frame_num_to_track=0, reverse=reverse, tqdm_disable=True, run_mem_encoder=run_mem_encoder, ): out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = out num_frames_propagated += 1 # only 1 frames should be propagated assert num_frames_propagated == 1 and out_frame_idx == frame_idx, ( f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}" ) assert isinstance(out_obj_ids, list) obj_ids_local.extend(out_obj_ids) low_res_masks_list.append(out_low_res_masks.squeeze(1)) obj_scores_list.append(out_obj_scores.squeeze(1)) # concatenate the output masklets from all local inference states H_mask = W_mask = self.tracker.low_res_mask_size if len(low_res_masks_list) > 0: low_res_masks_local = torch.cat(low_res_masks_list, dim=0) obj_scores_local = torch.cat(obj_scores_list, dim=0) assert low_res_masks_local.shape[1:] == (H_mask, W_mask) # Apply hole filling to the masks low_res_masks_local = fill_holes_in_mask_scores( low_res_masks_local.unsqueeze(1), max_area=self.fill_hole_area, fill_holes=True, remove_sprinkles=True, ) low_res_masks_local = low_res_masks_local.squeeze(1) else: low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device) obj_scores_local = torch.zeros(0, device=self.device) return obj_ids_local, low_res_masks_local, obj_scores_local def _associate_det_trk( self, det_masks: Tensor, det_scores_np: npt.NDArray, trk_masks: Tensor, trk_obj_ids: npt.NDArray, ): """ Match detections on the current frame with the existing masklets. Args: - det_masks: (N, H, W) tensor of predicted masks - det_scores_np: (N,) array of detection scores - trk_masks: (M, H, W) tensor of track masks - trk_obj_ids: (M,) array of object IDs corresponding to trk_masks Returns: - new_det_fa_inds: array of new object indices. - unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched to any detections on this frame (for unmatched, we only count masklets with >0 area) - det_to_matched_trk_obj_ids: dict[int, npt.NDArray]: mapping from detector's detection indices to the list of matched tracklet object IDs - empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction """ iou_threshold = self.assoc_iou_thresh iou_threshold_trk = self.trk_assoc_iou_thresh new_det_thresh = self.new_det_thresh assert det_masks.is_floating_point(), "float tensor expected (do not binarize)" assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)" assert trk_masks.size(0) == len(trk_obj_ids), ( f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" ) if trk_masks.size(0) == 0: # all detections are new new_det_fa_inds = np.arange(det_masks.size(0)) unmatched_trk_obj_ids = np.array([], np.int64) empty_trk_obj_ids = np.array([], np.int64) det_to_matched_trk_obj_ids = {} trk_id_to_max_iou_high_conf_det = {} return ( new_det_fa_inds, unmatched_trk_obj_ids, det_to_matched_trk_obj_ids, trk_id_to_max_iou_high_conf_det, empty_trk_obj_ids, ) elif det_masks.size(0) == 0: # all previous tracklets are unmatched if they have a non-zero area new_det_fa_inds = np.array([], np.int64) trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy() unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty] empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty] det_to_matched_trk_obj_ids = {} trk_id_to_max_iou_high_conf_det = {} return ( new_det_fa_inds, unmatched_trk_obj_ids, det_to_matched_trk_obj_ids, trk_id_to_max_iou_high_conf_det, empty_trk_obj_ids, ) if det_masks.shape[-2:] != trk_masks.shape[-2:]: # resize to the smaller size to save GPU memory if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]): trk_masks = F.interpolate( trk_masks.unsqueeze(1), size=det_masks.shape[-2:], mode="bilinear", align_corners=False, ).squeeze(1) else: # resize detections to track size det_masks = F.interpolate( det_masks.unsqueeze(1), size=trk_masks.shape[-2:], mode="bilinear", align_corners=False, ).squeeze(1) det_masks_binary = det_masks > 0 trk_masks_binary = trk_masks > 0 ious = mask_iou(det_masks_binary, trk_masks_binary) # (N, M) ious_np = ious.cpu().numpy() if self.o2o_matching_masklets_enable: from scipy.optimize import linear_sum_assignment # Hungarian matching for tracks (one-to-one: each track matches at most one detection) cost_matrix = 1 - ious_np # Hungarian solves for minimum cost row_ind, col_ind = linear_sum_assignment(cost_matrix) trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool) for d, t in zip(row_ind, col_ind): if ious_np[d, t] >= iou_threshold_trk: trk_is_matched[t] = True else: trk_is_matched = (ious_np >= iou_threshold_trk).any(axis=0) # Non-empty tracks not matched by Hungarian assignment above threshold are unmatched trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)).cpu().numpy() trk_is_unmatched = np.logical_and(trk_is_nonempty, ~trk_is_matched) unmatched_trk_obj_ids = trk_obj_ids[trk_is_unmatched] # also record masklets that have zero area in SAM 2 prediction empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty] # For detections: allow many tracks to match to the same detection (many-to-one) # So, a detection is 'new' if it does not match any track above threshold is_new_det = np.logical_and( det_scores_np >= new_det_thresh, np.logical_not(np.any(ious_np >= iou_threshold, axis=1)), ) new_det_fa_inds = np.nonzero(is_new_det)[0] # for each detection, which tracks it matched to (above threshold) det_to_matched_trk_obj_ids = {} trk_id_to_max_iou_high_conf_det = {} # trk id --> exactly one detection idx HIGH_CONF_THRESH = 0.8 HIGH_IOU_THRESH = 0.8 det_to_max_iou_trk_idx = np.argmax(ious_np, axis=1) det_is_high_conf = (det_scores_np >= HIGH_CONF_THRESH) & ~is_new_det det_is_high_iou = np.max(ious_np, axis=1) >= HIGH_IOU_THRESH det_is_high_conf_and_iou = set( np.nonzero(det_is_high_conf & det_is_high_iou)[0] ) for d in range(det_masks.size(0)): det_to_matched_trk_obj_ids[d] = trk_obj_ids[ious_np[d, :] >= iou_threshold] if d in det_is_high_conf_and_iou: trk_obj_id = trk_obj_ids[det_to_max_iou_trk_idx[d]].item() trk_id_to_max_iou_high_conf_det[trk_obj_id] = d return ( new_det_fa_inds, unmatched_trk_obj_ids, det_to_matched_trk_obj_ids, trk_id_to_max_iou_high_conf_det, empty_trk_obj_ids, ) def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu): """Distribute the new objects to the GPUs with the least workload.""" workload_per_gpu: npt.NDArray = prev_workload_per_gpu.copy() new_det_gpu_ids = np.zeros(new_det_num, np.int64) # assign the objects one by one for i in range(len(new_det_gpu_ids)): # find the GPU with the least workload min_gpu = np.argmin(workload_per_gpu) new_det_gpu_ids[i] = min_gpu workload_per_gpu[min_gpu] += 1 return new_det_gpu_ids def _process_hotstart( self, frame_idx: int, num_frames: int, reverse: bool, det_to_matched_trk_obj_ids: Dict[int, npt.NDArray], new_det_obj_ids: npt.NDArray, empty_trk_obj_ids: npt.NDArray, unmatched_trk_obj_ids: npt.NDArray, rank0_metadata: Dict[str, Any], tracker_metadata: Dict[str, Any], ): """Handle hotstart heuristics to remove unmatched or duplicated objects.""" # obj_id --> first frame index where the object was detected obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"] # obj_id --> [mismatched frame indices] unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"] trk_keep_alive = rank0_metadata["trk_keep_alive"] # (first_appear_obj_id, obj_id) --> [overlap frame indices] overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"] # removed_obj_ids: object IDs that are suppressed via hot-start removed_obj_ids = rank0_metadata["removed_obj_ids"] suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx] obj_ids_newly_removed = set() # object IDs to be newly removed on this frame hotstart_diff = ( frame_idx - self.hotstart_delay if not reverse else frame_idx + self.hotstart_delay ) # Step 1: log the frame index where each object ID first appears for obj_id in new_det_obj_ids: if obj_id not in obj_first_frame_idx: obj_first_frame_idx[obj_id] = frame_idx assert obj_id not in trk_keep_alive trk_keep_alive[obj_id] = self.init_trk_keep_alive matched_trks = set() # We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded for matched_trks_per_det in det_to_matched_trk_obj_ids.values(): matched_trks.update(matched_trks_per_det) for obj_id in matched_trks: # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive trk_keep_alive[obj_id] = min( self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1 ) for obj_id in unmatched_trk_obj_ids: unmatched_frame_inds[obj_id].append(frame_idx) # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive # The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough. trk_keep_alive[obj_id] = max( self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 ) if self.decrease_trk_keep_alive_for_empty_masklets: for obj_id in empty_trk_obj_ids: # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive trk_keep_alive[obj_id] = max( self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 ) # Step 2: removed tracks that has not matched with detections for `hotstart_unmatch_thresh` frames with hotstart period # a) add unmatched frame indices for each existing object ID # note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask # doesn't match any detection; it excludes those frames where SAM2 gives an empty mask # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more # than `self.hotstart_unmatch_thresh` frames for obj_id, frame_indices in unmatched_frame_inds.items(): if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: continue # skip if the object is already removed if len(frame_indices) >= self.hotstart_unmatch_thresh: is_within_hotstart = ( obj_first_frame_idx[obj_id] > hotstart_diff and not reverse ) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse) if is_within_hotstart: obj_ids_newly_removed.add(obj_id) logger.debug( f"Removing object {obj_id} at frame {frame_idx} " f"since it is unmatched for frames: {frame_indices}" ) if ( trk_keep_alive[obj_id] <= 0 # Object has not been matched for too long and not self.suppress_unmatched_only_within_hotstart and obj_id not in removed_obj_ids and obj_id not in obj_ids_newly_removed ): logger.debug( f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched" ) suppressed_obj_ids.add(obj_id) # Step 3: removed tracks that overlaps with another track for `hotstart_dup_thresh` frames # a) find overlaps tracks -- we consider overlap if they match to the same detection for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items(): if len(matched_trk_obj_ids) < 2: continue # only count detections that are matched to multiple (>=2) masklets # if there are multiple matched track ids, we need to find the one that appeared first; # these later appearing ids may be removed since they may be considered as duplicates first_appear_obj_id = ( min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) if not reverse else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) ) for obj_id in matched_trk_obj_ids: if obj_id != first_appear_obj_id: key = (first_appear_obj_id, obj_id) overlap_pair_to_frame_inds[key].append(frame_idx) # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items(): if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: continue # skip if the object is already removed if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or ( obj_first_frame_idx[obj_id] < hotstart_diff and reverse ): if len(frame_indices) >= self.hotstart_dup_thresh: obj_ids_newly_removed.add(obj_id) logger.debug( f"Removing object {obj_id} at frame {frame_idx} " f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}" ) removed_obj_ids.update(obj_ids_newly_removed) return obj_ids_newly_removed, rank0_metadata def _tracker_update_memories( self, tracker_inference_states: List[Any], frame_idx: int, tracker_metadata: Dict[str, Any], low_res_masks: Tensor, ): """ Run Sam2 memory encoder, enforcing non-overlapping constraints globally. """ if len(tracker_inference_states) == 0: return # Avoid an extra interpolation step by directly interpolating to `interpol_size` high_res_H, high_res_W = ( self.tracker.maskmem_backbone.mask_downsampler.interpol_size ) # NOTE: inspect this part if we observe OOMs in the demo high_res_masks = F.interpolate( low_res_masks.unsqueeze(1), size=(high_res_H, high_res_W), mode="bilinear", align_corners=False, ) # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics. if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: high_res_masks = self.tracker._suppress_object_pw_area_shrinkage( high_res_masks ) # Instead of gathering the predicted object scores, we use mask areas as a proxy. object_score_logits = torch.where( (high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0 ) # Run the memory encoder on local slices for each GPU start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank]) start_idx_state = start_idx_gpu for tracker_state in tracker_inference_states: num_obj_per_state = len(tracker_state["obj_ids"]) if num_obj_per_state == 0: continue # Get the local high-res masks and object score logits for this inference state end_idx_state = start_idx_state + num_obj_per_state local_high_res_masks = high_res_masks[start_idx_state:end_idx_state] local_object_score_logits = object_score_logits[ start_idx_state:end_idx_state ] local_batch_size = local_high_res_masks.size(0) # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default encoded_mem = self.tracker._run_memory_encoder( tracker_state, frame_idx, local_batch_size, local_high_res_masks, local_object_score_logits, is_mask_from_pts=False, ) local_maskmem_features, local_maskmem_pos_enc = encoded_mem # Store encoded memories in the local inference state output_dict = tracker_state["output_dict"] for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: if frame_idx not in output_dict[storage_key]: continue output_dict[storage_key][frame_idx]["maskmem_features"] = ( local_maskmem_features ) output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [ pos for pos in local_maskmem_pos_enc ] # for batched inference state, we also need to add per-object # memory slides to support instance interactivity self.tracker._add_output_per_object( inference_state=tracker_state, frame_idx=frame_idx, current_out=output_dict[storage_key][frame_idx], storage_key=storage_key, ) start_idx_state += num_obj_per_state def _tracker_add_new_objects( self, frame_idx: int, num_frames: int, new_obj_ids: List[int], new_obj_masks: Tensor, tracker_states_local: List[Any], orig_vid_height: int, orig_vid_width: int, feature_cache: Dict, ): """Add a new object to SAM2 inference states.""" prev_tracker_state = ( tracker_states_local[0] if len(tracker_states_local) > 0 else None ) # prepare inference_state # batch objects that first appear on the same frame together # Clear inference state. Keep the cached image features if available. new_tracker_state = self.tracker.init_state( cached_features=feature_cache, video_height=orig_vid_height, video_width=orig_vid_width, num_frames=num_frames, ) new_tracker_state["backbone_out"] = ( prev_tracker_state.get("backbone_out", None) if prev_tracker_state is not None else None ) assert len(new_obj_ids) == new_obj_masks.size(0) assert new_obj_masks.is_floating_point() input_mask_res = self.tracker.input_mask_size new_obj_masks = F.interpolate( new_obj_masks.unsqueeze(1), size=(input_mask_res, input_mask_res), mode="bilinear", align_corners=False, ).squeeze(1) new_obj_masks = new_obj_masks > 0 # add object one by one for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks): self.tracker.add_new_mask( inference_state=new_tracker_state, frame_idx=frame_idx, obj_id=new_obj_id, mask=new_mask, add_mask_to_memory=True, ) # NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects. self.tracker.propagate_in_video_preflight( new_tracker_state, run_mem_encoder=True ) tracker_states_local.append(new_tracker_state) return tracker_states_local def _tracker_remove_object(self, tracker_states_local: List[Any], obj_id: int): """ Remove an object from SAM2 inference states. This would remove the object from all frames in the video. """ tracker_states_local_before_removal = tracker_states_local.copy() tracker_states_local.clear() for tracker_inference_state in tracker_states_local_before_removal: # we try to remove `obj_id` on every inference state with `strict=False` # it will not do anything if an inference state doesn't contain `obj_id` new_obj_ids, _ = self.tracker.remove_object( tracker_inference_state, obj_id, strict=False, need_output=False ) # only keep an inference state if it's non-empty after object removal if len(new_obj_ids) > 0: tracker_states_local.append(tracker_inference_state) def _tracker_remove_objects( self, tracker_states_local: List[Any], obj_ids: list[int] ): """ Remove an object from SAM2 inference states. This would remove the object from all frames in the video. """ for obj_id in obj_ids: self._tracker_remove_object(tracker_states_local, obj_id) def _initialize_metadata(self): """Initialize metadata for the masklets.""" tracker_metadata = { "obj_ids_per_gpu": [np.array([], np.int64) for _ in range(self.world_size)], "obj_ids_all_gpu": np.array([], np.int64), "num_obj_per_gpu": np.zeros(self.world_size, np.int64), "max_obj_id": -1, "obj_id_to_score": {}, "obj_id_to_tracker_score_frame_wise": defaultdict(dict), "obj_id_to_last_occluded": {}, } if self.rank == 0: # "rank0_metadata" contains metadata that is only stored on (and accessible to) GPU 0 # - obj_first_frame_idx: obj_id --> first frame index where the object was detected # - unmatched_frame_inds: obj_id --> [mismatched frame indices] # - overlap_pair_to_frame_inds: (first_appear_obj_id, obj_id) --> [overlap frame indices] # - removed_obj_ids: object IDs that are suppressed via hot-start rank0_metadata = { "obj_first_frame_idx": {}, "unmatched_frame_inds": defaultdict(list), "trk_keep_alive": defaultdict( int ), # This is used only for object suppression not for removal "overlap_pair_to_frame_inds": defaultdict(list), "removed_obj_ids": set(), "suppressed_obj_ids": defaultdict( set ), # frame_idx --> set of objects with suppressed outputs, but still continue to be tracked } if self.masklet_confirmation_enable: # all the following are npt.NDArray with the same shape as `obj_ids_all_gpu` rank0_metadata["masklet_confirmation"] = { # "status" is the confirmation status of each masklet (in `MaskletConfirmationStatus`) "status": np.array([], np.int64), # "consecutive_det_num" is the number of consecutive frames where the masklet is # detected by the detector (with a matched detection) "consecutive_det_num": np.array([], np.int64), } tracker_metadata["rank0_metadata"] = rank0_metadata return tracker_metadata def update_masklet_confirmation_status( self, rank0_metadata: Dict[str, Any], obj_ids_all_gpu_prev: npt.NDArray, obj_ids_all_gpu_updated: npt.NDArray, det_to_matched_trk_obj_ids: Dict[int, npt.NDArray], new_det_obj_ids: npt.NDArray, ): confirmation_data = rank0_metadata["masklet_confirmation"] # a) first, expand "confirmation_data" to include new masklets added in this frame status_prev = confirmation_data["status"] consecutive_det_num_prev = confirmation_data["consecutive_det_num"] assert status_prev.shape == obj_ids_all_gpu_prev.shape, ( f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}" ) obj_id_to_updated_idx = { obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated) } prev_elem_is_in_updated = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated) prev_elem_obj_ids_in_updated = obj_ids_all_gpu_prev[prev_elem_is_in_updated] prev_elem_inds_in_updated = np.array( [obj_id_to_updated_idx[obj_id] for obj_id in prev_elem_obj_ids_in_updated], dtype=np.int64, ) # newly added masklets are initialized to "UNCONFIRMED" status unconfirmed_val = MaskletConfirmationStatus.UNCONFIRMED.value status = np.full_like(obj_ids_all_gpu_updated, fill_value=unconfirmed_val) status[prev_elem_inds_in_updated] = status_prev[prev_elem_is_in_updated] consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated) consecutive_det_num[prev_elem_inds_in_updated] = consecutive_det_num_prev[ prev_elem_is_in_updated ] # b) update the confirmation status of all masklets based on the current frame # b.1) update "consecutive_det_num" # "is_matched": whether a masklet is matched to a detection on this frame is_matched = np.isin(obj_ids_all_gpu_updated, new_det_obj_ids) for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values(): is_matched |= np.isin(obj_ids_all_gpu_updated, matched_trk_obj_ids) consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0) # b.2) update "status" change_to_confirmed = ( consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh ) status[change_to_confirmed] = MaskletConfirmationStatus.CONFIRMED.value confirmation_data["status"] = status confirmation_data["consecutive_det_num"] = consecutive_det_num return rank0_metadata def forward(self, input: BatchedDatapoint, is_inference: bool = False): raise NotImplementedError("Evaluation outside demo is not implemented yet") def _load_checkpoint(self, ckpt_path: str, strict: bool = True): sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] missing_keys, unexpected_keys = self.load_state_dict(sd, strict=strict) if len(missing_keys) > 0 or len(unexpected_keys) > 0: logger.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}") else: logger.info("Loaded ckpt successfully without missing or unexpected keys") def prep_for_evaluator(self, video_frames, tracking_res, scores_labels): """This method is only used for benchmark eval (not used in the demo).""" num_frames = len(video_frames) w, h = video_frames[0].size zero_mask = torch.zeros((1, h, w), dtype=torch.bool) object_ids = list(scores_labels.keys()) preds = {"scores": [], "labels": [], "boxes": [], "masks_rle": []} for oid in object_ids: o_masks = [] o_score = scores_labels[oid][0].item() o_label = scores_labels[oid][1] for frame_idx in range(num_frames): if frame_idx not in tracking_res: o_masks.append(zero_mask) else: o_masks.append(tracking_res[frame_idx].get(oid, zero_mask)) o_masks = torch.cat(o_masks, dim=0) # (n_frames, H, W) preds["scores"].append(o_score) preds["labels"].append(o_label) preds["boxes"].append(mask_to_box(o_masks.unsqueeze(1)).squeeze()) preds["masks_rle"].append(rle_encode(o_masks, return_areas=True)) preds["boxes"] = ( torch.stack(preds["boxes"], dim=0) if len(preds["boxes"]) > 0 else torch.empty( (0, num_frames, 4), dtype=torch.float32, device=self.device ) ) preds["scores"] = ( torch.tensor(preds["scores"], device=self.device) if len(preds["scores"]) > 0 else torch.empty((0,), device=self.device) ) preds["per_frame_scores"] = preds["scores"] preds["labels"] = ( torch.tensor(preds["labels"], device=self.device) if len(preds["labels"]) > 0 else torch.empty((0,), device=self.device) ) return preds def _encode_prompt(self, **kwargs): return self.detector._encode_prompt(**kwargs) def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep): """ Drop a few new detections based on the maximum number of objects. We drop new objects based on their detection scores, keeping the high-scoring ones and dropping the low-scoring ones. """ assert 0 <= num_to_keep <= len(new_det_fa_inds) if num_to_keep == 0: return np.array([], np.int64) # keep none if num_to_keep == len(new_det_fa_inds): return new_det_fa_inds # keep all # keep the top-scoring detections score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1] new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]] return new_det_fa_inds