# sam3_tracking_predictor.py
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import logging
from collections import OrderedDict

import torch

from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
from sam3.model.utils.sam2_utils import load_video_frames
from tqdm.auto import tqdm


class Sam3TrackerPredictor(Sam3TrackerBase):
    """
    The demo class that extends the `Sam3TrackerBase` to handle user interactions
    and manage inference states, with support for multi-object tracking.
    """
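    # Illustrative usage sketch (constructor kwargs are forwarded to `Sam3TrackerBase`
    # and depend on the model configuration, so they are only hinted at here):
    #
    #   predictor = Sam3TrackerPredictor(fill_hole_area=8, ...)
    #   state = predictor.init_state(video_path="<video path>")
    #   predictor.add_new_points_or_box(
    #       state, frame_idx=0, obj_id=1, points=[[0.5, 0.5]], labels=[1]
    #   )
    #   for frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores in (
    #       predictor.propagate_in_video(
    #           state,
    #           start_frame_idx=0,
    #           max_frame_num_to_track=None,
    #           reverse=False,
    #           propagate_preflight=True,
    #       )
    #   ):
    #       ...  # consume the per-frame masks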
    def __init__(
        self,
        # whether to clear non-conditioning memory of the surrounding frames (which may
        # contain outdated information) after adding correction clicks; note that this
        # would only apply to *single-object tracking* unless
        # `clear_non_cond_mem_for_multi_obj` is also set to True
        clear_non_cond_mem_around_input=False,
        # whether to also clear non-conditioning memory of the surrounding frames for
        # multi-object tracking (only effective when `clear_non_cond_mem_around_input` is True)
        clear_non_cond_mem_for_multi_obj=False,
        # if fill_hole_area > 0, we fill small holes in the final masks up to this area
        # (after resizing them to the original video resolution)
        fill_hole_area=0,
        # if always_start_from_first_ann_frame is True, we always start tracking from the
        # frame where we receive the first annotation (clicks or mask) and ignore the
        # `start_frame_idx` passed to `propagate_in_video`
        always_start_from_first_ann_frame=False,
        # the maximum number of points to be used in the prompt encoder, which reduces the
        # domain gap w.r.t. training (which only uses 8 points)
        # - if it's set to a positive integer, we only take the first
        #   `max_point_num_in_prompt_enc//2` points and the last
        #   `(max_point_num_in_prompt_enc - max_point_num_in_prompt_enc//2)` points
        #   in the prompt encoder
        # - if it's set to 0 or negative, this option is turned off and we use all points
        #   in the prompt encoder
        max_point_num_in_prompt_enc=16,
        non_overlap_masks_for_output=True,
        # checkpoint_file=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
        self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
        self.fill_hole_area = fill_hole_area
        self.always_start_from_first_ann_frame = always_start_from_first_ann_frame
        self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc
        self.non_overlap_masks_for_output = non_overlap_masks_for_output
        self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        self.bf16_context.__enter__()  # keep the autocast context active for the entire model process
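        # The two flags below are fixed demo behaviors rather than constructor options:
        # `iter_use_prev_mask_pred` feeds the previously predicted mask logits on a frame
        # back into the SAM mask decoder together with any new clicks on that frame, and
        # `add_all_frames_to_correct_as_cond` treats every frame that receives clicks or
        # a mask as a conditioning frame during tracking.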
        self.iter_use_prev_mask_pred = True
        self.add_all_frames_to_correct_as_cond = True

    @torch.inference_mode()
    def init_state(
        self,
        video_height=None,
        video_width=None,
        num_frames=None,
        video_path=None,
        cached_features=None,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
    ):
        """Initialize an inference state."""
        inference_state = {}
        # whether to offload the video frames to CPU memory
        # turning on this option saves the GPU memory with only a very small overhead
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # turning on this option saves the GPU memory at the cost of a lower tracking fps
        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
        # and from 24 to 21 when tracking two objects)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        inference_state["device"] = self.device
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = torch.device("cuda")
        if video_path is not None:
            images, video_height, video_width = load_video_frames(
                video_path=video_path,
                image_size=self.image_size,
                offload_video_to_cpu=offload_video_to_cpu,
                async_loading_frames=async_loading_frames,
                compute_device=inference_state["storage_device"],
            )
            inference_state["images"] = images
            inference_state["num_frames"] = len(images)
            inference_state["video_height"] = video_height
            inference_state["video_width"] = video_width
        else:
            # the original video height and width, used for resizing final output scores
            inference_state["video_height"] = video_height
            inference_state["video_width"] = video_width
            inference_state["num_frames"] = num_frames
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = (
            {} if cached_features is None else cached_features
        )
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object id and model-side object index
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # A storage to hold the model's tracking results and states on each frame
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        # The index of the frame that received the first annotation
        inference_state["first_ann_frame_idx"] = None
        # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # A temporary storage to hold new outputs when the user interacts with a frame
        # to add clicks or mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # Frames that already hold consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),  # set containing frame indices
            "non_cond_frame_outputs": set(),  # set containing frame indices
        }
        # metadata for each tracking frame (e.g. which direction it's tracked)
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"] = {}
        self.clear_all_points_in_video(inference_state)
        return inference_state

    def _obj_id_to_idx(self, inference_state, obj_id):
        """Map client-side object id to model-side object index."""
        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
        if obj_idx is not None:
            return obj_idx
        # This is a new object id not sent to the server before. We only allow adding
        # new objects *before* the tracking starts.
        allow_new_object = not inference_state["tracking_has_started"]
        if allow_new_object:
            # get the next object slot
            obj_idx = len(inference_state["obj_id_to_idx"])
            inference_state["obj_id_to_idx"][obj_id] = obj_idx
            inference_state["obj_idx_to_id"][obj_idx] = obj_id
            inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
            # set up input and output structures for this object
            inference_state["point_inputs_per_obj"][obj_idx] = {}
            inference_state["mask_inputs_per_obj"][obj_idx] = {}
            inference_state["output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            inference_state["temp_output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            return obj_idx
        else:
            raise RuntimeError(
                f"Cannot add new object id {obj_id} after tracking starts. "
                f"All existing object ids: {inference_state['obj_ids']}."
            )

    def _obj_idx_to_id(self, inference_state, obj_idx):
        """Map model-side object index to client-side object id."""
        return inference_state["obj_idx_to_id"][obj_idx]

    def _get_obj_num(self, inference_state):
        """Get the total number of unique object ids received so far in this session."""
        return len(inference_state["obj_idx_to_id"])

    @torch.inference_mode()
    def add_new_points_or_box(
        self,
        inference_state,
        frame_idx,
        obj_id,
        points=None,
        labels=None,
        clear_old_points=True,
        rel_coordinates=True,
        use_prev_mem_frame=False,
        normalize_coords=True,
        box=None,
    ):
        """Add new points or a box to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
        if (points is not None) != (labels is not None):
            raise ValueError("points and labels must be provided together")
        if points is None and box is None:
            raise ValueError("at least one of points or box must be provided as input")
        if points is None:
            points = torch.zeros(0, 2, dtype=torch.float32)
        elif not isinstance(points, torch.Tensor):
            points = torch.tensor(points, dtype=torch.float32)
        if labels is None:
            labels = torch.zeros(0, dtype=torch.int32)
        elif not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.int32)
        if points.dim() == 2:
            points = points.unsqueeze(0)  # add batch dimension
        if labels.dim() == 1:
            labels = labels.unsqueeze(0)  # add batch dimension
        if rel_coordinates:
            # convert the points from relative coordinates to absolute coordinates
            if points is not None:
                points = points * self.image_size
            if box is not None:
                box = box * self.image_size
        # If `box` is provided, we add it as the first two points with labels 2 and 3
        # along with the user-provided points (consistent with how SAM 2 is trained).
        if box is not None:
            if not clear_old_points:
                raise ValueError(
                    "cannot add box without clearing old points, since "
                    "box prompt must be provided before any point prompt "
                    "(please use clear_old_points=True instead)"
                )
            if not isinstance(box, torch.Tensor):
                box = torch.tensor(box, dtype=torch.float32, device=points.device)
            box_coords = box.reshape(1, 2, 2)
            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
            box_labels = box_labels.reshape(1, 2)
            points = torch.cat([box_coords, points], dim=1)
            labels = torch.cat([box_labels, labels], dim=1)
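            # At this point `points` has shape (1, 2 + P, 2) and `labels` has shape
            # (1, 2 + P), where P is the number of user-provided points and the two
            # leading entries are the box corners with labels 2 and 3.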
        points = points.to(inference_state["device"])
        labels = labels.to(inference_state["device"])
        if not clear_old_points:
            point_inputs = point_inputs_per_frame.get(frame_idx, None)
        else:
            point_inputs = None
        point_inputs = concat_points(point_inputs, points, labels)
        point_inputs_per_frame[frame_idx] = point_inputs
        mask_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input points are used to generate segments on this frame
        # without using any memory from other frames, like in SAM. Otherwise (if it has
        # been tracked), the input points will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Limit the number of input points fed to the prompt encoder (to reduce the domain gap)
        num_points = point_inputs["point_coords"].size(1)
        if num_points > self.max_point_num_in_prompt_enc > 0:
            num_first = self.max_point_num_in_prompt_enc // 2
            num_last = self.max_point_num_in_prompt_enc - num_first
            point_inputs["point_coords"] = torch.cat(
                [
                    point_inputs["point_coords"][:, :num_first],
                    point_inputs["point_coords"][:, -num_last:],
                ],
                dim=1,
            )
            point_inputs["point_labels"] = torch.cat(
                [
                    point_inputs["point_labels"][:, :num_first],
                    point_inputs["point_labels"][:, -num_last:],
                ],
                dim=1,
            )
            logging.warning(
                f"Too many points ({num_points}) are provided on frame {frame_idx}. Only "
                f"the first {num_first} points and the last {num_last} points will be used."
            )
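            # For example, with `max_point_num_in_prompt_enc=16` and 20 accumulated
            # points, the first 8 and the last 8 points are kept (num_first=num_last=8).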
        # Get any previously predicted mask logits on this object and feed them along with
        # the new clicks into the SAM mask decoder when `self.iter_use_prev_mask_pred=True`.
        prev_sam_mask_logits = None
        if self.iter_use_prev_mask_pred:
            # look up the temporary output dict first, which contains the most recent output
            # (if not found, then look up conditioning and non-conditioning frame outputs)
            prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
            if prev_out is None:
                prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
                if prev_out is None:
                    prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
            if prev_out is not None and prev_out["pred_masks"] is not None:
                prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
                # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
                prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=point_inputs,
            mask_inputs=None,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after the user finalizes their clicks).
            # This allows us to enforce non-overlapping constraints on all objects before
            # encoding them into memory.
            run_mem_encoder=False,
            prev_sam_mask_logits=prev_sam_mask_logits,
            use_prev_mem_frame=use_prev_mem_frame,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out
        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        low_res_masks = None  # not needed by the demo
        return frame_idx, obj_ids, low_res_masks, video_res_masks

    @torch.inference_mode()
    def add_new_mask(
        self,
        inference_state,
        frame_idx,
        obj_id,
        mask,
        add_mask_to_memory=False,
    ):
        """Add a new mask to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
        assert mask.dim() == 2
        mask_H, mask_W = mask.shape
        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
        # resize the mask if it doesn't match the model's input mask size
        if mask_H != self.input_mask_size or mask_W != self.input_mask_size:
            mask_inputs = torch.nn.functional.interpolate(
                mask_inputs_orig,
                size=(self.input_mask_size, self.input_mask_size),
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for downsampling
            )
        else:
            mask_inputs = mask_inputs_orig
        # also get the mask at the original video resolution (for outputting)
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        if mask_H != video_H or mask_W != video_W:
            mask_inputs_video_res = torch.nn.functional.interpolate(
                mask_inputs_orig,
                size=(video_H, video_W),
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for potential downsampling
            )
        else:
            mask_inputs_video_res = mask_inputs_orig
        # convert mask_inputs_video_res to binary (threshold at 0.5 as it is in range 0~1)
        mask_inputs_video_res = mask_inputs_video_res > 0.5
        mask_inputs_per_frame[frame_idx] = mask_inputs_video_res
        point_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input mask is used to generate segments on this frame
        # without using any memory from other frames, like in SAM. Otherwise (if it has
        # been tracked), the input mask will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=None,
            mask_inputs=mask_inputs,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after the user finalizes their clicks).
            # This allows us to enforce non-overlapping constraints on all objects before
            # encoding them into memory.
            run_mem_encoder=False,
        )
        # We directly use the input mask at video resolution as the output mask for a better
        # video editing experience (so that the masks don't change after each brushing).
        # Here NO_OBJ_SCORE is a large negative value to represent the background and
        # similarly -NO_OBJ_SCORE is a large positive value to represent the foreground.
        current_out["pred_masks"] = None
        current_out["pred_masks_video_res"] = torch.where(
            mask_inputs_video_res, -NO_OBJ_SCORE, NO_OBJ_SCORE
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out
        # Remove the overlapping proportion of other objects' input masks on this frame
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        for obj_idx2, obj_temp_output_dict2 in temp_output_dict_per_obj.items():
            if obj_idx2 == obj_idx:
                continue
            current_out2 = obj_temp_output_dict2[storage_key].get(frame_idx, None)
            if current_out2 is not None and "pred_masks_video_res" in current_out2:
                current_out2["pred_masks_video_res"] = torch.where(
                    mask_inputs_video_res,
                    NO_OBJ_SCORE,
                    current_out2["pred_masks_video_res"],
                )
        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        low_res_masks = None  # not needed by the demo
        return frame_idx, obj_ids, low_res_masks, video_res_masks

    def add_new_points(self, *args, **kwargs):
        """Deprecated method. Please use `add_new_points_or_box` instead."""
        return self.add_new_points_or_box(*args, **kwargs)

    def _get_orig_video_res_output(self, inference_state, any_res_masks):
        """
        Resize the object scores to the original video resolution (video_res_masks)
        and apply non-overlapping constraints for final output.
        """
        device = inference_state["device"]
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        any_res_masks = any_res_masks.to(device, non_blocking=True)
        if any_res_masks.shape[-2:] == (video_H, video_W):
            video_res_masks = any_res_masks
        else:
            video_res_masks = torch.nn.functional.interpolate(
                any_res_masks,
                size=(video_H, video_W),
                mode="bilinear",
                align_corners=False,
            )
        if self.non_overlap_masks_for_output:
            video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            video_res_masks = fill_holes_in_mask_scores(
                video_res_masks, self.fill_hole_area
            )
        return any_res_masks, video_res_masks

    def _consolidate_temp_output_across_obj(
        self,
        inference_state,
        frame_idx,
        is_cond,
        run_mem_encoder,
        consolidate_at_video_res=False,
    ):
        """
        Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
        a frame into a single output for all objects, including
        1) filling any missing objects either from `output_dict_per_obj` (if they exist in
           `output_dict_per_obj` for this frame) or leaving them as placeholder values
           (if they don't exist in `output_dict_per_obj` for this frame);
        2) if specified, rerunning the memory encoder after applying non-overlapping
           constraints on the object scores.
        """
        batch_size = self._get_obj_num(inference_state)
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Optionally, we allow consolidating the temporary outputs at the original
        # video resolution (to provide a better editing experience for mask prompts).
        if consolidate_at_video_res:
            assert not run_mem_encoder, "memory encoder cannot run at video resolution"
            consolidated_H = inference_state["video_height"]
            consolidated_W = inference_state["video_width"]
            consolidated_mask_key = "pred_masks_video_res"
        else:
            consolidated_H = consolidated_W = self.low_res_mask_size
            consolidated_mask_key = "pred_masks"
        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
        # will be added when rerunning the memory encoder after applying non-overlapping
        # constraints to object scores. Its "pred_masks" are prefilled with a large
        # negative value (NO_OBJ_SCORE) to represent missing objects.
        consolidated_out = {
            "maskmem_features": None,
            "maskmem_pos_enc": None,
            consolidated_mask_key: torch.full(
                size=(batch_size, 1, consolidated_H, consolidated_W),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["storage_device"],
            ),
            "obj_ptr": torch.full(
                size=(batch_size, self.hidden_dim),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["device"],
            ),
            "object_score_logits": torch.full(
                size=(batch_size, 1),
                # default to 10.0 for object_score_logits, i.e. assuming the object is
                # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
                fill_value=10.0,
                dtype=torch.float32,
                device=inference_state["device"],
            ),
        }
        if self.use_memory_selection:
            consolidated_out["iou_score"] = torch.full(
                size=(batch_size, 1),
                fill_value=0.0,
                dtype=torch.float32,
                device=inference_state["device"],
            )
        empty_mask_ptr = None
        for obj_idx in range(batch_size):
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            out = obj_temp_output_dict[storage_key].get(frame_idx, None)
            # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
            # we fall back and look up its previous output in "output_dict_per_obj".
            # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
            # "output_dict_per_obj" to find a previous output for this object.
            if out is None:
                out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
            if out is None:
                out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
            # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
            # placeholder above) and set its object pointer to be a dummy pointer.
            if out is None:
                # Fill in dummy object pointers for those objects without any inputs or
                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
                # i.e. when we need to build the memory for tracking).
                if run_mem_encoder:
                    if empty_mask_ptr is None:
                        empty_mask_ptr = self._get_empty_mask_ptr(
                            inference_state, frame_idx
                        )
                    # fill object pointer with a dummy pointer (based on an empty mask)
                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
                continue
            # Add the temporary object output mask to consolidated output mask
            # (use "pred_masks_video_res" if it's available)
            obj_mask = out.get("pred_masks_video_res", out["pred_masks"])
            consolidated_pred_masks = consolidated_out[consolidated_mask_key]
            if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
                consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
            else:
                # Resize first if temporary object mask has a different resolution
                is_downsampling = "pred_masks_video_res" in out
                resized_obj_mask = torch.nn.functional.interpolate(
                    obj_mask,
                    size=consolidated_pred_masks.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                    antialias=is_downsampling,  # use antialias for downsampling
                )
                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
            consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
                "object_score_logits"
            ]
            if self.use_memory_selection:
                consolidated_out["iou_score"][obj_idx : obj_idx + 1] = out["iou_score"]
        # Optionally, apply non-overlapping constraints on the consolidated scores
        # and rerun the memory encoder
        if run_mem_encoder:
            device = inference_state["device"]
            high_res_masks = torch.nn.functional.interpolate(
                consolidated_out["pred_masks"].to(device, non_blocking=True),
                size=(self.image_size, self.image_size),
                mode="bilinear",
                align_corners=False,
            )
            high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
            maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                inference_state=inference_state,
                frame_idx=frame_idx,
                batch_size=batch_size,
                high_res_masks=high_res_masks,
                object_score_logits=consolidated_out["object_score_logits"],
                is_mask_from_pts=True,  # these frames are what the user interacted with
            )
            consolidated_out["maskmem_features"] = maskmem_features
            consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
        return consolidated_out

    def _get_empty_mask_ptr(self, inference_state, frame_idx):
        """Get a dummy object pointer based on an empty mask on the current frame."""
        # A dummy (empty) mask with a single object
        batch_size = 1
        mask_inputs = torch.zeros(
            (batch_size, 1, self.image_size, self.image_size),
            dtype=torch.float32,
            device=inference_state["device"],
        )
        # Retrieve correct image features
        (
            image,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
        # Feed the empty mask and image feature above to get a dummy object pointer
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=True,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            image=image,
            point_inputs=None,
            mask_inputs=mask_inputs,
            output_dict={
                "cond_frame_outputs": {},
                "non_cond_frame_outputs": {},
            },
            num_frames=inference_state["num_frames"],
            track_in_reverse=False,
            run_mem_encoder=False,
            prev_sam_mask_logits=None,
        )
        return current_out["obj_ptr"]

    @torch.inference_mode()
    def propagate_in_video_preflight(self, inference_state, run_mem_encoder=True):
        """Prepare inference_state and consolidate temporary outputs before tracking."""
        # Tracking has started and we don't allow adding new objects until session is reset.
        inference_state["tracking_has_started"] = True
        batch_size = self._get_obj_num(inference_state)
        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
        # add them into "output_dict".
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        output_dict = inference_state["output_dict"]
        # "consolidated_frame_inds" contains indices of those frames where consolidated
        # temporary outputs have been added (either in this call or any previous calls
        # to `propagate_in_video_preflight`).
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        for is_cond in [False, True]:
            # Separately consolidate conditioning and non-conditioning temp outputs
            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
            # Find all the frames that contain temporary outputs for any objects
            # (these should be the frames that have just received clicks or mask inputs
            # via `add_new_points` or `add_new_mask`)
            temp_frame_inds = set()
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
            consolidated_frame_inds[storage_key].update(temp_frame_inds)
            # consolidate the temporary output across all objects on this frame
            for frame_idx in temp_frame_inds:
                consolidated_out = self._consolidate_temp_output_across_obj(
                    inference_state,
                    frame_idx,
                    is_cond=is_cond,
                    run_mem_encoder=run_mem_encoder,
                )
                # merge them into "output_dict" and also create per-object slices
                output_dict[storage_key][frame_idx] = consolidated_out
                self._add_output_per_object(
                    inference_state, frame_idx, consolidated_out, storage_key
                )
                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
                )
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
            # clear temporary outputs in `temp_output_dict_per_obj`
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                obj_temp_output_dict[storage_key].clear()
        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
        # output on the same frame in "non_cond_frame_outputs"
        for frame_idx in output_dict["cond_frame_outputs"]:
            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for obj_output_dict in inference_state["output_dict_per_obj"].values():
            for frame_idx in obj_output_dict["cond_frame_outputs"]:
                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
            assert frame_idx in output_dict["cond_frame_outputs"]
            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
        # with either points or mask inputs (which should be true under a correct demo workflow).
        all_consolidated_frame_inds = (
            consolidated_frame_inds["cond_frame_outputs"]
            | consolidated_frame_inds["non_cond_frame_outputs"]
        )
        input_frames_inds = set()
        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
            input_frames_inds.update(point_inputs_per_frame.keys())
        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
            input_frames_inds.update(mask_inputs_per_frame.keys())
        assert all_consolidated_frame_inds == input_frames_inds
        # Record the first interacted frame index (for tracking start)
        if inference_state["first_ann_frame_idx"] is None:
            inference_state["first_ann_frame_idx"] = min(
                input_frames_inds, default=None
            )
        # In case `first_ann_frame_idx` is not in the conditioning frames (e.g. because
        # we cleared the input points on that frame), pick the first conditioning frame
        if (
            inference_state["first_ann_frame_idx"]
            not in output_dict["cond_frame_outputs"]
        ):
            inference_state["first_ann_frame_idx"] = min(
                output_dict["cond_frame_outputs"], default=None
            )

    def _get_processing_order(
        self, inference_state, start_frame_idx, max_frame_num_to_track, reverse
    ):
        num_frames = inference_state["num_frames"]
        # set start index, end index, and processing order
        if self.always_start_from_first_ann_frame:
            # in this case, we always start tracking from the frame where we receive
            # the initial annotation and ignore the provided start_frame_idx
            start_frame_idx = inference_state["first_ann_frame_idx"]
        if start_frame_idx is None:
            # default: start from the earliest frame with input points
            start_frame_idx = min(inference_state["output_dict"]["cond_frame_outputs"])
        if max_frame_num_to_track is None:
            # default: track all the frames in the video
            max_frame_num_to_track = num_frames
        if reverse:
            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
            if start_frame_idx > 0:
                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
            else:
                # this is the edge case where we start from frame 0 and track in reverse order;
                # in this case, we track a single frame (frame 0)
                processing_order = [0]
        else:
            end_frame_idx = min(
                start_frame_idx + max_frame_num_to_track, num_frames - 1
            )
            processing_order = range(start_frame_idx, end_frame_idx + 1)
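            # For example, start_frame_idx=5, max_frame_num_to_track=10 and num_frames=12
            # give end_frame_idx=min(15, 11)=11, i.e. frames 5 through 11 are processed.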
        return processing_order

    @torch.inference_mode()
    def propagate_in_video(
        self,
        inference_state,
        start_frame_idx,
        max_frame_num_to_track,
        reverse,
        tqdm_disable=False,
        obj_ids=None,
        run_mem_encoder=True,
        propagate_preflight=False,
    ):
        """Propagate the input points across frames to track in the entire video."""
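        # This method is a generator: it yields
        # (frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores)
        # for each frame in the processing order.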
        if propagate_preflight:
            self.propagate_in_video_preflight(inference_state)
        # NOTE: This is a copy from the parent class, except that we return object scores as well.
        output_dict = inference_state["output_dict"]
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        if obj_ids is not None:
            raise NotImplementedError(
                "Per-object tracking is not yet implemented for batched inference."
            )
        obj_ids = inference_state["obj_ids"]
        batch_size = self._get_obj_num(inference_state)
        if len(output_dict["cond_frame_outputs"]) == 0:
            raise RuntimeError("No points are provided; please add points first")
        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
        )
        processing_order = self._get_processing_order(
            inference_state,
            start_frame_idx,
            max_frame_num_to_track,
            reverse,
        )
        for frame_idx in tqdm(
            processing_order, desc="propagate in video", disable=tqdm_disable
        ):
            # We skip those frames already in consolidated outputs (these are frames
            # that received input clicks or mask). Note that we cannot directly run
            # batched forward on them via `_run_single_frame_inference` because the
            # number of clicks on each object might be different.
            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
                storage_key = "cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
                obj_scores = current_out["object_score_logits"]
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
                storage_key = "non_cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
                obj_scores = current_out["object_score_logits"]
            else:
                storage_key = "non_cond_frame_outputs"
                current_out, pred_masks = self._run_single_frame_inference(
                    inference_state=inference_state,
                    output_dict=output_dict,
                    frame_idx=frame_idx,
                    batch_size=batch_size,
                    is_init_cond_frame=False,
                    point_inputs=None,
                    mask_inputs=None,
                    reverse=reverse,
                    run_mem_encoder=run_mem_encoder,
                )
                obj_scores = current_out["object_score_logits"]
                output_dict[storage_key][frame_idx] = current_out
            # Create slices of per-object outputs for subsequent interaction with each
            # individual object after tracking.
            self._add_output_per_object(
                inference_state, frame_idx, current_out, storage_key
            )
            inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
            # Resize the output mask to the original video resolution (we directly use
            # the mask scores on GPU for output to avoid any CPU conversion in between)
            low_res_masks, video_res_masks = self._get_orig_video_res_output(
                inference_state, pred_masks
            )
            yield frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores

    def _add_output_per_object(
        self, inference_state, frame_idx, current_out, storage_key
    ):
        """
        Split a multi-object output into per-object output slices and add them into
        `output_dict_per_obj`. The resulting slices share the same tensor storage.
        """
        maskmem_features = current_out["maskmem_features"]
        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
        maskmem_pos_enc = current_out["maskmem_pos_enc"]
        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
        output_dict_per_obj = inference_state["output_dict_per_obj"]
        for obj_idx, obj_output_dict in output_dict_per_obj.items():
            obj_slice = slice(obj_idx, obj_idx + 1)
            obj_out = {
                "maskmem_features": None,
                "maskmem_pos_enc": None,
                "pred_masks": current_out["pred_masks"][obj_slice],
                "obj_ptr": current_out["obj_ptr"][obj_slice],
                "object_score_logits": current_out["object_score_logits"][obj_slice],
            }
            if self.use_memory_selection:
                obj_out["iou_score"] = current_out["iou_score"][obj_slice]
            if maskmem_features is not None:
                obj_out["maskmem_features"] = maskmem_features[obj_slice]
            if maskmem_pos_enc is not None:
                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
            obj_output_dict[storage_key][frame_idx] = obj_out

    @torch.inference_mode()
    def clear_all_points_in_frame(
        self, inference_state, frame_idx, obj_id, need_output=True
    ):
        """Remove all input points or mask in a specific frame for a given object."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        # Clear the conditioning information on the given frame
        inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
        inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
        temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
        # Check and see if there are still any inputs left on this frame
        batch_size = self._get_obj_num(inference_state)
        frame_has_input = False
        for obj_idx2 in range(batch_size):
            if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
                frame_has_input = True
                break
            if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
                frame_has_input = True
                break
        # If this frame has no remaining inputs for any objects, we further clear its
        # conditioning frame status
        if not frame_has_input:
            output_dict = inference_state["output_dict"]
            consolidated_frame_inds = inference_state["consolidated_frame_inds"]
            consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
            # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
            out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
            if out is not None:
                # The frame is not a conditioning frame anymore since it's not receiving inputs,
                # so we "downgrade" its output (if exists) to a non-conditioning frame output.
                output_dict["non_cond_frame_outputs"][frame_idx] = out
                inference_state["frames_already_tracked"].pop(frame_idx, None)
            # Similarly, do it for the sliced output on each object.
            for obj_idx2 in range(batch_size):
                obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
                obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
                if obj_out is not None:
                    obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
            # If all the conditioning frames have been removed, we also clear the tracking outputs
            if len(output_dict["cond_frame_outputs"]) == 0:
                self._reset_tracking_results(inference_state)
        if not need_output:
            return
        # Finally, output updated masks per object (after removing the inputs above)
        obj_ids = inference_state["obj_ids"]
        is_cond = any(
            frame_idx in obj_temp_output_dict["cond_frame_outputs"]
            for obj_temp_output_dict in temp_output_dict_per_obj.values()
        )
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        low_res_masks = None  # not needed by the demo
        return frame_idx, obj_ids, low_res_masks, video_res_masks

    @torch.inference_mode()
    def clear_all_points_in_video(self, inference_state):
        """Remove all input points or mask in all frames throughout the video."""
        self._reset_tracking_results(inference_state)
        # Remove all object ids
        inference_state["obj_id_to_idx"].clear()
        inference_state["obj_idx_to_id"].clear()
        inference_state["obj_ids"].clear()
        inference_state["point_inputs_per_obj"].clear()
        inference_state["mask_inputs_per_obj"].clear()
        inference_state["output_dict_per_obj"].clear()
        inference_state["temp_output_dict_per_obj"].clear()

    def _reset_tracking_results(self, inference_state):
        """Reset all tracking inputs and results across the video."""
        for v in inference_state["point_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["mask_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        for v in inference_state["temp_output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        inference_state["output_dict"]["cond_frame_outputs"].clear()
        inference_state["output_dict"]["non_cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"].clear()
        inference_state["first_ann_frame_idx"] = None

    def _get_image_feature(self, inference_state, frame_idx, batch_size):
        """Compute the image features on a given frame."""
        # Look up in the cache
        image, backbone_out = inference_state["cached_features"].get(
            frame_idx, (None, None)
        )
        if backbone_out is None:
            if self.backbone is None:
                raise RuntimeError(
                    f"Image features for frame {frame_idx} are not cached. "
                    "Please run inference on this frame first."
                )
            else:
                # Cache miss -- we will run inference on a single image
                image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
                backbone_out = self.forward_image(image)
                # Cache the most recent frame's feature (for repeated interactions with
                # a frame; we can use an LRU cache for more frames in the future).
                inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
        if "tracker_backbone_out" in backbone_out:
            backbone_out = backbone_out["tracker_backbone_out"]  # get backbone output
        # expand the features to have the same dimension as the number of objects
        expanded_image = image.expand(batch_size, -1, -1, -1)
        expanded_backbone_out = {
            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
        }
        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
            feat = feat.expand(batch_size, -1, -1, -1)
            expanded_backbone_out["backbone_fpn"][i] = feat
        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
            pos = pos.expand(batch_size, -1, -1, -1)
            expanded_backbone_out["vision_pos_enc"][i] = pos
        features = self._prepare_backbone_features(expanded_backbone_out)
        features = (expanded_image,) + features
        return features

    def _run_single_frame_inference(
        self,
        inference_state,
        output_dict,
        frame_idx,
        batch_size,
        is_init_cond_frame,
        point_inputs,
        mask_inputs,
        reverse,
        run_mem_encoder,
        prev_sam_mask_logits=None,
        use_prev_mem_frame=True,
    ):
        """Run tracking on a single frame based on current inputs and previous memory."""
        # Retrieve correct image features
        (
            image,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            image=image,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
            use_prev_mem_frame=use_prev_mem_frame,
        )
        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
        obj_ptr = current_out["obj_ptr"]
        object_score_logits = current_out["object_score_logits"]
        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
            "object_score_logits": object_score_logits,
        }
        if self.use_memory_selection:
            compact_current_out["iou_score"] = current_out["iou_score"]
            compact_current_out["eff_iou_score"] = current_out["eff_iou_score"]
        return compact_current_out, pred_masks_gpu

    def _run_memory_encoder(
        self,
        inference_state,
        frame_idx,
        batch_size,
        high_res_masks,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Run the memory encoder on `high_res_masks`. This is usually done after applying
        non-overlapping constraints to object scores. Since the scores have changed, the
        memory also needs to be computed again with the memory encoder.
        """
        # Retrieve correct image features
        image, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
            inference_state, frame_idx, batch_size
        )
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            image=image,
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            object_score_logits=object_score_logits,
            is_mask_from_pts=is_mask_from_pts,
        )
        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(
            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
        )
        return maskmem_features, maskmem_pos_enc

    def _get_maskmem_pos_enc(self, inference_state, current_out):
        """
        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
        a constant in the inference session to reduce session storage size.
        """
        model_constants = inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's the same across objects
                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            expanded_maskmem_pos_enc = [
                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
            ]
        else:
            expanded_maskmem_pos_enc = None
        return expanded_maskmem_pos_enc
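
    # The cache-and-expand pattern in `_get_maskmem_pos_enc` (store a single object's
    # slice once, then broadcast it to the current batch size), shown in isolation as a
    # minimal sketch with a made-up tensor shape:
    #
    #     constants = {}
    #     out = [torch.randn(4, 64, 32, 32)]       # encodings for 4 objects on a frame
    #     if "maskmem_pos_enc" not in constants:
    #         constants["maskmem_pos_enc"] = [x[0:1].clone() for x in out]
    #     cached = constants["maskmem_pos_enc"]
    #     expanded = [x.expand(out[0].size(0), -1, -1, -1) for x in cached]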

    @torch.inference_mode()
    def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
        """
        Remove an object id from the tracking state. If strict is True, we check whether
        the object id actually exists and raise an error if it doesn't exist.
        """
        old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
        updated_frames = []
        # Check whether this object_id to remove actually exists and possibly raise an error.
        if old_obj_idx_to_rm is None:
            if not strict:
                return inference_state["obj_ids"], updated_frames
            raise RuntimeError(
                f"Cannot remove object id {obj_id} as it doesn't exist. "
                f"All existing object ids: {inference_state['obj_ids']}."
            )

        # If this is the only remaining object id, we simply reset the state.
        if len(inference_state["obj_id_to_idx"]) == 1:
            self.clear_all_points_in_video(inference_state)
            return inference_state["obj_ids"], updated_frames

        # There are still remaining objects after removing this object id. In this case,
        # we need to delete the object storage from inference state tensors.
        # Step 0: clear the input on those frames where this object id has point or mask input
        # (note that this step is required as it might downgrade conditioning frames to
        # non-conditioning ones)
        obj_input_frames_inds = set()
        obj_input_frames_inds.update(
            inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
        )
        obj_input_frames_inds.update(
            inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
        )
        for frame_idx in obj_input_frames_inds:
            self.clear_all_points_in_frame(
                inference_state, frame_idx, obj_id, need_output=False
            )

        # Step 1: Update the object id mapping (note that it must be done after Step 0,
        # since Step 0 still requires the old object id mappings in inference_state)
        old_obj_ids = inference_state["obj_ids"]
        old_obj_inds = list(range(len(old_obj_ids)))
        remain_old_obj_inds = old_obj_inds.copy()
        remain_old_obj_inds.remove(old_obj_idx_to_rm)
        new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
        new_obj_inds = list(range(len(new_obj_ids)))
        # build new mappings
        old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
        inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
        inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
        inference_state["obj_ids"] = new_obj_ids

        # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
        # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
        # it's already handled in Step 0)
        def _map_keys(container):
            new_kvs = []
            for k in old_obj_inds:
                v = container.pop(k)
                if k in old_idx_to_new_idx:
                    new_kvs.append((old_idx_to_new_idx[k], v))
            container.update(new_kvs)

        _map_keys(inference_state["point_inputs_per_obj"])
        _map_keys(inference_state["mask_inputs_per_obj"])
        _map_keys(inference_state["output_dict_per_obj"])
        _map_keys(inference_state["temp_output_dict_per_obj"])

        # Step 3: For packed tensor storage, we index the remaining ids and rebuild
        # the per-object slices.
        def _slice_state(output_dict, storage_key):
            for frame_idx, out in output_dict[storage_key].items():
                out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
                out["maskmem_pos_enc"] = [
                    x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
                ]
                # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
                out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
                out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
                out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
                out["object_score_logits"] = out["object_score_logits"][
                    remain_old_obj_inds
                ]
                if self.use_memory_selection:
                    out["iou_score"] = out["iou_score"][remain_old_obj_inds]
                    # recalculate the memory frame score
                    out["eff_iou_score"] = self.cal_mem_score(
                        out["object_score_logits"], out["iou_score"]
                    )
                # also update the per-object slices
                self._add_output_per_object(
                    inference_state, frame_idx, out, storage_key
                )

        _slice_state(inference_state["output_dict"], "cond_frame_outputs")
        _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")

        # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
        # could show an updated mask for objects previously occluded by the object being removed
        if need_output:
            temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
            for frame_idx in obj_input_frames_inds:
                is_cond = any(
                    frame_idx in obj_temp_output_dict["cond_frame_outputs"]
                    for obj_temp_output_dict in temp_output_dict_per_obj.values()
                )
                consolidated_out = self._consolidate_temp_output_across_obj(
                    inference_state,
                    frame_idx,
                    is_cond=is_cond,
                    run_mem_encoder=False,
                    consolidate_at_video_res=True,
                )
                _, video_res_masks = self._get_orig_video_res_output(
                    inference_state, consolidated_out["pred_masks_video_res"]
                )
                updated_frames.append((frame_idx, video_res_masks))
        return inference_state["obj_ids"], updated_frames
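
    # Usage sketch for `remove_object` (illustrative; assumes `predictor` is an instance
    # of this class and `inference_state` already tracks the given object id):
    #
    #     obj_ids, updated_frames = predictor.remove_object(
    #         inference_state, obj_id=2, strict=False, need_output=True
    #     )
    #     # `obj_ids` lists the remaining object ids; `updated_frames` holds
    #     # (frame_idx, video_res_masks) pairs for frames whose masks may have changed.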

    def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
        """
        Remove the non-conditioning memory around the input frame. When users provide
        correction clicks, the surrounding frames' non-conditioning memories can still
        contain outdated object appearance information and could confuse the model.
        This method clears those non-conditioning memories surrounding the interacted
        frame to avoid giving the model both old and new information about the object.
        """
        r = self.memory_temporal_stride_for_eval
        frame_idx_begin = frame_idx - r * self.num_maskmem
        frame_idx_end = frame_idx + r * self.num_maskmem
        batch_size = self._get_obj_num(inference_state)
        for obj_idx in range(batch_size):
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
            for t in range(frame_idx_begin, frame_idx_end + 1):
                non_cond_frame_outputs.pop(t, None)
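
    # Example of the window cleared above: with `memory_temporal_stride_for_eval == 1`
    # and `num_maskmem == 7`, an interaction on frame 100 removes the non-conditioning
    # memories of frames 93 through 107 (inclusive) for every tracked object.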

    def _suppress_shrinked_masks(
        self, pred_masks, new_pred_masks, shrink_threshold=0.3
    ):
        """
        Suppress masks in `pred_masks` whose positive area in `new_pred_masks` drops
        below `shrink_threshold` of their original area; suppressed masks are clamped
        to a low negative score (i.e. treated as background).
        """
        # mask areas before and after (number of positive-score pixels)
        area_before = (pred_masks > 0).sum(dim=(-1, -2))
        area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
        area_before = torch.clamp(area_before, min=1.0)
        area_ratio = area_after / area_before
        # keep masks that retain at least `shrink_threshold` of their original area
        keep = area_ratio >= shrink_threshold
        keep_mask = keep[..., None, None].expand_as(pred_masks)
        pred_masks_after = torch.where(
            keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0)
        )
        return pred_masks_after
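
    # Worked example for `_suppress_shrinked_masks`: a mask covering 1000 positive
    # pixels before the pixel-wise constraint and 200 afterwards has an area ratio of
    # 0.2 < 0.3 (the default `shrink_threshold`), so its scores are clamped to at most
    # -10.0 and it is effectively treated as background; a mask keeping 400 of 1000
    # pixels (ratio 0.4) passes through unchanged.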

    def _suppress_object_pw_area_shrinkage(self, pred_masks):
        """
        This function suppresses masks that shrink in area after applying pixel-wise
        non-overlapping constraints. Note that the final output can still be overlapping.
        """
        # Apply pixel-wise non-overlapping constraint based on mask scores
        pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints(
            pred_masks
        )
        # Fully suppress masks with high shrinkage (probably noisy) based on the
        # pixel-wise non-overlapping constraints.
        # NOTE: This can be a no-op if none of the masks shrank by a large factor.
        pred_masks = self._suppress_shrinked_masks(
            pred_masks, pixel_level_non_overlapping_masks
        )
        return pred_masks

    def _apply_object_wise_non_overlapping_constraints(
        self, pred_masks, obj_scores, background_value=-10.0
    ):
        """
        Applies non-overlapping constraints object-wise (i.e., only one object can
        claim the overlapping region).
        """
        # Replace pixel scores with object scores
        pred_masks_single_score = torch.where(
            pred_masks > 0, obj_scores[..., None, None], background_value
        )
        # Apply pixel-wise non-overlapping constraint based on mask scores
        pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints(
            pred_masks_single_score
        )
        # Restore the original pixel scores; note that now only one object can claim
        # the overlapping region
        pred_masks = torch.where(
            pixel_level_non_overlapping_masks > 0,
            pred_masks,
            torch.clamp(pred_masks, max=background_value),
        )
        return pred_masks
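
    # The object-wise variant above resolves overlaps by object score rather than by
    # per-pixel mask score: every positive pixel of an object first carries that
    # object's single score, so after the parent's pixel-wise constraint the
    # higher-scoring object keeps its entire overlapping region. E.g., with scores 0.9
    # vs. 0.6, the 0.9-scored object claims all shared pixels while the other object's
    # scores there are clamped to `background_value`.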