| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import contextlib
- import logging
- import os
- import uuid
- from pathlib import Path
- from threading import Lock
- from typing import Any, Dict, Generator, List
- import numpy as np
- import torch
- from app_conf import APP_ROOT, MODEL_SIZE
- from inference.data_types import (
- AddMaskRequest,
- AddPointsRequest,
- CancelPorpagateResponse,
- CancelPropagateInVideoRequest,
- ClearPointsInFrameRequest,
- ClearPointsInVideoRequest,
- ClearPointsInVideoResponse,
- CloseSessionRequest,
- CloseSessionResponse,
- Mask,
- PropagateDataResponse,
- PropagateDataValue,
- PropagateInVideoRequest,
- RemoveObjectRequest,
- RemoveObjectResponse,
- StartSessionRequest,
- StartSessionResponse,
- )
- from pycocotools.mask import decode as decode_masks, encode as encode_masks
- from sam2.build_sam import build_sam2_video_predictor
- logger = logging.getLogger(__name__)
class InferenceAPI:
    """Session-based facade over the SAM 2 video predictor.

    Each session owns one predictor inference state, stored in
    ``session_states`` next to a ``canceled`` flag. Every method that calls
    into the predictor does so while holding ``inference_lock`` and inside
    ``autocast_context()``.
    """

    def __init__(self) -> None:
        """Pick the checkpoint and compute device, then build the predictor.

        The checkpoint/config pair is chosen from ``MODEL_SIZE`` (anything
        other than "tiny", "small" or "large" falls back to base_plus).
        Setting the environment variable ``SAM2_DEMO_FORCE_CPU_DEVICE=1``
        forces the CPU device even when CUDA or MPS is available.
        """
        super().__init__()

        # session_id -> {"canceled": bool, "state": <predictor inference state>}
        self.session_states: Dict[str, Any] = {}
        # predictor mask values strictly above this threshold are foreground
        self.score_thresh = 0

        # MODEL_SIZE -> (checkpoint path relative to APP_ROOT, model config)
        model_files = {
            "tiny": (
                "checkpoints/sam2.1_hiera_tiny.pt",
                "configs/sam2.1/sam2.1_hiera_t.yaml",
            ),
            "small": (
                "checkpoints/sam2.1_hiera_small.pt",
                "configs/sam2.1/sam2.1_hiera_s.yaml",
            ),
            "large": (
                "checkpoints/sam2.1_hiera_large.pt",
                "configs/sam2.1/sam2.1_hiera_l.yaml",
            ),
            "base_plus": (
                "checkpoints/sam2.1_hiera_base_plus.pt",
                "configs/sam2.1/sam2.1_hiera_b+.yaml",
            ),
        }
        # base_plus is the default for any unrecognized size
        checkpoint_rel, model_cfg = model_files.get(
            MODEL_SIZE, model_files["base_plus"]
        )
        checkpoint = Path(APP_ROOT) / checkpoint_rel

        # select the device for computation
        force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
        if force_cpu_device:
            logger.info("forcing CPU device for SAM 2 demo")
            device = torch.device("cpu")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        logger.info(f"using device: {device}")

        if device.type == "cuda":
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        elif device.type == "mps":
            # use the module logger (not the root logger) like every other message here
            logger.warning(
                "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS. "
                "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
            )

        self.device = device
        self.predictor = build_sam2_video_predictor(
            model_cfg, checkpoint, device=device
        )
        # serializes every call into the predictor
        self.inference_lock = Lock()

    def autocast_context(self):
        """Return bfloat16 CUDA autocast on CUDA devices, else a no-op context."""
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        return contextlib.nullcontext()

    def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
        """Initialize predictor state for ``request.path`` under a new session id.

        Returns:
            A response carrying the freshly generated UUID4 session id.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = str(uuid.uuid4())
            # for MPS devices, we offload the video frames to CPU by default to avoid
            # memory fragmentation in MPS (which sometimes crashes the entire process)
            offload_video_to_cpu = self.device.type == "mps"
            inference_state = self.predictor.init_state(
                request.path,
                offload_video_to_cpu=offload_video_to_cpu,
            )
            self.session_states[session_id] = {
                "canceled": False,
                "state": inference_state,
            }
            return StartSessionResponse(session_id=session_id)

    def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
        """Drop the session; ``success`` is False when the id was unknown."""
        is_successful = self.__clear_session_state(request.session_id)
        return CloseSessionResponse(success=is_successful)

    def add_points(
        self, request: AddPointsRequest, test: str = ""
    ) -> PropagateDataResponse:
        """Add point prompts for one object on one frame and return that frame's masks.

        Args:
            request: carries session id, frame index, object id, points,
                labels and the ``clear_old_points`` flag.
            test: unused; kept only so existing callers keep working.

        Raises:
            RuntimeError: if the session id is unknown.
        """
        with self.autocast_context(), self.inference_lock:
            session = self.__get_session(request.session_id)

            # add new prompts and instantly get the output on the same frame
            frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=session["state"],
                frame_idx=request.frame_index,
                obj_id=request.object_id,
                points=request.points,
                labels=request.labels,
                clear_old_points=request.clear_old_points,
                normalize_coords=False,
            )
            return self.__frame_response(frame_idx, object_ids, masks)

    def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
        """
        Add a new mask prompt on a specific video frame.
        - ``request.mask`` is an RLE mask (``counts`` + ``size``) that is decoded
          with pycocotools; non-zero pixels are treated as foreground.
        Note: providing an input mask would overwrite any previous input points on this frame.

        Raises:
            RuntimeError: if the session id is unknown.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id
            rle_mask = {
                "counts": request.mask.counts,
                "size": request.mask.size,
            }
            mask = decode_masks(rle_mask)
            logger.info(
                f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
            )

            session = self.__get_session(session_id)
            # The predictor lives in `self.predictor`; no `self.model` attribute
            # is ever set on this class, so calling it raised AttributeError.
            frame_idx, obj_ids, video_res_masks = self.predictor.add_new_mask(
                inference_state=session["state"],
                frame_idx=frame_idx,
                obj_id=obj_id,
                mask=torch.tensor(mask > 0),
            )
            return self.__frame_response(frame_idx, obj_ids, video_res_masks)

    def clear_points_in_frame(
        self, request: ClearPointsInFrameRequest
    ) -> PropagateDataResponse:
        """
        Remove all input points in a specific frame.

        Raises:
            RuntimeError: if the session id is unknown.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id
            logger.info(
                f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
            )

            session = self.__get_session(session_id)
            frame_idx, obj_ids, video_res_masks = (
                self.predictor.clear_all_prompts_in_frame(
                    session["state"], frame_idx, obj_id
                )
            )
            return self.__frame_response(frame_idx, obj_ids, video_res_masks)

    def clear_points_in_video(
        self, request: ClearPointsInVideoRequest
    ) -> ClearPointsInVideoResponse:
        """
        Remove all input points in all frames throughout the video.

        Raises:
            RuntimeError: if the session id is unknown.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            logger.info(f"clear all inputs across the video in session {session_id}")
            session = self.__get_session(session_id)
            self.predictor.reset_state(session["state"])
            return ClearPointsInVideoResponse(success=True)

    def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
        """
        Remove an object id from the tracking state.

        Returns:
            One per-frame response for every frame the predictor reports as
            updated, encoded against the remaining object ids.

        Raises:
            RuntimeError: if the session id is unknown.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            obj_id = request.object_id
            logger.info(f"remove object in session {session_id}: {obj_id=}")

            session = self.__get_session(session_id)
            new_obj_ids, updated_frames = self.predictor.remove_object(
                session["state"], obj_id
            )
            results = [
                self.__frame_response(frame_index, new_obj_ids, video_res_masks)
                for frame_index, video_res_masks in updated_frames
            ]
            return RemoveObjectResponse(results=results)

    def propagate_in_video(
        self, request: PropagateInVideoRequest
    ) -> Generator[PropagateDataResponse, None, None]:
        """
        Propagate existing input points in all frames to track the object across video.

        Yields one response per processed frame: a forward pass first, then a
        backward pass, both starting at ``request.start_frame_index``. The
        generator stops early once the session's ``canceled`` flag is set
        (see ``cancel_propagate_in_video``).

        Raises:
            RuntimeError: if the session id is unknown.
        """
        session_id = request.session_id
        start_frame_idx = request.start_frame_index
        propagation_direction = "both"
        max_frame_num_to_track = None

        # Note that as this method is a generator, we also need to use autocast_context
        # in caller to this method to ensure that it's called under the correct context
        # (we've added `autocast_context` to `gen_track_with_mask_stream` in app.py).
        with self.autocast_context(), self.inference_lock:
            logger.info(
                f"propagate in video in session {session_id}: "
                f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
            )

            try:
                session = self.__get_session(session_id)
                session["canceled"] = False
                inference_state = session["state"]

                if propagation_direction not in ["both", "forward", "backward"]:
                    raise ValueError(
                        f"invalid propagation direction: {propagation_direction}"
                    )

                # Forward propagation first, then backward (reverse in time).
                reverse_flags = []
                if propagation_direction in ["both", "forward"]:
                    reverse_flags.append(False)
                if propagation_direction in ["both", "backward"]:
                    reverse_flags.append(True)

                for reverse in reverse_flags:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=reverse,
                    ):
                        # the flag is flipped from another call while we hold the lock
                        if session["canceled"]:
                            return

                        frame_idx, obj_ids, video_res_masks = outputs
                        yield self.__frame_response(
                            frame_idx, obj_ids, video_res_masks
                        )
            finally:
                # Log upon completion (so that e.g. we can see if two propagations happen in parallel).
                # Using `finally` here to log even when the tracking is aborted with GeneratorExit.
                logger.info(
                    f"propagation ended in session {session_id}; {self.__get_session_stats()}"
                )

    def cancel_propagate_in_video(
        self, request: CancelPropagateInVideoRequest
    ) -> CancelPorpagateResponse:
        """Flag the session as canceled so a running propagation stops.

        Deliberately does not take ``inference_lock``: a running
        ``propagate_in_video`` holds that lock for its whole duration.

        Raises:
            RuntimeError: if the session id is unknown.
        """
        session = self.__get_session(request.session_id)
        session["canceled"] = True
        return CancelPorpagateResponse(success=True)

    def __frame_response(
        self, frame_idx: int, object_ids: List[int], masks: Any
    ) -> PropagateDataResponse:
        """Threshold predictor mask outputs and package them for one frame."""
        # Keep values strictly above the threshold, take index 0 along axis 1
        # (presumably a singleton channel axis — TODO confirm), then copy to
        # host memory as a numpy array for RLE encoding.
        masks_binary = (masks > self.score_thresh)[:, 0].cpu().numpy()
        rle_mask_list = self.__get_rle_mask_list(
            object_ids=object_ids, masks=masks_binary
        )
        return PropagateDataResponse(
            frame_index=frame_idx,
            results=rle_mask_list,
        )

    def __get_rle_mask_list(
        self, object_ids: List[int], masks: np.ndarray
    ) -> List[PropagateDataValue]:
        """
        Return a list of data values, i.e. list of object/mask combos.

        ``object_ids`` and ``masks`` are paired positionally with ``zip``.
        """
        return [
            self.__get_mask_for_object(object_id=object_id, mask=mask)
            for object_id, mask in zip(object_ids, masks)
        ]

    def __get_mask_for_object(
        self, object_id: int, mask: np.ndarray
    ) -> PropagateDataValue:
        """
        Create a data value for an object/mask combo.
        """
        # encode a Fortran-ordered uint8 copy of the mask with pycocotools
        mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
        # `counts` comes back as bytes; the response carries it as str
        mask_rle["counts"] = mask_rle["counts"].decode()
        return PropagateDataValue(
            object_id=object_id,
            mask=Mask(
                size=mask_rle["size"],
                counts=mask_rle["counts"],
            ),
        )

    def __get_session(self, session_id: str):
        """Return the stored session dict.

        Raises:
            RuntimeError: if ``session_id`` is not in ``session_states``.
        """
        session = self.session_states.get(session_id, None)
        if session is None:
            raise RuntimeError(
                f"Cannot find session {session_id}; it might have expired"
            )
        return session

    def __get_session_stats(self):
        """Get a statistics string for live sessions and their GPU usage."""
        mib = 1024**2
        # print both the session ids and their video frame numbers
        live_session_strs = [
            f"'{session_id}' ({session['state']['num_frames']} frames, "
            f"{len(session['state']['obj_ids'])} objects)"
            for session_id, session in self.session_states.items()
        ]
        session_stats_str = (
            f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
            f"{torch.cuda.memory_allocated() // mib} MiB used and "
            f"{torch.cuda.memory_reserved() // mib} MiB reserved"
            f" (max over time: {torch.cuda.max_memory_allocated() // mib} MiB used "
            f"and {torch.cuda.max_memory_reserved() // mib} MiB reserved)"
        )
        return session_stats_str

    def __clear_session_state(self, session_id: str) -> bool:
        """Remove the session from ``session_states``; return whether it existed."""
        session = self.session_states.pop(session_id, None)
        if session is None:
            logger.warning(
                f"cannot close session {session_id} as it does not exist (it might have expired); "
                f"{self.__get_session_stats()}"
            )
            return False

        logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
        return True
|