|
|
@@ -0,0 +1,427 @@
|
|
|
+# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
+# All rights reserved.
|
|
|
+# This source code is licensed under the license found in the
|
|
|
+# LICENSE file in the root directory of this source tree.
|
|
|
+
|
|
|
+import contextlib
|
|
|
+import logging
|
|
|
+import os
|
|
|
+import uuid
|
|
|
+from pathlib import Path
|
|
|
+from threading import Lock
|
|
|
+from typing import Any, Dict, Generator, List
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+from app_conf import APP_ROOT, MODEL_SIZE
|
|
|
+from inference.data_types import (
|
|
|
+ AddMaskRequest,
|
|
|
+ AddPointsRequest,
|
|
|
+ CancelPorpagateResponse,
|
|
|
+ CancelPropagateInVideoRequest,
|
|
|
+ ClearPointsInFrameRequest,
|
|
|
+ ClearPointsInVideoRequest,
|
|
|
+ ClearPointsInVideoResponse,
|
|
|
+ CloseSessionRequest,
|
|
|
+ CloseSessionResponse,
|
|
|
+ Mask,
|
|
|
+ PropagateDataResponse,
|
|
|
+ PropagateDataValue,
|
|
|
+ PropagateInVideoRequest,
|
|
|
+ RemoveObjectRequest,
|
|
|
+ RemoveObjectResponse,
|
|
|
+ StartSessionRequest,
|
|
|
+ StartSessionResponse,
|
|
|
+)
|
|
|
+from pycocotools.mask import decode as decode_masks, encode as encode_masks
|
|
|
+from sam2.build_sam import build_sam2_video_predictor
|
|
|
+
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
class InferenceAPI:
    """Serving wrapper around the SAM 2 video predictor.

    Holds one predictor instance plus per-session inference states keyed by a
    UUID session id. Inference calls serialize through ``self.inference_lock``,
    so at most one predictor call runs at a time.
    """

    def __init__(self) -> None:
        # Modernized from `super(InferenceAPI, self).__init__()` — the Python 3
        # zero-argument form is equivalent here.
        super().__init__()

        # session_id -> {"canceled": bool, "state": <predictor inference state>}
        self.session_states: Dict[str, Any] = {}
        # logit threshold used to binarize predicted masks (> 0 == foreground)
        self.score_thresh = 0

        # Pick the checkpoint/config pair matching the configured model size.
        if MODEL_SIZE == "tiny":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_tiny.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        elif MODEL_SIZE == "small":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_small.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
        elif MODEL_SIZE == "large":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_large.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        else:  # base_plus (default)
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_base_plus.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

        # select the device for computation
        force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
        if force_cpu_device:
            logger.info("forcing CPU device for SAM 2 demo")
        if torch.cuda.is_available() and not force_cpu_device:
            device = torch.device("cuda")
        elif torch.backends.mps.is_available() and not force_cpu_device:
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        logger.info(f"using device: {device}")

        if device.type == "cuda":
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        elif device.type == "mps":
            logging.warning(
                "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS. "
                "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
            )

        self.device = device
        self.predictor = build_sam2_video_predictor(
            model_cfg, checkpoint, device=device
        )
        self.inference_lock = Lock()
