# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import os
from copy import deepcopy
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from sam3.model.model_misc import SAM3Output
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.perflib.nms import nms_masks
from sam3.train.data.collator import BatchedDatapoint

from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .geometry_encoders import Prompt
from .model_misc import inverse_sigmoid


def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
    out[out_name] = out_value[-1] if auxiliary else out_value
    if auxiliary and update_aux:
        if "aux_outputs" not in out:
            out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
        assert len(out["aux_outputs"]) == len(out_value) - 1
        for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
            aux_output[out_name] = aux_value
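

# Illustrative sketch (not part of the original source): `_update_out` splits a
# per-decoder-layer stack into the final prediction plus "aux_outputs" for auxiliary
# losses. Assuming a hypothetical stack of logits from 3 decoder layers:
#
#   out = {}
#   logits = torch.randn(3, 2, 100, 1)  # (num_layers, num_prompts, num_queries, 1)
#   _update_out(out, "pred_logits", logits)
#   # out["pred_logits"] is logits[-1] (the last decoder layer), while
#   # out["aux_outputs"] == [{"pred_logits": logits[0]}, {"pred_logits": logits[1]}]
#   # With auxiliary=False the full stack is stored as-is; with update_aux=False,
#   # only the top-level key is written.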


class Sam3Image(torch.nn.Module):
    TEXT_ID_FOR_TEXT = 0
    TEXT_ID_FOR_VISUAL = 1
    TEXT_ID_FOR_GEOMETRIC = 2

    def __init__(
        self,
        backbone: SAM3VLBackbone,
        transformer,
        input_geometry_encoder,
        segmentation_head=None,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=None,
        use_instance_query: bool = True,
        multimask_output: bool = True,
        use_act_checkpoint_seg_head: bool = True,
        interactivity_in_encoder: bool = True,
        matcher=None,
        use_dot_prod_scoring=True,
        supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
        detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
        separate_scorer_for_instance: bool = False,
        num_interactive_steps_val: int = 0,
        inst_interactive_predictor: SAM3InteractiveImagePredictor = None,
        **kwargs,
    ):
        super().__init__()
        self.backbone = backbone
        self.geometry_encoder = input_geometry_encoder
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.segmentation_head = segmentation_head
        self.o2m_mask_predict = o2m_mask_predict
        self.dot_prod_scoring = dot_prod_scoring
        self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
        self.interactivity_in_encoder = interactivity_in_encoder
        self.matcher = matcher
        self.num_interactive_steps_val = num_interactive_steps_val

        self.use_dot_prod_scoring = use_dot_prod_scoring
        if self.use_dot_prod_scoring:
            assert dot_prod_scoring is not None
            self.dot_prod_scoring = dot_prod_scoring
            self.instance_dot_prod_scoring = None
            if separate_scorer_for_instance:
                self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
        else:
            self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
            self.instance_class_embed = None
            if separate_scorer_for_instance:
                self.instance_class_embed = deepcopy(self.class_embed)

        self.supervise_joint_box_scores = supervise_joint_box_scores
        self.detach_presence_in_joint_score = detach_presence_in_joint_score

        # verify the number of queries for O2O and O2M
        num_o2o_static = self.transformer.decoder.num_queries
        num_o2m_static = self.transformer.decoder.num_o2m_queries
        assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
        self.dac = self.transformer.decoder.dac
        self.use_instance_query = use_instance_query
        self.multimask_output = multimask_output
        self.inst_interactive_predictor = inst_interactive_predictor

    @property
    def device(self):
        self._device = getattr(self, "_device", None) or next(self.parameters()).device
        return self._device

    def to(self, *args, **kwargs):
        # clear cached _device in case the model is moved to a different device
        self._device = None
        return super().to(*args, **kwargs)

    def _get_img_feats(self, backbone_out, img_ids):
        """Retrieve correct image features from backbone output."""
        if "backbone_fpn" in backbone_out:
            if "id_mapping" in backbone_out and backbone_out["id_mapping"] is not None:
                img_ids = backbone_out["id_mapping"][img_ids]
                # If this assert fails, it likely means we're requesting different img_ids
                # (perhaps a different frame?). We currently don't expect this to happen.
                # We could technically trigger a recompute here, but likely at the cost of
                # a cpu<->gpu sync point, which would deteriorate perf.
                torch._assert_async((img_ids >= 0).all())
            vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :]
            vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
            vis_feat_sizes = [x.shape[-2:] for x in vis_pos_enc]  # (H, W) shapes
            # index and flatten visual features NxCxHxW => HWxNxC (batch-first => seq-first)
            img_feats = [x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_feats]
            img_pos_embeds = [
                x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_pos_enc
            ]
            return backbone_out, img_feats, img_pos_embeds, vis_feat_sizes

        # Image features are not available in the backbone output, so we compute them on
        # the fly. This case likely occurs for video, where we want to forward only the
        # current frame.
        img_batch = backbone_out["img_batch_all_stages"]
        if img_ids.numel() > 1:
            # Only forward the backbone on unique image ids to avoid repetitive computation
            unique_ids, _ = torch.unique(img_ids, return_inverse=True)
        else:
            unique_ids, _ = img_ids, slice(None)
        # Compute the image features on those unique image ids.
        # note: we allow using a list (or other indexable types) of tensors as img_batch
        # (e.g. for async frame loading in demo); in this case we index into img_batch directly
        if isinstance(img_batch, torch.Tensor):
            image = img_batch[unique_ids]
        elif unique_ids.numel() == 1:
            image = img_batch[unique_ids.item()].unsqueeze(0)
        else:
            image = torch.stack([img_batch[i] for i in unique_ids.tolist()])
        # `img_batch` might be fp16 and offloaded to CPU
        image = image.to(dtype=torch.float32, device=self.device)
        # Next time we call this function, we want to remember which indices we computed
        id_mapping = torch.full(
            (len(img_batch),), -1, dtype=torch.long, device=self.device
        )
        id_mapping[unique_ids] = torch.arange(len(unique_ids), device=self.device)
        backbone_out = {
            **backbone_out,
            **self.backbone.forward_image(image),
            "id_mapping": id_mapping,
        }
        assert "backbone_fpn" in backbone_out
        return self._get_img_feats(backbone_out, img_ids=img_ids)
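
    # Illustrative sketch (not part of the original source): how the on-the-fly path
    # above remembers which frames it embedded. Assuming a hypothetical 5-frame video
    # where only frames 2 and 4 are requested:
    #
    #   img_ids    = torch.tensor([2, 4, 2])
    #   unique_ids = torch.tensor([2, 4])              # backbone runs on 2 images only
    #   id_mapping = torch.tensor([-1, -1, 0, -1, 1])  # frame index -> row in computed batch
    #
    # On the recursive call, "backbone_fpn" now exists and img_ids are remapped via
    # id_mapping[img_ids] == tensor([0, 1, 0]); a -1 there would trip the async assert.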

    def _encode_prompt(
        self,
        backbone_out,
        find_input,
        geometric_prompt,
        visual_prompt_embed=None,
        visual_prompt_mask=None,
        encode_text=True,
        prev_mask_pred=None,
    ):
        # index text features (note that regardless of early or late fusion, the batch size of
        # `txt_feats` is always the number of *prompts* in the encoder)
        txt_ids = find_input.text_ids
        txt_feats = backbone_out["language_features"][:, txt_ids]
        txt_masks = backbone_out["language_mask"][txt_ids]

        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple
        if prev_mask_pred is not None:
            img_feats = [img_feats[-1] + prev_mask_pred]

        # Encode geometry
        geo_feats, geo_masks = self.geometry_encoder(
            geo_prompt=geometric_prompt,
            img_feats=img_feats,
            img_sizes=vis_feat_sizes,
            img_pos_embeds=img_pos_embeds,
        )
        if visual_prompt_embed is None:
            visual_prompt_embed = torch.zeros(
                (0, *geo_feats.shape[1:]), device=geo_feats.device
            )
            visual_prompt_mask = torch.zeros(
                (*geo_masks.shape[:-1], 0),
                device=geo_masks.device,
                dtype=geo_masks.dtype,
            )
        if encode_text:
            prompt = torch.cat([txt_feats, geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([txt_masks, geo_masks, visual_prompt_mask], dim=1)
        else:
            prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
        return prompt, prompt_mask, backbone_out

    def _run_encoder(
        self,
        backbone_out,
        find_input,
        prompt,
        prompt_mask,
        encoder_extra_kwargs: Optional[Dict] = None,
    ):
        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple

        # Run the encoder
        prompt_pos_embed = torch.zeros_like(prompt)
        # make a copy of the image feature lists since the encoder may modify these lists in-place
        memory = self.transformer.encoder(
            src=img_feats.copy(),
            src_key_padding_mask=None,
            src_pos=img_pos_embeds.copy(),
            prompt=prompt,
            prompt_pos=prompt_pos_embed,
            prompt_key_padding_mask=prompt_mask,
            feat_sizes=vis_feat_sizes,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )
        encoder_out = {
            # encoded image features
            "encoder_hidden_states": memory["memory"],
            "pos_embed": memory["pos_embed"],
            "padding_mask": memory["padding_mask"],
            "level_start_index": memory["level_start_index"],
            "spatial_shapes": memory["spatial_shapes"],
            "valid_ratios": memory["valid_ratios"],
            "vis_feat_sizes": vis_feat_sizes,
            # encoded text features (or other prompts)
            "prompt_before_enc": prompt,
            "prompt_after_enc": memory.get("memory_text", prompt),
            "prompt_mask": prompt_mask,
        }
        return backbone_out, encoder_out, feat_tuple
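
    # Illustrative sketch (not part of the original source): the prompt sequence fed to
    # the encoder is the concatenation [text tokens; geometric tokens; visual tokens]
    # along the sequence dimension (dim 0), with the padding masks concatenated along
    # dim 1. Assuming hypothetical sizes of 8 text tokens, 3 geometric tokens, no visual
    # tokens, 2 prompts, and hidden size 256:
    #
    #   txt_feats: (8, 2, 256),  geo_feats: (3, 2, 256),  visual_prompt_embed: (0, 2, 256)
    #   prompt      = torch.cat([txt_feats, geo_feats, visual_prompt_embed], dim=0)  # (11, 2, 256)
    #   prompt_mask = torch.cat([txt_masks, geo_masks, visual_prompt_mask], dim=1)   # (2, 11)
    #
    # `_run_encoder` then gives this whole sequence all-zero positional embeddings.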

    def _run_decoder(
        self,
        pos_embed,
        memory,
        src_mask,
        out,
        prompt,
        prompt_mask,
        encoder_out,
    ):
        bs = memory.shape[1]
        query_embed = self.transformer.decoder.query_embed.weight
        tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
        apply_dac = self.transformer.decoder.dac and self.training
        hs, reference_boxes, dec_presence_out, dec_presence_feats = (
            self.transformer.decoder(
                tgt=tgt,
                memory=memory,
                memory_key_padding_mask=src_mask,
                pos=pos_embed,
                reference_boxes=None,
                level_start_index=encoder_out["level_start_index"],
                spatial_shapes=encoder_out["spatial_shapes"],
                valid_ratios=encoder_out["valid_ratios"],
                tgt_mask=None,
                memory_text=prompt,
                text_attention_mask=prompt_mask,
                apply_dac=apply_dac,
            )
        )
        hs = hs.transpose(1, 2)  # seq-first to batch-first
        reference_boxes = reference_boxes.transpose(1, 2)  # seq-first to batch-first
        if dec_presence_out is not None:
            # seq-first to batch-first
            dec_presence_out = dec_presence_out.transpose(1, 2)
        out["presence_feats"] = dec_presence_feats
        self._update_scores_and_boxes(
            out,
            hs,
            reference_boxes,
            prompt,
            prompt_mask,
            dec_presence_out=dec_presence_out,
        )
        return out, hs

    def _update_scores_and_boxes(
        self,
        out,
        hs,
        reference_boxes,
        prompt,
        prompt_mask,
        dec_presence_out=None,
        is_instance_prompt=False,
    ):
        apply_dac = self.transformer.decoder.dac and self.training
        num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
        num_o2m = hs.size(2) - num_o2o
        assert num_o2m == (num_o2o if apply_dac else 0)
        out["queries"] = hs[-1][:, :num_o2o]  # remove o2m queries if there are any

        # score prediction
        if self.use_dot_prod_scoring:
            dot_prod_scoring_head = self.dot_prod_scoring
            if is_instance_prompt and self.instance_dot_prod_scoring is not None:
                dot_prod_scoring_head = self.instance_dot_prod_scoring
            outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
        else:
            class_embed_head = self.class_embed
            if is_instance_prompt and self.instance_class_embed is not None:
                class_embed_head = self.instance_class_embed
            outputs_class = class_embed_head(hs)

        # box prediction
        box_head = self.transformer.decoder.bbox_embed
        if (
            is_instance_prompt
            and self.transformer.decoder.instance_bbox_embed is not None
        ):
            box_head = self.transformer.decoder.instance_bbox_embed
        anchor_box_offsets = box_head(hs)
        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
        outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
        outputs_boxes_xyxy = box_cxcywh_to_xyxy(outputs_coord)

        if dec_presence_out is not None:
            _update_out(
                out, "presence_logit_dec", dec_presence_out, update_aux=self.training
            )
        if self.supervise_joint_box_scores:
            assert dec_presence_out is not None
            prob_dec_presence_out = dec_presence_out.clone().sigmoid()
            if self.detach_presence_in_joint_score:
                prob_dec_presence_out = prob_dec_presence_out.detach()
            outputs_class = inverse_sigmoid(
                outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)
            ).clamp(min=-10.0, max=10.0)

        _update_out(
            out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out,
            "pred_boxes_xyxy",
            outputs_boxes_xyxy[:, :, :num_o2o],
            update_aux=self.training,
        )
        if num_o2m > 0 and self.training:
            _update_out(
                out,
                "pred_logits_o2m",
                outputs_class[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_o2m",
                outputs_coord[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_xyxy_o2m",
                outputs_boxes_xyxy[:, :, num_o2o:],
                update_aux=self.training,
            )
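
    # Illustrative sketch (not part of the original source): the two update rules used
    # above, written out for a single hypothetical query.
    #
    #   # box decoding: the head predicts an offset in inverse-sigmoid (logit) space
    #   # relative to the reference box, then squashes back to normalized cxcywh
    #   box_cxcywh = (inverse_sigmoid(reference_box) + anchor_box_offset).sigmoid()
    #   box_xyxy = box_cxcywh_to_xyxy(box_cxcywh)
    #
    #   # joint box/presence score (when supervise_joint_box_scores is on): multiply the
    #   # per-query probability by the presence probability, then go back to logit space
    #   joint_logit = inverse_sigmoid(query_logit.sigmoid() * presence_logit.sigmoid())
    #   joint_logit = joint_logit.clamp(min=-10.0, max=10.0)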

    def _run_segmentation_heads(
        self,
        out,
        backbone_out,
        img_ids,
        vis_feat_sizes,
        encoder_hidden_states,
        prompt,
        prompt_mask,
        hs,
    ):
        apply_dac = self.transformer.decoder.dac and self.training
        if self.segmentation_head is not None:
            num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
            num_o2m = hs.size(2) - num_o2o
            obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
            seg_head_outputs = activation_ckpt_wrapper(self.segmentation_head)(
                backbone_feats=backbone_out["backbone_fpn"],
                obj_queries=obj_queries,
                image_ids=img_ids,
                encoder_hidden_states=encoder_hidden_states,
                act_ckpt_enable=self.training and self.use_act_checkpoint_seg_head,
                prompt=prompt,
                prompt_mask=prompt_mask,
            )
            aux_masks = False  # self.aux_loss and self.segmentation_head.aux_masks
            for k, v in seg_head_outputs.items():
                if k in self.segmentation_head.instance_keys:
                    _update_out(out, k, v[:, :num_o2o], auxiliary=aux_masks)
                    if (
                        self.o2m_mask_predict and num_o2m > 0
                    ):  # handle o2m mask prediction
                        _update_out(
                            out, f"{k}_o2m", v[:, num_o2o:], auxiliary=aux_masks
                        )
                else:
                    out[k] = v
        else:
            backbone_out.pop("backbone_fpn", None)

    def _get_best_mask(self, out):
        prev_mask_idx = out["pred_logits"].argmax(dim=1).squeeze(1)
        batch_idx = torch.arange(
            out["pred_logits"].shape[0], device=prev_mask_idx.device
        )
        prev_mask_pred = out["pred_masks"][batch_idx, prev_mask_idx][:, None]
        # Downsample the mask to match the low-res image feature resolution.
        prev_mask_pred = self.geometry_encoder.mask_encoder.mask_downsampler(
            prev_mask_pred
        )
        prev_mask_pred = prev_mask_pred.flatten(-2).permute(2, 0, 1)
        return prev_mask_pred

    def forward_grounding(
        self,
        backbone_out,
        find_input,
        find_target,
        geometric_prompt: Prompt,
    ):
        with torch.profiler.record_function("SAM3Image._encode_prompt"):
            prompt, prompt_mask, backbone_out = self._encode_prompt(
                backbone_out, find_input, geometric_prompt
            )

        # Run the encoder
        with torch.profiler.record_function("SAM3Image._run_encoder"):
            backbone_out, encoder_out, _ = self._run_encoder(
                backbone_out, find_input, prompt, prompt_mask
            )
        out = {
            "encoder_hidden_states": encoder_out["encoder_hidden_states"],
            "prev_encoder_out": {
                "encoder_out": encoder_out,
                "backbone_out": backbone_out,
            },
        }

        # Run the decoder
        with torch.profiler.record_function("SAM3Image._run_decoder"):
            out, hs = self._run_decoder(
                memory=out["encoder_hidden_states"],
                pos_embed=encoder_out["pos_embed"],
                src_mask=encoder_out["padding_mask"],
                out=out,
                prompt=prompt,
                prompt_mask=prompt_mask,
                encoder_out=encoder_out,
            )

        # Run segmentation heads
        with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
            self._run_segmentation_heads(
                out=out,
                backbone_out=backbone_out,
                img_ids=find_input.img_ids,
                vis_feat_sizes=encoder_out["vis_feat_sizes"],
                encoder_hidden_states=out["encoder_hidden_states"],
                prompt=prompt,
                prompt_mask=prompt_mask,
                hs=hs,
            )

        if self.training or self.num_interactive_steps_val > 0:
            self._compute_matching(out, self.back_convert(find_target))
        return out
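
    # Illustrative sketch (not part of the original source): the single-frame grounding
    # pass above runs prompt encoding -> fusion encoder -> decoder -> segmentation heads
    # and returns one dict. Keys used elsewhere in this file include, roughly:
    #
    #   out = self.forward_grounding(backbone_out, find_input, find_target, geometric_prompt)
    #   out["pred_logits"]       # per-query scores, shape (num_prompts, num_queries, 1)
    #   out["pred_boxes"]        # normalized cxcywh boxes; "pred_boxes_xyxy" for xyxy
    #   out["pred_masks"]        # mask logits from the segmentation head (if present)
    #   out["aux_outputs"]       # per-decoder-layer predictions (training only)
    #   out["prev_encoder_out"]  # cached encoder/backbone outputs for reuse
    #
    # Exact shapes depend on the decoder and segmentation-head configuration.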

    def _postprocess_out(self, out: Dict, multimask_output: bool = False):
        # For multimask output, during eval we return the single best mask with the dict
        # keys expected by the evaluators, but also return the multimask outputs under new keys.
        num_mask_boxes = out["pred_boxes"].size(1)
        if not self.training and multimask_output and num_mask_boxes > 1:
            out["multi_pred_logits"] = out["pred_logits"]
            if "pred_masks" in out:
                out["multi_pred_masks"] = out["pred_masks"]
            out["multi_pred_boxes"] = out["pred_boxes"]
            out["multi_pred_boxes_xyxy"] = out["pred_boxes_xyxy"]
            best_mask_idx = out["pred_logits"].argmax(1).squeeze(1)
            batch_idx = torch.arange(len(best_mask_idx), device=best_mask_idx.device)
            out["pred_logits"] = out["pred_logits"][batch_idx, best_mask_idx].unsqueeze(
                1
            )
            if "pred_masks" in out:
                out["pred_masks"] = out["pred_masks"][
                    batch_idx, best_mask_idx
                ].unsqueeze(1)
            out["pred_boxes"] = out["pred_boxes"][batch_idx, best_mask_idx].unsqueeze(1)
            out["pred_boxes_xyxy"] = out["pred_boxes_xyxy"][
                batch_idx, best_mask_idx
            ].unsqueeze(1)
        return out

    def _get_dummy_prompt(self, num_prompts=1):
        device = self.device
        geometric_prompt = Prompt(
            box_embeddings=torch.zeros(0, num_prompts, 4, device=device),
            box_mask=torch.zeros(num_prompts, 0, device=device, dtype=torch.bool),
        )
        return geometric_prompt

    def forward(self, input: BatchedDatapoint):
        device = self.device
        backbone_out = {"img_batch_all_stages": input.img_batch}
        backbone_out.update(self.backbone.forward_image(input.img_batch))
        num_frames = len(input.find_inputs)
        assert num_frames == 1
        text_outputs = self.backbone.forward_text(input.find_text_batch, device=device)
        backbone_out.update(text_outputs)
        previous_stages_out = SAM3Output(
            iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE
        )
        find_input = input.find_inputs[0]
        find_target = input.find_targets[0]
        if find_input.input_points is not None and find_input.input_points.numel() > 0:
            print("Warning: Point prompts are ignored in PCS.")
        num_interactive_steps = 0 if self.training else self.num_interactive_steps_val
        geometric_prompt = Prompt(
            box_embeddings=find_input.input_boxes,
            box_mask=find_input.input_boxes_mask,
            box_labels=find_input.input_boxes_label,
        )
        # Init vars that are shared across the loop.
        stage_outs = []
        for cur_step in range(num_interactive_steps + 1):
            if cur_step > 0:
                # We sample interactive geometric prompts (boxes, points)
                geometric_prompt, _ = self.interactive_prompt_sampler.sample(
                    geo_prompt=geometric_prompt,
                    find_target=find_target,
                    previous_out=stage_outs[-1],
                )
            out = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_input,
                find_target=find_target,
                geometric_prompt=geometric_prompt.clone(),
            )
            stage_outs.append(out)
        previous_stages_out.append(stage_outs)
        return previous_stages_out
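
    # Illustrative sketch (not part of the original source): at eval time with
    # num_interactive_steps_val = 2, the loop in forward() above runs 3 grounding passes
    # on the same frame. Step 0 uses the boxes provided in the datapoint; each later step
    # asks interactive_prompt_sampler to turn the previous step's output and the targets
    # into refined geometric prompts, so the schedule is roughly:
    #
    #   step 0: prompts from find_input                     -> stage_outs[0]
    #   step 1: sampled from stage_outs[0] and find_target  -> stage_outs[1]
    #   step 2: sampled from stage_outs[1] and find_target  -> stage_outs[2]
    #
    # During training, num_interactive_steps is forced to 0 (a single pass).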

    def _compute_matching(self, out, targets):
        out["indices"] = self.matcher(out, targets)
        for aux_out in out.get("aux_outputs", []):
            aux_out["indices"] = self.matcher(aux_out, targets)

    def back_convert(self, targets):
        batched_targets = {
            "boxes": targets.boxes.view(-1, 4),
            "boxes_xyxy": box_cxcywh_to_xyxy(targets.boxes.view(-1, 4)),
            "boxes_padded": targets.boxes_padded,
            "positive_map": targets.boxes.new_ones(len(targets.boxes), 1),
            "num_boxes": targets.num_boxes,
            "masks": targets.segments,
            "semantic_masks": targets.semantic_segments,
            "is_valid_mask": targets.is_valid_segment,
            "is_exhaustive": targets.is_exhaustive,
            "object_ids_packed": targets.object_ids,
            "object_ids_padded": targets.object_ids_padded,
        }
        return batched_targets

    def predict_inst(
        self,
        inference_state,
        **kwargs,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        orig_h, orig_w = (
            inference_state["original_height"],
            inference_state["original_width"],
        )
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.inst_interactive_predictor.model._prepare_backbone_features(
            backbone_out
        )
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = (
            vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
        )
        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        self.inst_interactive_predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        self.inst_interactive_predictor._is_image_set = True
        self.inst_interactive_predictor._orig_hw = [(orig_h, orig_w)]
        res = self.inst_interactive_predictor.predict(**kwargs)
        self.inst_interactive_predictor._features = None
        self.inst_interactive_predictor._is_image_set = False
        return res
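
    # Illustrative sketch (not part of the original source): predict_inst above reuses
    # the SAM 2 backbone features cached in `inference_state` and forwards **kwargs to
    # the interactive predictor. Assuming (not verified here) that the predictor follows
    # the SAM 2 image-predictor convention of point_coords / point_labels / box /
    # multimask_output arguments, a call might look like:
    #
    #   masks, iou_scores, low_res_logits = model.predict_inst(
    #       inference_state,
    #       point_coords=np.array([[512, 384]]),
    #       point_labels=np.array([1]),  # 1 = foreground click
    #       multimask_output=True,
    #   )
    #
    # The cached `_features` are cleared again after the call, so the predictor holds no
    # state between invocations.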

    def predict_inst_batch(
        self,
        inference_state,
        *args,
        **kwargs,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.inst_interactive_predictor.model._prepare_backbone_features(
            backbone_out
        )
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = (
            vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
        )
        batch_size = vision_feats[-1].shape[1]
        orig_heights, orig_widths = (
            inference_state["original_heights"],
            inference_state["original_widths"],
        )
        assert batch_size == len(orig_heights) == len(orig_widths), (
            f"Batch size mismatch in predict_inst_batch. "
            f"Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
        )
        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        self.inst_interactive_predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        self.inst_interactive_predictor._is_image_set = True
        self.inst_interactive_predictor._is_batch = True
        self.inst_interactive_predictor._orig_hw = [
            (orig_h, orig_w) for orig_h, orig_w in zip(orig_heights, orig_widths)
        ]
        res = self.inst_interactive_predictor.predict_batch(*args, **kwargs)
        self.inst_interactive_predictor._features = None
        self.inst_interactive_predictor._is_image_set = False
        self.inst_interactive_predictor._is_batch = False
        return res


class Sam3ImageOnVideoMultiGPU(Sam3Image):
    def __init__(
        self, *args, async_all_gather=True, gather_backbone_out=None, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.rank = int(os.getenv("RANK", "0"))
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.async_all_gather = async_all_gather
        # if gather_backbone_out is not set, default to gathering only for `SAM3VLBackbone`
        if gather_backbone_out is None:
            gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone)
        self.gather_backbone_out = gather_backbone_out
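
    # Illustrative sketch (not part of the original source): the multi-GPU video path
    # below splits detection across ranks in a round-robin fashion and shares results
    # via all_gather. Assuming a hypothetical run with WORLD_SIZE=4 on a 10-frame video:
    #
    #   chunk [0, 4):  rank 0 -> frame 0, rank 1 -> frame 1, rank 2 -> frame 2, rank 3 -> frame 3
    #   chunk [4, 8):  rank 0 -> frame 4, ..., rank 3 -> frame 7
    #   chunk [8, 10): rank 0 -> frame 8, rank 1 -> frame 9, ranks 2-3 recompute frame 9
    #                  (frame_idx_local_gpu is clamped to frame_idx_end - 1)
    #
    # After the gather, every rank holds the detector outputs for every frame in the chunk.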
""" # Step 1: fetch the detector outputs in the current chunk from buffer frame_idx_curr_b = frame_idx - frame_idx % self.world_size frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames) # in case the current frame's detection results are not in the buffer yet, build the current chunk # (this should only happen on the first chunk, since we are also building the next chunk below) if frame_idx not in multigpu_buffer: with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"): self._build_multigpu_buffer_next_chunk( backbone_out=backbone_out, find_inputs=find_inputs, geometric_prompt=geometric_prompt, frame_idx_begin=frame_idx_curr_b, frame_idx_end=frame_idx_curr_e, num_frames=num_frames, multigpu_buffer=multigpu_buffer, run_nms=run_nms, nms_prob_thresh=nms_prob_thresh, nms_iou_thresh=nms_iou_thresh, ) # read out the current frame's results from `multigpu_buffer` out = {} for k, (v, handle) in multigpu_buffer[frame_idx].items(): if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats: continue if handle is not None: handle.wait() # wait for async all-gather to finish out[k] = v # Step 2: remove detection outputs of the previous chunk from cache to save GPU memory if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0: frame_idx_prev_e = frame_idx_curr_b frame_idx_prev_b = frame_idx_curr_b - self.world_size elif track_in_reverse and frame_idx_curr_e < num_frames: frame_idx_prev_b = frame_idx_curr_e frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames) else: frame_idx_prev_b = frame_idx_prev_e = None if frame_idx_prev_b is not None: for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e): multigpu_buffer.pop(frame_idx_rm, None) # Step 3: compute and cache detection outputs of the next chunk ahead of time # (so that we can overlap computation with all-gather transfer) if not track_in_reverse and frame_idx_curr_e < num_frames: frame_idx_next_b = frame_idx_curr_e frame_idx_next_e = min(frame_idx_next_b + self.world_size, num_frames) elif track_in_reverse and frame_idx_curr_b - self.world_size >= 0: frame_idx_next_e = frame_idx_curr_b frame_idx_next_b = frame_idx_curr_b - self.world_size else: frame_idx_next_b = frame_idx_next_e = None if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer: with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"): self._build_multigpu_buffer_next_chunk( backbone_out=backbone_out, find_inputs=find_inputs, geometric_prompt=geometric_prompt, frame_idx_begin=frame_idx_next_b, frame_idx_end=frame_idx_next_e, num_frames=num_frames, multigpu_buffer=multigpu_buffer, run_nms=run_nms, nms_prob_thresh=nms_prob_thresh, nms_iou_thresh=nms_iou_thresh, ) return out, backbone_out def _build_multigpu_buffer_next_chunk( self, backbone_out, find_inputs, geometric_prompt: Prompt, frame_idx_begin, frame_idx_end, num_frames, multigpu_buffer, run_nms=False, nms_prob_thresh=None, nms_iou_thresh=None, ): """Compute detection outputs on a chunk of frames and store their results in multigpu_buffer.""" # each GPU computes detections on one frame in the chunk (in a round-robin manner) frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1) # `forward_grounding` (from base class `Sam3ImageOnVideo`) runs the detector on a single frame with torch.profiler.record_function("forward_grounding"): out_local = self.forward_grounding( backbone_out=backbone_out, find_input=find_inputs[frame_idx_local_gpu], find_target=None, geometric_prompt=geometric_prompt, ) if run_nms: 
with torch.profiler.record_function("nms_masks"): # run NMS as a post-processing step on top of the detection outputs assert nms_prob_thresh is not None and nms_iou_thresh is not None pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid() pred_masks = out_local["pred_masks"] # loop over text prompts (not an overhead for demo where there's only 1 prompt) for prompt_idx in range(pred_probs.size(0)): keep = nms_masks( pred_probs=pred_probs[prompt_idx], pred_masks=pred_masks[prompt_idx], prob_threshold=nms_prob_thresh, iou_threshold=nms_iou_thresh, ) # set a very low threshold for those detections removed by NMS out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float() if self.gather_backbone_out: # gather the SAM 2 backbone features across GPUs feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"] assert len(feats["backbone_fpn"]) == 3 # SAM2 backbone always have 3 levels # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP) backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]] fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0]) fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1]) fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2]) # vision_pos_enc is the same on all frames, so no need to all-gather them vision_pos_enc = feats["vision_pos_enc"] # trim the detector output to only include the necessary keys out_local = { "pred_logits": out_local["pred_logits"], "pred_boxes": out_local["pred_boxes"], "pred_boxes_xyxy": out_local["pred_boxes_xyxy"], "pred_masks": out_local["pred_masks"], } # gather the results: after this step, each GPU will receive detector outputs on # all frames in the chunk and store them in `multigpu_buffer` out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()} for rank in range(self.world_size): frame_idx_to_save = frame_idx_begin + rank if frame_idx_to_save >= num_frames: continue frame_buffer = { k: (v[rank], handle) for k, (v, handle) in out_gathered.items() } if self.gather_backbone_out: # also add gathered SAM 2 backbone features to frame_buffer frame_buffer["tracker_backbone_fpn_0"] = (fpn0[rank], fpn_handle0) frame_buffer["tracker_backbone_fpn_1"] = (fpn1[rank], fpn_handle1) frame_buffer["tracker_backbone_fpn_2"] = (fpn2[rank], fpn_handle2) frame_buffer["tracker_backbone_pos_enc"] = (vision_pos_enc, None) multigpu_buffer[frame_idx_to_save] = frame_buffer def _gather_tensor(self, x): if self.world_size == 1: return [x], None async_op = self.async_all_gather # here `.contiguous()` is required -- otherwise NCCL all_gather # sometimes gives wrong results x = x.contiguous() # ensure contiguous memory for NCCL output_list = [torch.empty_like(x) for _ in range(self.world_size)] handle = torch.distributed.all_gather(output_list, x, async_op=async_op) return output_list, handle