sam2_video_predictor.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import warnings
  6. from collections import OrderedDict
  7. import torch
  8. from tqdm import tqdm
  9. from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
  10. from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
  11. class SAM2VideoPredictor(SAM2Base):
  12. """The predictor class to handle user interactions and manage inference states."""
  def __init__(
      self,
      fill_hole_area=0,
      non_overlap_masks=False,
      clear_non_cond_mem_around_input=False,
      clear_non_cond_mem_for_multi_obj=False,
      **kwargs,
  ):
      """
      Record the predictor-level options and forward everything else to the base class.

      Arguments:
        fill_hole_area: stored on the instance unchanged (it is not interpreted here).
        non_overlap_masks (bool): whether to apply non-overlapping constraints on
          the output object masks.
        clear_non_cond_mem_around_input (bool): whether to clear non-conditioning
          memory of the surrounding frames (which may contain outdated information)
          after adding correction clicks. This only applies to *single-object
          tracking* unless `clear_non_cond_mem_for_multi_obj` is also True.
        clear_non_cond_mem_for_multi_obj (bool): whether to also clear that memory
          when several objects are tracked (only effective when
          `clear_non_cond_mem_around_input` is True).
        **kwargs: passed through unchanged to the parent constructor.
      """
      super().__init__(**kwargs)
      # Plain option flags; they are only read by the other methods of this class.
      self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
      self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
      self.non_overlap_masks = non_overlap_masks
      self.fill_hole_area = fill_hole_area
  @torch.inference_mode()
  def init_state(
      self,
      video_path,
      offload_video_to_cpu=False,
      offload_state_to_cpu=False,
      async_loading_frames=False,
  ):
      """
      Initialize an inference state for one video.

      Arguments:
        video_path: forwarded to `load_video_frames` as the video source.
        offload_video_to_cpu (bool): whether to offload the video frames to CPU
          memory (saves GPU memory with only a very small overhead).
        offload_state_to_cpu (bool): whether to keep the inference state on CPU
          memory (saves GPU memory at the cost of a lower tracking fps, e.g. in a
          test case of a 768x768 model, fps dropped from 27 to 24 when tracking
          one object and from 24 to 21 when tracking two objects).
        async_loading_frames (bool): forwarded to `load_video_frames`.
      Returns:
        (dict): the newly created inference state.
      """
      images, video_height, video_width = load_video_frames(
          video_path=video_path,
          image_size=self.image_size,
          offload_video_to_cpu=offload_video_to_cpu,
          async_loading_frames=async_loading_frames,
      )

      # Use CUDA when it is available and fall back to CPU otherwise, instead of
      # unconditionally requesting "cuda" (which fails on hosts without a GPU).
      # NOTE(review): the device holding the model's parameters would be the ideal
      # choice; the base class is not visible from this file, so confirm before
      # switching to it.
      compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      if offload_state_to_cpu:
          storage_device = torch.device("cpu")
      else:
          storage_device = compute_device

      inference_state = {
          "images": images,
          "num_frames": len(images),
          "offload_video_to_cpu": offload_video_to_cpu,
          "offload_state_to_cpu": offload_state_to_cpu,
          # the original video height and width, used for resizing final output scores
          "video_height": video_height,
          "video_width": video_width,
          "device": compute_device,
          "storage_device": storage_device,
          # inputs on each frame
          "point_inputs_per_obj": {},
          "mask_inputs_per_obj": {},
          # visual features on a small number of recently visited frames for quick interactions
          "cached_features": {},
          # values that don't change across frames (so we only need to hold one copy of them)
          "constants": {},
          # mapping between client-side object id and model-side object index
          "obj_id_to_idx": OrderedDict(),
          "obj_idx_to_id": OrderedDict(),
          "obj_ids": [],
          # A storage to hold the model's tracking results and states on each frame
          "output_dict": {
              "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
              "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
          },
          # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
          "output_dict_per_obj": {},
          # A temporary storage to hold new outputs when user interact with a frame
          # to add clicks or mask (it's merged into "output_dict" before propagation starts)
          "temp_output_dict_per_obj": {},
          # Frames that already holds consolidated outputs from click or mask inputs
          # (we directly use their consolidated outputs during tracking)
          "consolidated_frame_inds": {
              "cond_frame_outputs": set(),  # set containing frame indices
              "non_cond_frame_outputs": set(),  # set containing frame indices
          },
          # metadata for each tracking frame (e.g. which direction it's tracked)
          "tracking_has_started": False,
          "frames_already_tracked": {},
      }

      # Warm up the visual backbone and cache the image feature on frame 0
      self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
      return inference_state
  @classmethod
  def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
      """
      Load a pretrained model from the Hugging Face hub.

      Arguments:
        model_id (str): The Hugging Face repository ID.
        **kwargs: Additional arguments to pass to the model constructor.
      Returns:
        (SAM2VideoPredictor): The loaded model.
      """
      # Imported lazily, as in the original code.
      from sam2.build_sam import build_sam2_video_predictor_hf

      # Return the builder's result directly. Wrapping it as `cls(sam_model)` would
      # bind the model to `fill_hole_area` (the first positional parameter of
      # `__init__`) and invoke the base constructor without any arguments.
      # NOTE(review): presumably the builder already returns a fully constructed
      # predictor -- confirm against `sam2.build_sam`.
      return build_sam2_video_predictor_hf(model_id, **kwargs)
  110. def _obj_id_to_idx(self, inference_state, obj_id):
  111. """Map client-side object id to model-side object index."""
  112. obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
  113. if obj_idx is not None:
  114. return obj_idx
  115. # This is a new object id not sent to the server before. We only allow adding
  116. # new objects *before* the tracking starts.
  117. allow_new_object = not inference_state["tracking_has_started"]
  118. if allow_new_object:
  119. # get the next object slot
  120. obj_idx = len(inference_state["obj_id_to_idx"])
  121. inference_state["obj_id_to_idx"][obj_id] = obj_idx
  122. inference_state["obj_idx_to_id"][obj_idx] = obj_id
  123. inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
  124. # set up input and output structures for this object
  125. inference_state["point_inputs_per_obj"][obj_idx] = {}
  126. inference_state["mask_inputs_per_obj"][obj_idx] = {}
  127. inference_state["output_dict_per_obj"][obj_idx] = {
  128. "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  129. "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  130. }
  131. inference_state["temp_output_dict_per_obj"][obj_idx] = {
  132. "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  133. "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  134. }
  135. return obj_idx
  136. else:
  137. raise RuntimeError(
  138. f"Cannot add new object id {obj_id} after tracking starts. "
  139. f"All existing object ids: {inference_state['obj_ids']}. "
  140. f"Please call 'reset_state' to restart from scratch."
  141. )
  142. def _obj_idx_to_id(self, inference_state, obj_idx):
  143. """Map model-side object index to client-side object id."""
  144. return inference_state["obj_idx_to_id"][obj_idx]
  145. def _get_obj_num(self, inference_state):
  146. """Get the total number of unique object ids received so far in this session."""
  147. return len(inference_state["obj_idx_to_id"])
  @torch.inference_mode()
  def add_new_points_or_box(
      self,
      inference_state,
      frame_idx,
      obj_id,
      points=None,
      labels=None,
      clear_old_points=True,
      normalize_coords=True,
      box=None,
  ):
      """
      Add new point prompts and/or a box prompt for one object on a frame.

      Arguments:
        inference_state (dict): the state to update.
        frame_idx: index of the frame that receives the prompts.
        obj_id: client-side object id (an unseen id is registered via
          `_obj_id_to_idx`, which only allows this before tracking starts).
        points: point coordinates, either a tensor or anything accepted by
          `torch.tensor`; a 2-dim input gets a batch dimension prepended.
        labels: point labels; must be provided together with `points`.
        clear_old_points (bool): if True, previously stored points on this frame
          are dropped; otherwise they are passed to `concat_points` together with
          the new ones.
        normalize_coords (bool): if True, coordinates are divided by
          (video_width, video_height) before being scaled by `self.image_size`.
        box: optional box prompt, reshaped into two points that are prepended
          with labels 2 and 3.
      Returns:
        (tuple): `(frame_idx, obj_ids, video_res_masks)`, where `video_res_masks`
          is the second output of `_get_orig_video_res_output`.
      Raises:
        ValueError: if only one of `points`/`labels` is given, if neither points
          nor a box is given, or if a box is given with `clear_old_points=False`.
      """
      obj_idx = self._obj_id_to_idx(inference_state, obj_id)
      point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
      mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
      device = inference_state["device"]

      if (points is not None) != (labels is not None):
          raise ValueError("points and labels must be provided together")
      if points is None and box is None:
          raise ValueError("at least one of points or box must be provided as input")

      if points is None:
          points = torch.zeros(0, 2, dtype=torch.float32)
      elif not isinstance(points, torch.Tensor):
          points = torch.tensor(points, dtype=torch.float32)
      if labels is None:
          labels = torch.zeros(0, dtype=torch.int32)
      elif not isinstance(labels, torch.Tensor):
          labels = torch.tensor(labels, dtype=torch.int32)
      if points.dim() == 2:
          points = points.unsqueeze(0)  # add batch dimension
      if labels.dim() == 1:
          labels = labels.unsqueeze(0)  # add batch dimension

      # If `box` is provided, we add it as the first two points with labels 2 and 3
      # along with the user-provided points (consistent with how SAM 2 is trained).
      if box is not None:
          if not clear_old_points:
              raise ValueError(
                  "cannot add box without clearing old points, since "
                  "box prompt must be provided before any point prompt "
                  "(please use clear_old_points=True instead)"
              )
          if inference_state["tracking_has_started"]:
              warnings.warn(
                  "You are adding a box after tracking starts. SAM 2 may not always be "
                  "able to incorporate a box prompt for *refinement*. If you intend to "
                  "use box prompt as an *initial* input before tracking, please call "
                  "'reset_state' on the inference state to restart from scratch.",
                  category=UserWarning,
                  stacklevel=2,
              )
          if not isinstance(box, torch.Tensor):
              box = torch.tensor(box, dtype=torch.float32, device=points.device)
          box_coords = box.reshape(1, 2, 2)
          box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
          box_labels = box_labels.reshape(1, 2)
          points = torch.cat([box_coords, points], dim=1)
          labels = torch.cat([box_labels, labels], dim=1)

      if normalize_coords:
          video_H = inference_state["video_height"]
          video_W = inference_state["video_width"]
          points = points / torch.tensor([video_W, video_H]).to(points.device)
      # scale the (normalized) coordinates by the model's internal image size
      points = points * self.image_size
      points = points.to(device)
      labels = labels.to(device)

      if not clear_old_points:
          point_inputs = point_inputs_per_frame.get(frame_idx, None)
      else:
          point_inputs = None
      point_inputs = concat_points(point_inputs, points, labels)

      # Point and mask inputs are mutually exclusive on a frame: storing points
      # drops any mask input previously recorded for it.
      point_inputs_per_frame[frame_idx] = point_inputs
      mask_inputs_per_frame.pop(frame_idx, None)

      # If this frame hasn't been tracked before, we treat it as an initial conditioning
      # frame, meaning that the inputs points are to generate segments on this frame without
      # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
      # the input points will be used to correct the already tracked masks.
      is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
      # whether to track in reverse time order
      if is_init_cond_frame:
          reverse = False
      else:
          reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
      obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
      obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
      # Add a frame to conditioning output if it's an initial conditioning frame or
      # if the model sees all frames receiving clicks/mask as conditioning frames.
      is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
      storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

      # Get any previously predicted mask logits on this object and feed it along with
      # the new clicks into the SAM mask decoder.
      prev_sam_mask_logits = None
      # lookup temporary output dict first, which contains the most recent output
      # (if not found, then lookup conditioning and non-conditioning frame output)
      prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
      if prev_out is None:
          prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
          if prev_out is None:
              prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

      if prev_out is not None and prev_out["pred_masks"] is not None:
          # Move the logits to the state's compute device (the same one used for
          # the points and labels above) rather than a hard-coded CUDA device.
          prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
          # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
          prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
      current_out, _ = self._run_single_frame_inference(
          inference_state=inference_state,
          output_dict=obj_output_dict,  # run on the slice of a single object
          frame_idx=frame_idx,
          batch_size=1,  # run on the slice of a single object
          is_init_cond_frame=is_init_cond_frame,
          point_inputs=point_inputs,
          mask_inputs=None,
          reverse=reverse,
          # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
          # at the beginning of `propagate_in_video` (after user finalize their clicks). This
          # allows us to enforce non-overlapping constraints on all objects before encoding
          # them into memory.
          run_mem_encoder=False,
          prev_sam_mask_logits=prev_sam_mask_logits,
      )
      # Add the output to the output dict (to be used as future memory)
      obj_temp_output_dict[storage_key][frame_idx] = current_out

      # Resize the output mask to the original video resolution
      obj_ids = inference_state["obj_ids"]
      consolidated_out = self._consolidate_temp_output_across_obj(
          inference_state,
          frame_idx,
          is_cond=is_cond,
          run_mem_encoder=False,
          consolidate_at_video_res=True,
      )
      _, video_res_masks = self._get_orig_video_res_output(
          inference_state, consolidated_out["pred_masks_video_res"]
      )
      return frame_idx, obj_ids, video_res_masks
  def add_new_points(self, *args, **kwargs):
      """Deprecated alias of `add_new_points_or_box`; all arguments are forwarded unchanged."""
      result = self.add_new_points_or_box(*args, **kwargs)
      return result
  @torch.inference_mode()
  def add_new_mask(
      self,
      inference_state,
      frame_idx,
      obj_id,
      mask,
  ):
      """
      Add a new mask prompt for one object on a frame.

      Arguments:
        inference_state (dict): the state to update.
        frame_idx: index of the frame that receives the mask.
        obj_id: client-side object id (an unseen id is registered via
          `_obj_id_to_idx`).
        mask: a 2-dim mask, either a tensor or anything accepted by
          `torch.tensor(..., dtype=torch.bool)`.
      Returns:
        (tuple): `(frame_idx, obj_ids, video_res_masks)`, where `video_res_masks`
          is the second output of `_get_orig_video_res_output`.
      """
      obj_idx = self._obj_id_to_idx(inference_state, obj_id)
      per_frame_points = inference_state["point_inputs_per_obj"][obj_idx]
      per_frame_masks = inference_state["mask_inputs_per_obj"][obj_idx]

      if not isinstance(mask, torch.Tensor):
          mask = torch.tensor(mask, dtype=torch.bool)
      assert mask.dim() == 2
      mask_H, mask_W = mask.shape
      # add batch and channel dimension, then convert to float on the compute device
      mask_inputs_orig = mask[None, None]
      mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])

      side = self.image_size
      if (mask_H, mask_W) == (side, side):
          mask_inputs = mask_inputs_orig
      else:
          # resize the mask since it doesn't match the model's image size,
          # then binarize the interpolated values again
          resized = torch.nn.functional.interpolate(
              mask_inputs_orig,
              size=(side, side),
              align_corners=False,
              mode="bilinear",
              antialias=True,  # use antialias for downsampling
          )
          mask_inputs = (resized >= 0.5).float()

      # Mask and point inputs are mutually exclusive on a frame: storing a mask
      # drops any point input previously recorded for it.
      per_frame_masks[frame_idx] = mask_inputs
      per_frame_points.pop(frame_idx, None)

      # A frame that hasn't been tracked before is treated as an initial conditioning
      # frame (the input generates segments on this frame without using any memory
      # from other frames, like in SAM). Otherwise the input is used to correct the
      # already tracked masks, following the direction the frame was tracked in.
      tracked_frames = inference_state["frames_already_tracked"]
      is_init_cond_frame = frame_idx not in tracked_frames
      reverse = False if is_init_cond_frame else tracked_frames[frame_idx]["reverse"]

      obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
      obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
      # A frame goes to the conditioning outputs if it's an initial conditioning frame
      # or if the model sees all frames receiving clicks/mask as conditioning frames.
      is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
      if is_cond:
          storage_key = "cond_frame_outputs"
      else:
          storage_key = "non_cond_frame_outputs"

      # The memory encoder is skipped when adding clicks or mask. It is executed at
      # the beginning of `propagate_in_video` (after user finalize their clicks), which
      # allows enforcing non-overlapping constraints on all objects before encoding
      # them into memory.
      current_out, _ = self._run_single_frame_inference(
          inference_state=inference_state,
          output_dict=obj_output_dict,  # run on the slice of a single object
          frame_idx=frame_idx,
          batch_size=1,  # run on the slice of a single object
          is_init_cond_frame=is_init_cond_frame,
          point_inputs=None,
          mask_inputs=mask_inputs,
          reverse=reverse,
          run_mem_encoder=False,
      )
      # Keep the output in the temporary dict (to be used as future memory)
      obj_temp_output_dict[storage_key][frame_idx] = current_out

      # Consolidate across objects at the original video resolution for the caller
      obj_ids = inference_state["obj_ids"]
      consolidated_out = self._consolidate_temp_output_across_obj(
          inference_state,
          frame_idx,
          is_cond=is_cond,
          run_mem_encoder=False,
          consolidate_at_video_res=True,
      )
      _, video_res_masks = self._get_orig_video_res_output(
          inference_state, consolidated_out["pred_masks_video_res"]
      )
      return frame_idx, obj_ids, video_res_masks
  362. def _get_orig_video_res_output(self, inference_state, any_res_masks):
  363. """
  364. Resize the object scores to the original video resolution (video_res_masks)
  365. and apply non-overlapping constraints for final output.
  366. """
  367. device = inference_state["device"]
  368. video_H = inference_state["video_height"]
  369. video_W = inference_state["video_width"]
  370. any_res_masks = any_res_masks.to(device, non_blocking=True)
  371. if any_res_masks.shape[-2:] == (video_H, video_W):
  372. video_res_masks = any_res_masks
  373. else:
  374. video_res_masks = torch.nn.functional.interpolate(
  375. any_res_masks,
  376. size=(video_H, video_W),
  377. mode="bilinear",
  378. align_corners=False,
  379. )
  380. if self.non_overlap_masks:
  381. video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
  382. return any_res_masks, video_res_masks
  def _consolidate_temp_output_across_obj(
      self,
      inference_state,
      frame_idx,
      is_cond,
      run_mem_encoder,
      consolidate_at_video_res=False,
  ):
      """
      Merge the per-object temporary outputs in `temp_output_dict_per_obj` on one
      frame into a single output covering all objects. This involves:
      1) filling any object without a temporary output either from
         `output_dict_per_obj` (if it has an entry for this frame there) or with
         placeholder values (if it has none);
      2) if requested, rerunning the memory encoder after applying the
         non-overlapping constraints on the object scores.

      Arguments:
        inference_state (dict): the state to read from.
        frame_idx: index of the frame to consolidate.
        is_cond (bool): selects "cond_frame_outputs" (True) or
          "non_cond_frame_outputs" (False) in the temporary storage.
        run_mem_encoder (bool): whether to rerun the memory encoder on the result.
        consolidate_at_video_res (bool): consolidate at the original video
          resolution under the key "pred_masks_video_res" (gives a better editing
          experience for mask prompts); incompatible with `run_mem_encoder`.
      Returns:
        (dict): the consolidated output for this frame.
      """
      num_objs = self._get_obj_num(inference_state)
      storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

      if consolidate_at_video_res:
          assert not run_mem_encoder, "memory encoder cannot run at video resolution"
          mask_hw = (inference_state["video_height"], inference_state["video_width"])
          mask_key = "pred_masks_video_res"
      else:
          low_res_side = self.image_size // 4
          mask_hw = (low_res_side, low_res_side)
          mask_key = "pred_masks"

      # Masks are prefilled with a large negative value (NO_OBJ_SCORE) to represent
      # missing objects. "maskmem_features" and "maskmem_pos_enc" stay None until the
      # memory encoder is rerun at the end of this method.
      merged_masks = torch.full(
          size=(num_objs, 1, *mask_hw),
          fill_value=NO_OBJ_SCORE,
          dtype=torch.float32,
          device=inference_state["storage_device"],
      )
      merged_ptrs = torch.full(
          size=(num_objs, self.hidden_dim),
          fill_value=NO_OBJ_SCORE,
          dtype=torch.float32,
          device=inference_state["device"],
      )
      consolidated_out = {
          "maskmem_features": None,
          "maskmem_pos_enc": None,
          mask_key: merged_masks,
          "obj_ptr": merged_ptrs,
      }

      empty_mask_ptr = None  # computed lazily, at most once per call
      for obj_idx in range(num_objs):
          row = slice(obj_idx, obj_idx + 1)
          temp_outputs = inference_state["temp_output_dict_per_obj"][obj_idx]
          stored_outputs = inference_state["output_dict_per_obj"][obj_idx]

          # Prefer the temporary output; otherwise fall back to a previous output in
          # "output_dict_per_obj" (conditioning first, then non-conditioning).
          out = None
          for candidates in (
              temp_outputs[storage_key],
              stored_outputs["cond_frame_outputs"],
              stored_outputs["non_cond_frame_outputs"],
          ):
              out = candidates.get(frame_idx, None)
              if out is not None:
                  break

          if out is None:
              # No input or tracking outcome for this object on this frame: its mask
              # keeps the NO_OBJ_SCORE placeholder. A dummy object pointer (based on
              # an empty mask) is only filled in when the memory is about to be built.
              if run_mem_encoder:
                  if empty_mask_ptr is None:
                      empty_mask_ptr = self._get_empty_mask_ptr(
                          inference_state, frame_idx
                      )
                  merged_ptrs[row] = empty_mask_ptr
              continue

          # Copy the object's mask into its slot, resizing first if its resolution
          # differs from the consolidated one.
          obj_mask = out["pred_masks"]
          if obj_mask.shape[-2:] != merged_masks.shape[-2:]:
              obj_mask = torch.nn.functional.interpolate(
                  obj_mask,
                  size=merged_masks.shape[-2:],
                  mode="bilinear",
                  align_corners=False,
              )
          merged_masks[row] = obj_mask
          merged_ptrs[row] = out["obj_ptr"]

      # Optionally, apply non-overlapping constraints on the consolidated scores
      # and rerun the memory encoder
      if run_mem_encoder:
          compute_device = inference_state["device"]
          high_res_masks = torch.nn.functional.interpolate(
              consolidated_out["pred_masks"].to(compute_device, non_blocking=True),
              size=(self.image_size, self.image_size),
              mode="bilinear",
              align_corners=False,
          )
          if self.non_overlap_masks_for_mem_enc:
              high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
          mem_features, mem_pos_enc = self._run_memory_encoder(
              inference_state=inference_state,
              frame_idx=frame_idx,
              batch_size=num_objs,
              high_res_masks=high_res_masks,
              is_mask_from_pts=True,  # these frames are what the user interacted with
          )
          consolidated_out["maskmem_features"] = mem_features
          consolidated_out["maskmem_pos_enc"] = mem_pos_enc

      return consolidated_out
  def _get_empty_mask_ptr(self, inference_state, frame_idx):
      """Compute a dummy object pointer from an all-zero mask on the given frame."""
      # A dummy (empty) mask with a single object
      single_obj = 1
      side = self.image_size
      empty_mask = torch.zeros(
          (single_obj, 1, side, side),
          dtype=torch.float32,
          device=inference_state["device"],
      )

      # Image features of this frame; the first two returned items are not needed here
      _, _, vision_feats, vision_pos_embeds, feat_sizes = self._get_image_feature(
          inference_state, frame_idx, single_obj
      )

      # Feed the empty mask and the image features through a tracking step (with no
      # memory and without the memory encoder) and keep only the object pointer.
      step_out = self.track_step(
          frame_idx=frame_idx,
          is_init_cond_frame=True,
          current_vision_feats=vision_feats,
          current_vision_pos_embeds=vision_pos_embeds,
          feat_sizes=feat_sizes,
          point_inputs=None,
          mask_inputs=empty_mask,
          output_dict={},
          num_frames=inference_state["num_frames"],
          track_in_reverse=False,
          run_mem_encoder=False,
          prev_sam_mask_logits=None,
      )
      return step_out["obj_ptr"]
  @torch.inference_mode()
  def propagate_in_video_preflight(self, inference_state):
      """
      Prepare inference_state before tracking: consolidate the per-object temporary
      outputs, merge them into "output_dict", and sanity-check the bookkeeping.
      """
      # Tracking has started and we don't allow adding new objects until session is reset.
      inference_state["tracking_has_started"] = True
      num_objs = self._get_obj_num(inference_state)

      pending_per_obj = inference_state["temp_output_dict_per_obj"]
      output_dict = inference_state["output_dict"]
      # "consolidated_frame_inds" contains indices of those frames where consolidated
      # temporary outputs have been added (either in this call or any previous calls
      # to `propagate_in_video_preflight`).
      consolidated_frame_inds = inference_state["consolidated_frame_inds"]

      # Whether to clear non-conditioning memory of the frames surrounding each
      # consolidated frame; the decision is the same for every frame.
      clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
          self.clear_non_cond_mem_for_multi_obj or num_objs <= 1
      )

      # Consolidate conditioning and non-conditioning temporary outputs separately
      for is_cond in (False, True):
          storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
          # All frames that hold a temporary output for any object (these should be
          # the frames that have just received clicks or mask inputs via
          # `add_new_points_or_box` or `add_new_mask`)
          pending_frames = set().union(
              *(pending[storage_key] for pending in pending_per_obj.values())
          )
          consolidated_frame_inds[storage_key].update(pending_frames)

          # consolidate the temporary output across all objects on each such frame
          for frame_idx in pending_frames:
              consolidated_out = self._consolidate_temp_output_across_obj(
                  inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
              )
              # merge them into "output_dict" and also create per-object slices
              output_dict[storage_key][frame_idx] = consolidated_out
              self._add_output_per_object(
                  inference_state, frame_idx, consolidated_out, storage_key
              )
              if clear_non_cond_mem:
                  # clear non-conditioning memory of the surrounding frames
                  self._clear_non_cond_mem_around_input(inference_state, frame_idx)

          # the temporary outputs of this kind are now merged, so drop them
          for pending in pending_per_obj.values():
              pending[storage_key].clear()

      # edge case: if an output is added to "cond_frame_outputs", we remove any prior
      # output on the same frame in "non_cond_frame_outputs" (for the joint output
      # dict first, then for every per-object slice)
      for outputs in (output_dict, *inference_state["output_dict_per_obj"].values()):
          for frame_idx in outputs["cond_frame_outputs"]:
              outputs["non_cond_frame_outputs"].pop(frame_idx, None)
      cond_frames = consolidated_frame_inds["cond_frame_outputs"]
      non_cond_frames = consolidated_frame_inds["non_cond_frame_outputs"]
      for frame_idx in cond_frames:
          assert frame_idx in output_dict["cond_frame_outputs"]
          non_cond_frames.discard(frame_idx)

      # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
      # with either points or mask inputs (which should be true under a correct workflow).
      input_frames = set()
      for inputs_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
          for inputs_per_frame in inference_state[inputs_key].values():
              input_frames.update(inputs_per_frame.keys())
      assert (cond_frames | non_cond_frames) == input_frames
  @torch.inference_mode()
  def propagate_in_video(
      self,
      inference_state,
      start_frame_idx=None,
      max_frame_num_to_track=None,
      reverse=False,
  ):
      """
      Walk through the video frame by frame and yield the masks for each frame.

      This is a generator. Frames whose index is listed in
      `inference_state["consolidated_frame_inds"]` are not re-run: their stored
      entry in `output_dict` is reused. Every other frame goes through
      `_run_single_frame_inference` and the result is written to
      `output_dict["non_cond_frame_outputs"]`.

      Args:
          inference_state: the per-video state dict.
          start_frame_idx: first frame to process. When None, the smallest key
              of `output_dict["cond_frame_outputs"]` is used.
          max_frame_num_to_track: offset from `start_frame_idx` bounding the
              last processed frame. When None, `num_frames` is used.
          reverse: when true, frames are visited in decreasing index order.

      Yields:
          `(frame_idx, obj_ids, video_res_masks)` where `video_res_masks` is the
          second value returned by `_get_orig_video_res_output`.

      Raises:
          RuntimeError: if `output_dict["cond_frame_outputs"]` is empty.
      """
      self.propagate_in_video_preflight(inference_state)

      output_dict = inference_state["output_dict"]
      consolidated_frame_inds = inference_state["consolidated_frame_inds"]
      obj_ids = inference_state["obj_ids"]
      num_frames = inference_state["num_frames"]
      batch_size = self._get_obj_num(inference_state)
      if not output_dict["cond_frame_outputs"]:
          raise RuntimeError("No points are provided; please add points first")

      clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
          self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
      )

      # Resolve the defaults: earliest conditioning frame, whole video.
      if start_frame_idx is None:
          start_frame_idx = min(output_dict["cond_frame_outputs"])
      if max_frame_num_to_track is None:
          max_frame_num_to_track = num_frames

      # Work out the order in which frames are visited.
      if reverse:
          end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
          processing_order = (
              range(start_frame_idx, end_frame_idx - 1, -1)
              if start_frame_idx > 0
              else []  # nothing to track backwards when starting from frame 0
          )
      else:
          end_frame_idx = min(
              start_frame_idx + max_frame_num_to_track, num_frames - 1
          )
          processing_order = range(start_frame_idx, end_frame_idx + 1)

      for frame_idx in tqdm(processing_order, desc="propagate in video"):
          # Frames already present in the consolidated outputs (they received
          # clicks or a mask) are reused rather than recomputed: a batched
          # forward via `_run_single_frame_inference` is not possible for them
          # because the number of clicks may differ between objects.
          is_cond_input = frame_idx in consolidated_frame_inds["cond_frame_outputs"]
          if (
              is_cond_input
              or frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]
          ):
              storage_key = (
                  "cond_frame_outputs" if is_cond_input else "non_cond_frame_outputs"
              )
              current_out = output_dict[storage_key][frame_idx]
              pred_masks = current_out["pred_masks"]
              if is_cond_input and clear_non_cond_mem:
                  # drop the non-conditioning memory of the surrounding frames
                  self._clear_non_cond_mem_around_input(inference_state, frame_idx)
          else:
              storage_key = "non_cond_frame_outputs"
              current_out, pred_masks = self._run_single_frame_inference(
                  inference_state=inference_state,
                  output_dict=output_dict,
                  frame_idx=frame_idx,
                  batch_size=batch_size,
                  is_init_cond_frame=False,
                  point_inputs=None,
                  mask_inputs=None,
                  reverse=reverse,
                  run_mem_encoder=True,
              )
              output_dict[storage_key][frame_idx] = current_out

          # Keep per-object slices so that each object can be interacted with
          # individually once tracking is done.
          self._add_output_per_object(
              inference_state, frame_idx, current_out, storage_key
          )
          inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}

          # Resize to the original video resolution; `pred_masks` is passed on
          # directly, without any intermediate conversion.
          _, video_res_masks = self._get_orig_video_res_output(
              inference_state, pred_masks
          )
          yield frame_idx, obj_ids, video_res_masks
  675. def _add_output_per_object(
  676. self, inference_state, frame_idx, current_out, storage_key
  677. ):
  678. """
  679. Split a multi-object output into per-object output slices and add them into
  680. `output_dict_per_obj`. The resulting slices share the same tensor storage.
  681. """
  682. maskmem_features = current_out["maskmem_features"]
  683. assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
  684. maskmem_pos_enc = current_out["maskmem_pos_enc"]
  685. assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
  686. output_dict_per_obj = inference_state["output_dict_per_obj"]
  687. for obj_idx, obj_output_dict in output_dict_per_obj.items():
  688. obj_slice = slice(obj_idx, obj_idx + 1)
  689. obj_out = {
  690. "maskmem_features": None,
  691. "maskmem_pos_enc": None,
  692. "pred_masks": current_out["pred_masks"][obj_slice],
  693. "obj_ptr": current_out["obj_ptr"][obj_slice],
  694. }
  695. if maskmem_features is not None:
  696. obj_out["maskmem_features"] = maskmem_features[obj_slice]
  697. if maskmem_pos_enc is not None:
  698. obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
  699. obj_output_dict[storage_key][frame_idx] = obj_out
  @torch.inference_mode()
  def reset_state(self, inference_state):
      """
      Wipe all inputs, tracking results and object bookkeeping from the state.

      First delegates to `_reset_tracking_results`, then empties (in place) the
      object-id mappings and every per-object container held in
      `inference_state`.
      """
      self._reset_tracking_results(inference_state)
      # Forget every registered object and its per-object containers.
      per_object_keys = (
          "obj_id_to_idx",
          "obj_idx_to_id",
          "obj_ids",
          "point_inputs_per_obj",
          "mask_inputs_per_obj",
          "output_dict_per_obj",
          "temp_output_dict_per_obj",
      )
      for key in per_object_keys:
          inference_state[key].clear()
  712. def _reset_tracking_results(self, inference_state):
  713. """Reset all tracking inputs and results across the videos."""
  714. for v in inference_state["point_inputs_per_obj"].values():
  715. v.clear()
  716. for v in inference_state["mask_inputs_per_obj"].values():
  717. v.clear()
  718. for v in inference_state["output_dict_per_obj"].values():
  719. v["cond_frame_outputs"].clear()
  720. v["non_cond_frame_outputs"].clear()
  721. for v in inference_state["temp_output_dict_per_obj"].values():
  722. v["cond_frame_outputs"].clear()
  723. v["non_cond_frame_outputs"].clear()
  724. inference_state["output_dict"]["cond_frame_outputs"].clear()
  725. inference_state["output_dict"]["non_cond_frame_outputs"].clear()
  726. inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
  727. inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
  728. inference_state["tracking_has_started"] = False
  729. inference_state["frames_already_tracked"].clear()
  def _get_image_feature(self, inference_state, frame_idx, batch_size):
      """
      Return the image and backbone features of one frame, expanded to `batch_size`.

      `inference_state["cached_features"]` is consulted first. On a miss the
      frame is taken from `inference_state["images"]`, moved to CUDA as float
      with a leading dimension of 1, passed through `self.forward_image`, and
      the cache is replaced by a dict holding only this frame.

      Returns:
          A tuple made of the expanded image followed by whatever tuple
          `self._prepare_backbone_features` returns for the expanded backbone
          output.
      """
      cached_entry = inference_state["cached_features"].get(
          frame_idx, (None, None)
      )
      image, backbone_out = cached_entry
      if backbone_out is None:
          # Cache miss: run the image encoder on this single frame.
          frame = inference_state["images"][frame_idx]
          image = frame.cuda().float().unsqueeze(0)
          backbone_out = self.forward_image(image)
          # Only the most recent frame is kept (useful for repeated interaction
          # with one frame; an LRU cache could hold more frames in the future).
          inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

      # Expand along the first dimension so there is one entry per object. The
      # cached lists are left untouched; new lists are built instead.
      expanded_image = image.expand(batch_size, -1, -1, -1)
      fpn_levels = backbone_out["backbone_fpn"].copy()
      pos_levels = backbone_out["vision_pos_enc"].copy()
      expanded_backbone_out = {
          "backbone_fpn": [
              feat.expand(batch_size, -1, -1, -1) for feat in fpn_levels
          ],
          "vision_pos_enc": [
              pos.expand(batch_size, -1, -1, -1) for pos in pos_levels
          ],
      }

      prepared = self._prepare_backbone_features(expanded_backbone_out)
      return (expanded_image,) + prepared
  def _run_single_frame_inference(
      self,
      inference_state,
      output_dict,
      frame_idx,
      batch_size,
      is_init_cond_frame,
      point_inputs,
      mask_inputs,
      reverse,
      run_mem_encoder,
      prev_sam_mask_logits=None,
  ):
      """
      Run `self.track_step` on one frame and pack its output compactly.

      `point_inputs` and `mask_inputs` must not both be given for the same
      frame (enforced with an assert).

      Returns:
          A pair `(compact_current_out, pred_masks_gpu)`. The first item is a
          dict with the keys "maskmem_features", "maskmem_pos_enc",
          "pred_masks" and "obj_ptr", where the memory features (cast to
          bfloat16) and the predicted masks have been moved to
          `inference_state["storage_device"]`. The second item is the predicted
          masks as they were before that move (after optional hole filling).
      """
      # Fetch the image features for this frame.
      features = self._get_image_feature(inference_state, frame_idx, batch_size)
      _, _, current_vision_feats, current_vision_pos_embeds, feat_sizes = features

      # a frame takes either points or a mask as input, never both
      assert point_inputs is None or mask_inputs is None
      current_out = self.track_step(
          frame_idx=frame_idx,
          is_init_cond_frame=is_init_cond_frame,
          current_vision_feats=current_vision_feats,
          current_vision_pos_embeds=current_vision_pos_embeds,
          feat_sizes=feat_sizes,
          point_inputs=point_inputs,
          mask_inputs=mask_inputs,
          output_dict=output_dict,
          num_frames=inference_state["num_frames"],
          track_in_reverse=reverse,
          run_mem_encoder=run_mem_encoder,
          prev_sam_mask_logits=prev_sam_mask_logits,
      )

      # Outputs may be offloaded to another device to save GPU space.
      storage_device = inference_state["storage_device"]
      maskmem_features = current_out["maskmem_features"]
      if maskmem_features is not None:
          maskmem_features = maskmem_features.to(torch.bfloat16).to(
              storage_device, non_blocking=True
          )

      pred_masks_gpu = current_out["pred_masks"]
      if self.fill_hole_area > 0:
          # optionally fill holes in the predicted masks
          pred_masks_gpu = fill_holes_in_mask_scores(
              pred_masks_gpu, self.fill_hole_area
          )
      pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)

      # "maskmem_pos_enc" is identical across frames, so one shared copy is used.
      maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)

      # Compact form of this frame's output, to keep the state small. The
      # object pointer is stored as-is, i.e. it is not moved to `storage_device`.
      compact_current_out = {
          "maskmem_features": maskmem_features,
          "maskmem_pos_enc": maskmem_pos_enc,
          "pred_masks": pred_masks,
          "obj_ptr": current_out["obj_ptr"],
      }
      return compact_current_out, pred_masks_gpu
  def _run_memory_encoder(
      self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
  ):
      """
      Encode `high_res_masks` of a frame into memory features.

      Typically called after non-overlapping constraints have been applied to
      the object scores: the scores changed, so their memory has to be
      recomputed with the memory encoder.

      Returns:
          `(maskmem_features, maskmem_pos_enc)`: the features cast to bfloat16
          and moved to `inference_state["storage_device"]`, and the positional
          encoding as returned by `_get_maskmem_pos_enc`.
      """
      # Fetch the image features for this frame.
      features = self._get_image_feature(inference_state, frame_idx, batch_size)
      _, _, current_vision_feats, _, feat_sizes = features

      maskmem_features, raw_pos_enc = self._encode_new_memory(
          current_vision_feats=current_vision_feats,
          feat_sizes=feat_sizes,
          pred_masks_high_res=high_res_masks,
          is_mask_from_pts=is_mask_from_pts,
      )

      # Outputs may be offloaded to another device to save GPU space.
      storage_device = inference_state["storage_device"]
      maskmem_features = maskmem_features.to(torch.bfloat16).to(
          storage_device, non_blocking=True
      )

      # "maskmem_pos_enc" is identical across frames, so one shared copy is used.
      maskmem_pos_enc = self._get_maskmem_pos_enc(
          inference_state, {"maskmem_pos_enc": raw_pos_enc}
      )
      return maskmem_features, maskmem_pos_enc
  849. def _get_maskmem_pos_enc(self, inference_state, current_out):
  850. """
  851. `maskmem_pos_enc` is the same across frames and objects, so we cache it as
  852. a constant in the inference session to reduce session storage size.
  853. """
  854. model_constants = inference_state["constants"]
  855. # "out_maskmem_pos_enc" should be either a list of tensors or None
  856. out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
  857. if out_maskmem_pos_enc is not None:
  858. if "maskmem_pos_enc" not in model_constants:
  859. assert isinstance(out_maskmem_pos_enc, list)
  860. # only take the slice for one object, since it's same across objects
  861. maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
  862. model_constants["maskmem_pos_enc"] = maskmem_pos_enc
  863. else:
  864. maskmem_pos_enc = model_constants["maskmem_pos_enc"]
  865. # expand the cached maskmem_pos_enc to the actual batch size
  866. batch_size = out_maskmem_pos_enc[0].size(0)
  867. expanded_maskmem_pos_enc = [
  868. x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
  869. ]
  870. else:
  871. expanded_maskmem_pos_enc = None
  872. return expanded_maskmem_pos_enc
  873. def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
  874. """
  875. Remove the non-conditioning memory around the input frame. When users provide
  876. correction clicks, the surrounding frames' non-conditioning memories can still
  877. contain outdated object appearance information and could confuse the model.
  878. This method clears those non-conditioning memories surrounding the interacted
  879. frame to avoid giving the model both old and new information about the object.
  880. """
  881. r = self.memory_temporal_stride_for_eval
  882. frame_idx_begin = frame_idx - r * self.num_maskmem
  883. frame_idx_end = frame_idx + r * self.num_maskmem
  884. output_dict = inference_state["output_dict"]
  885. non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
  886. for t in range(frame_idx_begin, frame_idx_end + 1):
  887. non_cond_frame_outputs.pop(t, None)
  888. for obj_output_dict in inference_state["output_dict_per_obj"].values():
  889. obj_output_dict["non_cond_frame_outputs"].pop(t, None)