sam2_video_predictor.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
from tqdm import tqdm

from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames


class SAM2VideoPredictor(SAM2Base):
    """The predictor class to handle user interactions and manage inference states."""

    def __init__(
        self,
        fill_hole_area=0,
        # whether to apply non-overlapping constraints on the output object masks
        non_overlap_masks=False,
        # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
        # note that this only applies to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
        clear_non_cond_mem_around_input=False,
        # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
        clear_non_cond_mem_for_multi_obj=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.fill_hole_area = fill_hole_area
        self.non_overlap_masks = non_overlap_masks
        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
        self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj

    @torch.inference_mode()
    def init_state(
        self,
        video_path,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
    ):
        """Initialize an inference state."""
        images, video_height, video_width = load_video_frames(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
        )
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # whether to offload the video frames to CPU memory
        # turning on this option saves the GPU memory with only a very small overhead
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # turning on this option saves the GPU memory at the cost of a lower tracking fps
        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
        # and from 24 to 21 when tracking two objects)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = torch.device("cuda")
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = torch.device("cuda")
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object id and model-side object index
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # A storage to hold the model's tracking results and states on each frame
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        # Slice (view) of each object's tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # A temporary storage to hold new outputs when users interact with a frame
        # to add clicks or a mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # Frames that already hold consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),  # set containing frame indices
            "non_cond_frame_outputs": set(),  # set containing frame indices
        }
        # metadata for each tracking frame (e.g. which direction it's tracked)
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state
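
    # Illustrative sketch: `init_state` is the entry point of a tracking session. The
    # builder arguments and video path below are hypothetical placeholders; `video_path`
    # may be whatever `load_video_frames` accepts in this codebase.
    #
    #     from sam2.build_sam import build_sam2_video_predictor
    #     predictor = build_sam2_video_predictor(model_cfg, checkpoint)  # hypothetical args
    #     inference_state = predictor.init_state(video_path="./videos/example")
    #     # set offload_video_to_cpu=True / offload_state_to_cpu=True to trade tracking
    #     # speed for lower GPU memory usage, as described in the comments above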

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
        """
        Load a pretrained model from the Hugging Face model hub.

        Arguments:
            model_id (str): The Hugging Face repository ID.
            **kwargs: Additional arguments to pass to the model constructor.

        Returns:
            (SAM2VideoPredictor): The loaded model.
        """
        from sam2.build_sam import build_sam2_video_predictor_hf

        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
        return sam_model
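
    # Illustrative sketch: loading directly from the Hugging Face hub. The repository
    # id below is only an example and may differ from the checkpoints actually published.
    #
    #     predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
    #     inference_state = predictor.init_state(video_path="./videos/example")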

    def _obj_id_to_idx(self, inference_state, obj_id):
        """Map client-side object id to model-side object index."""
        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
        if obj_idx is not None:
            return obj_idx

        # This is a new object id not sent to the server before. We only allow adding
        # new objects *before* the tracking starts.
        allow_new_object = not inference_state["tracking_has_started"]
        if allow_new_object:
            # get the next object slot
            obj_idx = len(inference_state["obj_id_to_idx"])
            inference_state["obj_id_to_idx"][obj_id] = obj_idx
            inference_state["obj_idx_to_id"][obj_idx] = obj_id
            inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
            # set up input and output structures for this object
            inference_state["point_inputs_per_obj"][obj_idx] = {}
            inference_state["mask_inputs_per_obj"][obj_idx] = {}
            inference_state["output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            inference_state["temp_output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            return obj_idx
        else:
            raise RuntimeError(
                f"Cannot add new object id {obj_id} after tracking starts. "
                f"All existing object ids: {inference_state['obj_ids']}. "
                f"Please call 'reset_state' to restart from scratch."
            )

    def _obj_idx_to_id(self, inference_state, obj_idx):
        """Map model-side object index to client-side object id."""
        return inference_state["obj_idx_to_id"][obj_idx]

    def _get_obj_num(self, inference_state):
        """Get the total number of unique object ids received so far in this session."""
        return len(inference_state["obj_idx_to_id"])

    @torch.inference_mode()
    def add_new_points(
        self,
        inference_state,
        frame_idx,
        obj_id,
        points,
        labels,
        clear_old_points=True,
        normalize_coords=True,
    ):
        """Add new points to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if not isinstance(points, torch.Tensor):
            points = torch.tensor(points, dtype=torch.float32)
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.int32)
        if points.dim() == 2:
            points = points.unsqueeze(0)  # add batch dimension
        if labels.dim() == 1:
            labels = labels.unsqueeze(0)  # add batch dimension
        if normalize_coords:
            video_H = inference_state["video_height"]
            video_W = inference_state["video_width"]
            points = points / torch.tensor([video_W, video_H]).to(points.device)
        # scale the (normalized) coordinates by the model's internal image size
        points = points * self.image_size
        points = points.to(inference_state["device"])
        labels = labels.to(inference_state["device"])

        if not clear_old_points:
            point_inputs = point_inputs_per_frame.get(frame_idx, None)
        else:
            point_inputs = None
        point_inputs = concat_points(point_inputs, points, labels)

        point_inputs_per_frame[frame_idx] = point_inputs
        mask_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input points are used to generate segments on this frame
        # without using any memory from other frames, like in SAM. Otherwise (if it has been
        # tracked), the input points will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        # Get any previously predicted mask logits on this object and feed it along with
        # the new clicks into the SAM mask decoder.
        prev_sam_mask_logits = None
        # lookup temporary output dict first, which contains the most recent output
        # (if not found, then lookup conditioning and non-conditioning frame output)
        prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
        if prev_out is None:
            prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
            if prev_out is None:
                prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

        if prev_out is not None and prev_out["pred_masks"] is not None:
            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
            # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
            prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=point_inputs,
            mask_inputs=None,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after users finalize their clicks). This
            # allows us to enforce non-overlapping constraints on all objects before encoding
            # them into memory.
            run_mem_encoder=False,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks
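
    # Illustrative sketch: points are (x, y) pixel coordinates in the original video
    # resolution (normalized and rescaled above when `normalize_coords=True`), and
    # labels are 1 for positive clicks and 0 for negative clicks. The coordinates and
    # ids below are hypothetical.
    #
    #     frame_idx, obj_ids, video_res_masks = predictor.add_new_points(
    #         inference_state,
    #         frame_idx=0,
    #         obj_id=1,
    #         points=[[210.0, 350.0], [250.0, 220.0]],
    #         labels=[1, 0],
    #     )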

    @torch.inference_mode()
    def add_new_mask(
        self,
        inference_state,
        frame_idx,
        obj_id,
        mask,
    ):
        """Add new mask to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask, dtype=torch.bool)
        assert mask.dim() == 2
        mask_H, mask_W = mask.shape
        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])

        # resize the mask if it doesn't match the model's image size
        if mask_H != self.image_size or mask_W != self.image_size:
            mask_inputs = torch.nn.functional.interpolate(
                mask_inputs_orig,
                size=(self.image_size, self.image_size),
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for downsampling
            )
            mask_inputs = (mask_inputs >= 0.5).float()
        else:
            mask_inputs = mask_inputs_orig

        mask_inputs_per_frame[frame_idx] = mask_inputs
        point_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input mask is used to generate segments on this frame
        # without using any memory from other frames, like in SAM. Otherwise (if it has been
        # tracked), the input mask will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=None,
            mask_inputs=mask_inputs,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after users finalize their clicks). This
            # allows us to enforce non-overlapping constraints on all objects before encoding
            # them into memory.
            run_mem_encoder=False,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks
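
    # Illustrative sketch: the mask prompt is a 2-D boolean array at any resolution;
    # it is resized to the model's input size above. The array and its shape below are
    # hypothetical (H, W stand for the original video height and width).
    #
    #     binary_mask = np.zeros((H, W), dtype=bool)  # e.g. a NumPy array; set the object region to True
    #     frame_idx, obj_ids, video_res_masks = predictor.add_new_mask(
    #         inference_state, frame_idx=0, obj_id=1, mask=binary_mask
    #     )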

    def _get_orig_video_res_output(self, inference_state, any_res_masks):
        """
        Resize the object scores to the original video resolution (video_res_masks)
        and apply non-overlapping constraints for final output.
        """
        device = inference_state["device"]
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        any_res_masks = any_res_masks.to(device, non_blocking=True)
        if any_res_masks.shape[-2:] == (video_H, video_W):
            video_res_masks = any_res_masks
        else:
            video_res_masks = torch.nn.functional.interpolate(
                any_res_masks,
                size=(video_H, video_W),
                mode="bilinear",
                align_corners=False,
            )
        if self.non_overlap_masks:
            video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
        return any_res_masks, video_res_masks

    def _consolidate_temp_output_across_obj(
        self,
        inference_state,
        frame_idx,
        is_cond,
        run_mem_encoder,
        consolidate_at_video_res=False,
    ):
        """
        Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
        a frame into a single output for all objects, including
        1) fill any missing objects either from `output_dict_per_obj` (if they exist in
           `output_dict_per_obj` for this frame) or leave them as placeholder values
           (if they don't exist in `output_dict_per_obj` for this frame);
        2) if specified, rerun the memory encoder after applying non-overlapping
           constraints on the object scores.
        """
        batch_size = self._get_obj_num(inference_state)
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Optionally, we allow consolidating the temporary outputs at the original
        # video resolution (to provide a better editing experience for mask prompts).
        if consolidate_at_video_res:
            assert not run_mem_encoder, "memory encoder cannot run at video resolution"
            consolidated_H = inference_state["video_height"]
            consolidated_W = inference_state["video_width"]
            consolidated_mask_key = "pred_masks_video_res"
        else:
            consolidated_H = consolidated_W = self.image_size // 4
            consolidated_mask_key = "pred_masks"
        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
        # will be added when rerunning the memory encoder after applying non-overlapping
        # constraints to object scores. Its "pred_masks" are prefilled with a large
        # negative value (NO_OBJ_SCORE) to represent missing objects.
        consolidated_out = {
            "maskmem_features": None,
            "maskmem_pos_enc": None,
            consolidated_mask_key: torch.full(
                size=(batch_size, 1, consolidated_H, consolidated_W),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["storage_device"],
            ),
            "obj_ptr": torch.full(
                size=(batch_size, self.hidden_dim),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["device"],
            ),
        }
        empty_mask_ptr = None
        for obj_idx in range(batch_size):
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            out = obj_temp_output_dict[storage_key].get(frame_idx, None)
            # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
            # we fall back and look up its previous output in "output_dict_per_obj".
            # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
            # "output_dict_per_obj" to find a previous output for this object.
            if out is None:
                out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
            if out is None:
                out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
            # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
            # placeholder above) and set its object pointer to be a dummy pointer.
            if out is None:
                # Fill in dummy object pointers for those objects without any inputs or
                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
                # i.e. when we need to build the memory for tracking).
                if run_mem_encoder:
                    if empty_mask_ptr is None:
                        empty_mask_ptr = self._get_empty_mask_ptr(
                            inference_state, frame_idx
                        )
                    # fill object pointer with a dummy pointer (based on an empty mask)
                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
                continue
            # Add the temporary object output mask to consolidated output mask
            obj_mask = out["pred_masks"]
            consolidated_pred_masks = consolidated_out[consolidated_mask_key]
            if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
                consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
            else:
                # Resize first if temporary object mask has a different resolution
                resized_obj_mask = torch.nn.functional.interpolate(
                    obj_mask,
                    size=consolidated_pred_masks.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]

        # Optionally, apply non-overlapping constraints on the consolidated scores
        # and rerun the memory encoder
        if run_mem_encoder:
            device = inference_state["device"]
            high_res_masks = torch.nn.functional.interpolate(
                consolidated_out["pred_masks"].to(device, non_blocking=True),
                size=(self.image_size, self.image_size),
                mode="bilinear",
                align_corners=False,
            )
            if self.non_overlap_masks_for_mem_enc:
                high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
            maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                inference_state=inference_state,
                frame_idx=frame_idx,
                batch_size=batch_size,
                high_res_masks=high_res_masks,
                is_mask_from_pts=True,  # these frames are what the user interacted with
            )
            consolidated_out["maskmem_features"] = maskmem_features
            consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc

        return consolidated_out

    def _get_empty_mask_ptr(self, inference_state, frame_idx):
        """Get a dummy object pointer based on an empty mask on the current frame."""
        # A dummy (empty) mask with a single object
        batch_size = 1
        mask_inputs = torch.zeros(
            (batch_size, 1, self.image_size, self.image_size),
            dtype=torch.float32,
            device=inference_state["device"],
        )

        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # Feed the empty mask and image feature above to get a dummy object pointer
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=True,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=None,
            mask_inputs=mask_inputs,
            output_dict={},
            num_frames=inference_state["num_frames"],
            track_in_reverse=False,
            run_mem_encoder=False,
            prev_sam_mask_logits=None,
        )
        return current_out["obj_ptr"]

    @torch.inference_mode()
    def propagate_in_video_preflight(self, inference_state):
        """Prepare inference_state and consolidate temporary outputs before tracking."""
        # Tracking has started and we don't allow adding new objects until session is reset.
        inference_state["tracking_has_started"] = True
        batch_size = self._get_obj_num(inference_state)

        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
        # add them into "output_dict".
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        output_dict = inference_state["output_dict"]
        # "consolidated_frame_inds" contains indices of those frames where consolidated
        # temporary outputs have been added (either in this call or any previous calls
        # to `propagate_in_video_preflight`).
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        for is_cond in [False, True]:
            # Separately consolidate conditioning and non-conditioning temp outputs
            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
            # Find all the frames that contain temporary outputs for any objects
            # (these should be the frames that have just received clicks or mask inputs
            # via `add_new_points` or `add_new_mask`)
            temp_frame_inds = set()
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
            consolidated_frame_inds[storage_key].update(temp_frame_inds)
            # consolidate the temporary output across all objects on this frame
            for frame_idx in temp_frame_inds:
                consolidated_out = self._consolidate_temp_output_across_obj(
                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
                )
                # merge them into "output_dict" and also create per-object slices
                output_dict[storage_key][frame_idx] = consolidated_out
                self._add_output_per_object(
                    inference_state, frame_idx, consolidated_out, storage_key
                )
                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
                )
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)

            # clear temporary outputs in `temp_output_dict_per_obj`
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                obj_temp_output_dict[storage_key].clear()

        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
        # output on the same frame in "non_cond_frame_outputs"
        for frame_idx in output_dict["cond_frame_outputs"]:
            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for obj_output_dict in inference_state["output_dict_per_obj"].values():
            for frame_idx in obj_output_dict["cond_frame_outputs"]:
                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
            assert frame_idx in output_dict["cond_frame_outputs"]
            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)

        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
        # with either points or mask inputs (which should be true under a correct workflow).
        all_consolidated_frame_inds = (
            consolidated_frame_inds["cond_frame_outputs"]
            | consolidated_frame_inds["non_cond_frame_outputs"]
        )
        input_frames_inds = set()
        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
            input_frames_inds.update(point_inputs_per_frame.keys())
        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
            input_frames_inds.update(mask_inputs_per_frame.keys())
        assert all_consolidated_frame_inds == input_frames_inds

    @torch.inference_mode()
    def propagate_in_video(
        self,
        inference_state,
        start_frame_idx=None,
        max_frame_num_to_track=None,
        reverse=False,
    ):
        """Propagate the input points across frames to track in the entire video."""
        self.propagate_in_video_preflight(inference_state)

        output_dict = inference_state["output_dict"]
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        obj_ids = inference_state["obj_ids"]
        num_frames = inference_state["num_frames"]
        batch_size = self._get_obj_num(inference_state)
        if len(output_dict["cond_frame_outputs"]) == 0:
            raise RuntimeError("No points are provided; please add points first")
        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
        )

        # set start index, end index, and processing order
        if start_frame_idx is None:
            # default: start from the earliest frame with input points
            start_frame_idx = min(output_dict["cond_frame_outputs"])
        if max_frame_num_to_track is None:
            # default: track all the frames in the video
            max_frame_num_to_track = num_frames
        if reverse:
            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
            if start_frame_idx > 0:
                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
            else:
                processing_order = []  # skip reverse tracking if starting from frame 0
        else:
            end_frame_idx = min(
                start_frame_idx + max_frame_num_to_track, num_frames - 1
            )
            processing_order = range(start_frame_idx, end_frame_idx + 1)

        for frame_idx in tqdm(processing_order, desc="propagate in video"):
            # We skip those frames already in consolidated outputs (these are frames
            # that received input clicks or mask). Note that we cannot directly run
            # batched forward on them via `_run_single_frame_inference` because the
            # number of clicks on each object might be different.
            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
                storage_key = "cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
                storage_key = "non_cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
            else:
                storage_key = "non_cond_frame_outputs"
                current_out, pred_masks = self._run_single_frame_inference(
                    inference_state=inference_state,
                    output_dict=output_dict,
                    frame_idx=frame_idx,
                    batch_size=batch_size,
                    is_init_cond_frame=False,
                    point_inputs=None,
                    mask_inputs=None,
                    reverse=reverse,
                    run_mem_encoder=True,
                )
                output_dict[storage_key][frame_idx] = current_out
            # Create slices of per-object outputs for subsequent interaction with each
            # individual object after tracking.
            self._add_output_per_object(
                inference_state, frame_idx, current_out, storage_key
            )
            inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}

            # Resize the output mask to the original video resolution (we directly use
            # the mask scores on GPU for output to avoid any CPU conversion in between)
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, pred_masks
            )
            yield frame_idx, obj_ids, video_res_masks
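
    # Illustrative sketch: `propagate_in_video` is a generator, so results are consumed
    # frame by frame; thresholding the returned logits at 0 gives binary masks. The
    # collection pattern below is one possible use, not the only one.
    #
    #     video_segments = {}
    #     for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(inference_state):
    #         video_segments[frame_idx] = {
    #             obj_id: (video_res_masks[i] > 0.0).cpu().numpy()
    #             for i, obj_id in enumerate(obj_ids)
    #         }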

    def _add_output_per_object(
        self, inference_state, frame_idx, current_out, storage_key
    ):
        """
        Split a multi-object output into per-object output slices and add them into
        `output_dict_per_obj`. The resulting slices share the same tensor storage.
        """
        maskmem_features = current_out["maskmem_features"]
        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)

        maskmem_pos_enc = current_out["maskmem_pos_enc"]
        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)

        output_dict_per_obj = inference_state["output_dict_per_obj"]
        for obj_idx, obj_output_dict in output_dict_per_obj.items():
            obj_slice = slice(obj_idx, obj_idx + 1)
            obj_out = {
                "maskmem_features": None,
                "maskmem_pos_enc": None,
                "pred_masks": current_out["pred_masks"][obj_slice],
                "obj_ptr": current_out["obj_ptr"][obj_slice],
            }
            if maskmem_features is not None:
                obj_out["maskmem_features"] = maskmem_features[obj_slice]
            if maskmem_pos_enc is not None:
                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
            obj_output_dict[storage_key][frame_idx] = obj_out

    @torch.inference_mode()
    def reset_state(self, inference_state):
        """Remove all input points or masks in all frames throughout the video."""
        self._reset_tracking_results(inference_state)
        # Remove all object ids
        inference_state["obj_id_to_idx"].clear()
        inference_state["obj_idx_to_id"].clear()
        inference_state["obj_ids"].clear()
        inference_state["point_inputs_per_obj"].clear()
        inference_state["mask_inputs_per_obj"].clear()
        inference_state["output_dict_per_obj"].clear()
        inference_state["temp_output_dict_per_obj"].clear()

    def _reset_tracking_results(self, inference_state):
        """Reset all tracking inputs and results across the video."""
        for v in inference_state["point_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["mask_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        for v in inference_state["temp_output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        inference_state["output_dict"]["cond_frame_outputs"].clear()
        inference_state["output_dict"]["non_cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"].clear()

    def _get_image_feature(self, inference_state, frame_idx, batch_size):
        """Compute the image features on a given frame."""
        # Look up in the cache first
        image, backbone_out = inference_state["cached_features"].get(
            frame_idx, (None, None)
        )
        if backbone_out is None:
            # Cache miss -- we will run inference on a single image
            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
            backbone_out = self.forward_image(image)
            # Cache the most recent frame's feature (for repeated interactions with
            # a frame; we can use an LRU cache for more frames in the future).
            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

        # expand the features to have the same dimension as the number of objects
        expanded_image = image.expand(batch_size, -1, -1, -1)
        expanded_backbone_out = {
            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
        }
        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
            expanded_backbone_out["backbone_fpn"][i] = feat.expand(
                batch_size, -1, -1, -1
            )
        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
            pos = pos.expand(batch_size, -1, -1, -1)
            expanded_backbone_out["vision_pos_enc"][i] = pos

        features = self._prepare_backbone_features(expanded_backbone_out)
        features = (expanded_image,) + features
        return features

    def _run_single_frame_inference(
        self,
        inference_state,
        output_dict,
        frame_idx,
        batch_size,
        is_init_cond_frame,
        point_inputs,
        mask_inputs,
        reverse,
        run_mem_encoder,
        prev_sam_mask_logits=None,
    ):
        """Run tracking on a single frame based on current inputs and previous memory."""
        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            pred_masks_gpu = fill_holes_in_mask_scores(
                pred_masks_gpu, self.fill_hole_area
            )
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
        obj_ptr = current_out["obj_ptr"]
        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
        }
        return compact_current_out, pred_masks_gpu

    def _run_memory_encoder(
        self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
    ):
        """
        Run the memory encoder on `high_res_masks`. This is usually after applying
        non-overlapping constraints to object scores. Since their scores changed, their
        memory also needs to be computed again with the memory encoder.
        """
        # Retrieve correct image features
        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
            inference_state, frame_idx, batch_size
        )
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            is_mask_from_pts=is_mask_from_pts,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(
            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
        )
        return maskmem_features, maskmem_pos_enc

    def _get_maskmem_pos_enc(self, inference_state, current_out):
        """
        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
        a constant in the inference session to reduce session storage size.
        """
        model_constants = inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's the same across objects
                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            expanded_maskmem_pos_enc = [
                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
            ]
        else:
            expanded_maskmem_pos_enc = None
        return expanded_maskmem_pos_enc

    def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
        """
        Remove the non-conditioning memory around the input frame. When users provide
        correction clicks, the surrounding frames' non-conditioning memories can still
        contain outdated object appearance information and could confuse the model.

        This method clears those non-conditioning memories surrounding the interacted
        frame to avoid giving the model both old and new information about the object.
        """
        r = self.memory_temporal_stride_for_eval
        frame_idx_begin = frame_idx - r * self.num_maskmem
        frame_idx_end = frame_idx + r * self.num_maskmem
        output_dict = inference_state["output_dict"]
        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
        for t in range(frame_idx_begin, frame_idx_end + 1):
            non_cond_frame_outputs.pop(t, None)
            for obj_output_dict in inference_state["output_dict_per_obj"].values():
                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
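
    # Illustrative note: with hypothetical settings of num_maskmem = 7 and
    # memory_temporal_stride_for_eval = 1, a correction on frame t clears the
    # non-conditioning memories of frames in [t - 7, t + 7] above, so subsequent
    # tracking rebuilds them from the corrected conditioning frame.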