@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from collections import OrderedDict
 
 import torch
@@ -43,12 +44,14 @@ class SAM2VideoPredictor(SAM2Base):
         offload_state_to_cpu=False,
         async_loading_frames=False,
     ):
-        """Initialize a inference state."""
+        """Initialize an inference state."""
+        compute_device = self.device  # device of the model
         images, video_height, video_width = load_video_frames(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
             async_loading_frames=async_loading_frames,
+            compute_device=compute_device,
         )
         inference_state = {}
         inference_state["images"] = images
@@ -64,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
         # the original video height and width, used for resizing final output scores
         inference_state["video_height"] = video_height
         inference_state["video_width"] = video_width
-        inference_state["device"] = torch.device("cuda")
+        inference_state["device"] = compute_device
         if offload_state_to_cpu:
             inference_state["storage_device"] = torch.device("cpu")
         else:
-            inference_state["storage_device"] = torch.device("cuda")
+            inference_state["storage_device"] = compute_device
         # inputs on each frame
         inference_state["point_inputs_per_obj"] = {}
         inference_state["mask_inputs_per_obj"] = {}
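
Note: with this hunk the inference state inherits the model's device instead of hard-coding CUDA, while `offload_state_to_cpu` still routes cached per-frame state to the CPU to save accelerator memory. A minimal sketch of the resulting behavior, assuming the standard `build_sam2_video_predictor` entry point; the config name and paths below are placeholders, not part of this diff:

    import torch
    from sam2.build_sam import build_sam2_video_predictor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Placeholder config/checkpoint -- substitute your own.
    predictor = build_sam2_video_predictor(
        "sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt", device=device
    )
    state = predictor.init_state("./videos/clip.mp4", offload_state_to_cpu=True)
    # Inference runs on the model's device; cached state stays on the CPU.
    assert state["device"] == device
    assert state["storage_device"] == torch.device("cpu")
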
@@ -103,6 +106,23 @@ class SAM2VideoPredictor(SAM2Base):
         self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
         return inference_state
 
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2VideoPredictor): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_video_predictor_hf
+
+        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
+        return sam_model
+
     def _obj_id_to_idx(self, inference_state, obj_id):
         """Map client-side object id to model-side object index."""
         obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
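
Note: the new `from_pretrained` classmethod delegates to `build_sam2_video_predictor_hf`, mirroring a Hugging Face-style entry point. Usage is a one-liner; the repository ID below is an assumed example, not fixed by this diff:

    from sam2.sam2_video_predictor import SAM2VideoPredictor

    # Assumed hub ID; any SAM 2 video checkpoint repository works here.
    predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
    state = predictor.init_state(video_path="./videos/clip.mp4")
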
@@ -146,29 +166,66 @@ class SAM2VideoPredictor(SAM2Base):
         return len(inference_state["obj_idx_to_id"])
 
     @torch.inference_mode()
-    def add_new_points(
+    def add_new_points_or_box(
         self,
         inference_state,
         frame_idx,
         obj_id,
-        points,
-        labels,
+        points=None,
+        labels=None,
         clear_old_points=True,
         normalize_coords=True,
+        box=None,
     ):
         """Add new points to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
         point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
         mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
 
-        if not isinstance(points, torch.Tensor):
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
             points = torch.tensor(points, dtype=torch.float32)
-        if not isinstance(labels, torch.Tensor):
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
             labels = torch.tensor(labels, dtype=torch.int32)
         if points.dim() == 2:
             points = points.unsqueeze(0)  # add batch dimension
         if labels.dim() == 1:
             labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
         if normalize_coords:
             video_H = inference_state["video_height"]
             video_W = inference_state["video_width"]
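
Note: the box is not a separate prompt type here; it is folded into the point tensor as its top-left and bottom-right corners, carrying the reserved labels 2 and 3 and placed ahead of any user clicks, matching the layout SAM 2 saw during training. A standalone sketch of that encoding (tensor shapes only, no model involved):

    import torch

    box = torch.tensor([10.0, 20.0, 110.0, 220.0])   # (x_min, y_min, x_max, y_max)
    points = torch.tensor([[[60.0, 120.0]]])          # one click, shape (1, 1, 2)
    labels = torch.tensor([[1]], dtype=torch.int32)   # 1 = foreground click

    box_coords = box.reshape(1, 2, 2)                 # corners become two pseudo-points
    box_labels = torch.tensor([[2, 3]], dtype=torch.int32)  # reserved corner labels
    points = torch.cat([box_coords, points], dim=1)   # box corners go first
    labels = torch.cat([box_labels, labels], dim=1)
    print(points.shape, labels.shape)                 # (1, 3, 2) and (1, 3)
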
@@ -215,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
             prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
 
         if prev_out is not None and prev_out["pred_masks"] is not None:
-            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+            device = inference_state["device"]
+            prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
             # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
             prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
         current_out, _ = self._run_single_frame_inference(
@@ -251,6 +309,10 @@ class SAM2VideoPredictor(SAM2Base):
         )
         return frame_idx, obj_ids, video_res_masks
 
+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
     @torch.inference_mode()
     def add_new_mask(
         self,
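
Note: the backward-compatibility shim forwards silently. Since `warnings` is now imported anyway, a variant that surfaces the rename to callers would also be possible; this is a suggestion, not what the diff does:

    def add_new_points(self, *args, **kwargs):
        """Deprecated method. Please use `add_new_points_or_box` instead."""
        warnings.warn(
            "add_new_points is deprecated; use add_new_points_or_box instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        return self.add_new_points_or_box(*args, **kwargs)
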
@@ -527,16 +589,16 @@ class SAM2VideoPredictor(SAM2Base):
         # to `propagate_in_video_preflight`).
         consolidated_frame_inds = inference_state["consolidated_frame_inds"]
         for is_cond in [False, True]:
-            # Separately consolidate conditioning and non-conditioning temp outptus
+            # Separately consolidate conditioning and non-conditioning temp outputs
             storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
             # Find all the frames that contain temporary outputs for any objects
             # (these should be the frames that have just received clicks for mask inputs
-            # via `add_new_points` or `add_new_mask`)
+            # via `add_new_points_or_box` or `add_new_mask`)
             temp_frame_inds = set()
             for obj_temp_output_dict in temp_output_dict_per_obj.values():
                 temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
             consolidated_frame_inds[storage_key].update(temp_frame_inds)
-            # consolidate the temprary output across all objects on this frame
+            # consolidate the temporary output across all objects on this frame
             for frame_idx in temp_frame_inds:
                 consolidated_out = self._consolidate_temp_output_across_obj(
                     inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
@@ -734,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
         )
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
-            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+            device = inference_state["device"]
+            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
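
Note: together with the earlier hunks, this removes the last hard `.cuda()` calls, so frame features are computed wherever the model lives. A hedged device-selection sketch (backend availability checks only; pass the result to the builder as in the earlier note):

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    # Pass `device` to build_sam2_video_predictor(...) as in the earlier sketch.
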