# sam2_video_predictor.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch

from tqdm import tqdm

from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames


class SAM2VideoPredictor(SAM2Base):
    """The predictor class to handle user interactions and manage inference states."""

    def __init__(
        self,
        fill_hole_area=0,
        # whether to apply non-overlapping constraints on the output object masks
        non_overlap_masks=False,
        # whether to clear non-conditioning memory of the surrounding frames (which may contain
        # outdated information) after adding correction clicks; note that this only applies to
        # *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
        clear_non_cond_mem_around_input=False,
        # whether to also clear non-conditioning memory of the surrounding frames for multi-object
        # tracking (only effective when `clear_non_cond_mem_around_input` is True)
        clear_non_cond_mem_for_multi_obj=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.fill_hole_area = fill_hole_area
        self.non_overlap_masks = non_overlap_masks
        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
        self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj

    @torch.inference_mode()
    def init_state(
        self,
        video_path,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
    ):
        """Initialize an inference state."""
        images, video_height, video_width = load_video_frames(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
        )
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # whether to offload the video frames to CPU memory
        # turning on this option saves the GPU memory with only a very small overhead
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # turning on this option saves the GPU memory at the cost of a lower tracking fps
        # (e.g. in a test case of the 768x768 model, fps dropped from 27 to 24 when tracking
        # one object and from 24 to 21 when tracking two objects)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = torch.device("cuda")
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = torch.device("cuda")
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object id and model-side object index
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # A storage to hold the model's tracking results and states on each frame
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        # Slices (views) of each object's tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # A temporary storage to hold new outputs when the user interacts with a frame
        # to add clicks or a mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # Frames that already hold consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),  # set containing frame indices
            "non_cond_frame_outputs": set(),  # set containing frame indices
        }
        # metadata for each tracking frame (e.g. which direction it's tracked)
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state
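
    # --- Illustrative usage sketch (added for documentation; not part of the original
    # file). A typical interactive session, assuming a constructed `predictor` and a
    # placeholder video path "<your_video>":
    #
    #   state = predictor.init_state(video_path="<your_video>")
    #   frame_idx, obj_ids, masks = predictor.add_new_points(
    #       state, frame_idx=0, obj_id=1,
    #       points=[[210.0, 350.0]], labels=[1],  # one positive click, (x, y) in video pixels
    #   )
    #   for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
    #       pass  # masks: (num_objects, 1, video_height, video_width) mask logits
    #   predictor.reset_state(state)  # discard all inputs and results to start over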

    def _obj_id_to_idx(self, inference_state, obj_id):
        """Map client-side object id to model-side object index."""
        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
        if obj_idx is not None:
            return obj_idx

        # This is a new object id not sent to the server before. We only allow adding
        # new objects *before* the tracking starts.
        allow_new_object = not inference_state["tracking_has_started"]
        if allow_new_object:
            # get the next object slot
            obj_idx = len(inference_state["obj_id_to_idx"])
            inference_state["obj_id_to_idx"][obj_id] = obj_idx
            inference_state["obj_idx_to_id"][obj_idx] = obj_id
            inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
            # set up input and output structures for this object
            inference_state["point_inputs_per_obj"][obj_idx] = {}
            inference_state["mask_inputs_per_obj"][obj_idx] = {}
            inference_state["output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            inference_state["temp_output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            return obj_idx
        else:
            raise RuntimeError(
                f"Cannot add new object id {obj_id} after tracking starts. "
                f"All existing object ids: {inference_state['obj_ids']}. "
                f"Please call 'reset_state' to restart from scratch."
            )

    def _obj_idx_to_id(self, inference_state, obj_idx):
        """Map model-side object index to client-side object id."""
        return inference_state["obj_idx_to_id"][obj_idx]

    def _get_obj_num(self, inference_state):
        """Get the total number of unique object ids received so far in this session."""
        return len(inference_state["obj_idx_to_id"])

    @torch.inference_mode()
    def add_new_points(
        self,
        inference_state,
        frame_idx,
        obj_id,
        points,
        labels,
        clear_old_points=True,
        normalize_coords=True,
    ):
        """Add new points to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if not isinstance(points, torch.Tensor):
            points = torch.tensor(points, dtype=torch.float32)
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.int32)
        if points.dim() == 2:
            points = points.unsqueeze(0)  # add batch dimension
        if labels.dim() == 1:
            labels = labels.unsqueeze(0)  # add batch dimension
        if normalize_coords:
            video_H = inference_state["video_height"]
            video_W = inference_state["video_width"]
            points = points / torch.tensor([video_W, video_H]).to(points.device)
        # scale the (normalized) coordinates by the model's internal image size
        points = points * self.image_size
        points = points.to(inference_state["device"])
        labels = labels.to(inference_state["device"])

        if not clear_old_points:
            point_inputs = point_inputs_per_frame.get(frame_idx, None)
        else:
            point_inputs = None
        point_inputs = concat_points(point_inputs, points, labels)

        point_inputs_per_frame[frame_idx] = point_inputs
        mask_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input points are used to generate segments on this frame
        # without using any memory from other frames, as in SAM. Otherwise (if it has been
        # tracked), the input points will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        # Get any previously predicted mask logits on this object and feed them along with
        # the new clicks into the SAM mask decoder.
        prev_sam_mask_logits = None
        # look up the temporary output dict first, which contains the most recent output
        # (if not found, then look up the conditioning and non-conditioning frame outputs)
        prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
        if prev_out is None:
            prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
            if prev_out is None:
                prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

        if prev_out is not None and prev_out["pred_masks"] is not None:
            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
            # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
            prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=point_inputs,
            mask_inputs=None,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after the user finalizes their clicks).
            # This allows us to enforce non-overlapping constraints on all objects before
            # encoding them into memory.
            run_mem_encoder=False,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks
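
    # --- Note on the points API (added for clarity; not part of the original file) ---
    # With `normalize_coords=True` (the default), `points` are (x, y) coordinates in the
    # original video frame, and `labels` follow the SAM convention (1 = positive click,
    # 0 = negative click). For example, a hypothetical corrective negative click that
    # keeps the earlier clicks on the same frame:
    #
    #   predictor.add_new_points(
    #       state, frame_idx=0, obj_id=1,
    #       points=[[420.0, 120.0]], labels=[0],
    #       clear_old_points=False,  # append to, rather than replace, previous clicks
    #   )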

    @torch.inference_mode()
    def add_new_mask(
        self,
        inference_state,
        frame_idx,
        obj_id,
        mask,
    ):
        """Add a new mask to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask, dtype=torch.bool)
        assert mask.dim() == 2
        mask_H, mask_W = mask.shape
        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])

        # resize the mask if it doesn't match the model's image size
        if mask_H != self.image_size or mask_W != self.image_size:
            mask_inputs = torch.nn.functional.interpolate(
                mask_inputs_orig,
                size=(self.image_size, self.image_size),
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for downsampling
            )
            mask_inputs = (mask_inputs >= 0.5).float()
        else:
            mask_inputs = mask_inputs_orig

        mask_inputs_per_frame[frame_idx] = mask_inputs
        point_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input mask is used to generate segments on this frame
        # without using any memory from other frames, as in SAM. Otherwise (if it has been
        # tracked), the input mask will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=None,
            mask_inputs=mask_inputs,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after the user finalizes their clicks).
            # This allows us to enforce non-overlapping constraints on all objects before
            # encoding them into memory.
            run_mem_encoder=False,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks
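
    # --- Note on the mask API (added for clarity; not part of the original file) ---
    # `mask` can be any 2-D boolean array-like at any resolution; it is resized to the
    # model's square input size internally. A hypothetical example prompting object 2
    # with a rectangular region on frame 0:
    #
    #   mask = torch.zeros(
    #       state["video_height"], state["video_width"], dtype=torch.bool
    #   )
    #   mask[100:200, 150:300] = True  # rows 100-199, columns 150-299 as foreground
    #   predictor.add_new_mask(state, frame_idx=0, obj_id=2, mask=mask)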

    def _get_orig_video_res_output(self, inference_state, any_res_masks):
        """
        Resize the object scores to the original video resolution (video_res_masks)
        and apply non-overlapping constraints for final output.
        """
        device = inference_state["device"]
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        any_res_masks = any_res_masks.to(device, non_blocking=True)
        if any_res_masks.shape[-2:] == (video_H, video_W):
            video_res_masks = any_res_masks
        else:
            video_res_masks = torch.nn.functional.interpolate(
                any_res_masks,
                size=(video_H, video_W),
                mode="bilinear",
                align_corners=False,
            )
        if self.non_overlap_masks:
            video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
        return any_res_masks, video_res_masks

    def _consolidate_temp_output_across_obj(
        self,
        inference_state,
        frame_idx,
        is_cond,
        run_mem_encoder,
        consolidate_at_video_res=False,
    ):
        """
        Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
        a frame into a single output for all objects, including
        1) fill any missing objects either from `output_dict_per_obj` (if they exist in
           `output_dict_per_obj` for this frame) or leave them as placeholder values
           (if they don't exist in `output_dict_per_obj` for this frame);
        2) if specified, rerun the memory encoder after applying non-overlapping
           constraints on the object scores.
        """
        batch_size = self._get_obj_num(inference_state)
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Optionally, we allow consolidating the temporary outputs at the original
        # video resolution (to provide a better editing experience for mask prompts).
        if consolidate_at_video_res:
            assert not run_mem_encoder, "memory encoder cannot run at video resolution"
            consolidated_H = inference_state["video_height"]
            consolidated_W = inference_state["video_width"]
            consolidated_mask_key = "pred_masks_video_res"
        else:
            consolidated_H = consolidated_W = self.image_size // 4
            consolidated_mask_key = "pred_masks"
        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
        # will be added when rerunning the memory encoder after applying non-overlapping
        # constraints to object scores. Its "pred_masks" are prefilled with a large
        # negative value (NO_OBJ_SCORE) to represent missing objects.
        consolidated_out = {
            "maskmem_features": None,
            "maskmem_pos_enc": None,
            consolidated_mask_key: torch.full(
                size=(batch_size, 1, consolidated_H, consolidated_W),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["storage_device"],
            ),
            "obj_ptr": torch.full(
                size=(batch_size, self.hidden_dim),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["device"],
            ),
        }
        empty_mask_ptr = None
        for obj_idx in range(batch_size):
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            out = obj_temp_output_dict[storage_key].get(frame_idx, None)
            # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
            # we fall back and look up its previous output in "output_dict_per_obj".
            # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
            # "output_dict_per_obj" to find a previous output for this object.
            if out is None:
                out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
            if out is None:
                out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
            # and leave its mask scores at the default scores (i.e. the NO_OBJ_SCORE
            # placeholder above) and set its object pointer to be a dummy pointer.
            if out is None:
                # Fill in dummy object pointers for those objects without any inputs or
                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
                # i.e. when we need to build the memory for tracking).
                if run_mem_encoder:
                    if empty_mask_ptr is None:
                        empty_mask_ptr = self._get_empty_mask_ptr(
                            inference_state, frame_idx
                        )
                    # fill object pointer with a dummy pointer (based on an empty mask)
                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
                continue
            # Add the temporary object output mask to consolidated output mask
            obj_mask = out["pred_masks"]
            consolidated_pred_masks = consolidated_out[consolidated_mask_key]
            if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
                consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
            else:
                # Resize first if temporary object mask has a different resolution
                resized_obj_mask = torch.nn.functional.interpolate(
                    obj_mask,
                    size=consolidated_pred_masks.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]

        # Optionally, apply non-overlapping constraints on the consolidated scores
        # and rerun the memory encoder
        if run_mem_encoder:
            device = inference_state["device"]
            high_res_masks = torch.nn.functional.interpolate(
                consolidated_out["pred_masks"].to(device, non_blocking=True),
                size=(self.image_size, self.image_size),
                mode="bilinear",
                align_corners=False,
            )
            if self.non_overlap_masks_for_mem_enc:
                high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
            maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                inference_state=inference_state,
                frame_idx=frame_idx,
                batch_size=batch_size,
                high_res_masks=high_res_masks,
                is_mask_from_pts=True,  # these frames are what the user interacted with
            )
            consolidated_out["maskmem_features"] = maskmem_features
            consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc

        return consolidated_out
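
    # Note (added for clarity; not part of the original file): the consolidated
    # "pred_masks" tensor has shape (num_objects, 1, H, W) with H = W = image_size // 4,
    # or (video_height, video_width) when `consolidate_at_video_res=True`; entries for
    # missing objects keep the NO_OBJ_SCORE fill value, and "obj_ptr" has shape
    # (num_objects, hidden_dim).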

    def _get_empty_mask_ptr(self, inference_state, frame_idx):
        """Get a dummy object pointer based on an empty mask on the current frame."""
        # A dummy (empty) mask with a single object
        batch_size = 1
        mask_inputs = torch.zeros(
            (batch_size, 1, self.image_size, self.image_size),
            dtype=torch.float32,
            device=inference_state["device"],
        )

        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # Feed the empty mask and image feature above to get a dummy object pointer
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=True,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=None,
            mask_inputs=mask_inputs,
            output_dict={},
            num_frames=inference_state["num_frames"],
            track_in_reverse=False,
            run_mem_encoder=False,
            prev_sam_mask_logits=None,
        )
        return current_out["obj_ptr"]

    @torch.inference_mode()
    def propagate_in_video_preflight(self, inference_state):
        """Prepare inference_state and consolidate temporary outputs before tracking."""
        # Tracking has started and we don't allow adding new objects until the session is reset.
        inference_state["tracking_has_started"] = True
        batch_size = self._get_obj_num(inference_state)

        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
        # add them into "output_dict".
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        output_dict = inference_state["output_dict"]
        # "consolidated_frame_inds" contains indices of those frames where consolidated
        # temporary outputs have been added (either in this call or any previous calls
        # to `propagate_in_video_preflight`).
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        for is_cond in [False, True]:
            # Separately consolidate conditioning and non-conditioning temp outputs
            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
            # Find all the frames that contain temporary outputs for any objects
            # (these should be the frames that have just received clicks or mask inputs
            # via `add_new_points` or `add_new_mask`)
            temp_frame_inds = set()
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
            consolidated_frame_inds[storage_key].update(temp_frame_inds)
            # consolidate the temporary output across all objects on this frame
            for frame_idx in temp_frame_inds:
                consolidated_out = self._consolidate_temp_output_across_obj(
                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
                )
                # merge them into "output_dict" and also create per-object slices
                output_dict[storage_key][frame_idx] = consolidated_out
                self._add_output_per_object(
                    inference_state, frame_idx, consolidated_out, storage_key
                )
                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
                )
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)

            # clear temporary outputs in `temp_output_dict_per_obj`
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                obj_temp_output_dict[storage_key].clear()

        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
        # output on the same frame in "non_cond_frame_outputs"
        for frame_idx in output_dict["cond_frame_outputs"]:
            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for obj_output_dict in inference_state["output_dict_per_obj"].values():
            for frame_idx in obj_output_dict["cond_frame_outputs"]:
                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
            assert frame_idx in output_dict["cond_frame_outputs"]
            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)

        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
        # with either points or mask inputs (which should be true under a correct workflow).
        all_consolidated_frame_inds = (
            consolidated_frame_inds["cond_frame_outputs"]
            | consolidated_frame_inds["non_cond_frame_outputs"]
        )
        input_frames_inds = set()
        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
            input_frames_inds.update(point_inputs_per_frame.keys())
        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
            input_frames_inds.update(mask_inputs_per_frame.keys())
        assert all_consolidated_frame_inds == input_frames_inds

    @torch.inference_mode()
    def propagate_in_video(
        self,
        inference_state,
        start_frame_idx=None,
        max_frame_num_to_track=None,
        reverse=False,
    ):
        """Propagate the input points across frames to track in the entire video."""
        self.propagate_in_video_preflight(inference_state)

        output_dict = inference_state["output_dict"]
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        obj_ids = inference_state["obj_ids"]
        num_frames = inference_state["num_frames"]
        batch_size = self._get_obj_num(inference_state)
        if len(output_dict["cond_frame_outputs"]) == 0:
            raise RuntimeError("No points are provided; please add points first")
        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
        )

        # set start index, end index, and processing order
        if start_frame_idx is None:
            # default: start from the earliest frame with input points
            start_frame_idx = min(output_dict["cond_frame_outputs"])
        if max_frame_num_to_track is None:
            # default: track all the frames in the video
            max_frame_num_to_track = num_frames
        if reverse:
            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
            if start_frame_idx > 0:
                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
            else:
                processing_order = []  # skip reverse tracking if starting from frame 0
        else:
            end_frame_idx = min(
                start_frame_idx + max_frame_num_to_track, num_frames - 1
            )
            processing_order = range(start_frame_idx, end_frame_idx + 1)

        for frame_idx in tqdm(processing_order, desc="propagate in video"):
            # We skip those frames already in consolidated outputs (these are frames
            # that received input clicks or mask). Note that we cannot directly run
            # batched forward on them via `_run_single_frame_inference` because the
            # number of clicks on each object might be different.
            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
                storage_key = "cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
                storage_key = "non_cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
            else:
                storage_key = "non_cond_frame_outputs"
                current_out, pred_masks = self._run_single_frame_inference(
                    inference_state=inference_state,
                    output_dict=output_dict,
                    frame_idx=frame_idx,
                    batch_size=batch_size,
                    is_init_cond_frame=False,
                    point_inputs=None,
                    mask_inputs=None,
                    reverse=reverse,
                    run_mem_encoder=True,
                )
                output_dict[storage_key][frame_idx] = current_out
            # Create slices of per-object outputs for subsequent interaction with each
            # individual object after tracking.
            self._add_output_per_object(
                inference_state, frame_idx, current_out, storage_key
            )
            inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}

            # Resize the output mask to the original video resolution (we directly use
            # the mask scores on GPU for output to avoid any CPU conversion in between)
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, pred_masks
            )
            yield frame_idx, obj_ids, video_res_masks
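
    # --- Illustrative usage sketch (added for documentation; not part of the original
    # file). `propagate_in_video` is a generator, so tracking runs lazily as it is
    # consumed. A common pattern is to collect per-frame binary masks by thresholding
    # the returned logits at 0:
    #
    #   video_segments = {}
    #   for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
    #       video_segments[frame_idx] = {
    #           obj_id: (masks[i] > 0.0).cpu().numpy()
    #           for i, obj_id in enumerate(obj_ids)
    #       }
    #   # optionally also track backwards from the prompted frame:
    #   for frame_idx, obj_ids, masks in predictor.propagate_in_video(state, reverse=True):
    #       ...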

    def _add_output_per_object(
        self, inference_state, frame_idx, current_out, storage_key
    ):
        """
        Split a multi-object output into per-object output slices and add them into
        `output_dict_per_obj`. The resulting slices share the same tensor storage.
        """
        maskmem_features = current_out["maskmem_features"]
        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)

        maskmem_pos_enc = current_out["maskmem_pos_enc"]
        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)

        output_dict_per_obj = inference_state["output_dict_per_obj"]
        for obj_idx, obj_output_dict in output_dict_per_obj.items():
            obj_slice = slice(obj_idx, obj_idx + 1)
            obj_out = {
                "maskmem_features": None,
                "maskmem_pos_enc": None,
                "pred_masks": current_out["pred_masks"][obj_slice],
                "obj_ptr": current_out["obj_ptr"][obj_slice],
            }
            if maskmem_features is not None:
                obj_out["maskmem_features"] = maskmem_features[obj_slice]
            if maskmem_pos_enc is not None:
                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
            obj_output_dict[storage_key][frame_idx] = obj_out

    @torch.inference_mode()
    def reset_state(self, inference_state):
        """Remove all input points or masks in all frames throughout the video."""
        self._reset_tracking_results(inference_state)
        # Remove all object ids
        inference_state["obj_id_to_idx"].clear()
        inference_state["obj_idx_to_id"].clear()
        inference_state["obj_ids"].clear()
        inference_state["point_inputs_per_obj"].clear()
        inference_state["mask_inputs_per_obj"].clear()
        inference_state["output_dict_per_obj"].clear()
        inference_state["temp_output_dict_per_obj"].clear()

    def _reset_tracking_results(self, inference_state):
        """Reset all tracking inputs and results across the video."""
        for v in inference_state["point_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["mask_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        for v in inference_state["temp_output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        inference_state["output_dict"]["cond_frame_outputs"].clear()
        inference_state["output_dict"]["non_cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"].clear()

    def _get_image_feature(self, inference_state, frame_idx, batch_size):
        """Compute the image features on a given frame."""
        # Look up in the cache first
        image, backbone_out = inference_state["cached_features"].get(
            frame_idx, (None, None)
        )
        if backbone_out is None:
            # Cache miss -- we will run inference on a single image
            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
            backbone_out = self.forward_image(image)
            # Cache the most recent frame's feature (for repeated interactions with
            # a frame; we can use an LRU cache for more frames in the future).
            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

        # expand the features to have the same dimension as the number of objects
        expanded_image = image.expand(batch_size, -1, -1, -1)
        expanded_backbone_out = {
            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
        }
        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
            expanded_backbone_out["backbone_fpn"][i] = feat.expand(
                batch_size, -1, -1, -1
            )
        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
            pos = pos.expand(batch_size, -1, -1, -1)
            expanded_backbone_out["vision_pos_enc"][i] = pos

        features = self._prepare_backbone_features(expanded_backbone_out)
        features = (expanded_image,) + features
        return features

    def _run_single_frame_inference(
        self,
        inference_state,
        output_dict,
        frame_idx,
        batch_size,
        is_init_cond_frame,
        point_inputs,
        mask_inputs,
        reverse,
        run_mem_encoder,
        prev_sam_mask_logits=None,
    ):
        """Run tracking on a single frame based on current inputs and previous memory."""
        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            pred_masks_gpu = fill_holes_in_mask_scores(
                pred_masks_gpu, self.fill_hole_area
            )
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
        obj_ptr = current_out["obj_ptr"]
        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
        }
        return compact_current_out, pred_masks_gpu

    def _run_memory_encoder(
        self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
    ):
        """
        Run the memory encoder on `high_res_masks`. This is usually done after applying
        non-overlapping constraints to object scores. Since their scores changed, their
        memory also needs to be computed again with the memory encoder.
        """
        # Retrieve correct image features
        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
            inference_state, frame_idx, batch_size
        )
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            is_mask_from_pts=is_mask_from_pts,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(
            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
        )
        return maskmem_features, maskmem_pos_enc

    def _get_maskmem_pos_enc(self, inference_state, current_out):
        """
        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
        a constant in the inference session to reduce session storage size.
        """
        model_constants = inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's the same across objects
                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            expanded_maskmem_pos_enc = [
                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
            ]
        else:
            expanded_maskmem_pos_enc = None
        return expanded_maskmem_pos_enc

    def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
        """
        Remove the non-conditioning memory around the input frame. When users provide
        correction clicks, the surrounding frames' non-conditioning memories can still
        contain outdated object appearance information and could confuse the model.

        This method clears those non-conditioning memories surrounding the interacted
        frame to avoid giving the model both old and new information about the object.
        """
        r = self.memory_temporal_stride_for_eval
        frame_idx_begin = frame_idx - r * self.num_maskmem
        frame_idx_end = frame_idx + r * self.num_maskmem
        output_dict = inference_state["output_dict"]
        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
        for t in range(frame_idx_begin, frame_idx_end + 1):
            non_cond_frame_outputs.pop(t, None)
            for obj_output_dict in inference_state["output_dict_per_obj"].values():
                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
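        # Note (added for clarity; not part of the original file): for example, if
        # `num_maskmem` is 7 and the eval memory stride `r` is 1, the loop above clears
        # non-conditioning outputs on frames [frame_idx - 7, frame_idx + 7].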