|
|
|
+
|
|
|
+ def autocast_context(self):
|
|
|
+ if self.device.type == "cuda":
|
|
|
+ return torch.autocast("cuda", dtype=torch.bfloat16)
|
|
|
+ else:
|
|
|
+ return contextlib.nullcontext()
|
|
|
+
|
|
|
+ def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
|
|
|
+ with self.autocast_context(), self.inference_lock:
|
|
|
+ session_id = str(uuid.uuid4())
|
|
|
+ # for MPS devices, we offload the video frames to CPU by default to avoid
|
|
|
+ # memory fragmentation in MPS (which sometimes crashes the entire process)
|
|
|
+ offload_video_to_cpu = self.device.type == "mps"
|
|
|
+ inference_state = self.predictor.init_state(
|
|
|
+ request.path,
|
|
|
+ offload_video_to_cpu=offload_video_to_cpu,
|
|
|
+ )
|
|
|
+ self.session_states[session_id] = {
|
|
|
+ "canceled": False,
|
|
|
+ "state": inference_state,
|
|
|
+ }
|
|
|
+ return StartSessionResponse(session_id=session_id)
|
|
|
+
|
|
|
+ def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
|
|
|
+ is_successful = self.__clear_session_state(request.session_id)
|
|
|
+ return CloseSessionResponse(success=is_successful)
|
|
|
+
|
|
|
+ def add_points(
|
|
|
+ self, request: AddPointsRequest, test: str = ""
|
|
|
+ ) -> PropagateDataResponse:
|
|
|
+ with self.autocast_context(), self.inference_lock:
|
|
|
+ session = self.__get_session(request.session_id)
|
|
|
+ inference_state = session["state"]
|
|
|
+
|
|
|
+ frame_idx = request.frame_index
|
|
|
+ obj_id = request.object_id
|
|
|
+ points = request.points
|
|
|
+ labels = request.labels
|
|
|
+ clear_old_points = request.clear_old_points
|
|
|
+
|
|
|
+ # add new prompts and instantly get the output on the same frame
|
|
|
+ frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
|
|
|
+ inference_state=inference_state,
|
|
|
+ frame_idx=frame_idx,
|
|
|
+ obj_id=obj_id,
|
|
|
+ points=points,
|
|
|
+ labels=labels,
|
|
|
+ clear_old_points=clear_old_points,
|
|
|
+ normalize_coords=False,
|
|
|
+ )
|
|
|
+
|
|
|
+ masks_binary = (masks > self.score_thresh)[:, 0].cpu().numpy()
|
|
|
+
|
|
|
+ rle_mask_list = self.__get_rle_mask_list(
|
|
|
+ object_ids=object_ids, masks=masks_binary
|
|
|
+ )
|
|
|
+
|
|
|
+ return PropagateDataResponse(
|
|
|
+ frame_index=frame_idx,
|
|
|
+ results=rle_mask_list,
|
|
|
+ )
|
|
|
+
|
|
|
+ def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
|
|
|
+ """
|
|
|
+ Add new points on a specific video frame.
|
|
|
+ - mask is a numpy array of shape [H_im, W_im] (containing 1 for foreground and 0 for background).
|
|
|
+ Note: providing an input mask would overwrite any previous input points on this frame.
|
|
|
+ """
|
|
|
+ with self.autocast_context(), self.inference_lock:
|
|
|
+ session_id = request.session_id
|
|
|
+ frame_idx = request.frame_index
|
|
|
+ obj_id = request.object_id
|
|
|
+ rle_mask = {
|
|
|
+ "counts": request.mask.counts,
|
|
|
+ "size": request.mask.size,
|
|
|
+ }
|
|
|
+
|
|
|
+ mask = decode_masks(rle_mask)
|
|
|
+
|
|
|
+ logger.info(
|
|
|
+ f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
|
|
|
+ )
|
|
|
+ session = self.__get_session(session_id)
|
|
|
+ inference_state = session["state"]
|
|
|
+
|
|
|
+ frame_idx, obj_ids, video_res_masks = self.model.add_new_mask(
|
|
|
+ inference_state=inference_state,
|
|
|
+ frame_idx=frame_idx,
|
|
|
+ obj_id=obj_id,
|
|
|
+ mask=torch.tensor(mask > 0),
|
|
|
+ )
|
|
|
+ masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
|
|
+
|
|
|
+ rle_mask_list = self.__get_rle_mask_list(
|
|
|
+ object_ids=obj_ids, masks=masks_binary
|
|
|
+ )
|
|
|
+
|
|
|
+ return PropagateDataResponse(
|
|
|
+ frame_index=frame_idx,
|
|
|
+ results=rle_mask_list,
|
|
|
+ )
|
|
|
+
|
|
|
+ def clear_points_in_frame(
|
|
|
+ self, request: ClearPointsInFrameRequest
|
|
|
+ ) -> PropagateDataResponse:
|
|
|
+ """
|
|
|
+ Remove all input points in a specific frame.
|
|
|
+ """
|
|
|
+ with self.autocast_context(), self.inference_lock:
|
|
|
+ session_id = request.session_id
|
|
|
+ frame_idx = request.frame_index
|
|
|
+ obj_id = request.object_id
|
|
|
+
|
|
|
+ logger.info(
|
|
|
+ f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
|
|
|
+ )
|
|
|
+ session = self.__get_session(session_id)
|
|
|
+ inference_state = session["state"]
|
|
|
+ frame_idx, obj_ids, video_res_masks = (
|
|
|
+ self.predictor.clear_all_prompts_in_frame(
|
|
|
+ inference_state, frame_idx, obj_id
|
|
|
+ )
|
|
|
+ )
|
|
|
+ masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
|
|
+
|
|
|
+ rle_mask_list = self.__get_rle_mask_list(
|
|
|
+ object_ids=obj_ids, masks=masks_binary
|
|
|
+ )
|
|
|
+
|
|
|
+ return PropagateDataResponse(
|
|
|
+ frame_index=frame_idx,
|
|
|
+ results=rle_mask_list,
|
|
|
+ )
|
|
|
+
|
|
|
+ def clear_points_in_video(
|
|
|
+ self, request: ClearPointsInVideoRequest
|
|
|
+ ) -> ClearPointsInVideoResponse:
|
|
|
+ """
|
|
|
+ Remove all input points in all frames throughout the video.
|
|
|
+ """
|
|
|
+ with self.autocast_context(), self.inference_lock:
|
|
|
+ session_id = request.session_id
|
|
|
+ logger.info(f"clear all inputs across the video in session {session_id}")
|
|
|
+ session = self.__get_session(session_id)
|
|
|
+ inference_state = session["state"]
|
|
|
+ self.predictor.reset_state(inference_state)
|
|
|
+ return ClearPointsInVideoResponse(success=True)
|
|
|
+
|
|
|
+ def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
|
|
|
+ """
|
|
|
+ Remove an object id from the tracking state.
|
|
|
+ """
|
|
|
+ with self.autocast_context(), self.inference_lock:
|
|
|
+ session_id = request.session_id
|
|
|
+ obj_id = request.object_id
|
|
|
+ logger.info(f"remove object in session {session_id}: {obj_id=}")
|
|
|
+ session = self.__get_session(session_id)
|
|
|
+ inference_state = session["state"]
|
|
|
+ new_obj_ids, updated_frames = self.predictor.remove_object(
|
|
|
+ inference_state, obj_id
|
|
|
+ )
|
|
|
+
|
|
|
+ results = []
|
|
|
+ for frame_index, video_res_masks in updated_frames:
|
|
|
+ masks = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
|
|
+ rle_mask_list = self.__get_rle_mask_list(
|
|
|
+ object_ids=new_obj_ids, masks=masks
|
|
|
+ )
|
|
|
+ results.append(
|
|
|
+ PropagateDataResponse(
|
|
|
+ frame_index=frame_index,
|
|
|
+ results=rle_mask_list,
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ return RemoveObjectResponse(results=results)
|
|
|
+
|
|
|
+ def propagate_in_video(
|
|
|
+ self, request: PropagateInVideoRequest
|
|
|
+ ) -> Generator[PropagateDataResponse, None, None]:
|
|
|
+ session_id = request.session_id
|
|
|
+ start_frame_idx = request.start_frame_index
|
|
|
+ propagation_direction = "both"
|
|
|
+ max_frame_num_to_track = None
|
|
|
+
|
|
|
+ """
|
|
|
+ Propagate existing input points in all frames to track the object across video.
|
|
|
+ """
|
|
|
+
|
|
|
+ # Note that as this method is a generator, we also need to use autocast_context
|
|
|
+ # in caller to this method to ensure that it's called under the correct context
|
|
|
+ # (we've added `autocast_context` to `gen_track_with_mask_stream` in app.py).
|
|
|
+ with self.autocast_context(), self.inference_lock:
|
|
|
+ logger.info(
|
|
|
+ f"propagate in video in session {session_id}: "
|
|
|
+ f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
|
|
|
+ )
|
|
|
+
|
|
|
+ try:
|
|
|
+ session = self.__get_session(session_id)
|
|
|
+ session["canceled"] = False
|
|
|
+
|
|
|
+ inference_state = session["state"]
|
|
|
+ if propagation_direction not in ["both", "forward", "backward"]:
|
|
|
+ raise ValueError(
|
|
|
+ f"invalid propagation direction: {propagation_direction}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # First doing the forward propagation
|
|
|
+ if propagation_direction in ["both", "forward"]:
|
|
|
+ for outputs in self.predictor.propagate_in_video(
|
|
|
+ inference_state=inference_state,
|
|
|
+ start_frame_idx=start_frame_idx,
|
|
|
+ max_frame_num_to_track=max_frame_num_to_track,
|
|
|
+ reverse=False,
|
|
|
+ ):
|
|
|
+ if session["canceled"]:
|
|
|
+ return None
|
|
|
+
|
|
|
+ frame_idx, obj_ids, video_res_masks = outputs
|
|
|
+ masks_binary = (
|
|
|
+ (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
|
|
+ )
|
|
|
+
|
|
|
+ rle_mask_list = self.__get_rle_mask_list(
|
|
|
+ object_ids=obj_ids, masks=masks_binary
|
|
|
+ )
|
|
|
+
|
|
|
+ yield PropagateDataResponse(
|
|
|
+ frame_index=frame_idx,
|
|
|
+ results=rle_mask_list,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Then doing the backward propagation (reverse in time)
|
|
|
+ if propagation_direction in ["both", "backward"]:
|
|
|
+ for outputs in self.predictor.propagate_in_video(
|
|
|
+ inference_state=inference_state,
|
|
|
+ start_frame_idx=start_frame_idx,
|
|
|
+ max_frame_num_to_track=max_frame_num_to_track,
|
|
|
+ reverse=True,
|
|
|
+ ):
|
|
|
+ if session["canceled"]:
|
|
|
+ return None
|
|
|
+
|
|
|
+ frame_idx, obj_ids, video_res_masks = outputs
|
|
|
+ masks_binary = (
|
|
|
+ (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
|
|
+ )
|
|
|
+
|
|
|
+ rle_mask_list = self.__get_rle_mask_list(
|
|
|
+ object_ids=obj_ids, masks=masks_binary
|
|
|
+ )
|
|
|
+
|
|
|
+ yield PropagateDataResponse(
|
|
|
+ frame_index=frame_idx,
|
|
|
+ results=rle_mask_list,
|
|
|
+ )
|
|
|
+ finally:
|
|
|
+ # Log upon completion (so that e.g. we can see if two propagations happen in parallel).
|
|
|
+ # Using `finally` here to log even when the tracking is aborted with GeneratorExit.
|
|
|
+ logger.info(
|
|
|
+ f"propagation ended in session {session_id}; {self.__get_session_stats()}"
|
|
|
+ )
|
|
|
+
|
|
|
+ def cancel_propagate_in_video(
|
|
|
+ self, request: CancelPropagateInVideoRequest
|
|
|
+ ) -> CancelPorpagateResponse:
|
|
|
+ session = self.__get_session(request.session_id)
|
|
|
+ session["canceled"] = True
|
|
|
+ return CancelPorpagateResponse(success=True)
|
|
|
+
|
|
|
+ def __get_rle_mask_list(
|
|
|
+ self, object_ids: List[int], masks: np.ndarray
|
|
|
+ ) -> List[PropagateDataValue]:
|
|
|
+ """
|
|
|
+ Return a list of data values, i.e. list of object/mask combos.
|
|
|
+ """
|
|
|
+ return [
|
|
|
+ self.__get_mask_for_object(object_id=object_id, mask=mask)
|
|
|
+ for object_id, mask in zip(object_ids, masks)
|
|
|
+ ]
|
|
|
+
|
|
|
+ def __get_mask_for_object(
|
|
|
+ self, object_id: int, mask: np.ndarray
|
|
|
+ ) -> PropagateDataValue:
|
|
|
+ """
|
|
|
+ Create a data value for an object/mask combo.
|
|
|
+ """
|
|
|
+ mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
|
|
|
+ mask_rle["counts"] = mask_rle["counts"].decode()
|
|
|
+ return PropagateDataValue(
|
|
|
+ object_id=object_id,
|
|
|
+ mask=Mask(
|
|
|
+ size=mask_rle["size"],
|
|
|
+ counts=mask_rle["counts"],
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ def __get_session(self, session_id: str):
|
|
|
+ session = self.session_states.get(session_id, None)
|
|
|
+ if session is None:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Cannot find session {session_id}; it might have expired"
|
|
|
+ )
|
|
|
+ return session
|
|
|
+
|
|
|
+ def __get_session_stats(self):
|
|
|
+ """Get a statistics string for live sessions and their GPU usage."""
|
|
|
+ # print both the session ids and their video frame numbers
|
|
|
+ live_session_strs = [
|
|
|
+ f"'{session_id}' ({session['state']['num_frames']} frames, "
|
|
|
+ f"{len(session['state']['obj_ids'])} objects)"
|
|
|
+ for session_id, session in self.session_states.items()
|
|
|
+ ]
|
|
|
+ session_stats_str = (
|
|
|
+ "Test String Here - -"
|
|
|
+ f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
|
|
|
+ f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
|
|
|
+ f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
|
|
|
+ f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
|
|
|
+ f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
|
|
|
+ )
|
|
|
+ return session_stats_str
|
|
|
+
|
|
|
+ def __clear_session_state(self, session_id: str) -> bool:
|
|
|
+ session = self.session_states.pop(session_id, None)
|
|
|
+ if session is None:
|
|
|
+ logger.warning(
|
|
|
+ f"cannot close session {session_id} as it does not exist (it might have expired); "
|
|
|
+ f"{self.__get_session_stats()}"
|
|
|
+ )
|
|
|
+ return False
|
|
|
+ else:
|
|
|
+ logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
|
|
|
+ return True
|