@@ -0,0 +1,1172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+from collections import OrderedDict
+
+import torch
+
+from tqdm import tqdm
+
+from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
+from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+
+
+class SAM2VideoPredictor(SAM2Base):
+    """The predictor class to handle user interactions and manage inference states."""
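+
+    # Typical interaction loop (an illustrative sketch; the video path and click
+    # coordinates below are placeholders, not values defined in this file):
+    #
+    #     state = predictor.init_state(video_path="<path/to/video_frames>")
+    #     frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
+    #         state, frame_idx=0, obj_id=1, points=[[210, 350]], labels=[1]
+    #     )
+    #     for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
+    #         ...  # per-frame masks at the original video resolution
+    #
+    # where `predictor` is an instance of this class (e.g. built via the
+    # `from_pretrained` classmethod below).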
+
+    def __init__(
+        self,
+        fill_hole_area=0,
+        # whether to apply non-overlapping constraints on the output object masks
+        non_overlap_masks=False,
+        # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
+        # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
+        clear_non_cond_mem_around_input=False,
+        # whether to also clear non-conditioning memory of the surrounding frames in multi-object tracking (only effective when `clear_non_cond_mem_around_input` is True)
+        clear_non_cond_mem_for_multi_obj=False,
+        # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
+        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
+        add_all_frames_to_correct_as_cond=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.fill_hole_area = fill_hole_area
+        self.non_overlap_masks = non_overlap_masks
+        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
+        self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
+        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+
+    @torch.inference_mode()
+    def init_state(
+        self,
+        video_path,
+        offload_video_to_cpu=False,
+        offload_state_to_cpu=False,
+        async_loading_frames=False,
+    ):
+        """Initialize an inference state."""
+        compute_device = self.device  # device of the model
+        images, video_height, video_width = load_video_frames(
+            video_path=video_path,
+            image_size=self.image_size,
+            offload_video_to_cpu=offload_video_to_cpu,
+            async_loading_frames=async_loading_frames,
+            compute_device=compute_device,
+        )
+        inference_state = {}
+        inference_state["images"] = images
+        inference_state["num_frames"] = len(images)
+        # whether to offload the video frames to CPU memory
+        # turning on this option saves the GPU memory with only a very small overhead
+        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
+        # whether to offload the inference state to CPU memory
+        # turning on this option saves the GPU memory at the cost of a lower tracking fps
+        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
+        # and from 24 to 21 when tracking two objects)
+        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
+        # the original video height and width, used for resizing final output scores
+        inference_state["video_height"] = video_height
+        inference_state["video_width"] = video_width
+        inference_state["device"] = compute_device
+        if offload_state_to_cpu:
+            inference_state["storage_device"] = torch.device("cpu")
+        else:
+            inference_state["storage_device"] = compute_device
+        # inputs on each frame
+        inference_state["point_inputs_per_obj"] = {}
+        inference_state["mask_inputs_per_obj"] = {}
+        # visual features on a small number of recently visited frames for quick interactions
+        inference_state["cached_features"] = {}
+        # values that don't change across frames (so we only need to hold one copy of them)
+        inference_state["constants"] = {}
+        # mapping between client-side object id and model-side object index
+        inference_state["obj_id_to_idx"] = OrderedDict()
+        inference_state["obj_idx_to_id"] = OrderedDict()
+        inference_state["obj_ids"] = []
+        # A storage to hold the model's tracking results and states on each frame
+        inference_state["output_dict"] = {
+            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+        }
+        # Slices (views) of each object's tracking results, sharing the same memory with "output_dict"
+        inference_state["output_dict_per_obj"] = {}
+        # A temporary storage to hold new outputs when the user interacts with a frame
+        # to add clicks or a mask (it's merged into "output_dict" before propagation starts)
+        inference_state["temp_output_dict_per_obj"] = {}
+        # Frames that already hold consolidated outputs from click or mask inputs
+        # (we directly use their consolidated outputs during tracking)
+        inference_state["consolidated_frame_inds"] = {
+            "cond_frame_outputs": set(),  # set containing frame indices
+            "non_cond_frame_outputs": set(),  # set containing frame indices
+        }
+        # metadata for each tracking frame (e.g. which direction it's tracked)
+        inference_state["tracking_has_started"] = False
+        inference_state["frames_already_tracked"] = {}
+        # Warm up the visual backbone and cache the image feature on frame 0
+        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+        return inference_state
+
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2VideoPredictor): The loaded model.
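+
+        Example (a minimal sketch; the repository id below is assumed to be a
+        valid SAM 2 checkpoint on the Hugging Face hub):
+            >>> predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")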
+        """
+        from sam2.build_sam import build_sam2_video_predictor_hf
+
+        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
+        return sam_model
+
+    def _obj_id_to_idx(self, inference_state, obj_id):
+        """Map client-side object id to model-side object index."""
+        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
+        if obj_idx is not None:
+            return obj_idx
+
+        # This is a new object id not sent to the server before. We only allow adding
+        # new objects *before* the tracking starts.
+        allow_new_object = not inference_state["tracking_has_started"]
+        if allow_new_object:
+            # get the next object slot
+            obj_idx = len(inference_state["obj_id_to_idx"])
+            inference_state["obj_id_to_idx"][obj_id] = obj_idx
+            inference_state["obj_idx_to_id"][obj_idx] = obj_id
+            inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
+            # set up input and output structures for this object
+            inference_state["point_inputs_per_obj"][obj_idx] = {}
+            inference_state["mask_inputs_per_obj"][obj_idx] = {}
+            inference_state["output_dict_per_obj"][obj_idx] = {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            }
+            inference_state["temp_output_dict_per_obj"][obj_idx] = {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            }
+            return obj_idx
+        else:
+            raise RuntimeError(
+                f"Cannot add new object id {obj_id} after tracking starts. "
+                f"All existing object ids: {inference_state['obj_ids']}. "
+                f"Please call 'reset_state' to restart from scratch."
+            )
+
+    def _obj_idx_to_id(self, inference_state, obj_idx):
+        """Map model-side object index to client-side object id."""
+        return inference_state["obj_idx_to_id"][obj_idx]
+
+    def _get_obj_num(self, inference_state):
+        """Get the total number of unique object ids received so far in this session."""
+        return len(inference_state["obj_idx_to_id"])
+
+    @torch.inference_mode()
+    def add_new_points_or_box(
+        self,
+        inference_state,
+        frame_idx,
+        obj_id,
+        points=None,
+        labels=None,
+        clear_old_points=True,
+        normalize_coords=True,
+        box=None,
+    ):
+        """Add new points or a box prompt to a frame."""
+        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
+            points = torch.tensor(points, dtype=torch.float32)
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
+            labels = torch.tensor(labels, dtype=torch.int32)
+        if points.dim() == 2:
+            points = points.unsqueeze(0)  # add batch dimension
+        if labels.dim() == 1:
+            labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
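+        # For example, a box given as (x0, y0, x1, y1) becomes the two corner points
+        # [[x0, y0], [x1, y1]] with labels [2, 3] (top-left and bottom-right corners),
+        # and any user-provided clicks are appended after these two entries.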
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
+        if normalize_coords:
+            video_H = inference_state["video_height"]
+            video_W = inference_state["video_width"]
+            points = points / torch.tensor([video_W, video_H]).to(points.device)
+        # scale the (normalized) coordinates by the model's internal image size
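+        # (worked example, assuming image_size=1024: with normalize_coords=True, a click
+        # at (640, 360) on a 1280x720 video is normalized to (0.5, 0.5) above and then
+        # scaled to (512.0, 512.0) here)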
+        points = points * self.image_size
+        points = points.to(inference_state["device"])
+        labels = labels.to(inference_state["device"])
+
+        if not clear_old_points:
+            point_inputs = point_inputs_per_frame.get(frame_idx, None)
+        else:
+            point_inputs = None
+        point_inputs = concat_points(point_inputs, points, labels)
+
+        point_inputs_per_frame[frame_idx] = point_inputs
+        mask_inputs_per_frame.pop(frame_idx, None)
+        # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input points are used to generate segments on this frame without
+        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+        # the input points will be used to correct the already tracked masks.
+        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+        # whether to track in reverse time order
+        if is_init_cond_frame:
+            reverse = False
+        else:
+            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+        # Add a frame to conditioning output if it's an initial conditioning frame or
+        # if the model is set to treat all frames receiving clicks/mask inputs as conditioning frames.
+        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+        # Get any previously predicted mask logits on this object and feed it along with
+        # the new clicks into the SAM mask decoder.
+        prev_sam_mask_logits = None
+        # lookup temporary output dict first, which contains the most recent output
+        # (if not found, then lookup conditioning and non-conditioning frame output)
+        prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
+        if prev_out is None:
+            prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
+            if prev_out is None:
+                prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+
+        if prev_out is not None and prev_out["pred_masks"] is not None:
+            device = inference_state["device"]
+            prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
+            # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+            prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
+        current_out, _ = self._run_single_frame_inference(
+            inference_state=inference_state,
+            output_dict=obj_output_dict,  # run on the slice of a single object
+            frame_idx=frame_idx,
+            batch_size=1,  # run on the slice of a single object
+            is_init_cond_frame=is_init_cond_frame,
+            point_inputs=point_inputs,
+            mask_inputs=None,
+            reverse=reverse,
+            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+            # allows us to enforce non-overlapping constraints on all objects before encoding
+            # them into memory.
+            run_mem_encoder=False,
+            prev_sam_mask_logits=prev_sam_mask_logits,
+        )
+        # Add the output to the output dict (to be used as future memory)
+        obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+        # Resize the output mask to the original video resolution
+        obj_ids = inference_state["obj_ids"]
+        consolidated_out = self._consolidate_temp_output_across_obj(
+            inference_state,
+            frame_idx,
+            is_cond=is_cond,
+            run_mem_encoder=False,
+            consolidate_at_video_res=True,
+        )
+        _, video_res_masks = self._get_orig_video_res_output(
+            inference_state, consolidated_out["pred_masks_video_res"]
+        )
+        return frame_idx, obj_ids, video_res_masks
+
+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
+    @torch.inference_mode()
+    def add_new_mask(
+        self,
+        inference_state,
+        frame_idx,
+        obj_id,
+        mask,
+    ):
+        """Add a new mask to a frame."""
+        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+        if not isinstance(mask, torch.Tensor):
+            mask = torch.tensor(mask, dtype=torch.bool)
+        assert mask.dim() == 2
+        mask_H, mask_W = mask.shape
+        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
+        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
+
+        # resize the mask if it doesn't match the model's image size
+        if mask_H != self.image_size or mask_W != self.image_size:
+            mask_inputs = torch.nn.functional.interpolate(
+                mask_inputs_orig,
+                size=(self.image_size, self.image_size),
+                align_corners=False,
+                mode="bilinear",
+                antialias=True,  # use antialias for downsampling
+            )
+            mask_inputs = (mask_inputs >= 0.5).float()
+        else:
+            mask_inputs = mask_inputs_orig
+
+        mask_inputs_per_frame[frame_idx] = mask_inputs
+        point_inputs_per_frame.pop(frame_idx, None)
+        # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input mask is used to generate segments on this frame without
+        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+        # the input mask will be used to correct the already tracked masks.
+        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+        # whether to track in reverse time order
+        if is_init_cond_frame:
+            reverse = False
+        else:
+            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+        # Add a frame to conditioning output if it's an initial conditioning frame or
+        # if the model is set to treat all frames receiving clicks/mask inputs as conditioning frames.
+        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+        current_out, _ = self._run_single_frame_inference(
+            inference_state=inference_state,
+            output_dict=obj_output_dict,  # run on the slice of a single object
+            frame_idx=frame_idx,
+            batch_size=1,  # run on the slice of a single object
+            is_init_cond_frame=is_init_cond_frame,
+            point_inputs=None,
+            mask_inputs=mask_inputs,
+            reverse=reverse,
+            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+            # allows us to enforce non-overlapping constraints on all objects before encoding
+            # them into memory.
+            run_mem_encoder=False,
+        )
+        # Add the output to the output dict (to be used as future memory)
+        obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+        # Resize the output mask to the original video resolution
+        obj_ids = inference_state["obj_ids"]
+        consolidated_out = self._consolidate_temp_output_across_obj(
+            inference_state,
+            frame_idx,
+            is_cond=is_cond,
+            run_mem_encoder=False,
+            consolidate_at_video_res=True,
+        )
+        _, video_res_masks = self._get_orig_video_res_output(
+            inference_state, consolidated_out["pred_masks_video_res"]
+        )
+        return frame_idx, obj_ids, video_res_masks
+
+    def _get_orig_video_res_output(self, inference_state, any_res_masks):
+        """
+        Resize the object scores to the original video resolution (video_res_masks)
+        and apply non-overlapping constraints for final output.
+        """
+        device = inference_state["device"]
+        video_H = inference_state["video_height"]
+        video_W = inference_state["video_width"]
+        any_res_masks = any_res_masks.to(device, non_blocking=True)
+        if any_res_masks.shape[-2:] == (video_H, video_W):
+            video_res_masks = any_res_masks
+        else:
+            video_res_masks = torch.nn.functional.interpolate(
+                any_res_masks,
+                size=(video_H, video_W),
+                mode="bilinear",
+                align_corners=False,
+            )
+        if self.non_overlap_masks:
+            video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
+        return any_res_masks, video_res_masks
+
+    def _consolidate_temp_output_across_obj(
+        self,
+        inference_state,
+        frame_idx,
+        is_cond,
+        run_mem_encoder,
+        consolidate_at_video_res=False,
+    ):
+        """
+        Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
+        a frame into a single output for all objects, including
+        1) fill any missing objects either from `output_dict_per_obj` (if they exist in
+           `output_dict_per_obj` for this frame) or leave them as placeholder values
+           (if they don't exist in `output_dict_per_obj` for this frame);
+        2) if specified, rerun memory encoder after applying non-overlapping constraints
+           on the object scores.
+        """
+        batch_size = self._get_obj_num(inference_state)
+        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+        # Optionally, we allow consolidating the temporary outputs at the original
+        # video resolution (to provide a better editing experience for mask prompts).
+        if consolidate_at_video_res:
+            assert not run_mem_encoder, "memory encoder cannot run at video resolution"
+            consolidated_H = inference_state["video_height"]
+            consolidated_W = inference_state["video_width"]
+            consolidated_mask_key = "pred_masks_video_res"
+        else:
+            consolidated_H = consolidated_W = self.image_size // 4
+            consolidated_mask_key = "pred_masks"
+
+        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+        # will be added when rerunning the memory encoder after applying non-overlapping
+        # constraints to object scores. Its "pred_masks" are prefilled with a large
+        # negative value (NO_OBJ_SCORE) to represent missing objects.
+        consolidated_out = {
+            "maskmem_features": None,
+            "maskmem_pos_enc": None,
+            consolidated_mask_key: torch.full(
+                size=(batch_size, 1, consolidated_H, consolidated_W),
+                fill_value=NO_OBJ_SCORE,
+                dtype=torch.float32,
+                device=inference_state["storage_device"],
+            ),
+            "obj_ptr": torch.full(
+                size=(batch_size, self.hidden_dim),
+                fill_value=NO_OBJ_SCORE,
+                dtype=torch.float32,
+                device=inference_state["device"],
+            ),
+            "object_score_logits": torch.full(
+                size=(batch_size, 1),
+                # default to 10.0 for object_score_logits, i.e. assuming the object is
+                # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+                fill_value=10.0,
+                dtype=torch.float32,
+                device=inference_state["device"],
+            ),
+        }
+        empty_mask_ptr = None
+        for obj_idx in range(batch_size):
+            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+            out = obj_temp_output_dict[storage_key].get(frame_idx, None)
+            # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+            # we fall back and look up its previous output in "output_dict_per_obj".
+            # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+            # "output_dict_per_obj" to find a previous output for this object.
+            if out is None:
+                out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
+            if out is None:
+                out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
+            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+            # and leave its mask scores at the default scores (i.e. the NO_OBJ_SCORE
+            # placeholder above) and set its object pointer to be a dummy pointer.
+            if out is None:
+                # Fill in dummy object pointers for those objects without any inputs or
+                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
+                # i.e. when we need to build the memory for tracking).
+                if run_mem_encoder:
+                    if empty_mask_ptr is None:
+                        empty_mask_ptr = self._get_empty_mask_ptr(
+                            inference_state, frame_idx
+                        )
+                    # fill object pointer with a dummy pointer (based on an empty mask)
+                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
+                continue
+            # Add the temporary object output mask to consolidated output mask
+            obj_mask = out["pred_masks"]
+            consolidated_pred_masks = consolidated_out[consolidated_mask_key]
+            if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
+                consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
+            else:
+                # Resize first if temporary object mask has a different resolution
+                resized_obj_mask = torch.nn.functional.interpolate(
+                    obj_mask,
+                    size=consolidated_pred_masks.shape[-2:],
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
+            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+            consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
+                "object_score_logits"
+            ]
+
+        # Optionally, apply non-overlapping constraints on the consolidated scores
+        # and rerun the memory encoder
+        if run_mem_encoder:
+            device = inference_state["device"]
+            high_res_masks = torch.nn.functional.interpolate(
+                consolidated_out["pred_masks"].to(device, non_blocking=True),
+                size=(self.image_size, self.image_size),
+                mode="bilinear",
+                align_corners=False,
+            )
+            if self.non_overlap_masks_for_mem_enc:
+                high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
+            maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
+                inference_state=inference_state,
+                frame_idx=frame_idx,
+                batch_size=batch_size,
+                high_res_masks=high_res_masks,
+                object_score_logits=consolidated_out["object_score_logits"],
+                is_mask_from_pts=True,  # these frames are what the user interacted with
+            )
+            consolidated_out["maskmem_features"] = maskmem_features
+            consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+
+        return consolidated_out
+
+    def _get_empty_mask_ptr(self, inference_state, frame_idx):
+        """Get a dummy object pointer based on an empty mask on the current frame."""
+        # A dummy (empty) mask with a single object
+        batch_size = 1
+        mask_inputs = torch.zeros(
+            (batch_size, 1, self.image_size, self.image_size),
+            dtype=torch.float32,
+            device=inference_state["device"],
+        )
+
+        # Retrieve correct image features
+        (
+            _,
+            _,
+            current_vision_feats,
+            current_vision_pos_embeds,
+            feat_sizes,
+        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+        # Feed the empty mask and image feature above to get a dummy object pointer
+        current_out = self.track_step(
+            frame_idx=frame_idx,
+            is_init_cond_frame=True,
+            current_vision_feats=current_vision_feats,
+            current_vision_pos_embeds=current_vision_pos_embeds,
+            feat_sizes=feat_sizes,
+            point_inputs=None,
+            mask_inputs=mask_inputs,
+            output_dict={},
+            num_frames=inference_state["num_frames"],
+            track_in_reverse=False,
+            run_mem_encoder=False,
+            prev_sam_mask_logits=None,
+        )
+        return current_out["obj_ptr"]
+
+    @torch.inference_mode()
+    def propagate_in_video_preflight(self, inference_state):
+        """Prepare inference_state and consolidate temporary outputs before tracking."""
+        # Tracking has started and we don't allow adding new objects until the session is reset.
+        inference_state["tracking_has_started"] = True
+        batch_size = self._get_obj_num(inference_state)
+
+        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+        # add them into "output_dict".
+        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+        output_dict = inference_state["output_dict"]
+        # "consolidated_frame_inds" contains indices of those frames where consolidated
+        # temporary outputs have been added (either in this call or any previous calls
+        # to `propagate_in_video_preflight`).
+        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+        for is_cond in [False, True]:
+            # Separately consolidate conditioning and non-conditioning temp outputs
+            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+            # Find all the frames that contain temporary outputs for any objects
+            # (these should be the frames that have just received clicks or mask inputs
+            # via `add_new_points_or_box` or `add_new_mask`)
+            temp_frame_inds = set()
+            for obj_temp_output_dict in temp_output_dict_per_obj.values():
+                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+            consolidated_frame_inds[storage_key].update(temp_frame_inds)
+            # consolidate the temporary output across all objects on this frame
+            for frame_idx in temp_frame_inds:
+                consolidated_out = self._consolidate_temp_output_across_obj(
+                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
+                )
+                # merge them into "output_dict" and also create per-object slices
+                output_dict[storage_key][frame_idx] = consolidated_out
+                self._add_output_per_object(
+                    inference_state, frame_idx, consolidated_out, storage_key
+                )
+                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+                )
+                if clear_non_cond_mem:
+                    # clear non-conditioning memory of the surrounding frames
+                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+
+            # clear temporary outputs in `temp_output_dict_per_obj`
+            for obj_temp_output_dict in temp_output_dict_per_obj.values():
+                obj_temp_output_dict[storage_key].clear()
+
+        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+        # output on the same frame in "non_cond_frame_outputs"
+        for frame_idx in output_dict["cond_frame_outputs"]:
+            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+        for obj_output_dict in inference_state["output_dict_per_obj"].values():
+            for frame_idx in obj_output_dict["cond_frame_outputs"]:
+                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+            assert frame_idx in output_dict["cond_frame_outputs"]
+            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+
+        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+        # with either points or mask inputs (which should be true under a correct workflow).
+        all_consolidated_frame_inds = (
+            consolidated_frame_inds["cond_frame_outputs"]
+            | consolidated_frame_inds["non_cond_frame_outputs"]
+        )
+        input_frames_inds = set()
+        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
+            input_frames_inds.update(point_inputs_per_frame.keys())
+        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
+            input_frames_inds.update(mask_inputs_per_frame.keys())
+        assert all_consolidated_frame_inds == input_frames_inds
+
+    @torch.inference_mode()
+    def propagate_in_video(
+        self,
+        inference_state,
+        start_frame_idx=None,
+        max_frame_num_to_track=None,
+        reverse=False,
+    ):
+        """Propagate the input points across frames to track in the entire video."""
+        self.propagate_in_video_preflight(inference_state)
+
+        output_dict = inference_state["output_dict"]
+        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+        obj_ids = inference_state["obj_ids"]
+        num_frames = inference_state["num_frames"]
+        batch_size = self._get_obj_num(inference_state)
+        if len(output_dict["cond_frame_outputs"]) == 0:
+            raise RuntimeError("No points are provided; please add points first")
+        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+        )
+
+        # set start index, end index, and processing order
+        if start_frame_idx is None:
+            # default: start from the earliest frame with input points
+            start_frame_idx = min(output_dict["cond_frame_outputs"])
+        if max_frame_num_to_track is None:
+            # default: track all the frames in the video
+            max_frame_num_to_track = num_frames
+        if reverse:
+            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+            if start_frame_idx > 0:
+                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+            else:
+                processing_order = []  # skip reverse tracking if starting from frame 0
+        else:
+            end_frame_idx = min(
+                start_frame_idx + max_frame_num_to_track, num_frames - 1
+            )
+            processing_order = range(start_frame_idx, end_frame_idx + 1)
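+        # Worked example with illustrative values: start_frame_idx=10,
+        # max_frame_num_to_track=5 and num_frames=100 -> forward tracking visits
+        # frames 10..15 (range(10, 16)); with reverse=True it visits frames
+        # 10 down to 5 (range(10, 4, -1)).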
+
+        for frame_idx in tqdm(processing_order, desc="propagate in video"):
+            # We skip those frames already in consolidated outputs (these are frames
+            # that received input clicks or mask inputs). Note that we cannot directly run
+            # batched forward on them via `_run_single_frame_inference` because the
+            # number of clicks on each object might be different.
+            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+                storage_key = "cond_frame_outputs"
+                current_out = output_dict[storage_key][frame_idx]
+                pred_masks = current_out["pred_masks"]
+                if clear_non_cond_mem:
+                    # clear non-conditioning memory of the surrounding frames
+                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
+                storage_key = "non_cond_frame_outputs"
+                current_out = output_dict[storage_key][frame_idx]
+                pred_masks = current_out["pred_masks"]
+            else:
+                storage_key = "non_cond_frame_outputs"
+                current_out, pred_masks = self._run_single_frame_inference(
+                    inference_state=inference_state,
+                    output_dict=output_dict,
+                    frame_idx=frame_idx,
+                    batch_size=batch_size,
+                    is_init_cond_frame=False,
+                    point_inputs=None,
+                    mask_inputs=None,
+                    reverse=reverse,
+                    run_mem_encoder=True,
+                )
+                output_dict[storage_key][frame_idx] = current_out
+            # Create slices of per-object outputs for subsequent interaction with each
+            # individual object after tracking.
+            self._add_output_per_object(
+                inference_state, frame_idx, current_out, storage_key
+            )
+            inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
+
+            # Resize the output mask to the original video resolution (we directly use
+            # the mask scores on GPU for output to avoid any CPU conversion in between)
+            _, video_res_masks = self._get_orig_video_res_output(
+                inference_state, pred_masks
+            )
+            yield frame_idx, obj_ids, video_res_masks
+
+    def _add_output_per_object(
+        self, inference_state, frame_idx, current_out, storage_key
+    ):
+        """
+        Split a multi-object output into per-object output slices and add them into
+        `output_dict_per_obj`. The resulting slices share the same tensor storage.
+        """
+        maskmem_features = current_out["maskmem_features"]
+        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
+
+        maskmem_pos_enc = current_out["maskmem_pos_enc"]
+        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
+
+        output_dict_per_obj = inference_state["output_dict_per_obj"]
+        for obj_idx, obj_output_dict in output_dict_per_obj.items():
+            obj_slice = slice(obj_idx, obj_idx + 1)
+            obj_out = {
+                "maskmem_features": None,
+                "maskmem_pos_enc": None,
+                "pred_masks": current_out["pred_masks"][obj_slice],
+                "obj_ptr": current_out["obj_ptr"][obj_slice],
+                "object_score_logits": current_out["object_score_logits"][obj_slice],
+            }
+            if maskmem_features is not None:
+                obj_out["maskmem_features"] = maskmem_features[obj_slice]
+            if maskmem_pos_enc is not None:
+                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
+            obj_output_dict[storage_key][frame_idx] = obj_out
+
+    @torch.inference_mode()
+    def clear_all_prompts_in_frame(
+        self, inference_state, frame_idx, obj_id, need_output=True
+    ):
+        """Remove all input points or masks on a specific frame for a given object."""
+        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+
+        # Clear the conditioning information on the given frame
+        inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+        inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+
+        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+        temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
+        temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
+
+        # Check and see if there are still any inputs left on this frame
+        batch_size = self._get_obj_num(inference_state)
+        frame_has_input = False
+        for obj_idx2 in range(batch_size):
+            if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
+                frame_has_input = True
+                break
+            if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
+                frame_has_input = True
+                break
+
+        # If this frame has no remaining inputs for any objects, we further clear its
+        # conditioning frame status
+        if not frame_has_input:
+            output_dict = inference_state["output_dict"]
+            consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+            consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
+            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+            # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
+            out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
+            if out is not None:
+                # The frame is not a conditioning frame anymore since it's not receiving inputs,
+                # so we "downgrade" its output (if exists) to a non-conditioning frame output.
+                output_dict["non_cond_frame_outputs"][frame_idx] = out
+                inference_state["frames_already_tracked"].pop(frame_idx, None)
+            # Similarly, do it for the sliced output on each object.
+            for obj_idx2 in range(batch_size):
+                obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
+                obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
+                if obj_out is not None:
+                    obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
+
+            # If all the conditioning frames have been removed, we also clear the tracking outputs
+            if len(output_dict["cond_frame_outputs"]) == 0:
+                self._reset_tracking_results(inference_state)
+
+        if not need_output:
+            return
+        # Finally, output updated masks per object (after removing the inputs above)
+        obj_ids = inference_state["obj_ids"]
+        is_cond = any(
+            frame_idx in obj_temp_output_dict["cond_frame_outputs"]
+            for obj_temp_output_dict in temp_output_dict_per_obj.values()
+        )
+        consolidated_out = self._consolidate_temp_output_across_obj(
+            inference_state,
+            frame_idx,
+            is_cond=is_cond,
+            run_mem_encoder=False,
+            consolidate_at_video_res=True,
+        )
+        _, video_res_masks = self._get_orig_video_res_output(
+            inference_state, consolidated_out["pred_masks_video_res"]
+        )
+        return frame_idx, obj_ids, video_res_masks
+
+    @torch.inference_mode()
+    def reset_state(self, inference_state):
+        """Remove all input points or masks in all frames throughout the video."""
+        self._reset_tracking_results(inference_state)
+        # Remove all object ids
+        inference_state["obj_id_to_idx"].clear()
+        inference_state["obj_idx_to_id"].clear()
+        inference_state["obj_ids"].clear()
+        inference_state["point_inputs_per_obj"].clear()
+        inference_state["mask_inputs_per_obj"].clear()
+        inference_state["output_dict_per_obj"].clear()
+        inference_state["temp_output_dict_per_obj"].clear()
+
+    def _reset_tracking_results(self, inference_state):
+        """Reset all tracking inputs and results across the video."""
+        for v in inference_state["point_inputs_per_obj"].values():
+            v.clear()
+        for v in inference_state["mask_inputs_per_obj"].values():
+            v.clear()
+        for v in inference_state["output_dict_per_obj"].values():
+            v["cond_frame_outputs"].clear()
+            v["non_cond_frame_outputs"].clear()
+        for v in inference_state["temp_output_dict_per_obj"].values():
+            v["cond_frame_outputs"].clear()
+            v["non_cond_frame_outputs"].clear()
+        inference_state["output_dict"]["cond_frame_outputs"].clear()
+        inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+        inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+        inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+        inference_state["tracking_has_started"] = False
+        inference_state["frames_already_tracked"].clear()
+
+    def _get_image_feature(self, inference_state, frame_idx, batch_size):
+        """Compute the image features on a given frame."""
+        # Look up in the cache first
+        image, backbone_out = inference_state["cached_features"].get(
+            frame_idx, (None, None)
+        )
+        if backbone_out is None:
+            # Cache miss -- we will run inference on a single image
+            device = inference_state["device"]
+            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
+            backbone_out = self.forward_image(image)
+            # Cache the most recent frame's feature (for repeated interactions with
+            # a frame; we can use an LRU cache for more frames in the future).
+            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+
+        # expand the features to have the same dimension as the number of objects
+        expanded_image = image.expand(batch_size, -1, -1, -1)
+        expanded_backbone_out = {
+            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
+            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
+        }
+        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
+            expanded_backbone_out["backbone_fpn"][i] = feat.expand(
+                batch_size, -1, -1, -1
+            )
+        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
+            pos = pos.expand(batch_size, -1, -1, -1)
+            expanded_backbone_out["vision_pos_enc"][i] = pos
+
+        features = self._prepare_backbone_features(expanded_backbone_out)
+        features = (expanded_image,) + features
+        return features
+
+    def _run_single_frame_inference(
+        self,
+        inference_state,
+        output_dict,
+        frame_idx,
+        batch_size,
+        is_init_cond_frame,
+        point_inputs,
+        mask_inputs,
+        reverse,
+        run_mem_encoder,
+        prev_sam_mask_logits=None,
+    ):
+        """Run tracking on a single frame based on current inputs and previous memory."""
+        # Retrieve correct image features
+        (
+            _,
+            _,
+            current_vision_feats,
+            current_vision_pos_embeds,
+            feat_sizes,
+        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+        # point and mask should not appear as input simultaneously on the same frame
+        assert point_inputs is None or mask_inputs is None
+        current_out = self.track_step(
+            frame_idx=frame_idx,
+            is_init_cond_frame=is_init_cond_frame,
+            current_vision_feats=current_vision_feats,
+            current_vision_pos_embeds=current_vision_pos_embeds,
+            feat_sizes=feat_sizes,
+            point_inputs=point_inputs,
+            mask_inputs=mask_inputs,
+            output_dict=output_dict,
+            num_frames=inference_state["num_frames"],
+            track_in_reverse=reverse,
+            run_mem_encoder=run_mem_encoder,
+            prev_sam_mask_logits=prev_sam_mask_logits,
+        )
+
+        # optionally offload the output to CPU memory to save GPU space
+        storage_device = inference_state["storage_device"]
+        maskmem_features = current_out["maskmem_features"]
+        if maskmem_features is not None:
+            maskmem_features = maskmem_features.to(torch.bfloat16)
+            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+        pred_masks_gpu = current_out["pred_masks"]
+        # potentially fill holes in the predicted masks
+        if self.fill_hole_area > 0:
+            pred_masks_gpu = fill_holes_in_mask_scores(
+                pred_masks_gpu, self.fill_hole_area
+            )
+        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
+        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+        obj_ptr = current_out["obj_ptr"]
+        object_score_logits = current_out["object_score_logits"]
+        # make a compact version of this frame's output to reduce the state size
+        compact_current_out = {
+            "maskmem_features": maskmem_features,
+            "maskmem_pos_enc": maskmem_pos_enc,
+            "pred_masks": pred_masks,
+            "obj_ptr": obj_ptr,
+            "object_score_logits": object_score_logits,
+        }
+        return compact_current_out, pred_masks_gpu
+
+    def _run_memory_encoder(
+        self,
+        inference_state,
+        frame_idx,
+        batch_size,
+        high_res_masks,
+        object_score_logits,
+        is_mask_from_pts,
+    ):
+        """
+        Run the memory encoder on `high_res_masks`. This is usually after applying
+        non-overlapping constraints to object scores. Since their scores changed, their
+        memory also needs to be computed again with the memory encoder.
+        """
+        # Retrieve correct image features
+        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
+            inference_state, frame_idx, batch_size
+        )
+        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+            current_vision_feats=current_vision_feats,
+            feat_sizes=feat_sizes,
+            pred_masks_high_res=high_res_masks,
+            object_score_logits=object_score_logits,
+            is_mask_from_pts=is_mask_from_pts,
+        )
+
+        # optionally offload the output to CPU memory to save GPU space
+        storage_device = inference_state["storage_device"]
+        maskmem_features = maskmem_features.to(torch.bfloat16)
+        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+        maskmem_pos_enc = self._get_maskmem_pos_enc(
+            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
+        )
+        return maskmem_features, maskmem_pos_enc
+
+    def _get_maskmem_pos_enc(self, inference_state, current_out):
+        """
+        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
+        a constant in the inference session to reduce session storage size.
+        """
+        model_constants = inference_state["constants"]
+        # "out_maskmem_pos_enc" should be either a list of tensors or None
+        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
+        if out_maskmem_pos_enc is not None:
+            if "maskmem_pos_enc" not in model_constants:
+                assert isinstance(out_maskmem_pos_enc, list)
+                # only take the slice for one object, since it's the same across objects
+                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
+                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+            else:
+                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+            # expand the cached maskmem_pos_enc to the actual batch size
+            batch_size = out_maskmem_pos_enc[0].size(0)
+            expanded_maskmem_pos_enc = [
+                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
+            ]
+        else:
+            expanded_maskmem_pos_enc = None
+        return expanded_maskmem_pos_enc
+
+    @torch.inference_mode()
+    def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
+        """
+        Remove an object id from the tracking state. If strict is True, we check whether
+        the object id actually exists and raise an error if it doesn't exist.
+        """
+        old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
+        updated_frames = []
+        # Check whether this object_id to remove actually exists and possibly raise an error.
+        if old_obj_idx_to_rm is None:
+            if not strict:
+                return inference_state["obj_ids"], updated_frames
+            raise RuntimeError(
+                f"Cannot remove object id {obj_id} as it doesn't exist. "
+                f"All existing object ids: {inference_state['obj_ids']}."
+            )
+
+        # If this is the only remaining object id, we simply reset the state.
+        if len(inference_state["obj_id_to_idx"]) == 1:
+            self.reset_state(inference_state)
+            return inference_state["obj_ids"], updated_frames
+
+        # There are still remaining objects after removing this object id. In this case,
+        # we need to delete the object storage from inference state tensors.
+        # Step 0: clear the input on those frames where this object id has point or mask input
+        # (note that this step is required as it might downgrade conditioning frames to
+        # non-conditioning ones)
+        obj_input_frames_inds = set()
+        obj_input_frames_inds.update(
+            inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
+        )
+        obj_input_frames_inds.update(
+            inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
+        )
+        for frame_idx in obj_input_frames_inds:
+            self.clear_all_prompts_in_frame(
+                inference_state, frame_idx, obj_id, need_output=False
+            )
+
+        # Step 1: Update the object id mapping (note that it must be done after Step 0,
+        # since Step 0 still requires the old object id mappings in inference_state)
+        old_obj_ids = inference_state["obj_ids"]
+        old_obj_inds = list(range(len(old_obj_ids)))
+        remain_old_obj_inds = old_obj_inds.copy()
+        remain_old_obj_inds.remove(old_obj_idx_to_rm)
+        new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
+        new_obj_inds = list(range(len(new_obj_ids)))
+        # build new mappings
+        old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
+        inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
+        inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
+        inference_state["obj_ids"] = new_obj_ids
+
+        # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
+        # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
+        # it's already handled in Step 0)
+        def _map_keys(container):
+            new_kvs = []
+            for k in old_obj_inds:
+                v = container.pop(k)
+                if k in old_idx_to_new_idx:
+                    new_kvs.append((old_idx_to_new_idx[k], v))
+            container.update(new_kvs)
+
+        _map_keys(inference_state["point_inputs_per_obj"])
+        _map_keys(inference_state["mask_inputs_per_obj"])
+        _map_keys(inference_state["output_dict_per_obj"])
+        _map_keys(inference_state["temp_output_dict_per_obj"])
+
+        # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
+        def _slice_state(output_dict, storage_key):
+            for frame_idx, out in output_dict[storage_key].items():
+                out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
+                out["maskmem_pos_enc"] = [
+                    x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
+                ]
+                # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+                out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
+                out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
+                out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
+                out["object_score_logits"] = out["object_score_logits"][
+                    remain_old_obj_inds
+                ]
+                # also update the per-object slices
+                self._add_output_per_object(
+                    inference_state, frame_idx, out, storage_key
+                )
+
+        _slice_state(inference_state["output_dict"], "cond_frame_outputs")
+        _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
+
+        # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
+        # could show an updated mask for objects previously occluded by the object being removed
+        if need_output:
+            temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+            for frame_idx in obj_input_frames_inds:
+                is_cond = any(
+                    frame_idx in obj_temp_output_dict["cond_frame_outputs"]
+                    for obj_temp_output_dict in temp_output_dict_per_obj.values()
+                )
+                consolidated_out = self._consolidate_temp_output_across_obj(
+                    inference_state,
+                    frame_idx,
+                    is_cond=is_cond,
+                    run_mem_encoder=False,
+                    consolidate_at_video_res=True,
+                )
+                _, video_res_masks = self._get_orig_video_res_output(
+                    inference_state, consolidated_out["pred_masks_video_res"]
+                )
+                updated_frames.append((frame_idx, video_res_masks))
+
+        return inference_state["obj_ids"], updated_frames
+
+    def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
+        """
+        Remove the non-conditioning memory around the input frame. When users provide
+        correction clicks, the surrounding frames' non-conditioning memories can still
+        contain outdated object appearance information and could confuse the model.
+
+        This method clears those non-conditioning memories surrounding the interacted
+        frame to avoid giving the model both old and new information about the object.
+        """
+        r = self.memory_temporal_stride_for_eval
+        frame_idx_begin = frame_idx - r * self.num_maskmem
+        frame_idx_end = frame_idx + r * self.num_maskmem
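+        # e.g., assuming the default num_maskmem=7 and a memory stride r=1, this clears
+        # the non-conditioning outputs on frames frame_idx-7 .. frame_idx+7 around the input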
+        output_dict = inference_state["output_dict"]
+        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
+        for t in range(frame_idx_begin, frame_idx_end + 1):
+            non_cond_frame_outputs.pop(t, None)
+            for obj_output_dict in inference_state["output_dict_per_obj"].values():
+                obj_output_dict["non_cond_frame_outputs"].pop(t, None)