predictor.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging
import os
import uuid
from pathlib import Path
from threading import Lock
from typing import Any, Dict, Generator, List

import numpy as np
import torch
from app_conf import APP_ROOT, MODEL_SIZE
from inference.data_types import (
    AddMaskRequest,
    AddPointsRequest,
    CancelPorpagateResponse,
    CancelPropagateInVideoRequest,
    ClearPointsInFrameRequest,
    ClearPointsInVideoRequest,
    ClearPointsInVideoResponse,
    CloseSessionRequest,
    CloseSessionResponse,
    Mask,
    PropagateDataResponse,
    PropagateDataValue,
    PropagateInVideoRequest,
    RemoveObjectRequest,
    RemoveObjectResponse,
    StartSessionRequest,
    StartSessionResponse,
)
from pycocotools.mask import decode as decode_masks, encode as encode_masks
from sam2.build_sam import build_sam2_video_predictor

logger = logging.getLogger(__name__)


class InferenceAPI:

    def __init__(self) -> None:
        super().__init__()

        self.session_states: Dict[str, Any] = {}
        # Mask logits above this threshold are kept as foreground pixels.
        self.score_thresh = 0

        if MODEL_SIZE == "tiny":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_tiny.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        elif MODEL_SIZE == "small":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_small.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
        elif MODEL_SIZE == "large":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_large.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        else:  # base_plus (default)
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_base_plus.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

        # select the device for computation
        force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
        if force_cpu_device:
            logger.info("forcing CPU device for SAM 2 demo")
        if torch.cuda.is_available() and not force_cpu_device:
            device = torch.device("cuda")
        elif torch.backends.mps.is_available() and not force_cpu_device:
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        logger.info(f"using device: {device}")

        if device.type == "cuda":
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        elif device.type == "mps":
            logger.warning(
                "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS. "
                "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
            )

        self.device = device
        self.predictor = build_sam2_video_predictor(
            model_cfg, checkpoint, device=device
        )
        # Serializes access to the predictor: inference calls are not thread-safe.
        self.inference_lock = Lock()

    def autocast_context(self):
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            return contextlib.nullcontext()

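    # Usage sketch (illustrative, not called anywhere in this module): predictor
    # calls are expected to run under this context so that CUDA inference uses
    # bfloat16 autocast while CPU/MPS paths fall back to a no-op context:
    #
    #   api = InferenceAPI()
    #   with api.autocast_context(), api.inference_lock:
    #       ...  # call api.predictor methods here
    #
    # (`api` is just a local name for this sketch.)
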
    def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
        with self.autocast_context(), self.inference_lock:
            session_id = str(uuid.uuid4())
            # for MPS devices, we offload the video frames to CPU by default to avoid
            # memory fragmentation in MPS (which sometimes crashes the entire process)
            offload_video_to_cpu = self.device.type == "mps"
            inference_state = self.predictor.init_state(
                request.path,
                offload_video_to_cpu=offload_video_to_cpu,
            )
            self.session_states[session_id] = {
                "canceled": False,
                "state": inference_state,
            }
            return StartSessionResponse(session_id=session_id)

    def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
        is_successful = self.__clear_session_state(request.session_id)
        return CloseSessionResponse(success=is_successful)

    def add_points(self, request: AddPointsRequest) -> PropagateDataResponse:
        with self.autocast_context(), self.inference_lock:
            session = self.__get_session(request.session_id)
            inference_state = session["state"]
            frame_idx = request.frame_index
            obj_id = request.object_id
            points = request.points
            labels = request.labels
            clear_old_points = request.clear_old_points

            # add new prompts and instantly get the output on the same frame
            frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
                clear_old_points=clear_old_points,
                normalize_coords=False,
            )

            masks_binary = (masks > self.score_thresh)[:, 0].cpu().numpy()
            rle_mask_list = self.__get_rle_mask_list(
                object_ids=object_ids, masks=masks_binary
            )
            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

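    # Note on thresholding (an explanatory aside): the `masks` returned above are
    # raw logits, so comparing against `score_thresh = 0` keeps pixels whose
    # foreground probability exceeds 0.5 (sigmoid(0) == 0.5). A stricter,
    # hypothetical cutoff would be a one-line change in __init__, e.g.
    # `self.score_thresh = 2.0` to keep only high-confidence pixels.
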
    def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
        """
        Add a new mask on a specific video frame.
        - mask is a numpy array of shape [H_im, W_im] (containing 1 for foreground and 0 for background).
        Note: providing an input mask overwrites any previous input points on this frame.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id
            rle_mask = {
                "counts": request.mask.counts,
                "size": request.mask.size,
            }
            mask = decode_masks(rle_mask)

            logger.info(
                f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]

            frame_idx, obj_ids, video_res_masks = self.predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                mask=torch.tensor(mask > 0),
            )
            masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=obj_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def clear_points_in_frame(
        self, request: ClearPointsInFrameRequest
    ) -> PropagateDataResponse:
        """
        Remove all input points in a specific frame.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id

            logger.info(
                f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]
            frame_idx, obj_ids, video_res_masks = (
                self.predictor.clear_all_prompts_in_frame(
                    inference_state, frame_idx, obj_id
                )
            )
            masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=obj_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def clear_points_in_video(
        self, request: ClearPointsInVideoRequest
    ) -> ClearPointsInVideoResponse:
        """
        Remove all input points in all frames throughout the video.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            logger.info(f"clear all inputs across the video in session {session_id}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            self.predictor.reset_state(inference_state)
            return ClearPointsInVideoResponse(success=True)

    def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
        """
        Remove an object id from the tracking state.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            obj_id = request.object_id
            logger.info(f"remove object in session {session_id}: {obj_id=}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            new_obj_ids, updated_frames = self.predictor.remove_object(
                inference_state, obj_id
            )

            results = []
            for frame_index, video_res_masks in updated_frames:
                masks = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                rle_mask_list = self.__get_rle_mask_list(
                    object_ids=new_obj_ids, masks=masks
                )
                results.append(
                    PropagateDataResponse(
                        frame_index=frame_index,
                        results=rle_mask_list,
                    )
                )

            return RemoveObjectResponse(results=results)

    def propagate_in_video(
        self, request: PropagateInVideoRequest
    ) -> Generator[PropagateDataResponse, None, None]:
        """
        Propagate existing input points in all frames to track the object across the video.
        """
        session_id = request.session_id
        start_frame_idx = request.start_frame_index
        propagation_direction = "both"
        max_frame_num_to_track = None

        # Note that as this method is a generator, we also need to use autocast_context
        # in the caller of this method to ensure that it's called under the correct context
        # (we've added `autocast_context` to `gen_track_with_mask_stream` in app.py).
        with self.autocast_context(), self.inference_lock:
            logger.info(
                f"propagate in video in session {session_id}: "
                f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
            )

            try:
                session = self.__get_session(session_id)
                session["canceled"] = False

                inference_state = session["state"]
                if propagation_direction not in ["both", "forward", "backward"]:
                    raise ValueError(
                        f"invalid propagation direction: {propagation_direction}"
                    )

                # First doing the forward propagation
                if propagation_direction in ["both", "forward"]:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=False,
                    ):
                        if session["canceled"]:
                            return None

                        frame_idx, obj_ids, video_res_masks = outputs
                        masks_binary = (
                            (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                        )
                        rle_mask_list = self.__get_rle_mask_list(
                            object_ids=obj_ids, masks=masks_binary
                        )
                        yield PropagateDataResponse(
                            frame_index=frame_idx,
                            results=rle_mask_list,
                        )

                # Then doing the backward propagation (reverse in time)
                if propagation_direction in ["both", "backward"]:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=True,
                    ):
                        if session["canceled"]:
                            return None

                        frame_idx, obj_ids, video_res_masks = outputs
                        masks_binary = (
                            (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                        )
                        rle_mask_list = self.__get_rle_mask_list(
                            object_ids=obj_ids, masks=masks_binary
                        )
                        yield PropagateDataResponse(
                            frame_index=frame_idx,
                            results=rle_mask_list,
                        )
            finally:
                # Log upon completion (so that e.g. we can see if two propagations happen in parallel).
                # Using `finally` here to log even when the tracking is aborted with GeneratorExit.
                logger.info(
                    f"propagation ended in session {session_id}; {self.__get_session_stats()}"
                )

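    # A caveat worth noting: because `with self.inference_lock` sits inside a
    # generator, the lock is acquired on the first next() call and stays held
    # until the generator is exhausted or closed, including while the generator
    # is suspended between yields. Callers should consume the stream promptly:
    #
    #   for response in api.propagate_in_video(request):
    #       handle(response)  # `api`, `request`, `handle` are hypothetical names
    #
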
    def cancel_propagate_in_video(
        self, request: CancelPropagateInVideoRequest
    ) -> CancelPorpagateResponse:
        session = self.__get_session(request.session_id)
        session["canceled"] = True
        return CancelPorpagateResponse(success=True)

    def __get_rle_mask_list(
        self, object_ids: List[int], masks: np.ndarray
    ) -> List[PropagateDataValue]:
        """
        Return a list of data values, i.e. list of object/mask combos.
        """
        return [
            self.__get_mask_for_object(object_id=object_id, mask=mask)
            for object_id, mask in zip(object_ids, masks)
        ]

    def __get_mask_for_object(
        self, object_id: int, mask: np.ndarray
    ) -> PropagateDataValue:
        """
        Create a data value for an object/mask combo.
        """
        mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
        mask_rle["counts"] = mask_rle["counts"].decode()
        return PropagateDataValue(
            object_id=object_id,
            mask=Mask(
                size=mask_rle["size"],
                counts=mask_rle["counts"],
            ),
        )

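    # Round-trip sketch: the COCO RLE produced above decodes back into a binary
    # H x W numpy mask with pycocotools (counts must be re-encoded to bytes):
    #
    #   rle = {"size": value.mask.size, "counts": value.mask.counts.encode()}
    #   mask = decode_masks(rle)  # `value` is a PropagateDataValue from above
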
    def __get_session(self, session_id: str):
        session = self.session_states.get(session_id, None)
        if session is None:
            raise RuntimeError(
                f"Cannot find session {session_id}; it might have expired"
            )
        return session

    def __get_session_stats(self):
        """Get a statistics string for live sessions and their GPU usage."""
        # print both the session ids and their video frame numbers
        live_session_strs = [
            f"'{session_id}' ({session['state']['num_frames']} frames, "
            f"{len(session['state']['obj_ids'])} objects)"
            for session_id, session in self.session_states.items()
        ]
        session_stats_str = f"live sessions: [{', '.join(live_session_strs)}]"
        # CUDA memory counters are only meaningful (and safe to query) on GPU
        if self.device.type == "cuda":
            session_stats_str += (
                ", GPU memory: "
                f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
                f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
                f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
                f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
            )
        return session_stats_str

    def __clear_session_state(self, session_id: str) -> bool:
        session = self.session_states.pop(session_id, None)
        if session is None:
            logger.warning(
                f"cannot close session {session_id} as it does not exist (it might have expired); "
                f"{self.__get_session_stats()}"
            )
            return False
        else:
            logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
            return True
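

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the demo server). The request
# constructor fields below are inferred from the attributes this module reads;
# the actual dataclasses in inference.data_types may require additional fields
# (e.g. a request `type`), and the video path is a hypothetical placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    api = InferenceAPI()

    # Start a session on a local video (hypothetical path).
    start_resp = api.start_session(StartSessionRequest(path="/path/to/video.mp4"))
    session_id = start_resp.session_id

    # Add a positive click (label 1) on frame 0 for object id 1.
    first_frame = api.add_points(
        AddPointsRequest(
            session_id=session_id,
            frame_index=0,
            object_id=1,
            points=[[210.0, 350.0]],
            labels=[1],
            clear_old_points=True,
        )
    )
    print(f"frame {first_frame.frame_index}: {len(first_frame.results)} mask(s)")

    # Stream per-frame masks across the whole video, then clean up.
    for response in api.propagate_in_video(
        PropagateInVideoRequest(session_id=session_id, start_frame_index=0)
    ):
        print(f"frame {response.frame_index}: {len(response.results)} mask(s)")
    api.close_session(CloseSessionRequest(session_id=session_id))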