sam2_video_predictor.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import warnings
  6. from collections import OrderedDict
  7. import torch
  8. from tqdm import tqdm
  9. from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
  10. from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
  11. class SAM2VideoPredictor(SAM2Base):
  12. """The predictor class to handle user interactions and manage inference states."""
  def __init__(
      self,
      fill_hole_area=0,
      non_overlap_masks=False,
      clear_non_cond_mem_around_input=False,
      clear_non_cond_mem_for_multi_obj=False,
      add_all_frames_to_correct_as_cond=False,
      **kwargs,
  ):
      """Configure the interactive video predictor on top of ``SAM2Base``.

      Args:
          fill_hole_area: stored as ``self.fill_hole_area``; presumably the hole
              size threshold used with ``fill_holes_in_mask_scores`` elsewhere in
              this file -- TODO confirm.
          non_overlap_masks: whether to apply non-overlapping constraints to the
              output object masks handed back to the caller.
          clear_non_cond_mem_around_input: whether to clear the non-conditioning
              memory of the frames surrounding an input frame (which may contain
              outdated information) after correction clicks are added. On its own
              this only applies to *single-object tracking*.
          clear_non_cond_mem_for_multi_obj: whether to apply that clearing to
              multi-object tracking as well; only effective when
              ``clear_non_cond_mem_around_input`` is True.
          add_all_frames_to_correct_as_cond: if True, any frame that later
              receives a correction click or mask is also appended to the
              conditioning frame list; if False, the conditioning frame list only
              holds the initial conditioning frames.
          **kwargs: forwarded unchanged to ``SAM2Base.__init__``.
      """
      # Build the underlying SAM2Base model first.
      super().__init__(**kwargs)

      # Post-processing options for masks returned to the user.
      self.fill_hole_area = fill_hole_area
      self.non_overlap_masks = non_overlap_masks

      # Memory-management options applied when correction inputs are added.
      self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
      self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
      self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
  @torch.inference_mode()
  def init_state(
      self,
      video_path,
      offload_video_to_cpu=False,
      offload_state_to_cpu=False,
      async_loading_frames=False,
  ):
      """Create the inference state dict for one video.

      Frames are loaded with ``load_video_frames`` (passing
      ``image_size=self.image_size``), empty per-object input/output containers
      are set up, and the visual backbone is warmed up by computing (and caching)
      the image feature of frame 0.

      Args:
          video_path: forwarded to ``load_video_frames``.
          offload_video_to_cpu: keep the video frames in CPU memory; this saves GPU
              memory with only a very small overhead.
          offload_state_to_cpu: keep the inference state in CPU memory (it becomes
              the ``storage_device``); this saves GPU memory at the cost of a lower
              tracking fps (e.g. with a 768x768 model, 27 -> 24 fps for one object
              and 24 -> 21 fps for two objects).
          async_loading_frames: forwarded to ``load_video_frames``.

      Returns:
          dict: the freshly initialized inference state.
      """
      model_device = self.device  # device the model runs on
      frames, orig_height, orig_width = load_video_frames(
          video_path=video_path,
          image_size=self.image_size,
          offload_video_to_cpu=offload_video_to_cpu,
          async_loading_frames=async_loading_frames,
          compute_device=model_device,
      )
      storage_device = (
          torch.device("cpu") if offload_state_to_cpu else model_device
      )

      inference_state = {
          "images": frames,
          "num_frames": len(frames),
          "offload_video_to_cpu": offload_video_to_cpu,
          "offload_state_to_cpu": offload_state_to_cpu,
          # original video size, used to resize the final output scores
          "video_height": orig_height,
          "video_width": orig_width,
          "device": model_device,
          "storage_device": storage_device,
          # user inputs on each frame, keyed by object index
          "point_inputs_per_obj": {},
          "mask_inputs_per_obj": {},
          # visual features of a few recently visited frames, for quick interaction
          "cached_features": {},
          # values that are the same on every frame (a single copy is kept)
          "constants": {},
          # client-side object id <-> model-side object index
          "obj_id_to_idx": OrderedDict(),
          "obj_idx_to_id": OrderedDict(),
          "obj_ids": [],
          # tracking results and states on each frame, for all objects
          "output_dict": {
              "cond_frame_outputs": {},  # {frame_idx: <out>}
              "non_cond_frame_outputs": {},  # {frame_idx: <out>}
          },
          # per-object slices (views) sharing memory with "output_dict"
          "output_dict_per_obj": {},
          # new outputs from clicks/masks, merged into "output_dict" before
          # propagation starts
          "temp_output_dict_per_obj": {},
          # frames that already hold consolidated outputs from click or mask
          # inputs (used directly during tracking)
          "consolidated_frame_inds": {
              "cond_frame_outputs": set(),  # frame indices
              "non_cond_frame_outputs": set(),  # frame indices
          },
          # per-frame tracking metadata (e.g. the direction it was tracked in)
          "tracking_has_started": False,
          "frames_already_tracked": {},
      }

      # Warm up the visual backbone and cache the image feature of frame 0.
      self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
      return inference_state
  @classmethod
  def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
      """
      Build a video predictor from a pretrained checkpoint on the Hugging Face hub.

      Arguments:
          model_id (str): Hugging Face repository ID to load from.
          **kwargs: extra arguments forwarded to ``build_sam2_video_predictor_hf``.

      Returns:
          (SAM2VideoPredictor): the loaded predictor.
      """
      # Imported lazily inside the method rather than at module level.
      from sam2.build_sam import build_sam2_video_predictor_hf

      return build_sam2_video_predictor_hf(model_id, **kwargs)
  116. def _obj_id_to_idx(self, inference_state, obj_id):
  117. """Map client-side object id to model-side object index."""
  118. obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
  119. if obj_idx is not None:
  120. return obj_idx
  121. # This is a new object id not sent to the server before. We only allow adding
  122. # new objects *before* the tracking starts.
  123. allow_new_object = not inference_state["tracking_has_started"]
  124. if allow_new_object:
  125. # get the next object slot
  126. obj_idx = len(inference_state["obj_id_to_idx"])
  127. inference_state["obj_id_to_idx"][obj_id] = obj_idx
  128. inference_state["obj_idx_to_id"][obj_idx] = obj_id
  129. inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
  130. # set up input and output structures for this object
  131. inference_state["point_inputs_per_obj"][obj_idx] = {}
  132. inference_state["mask_inputs_per_obj"][obj_idx] = {}
  133. inference_state["output_dict_per_obj"][obj_idx] = {
  134. "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  135. "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  136. }
  137. inference_state["temp_output_dict_per_obj"][obj_idx] = {
  138. "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  139. "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  140. }
  141. return obj_idx
  142. else:
  143. raise RuntimeError(
  144. f"Cannot add new object id {obj_id} after tracking starts. "
  145. f"All existing object ids: {inference_state['obj_ids']}. "
  146. f"Please call 'reset_state' to restart from scratch."
  147. )
  148. def _obj_idx_to_id(self, inference_state, obj_idx):
  149. """Map model-side object index to client-side object id."""
  150. return inference_state["obj_idx_to_id"][obj_idx]
  151. def _get_obj_num(self, inference_state):
  152. """Get the total number of unique object ids received so far in this session."""
  153. return len(inference_state["obj_idx_to_id"])
  @torch.inference_mode()
  def add_new_points_or_box(
      self,
      inference_state,
      frame_idx,
      obj_id,
      points=None,
      labels=None,
      clear_old_points=True,
      normalize_coords=True,
      box=None,
  ):
      """Add point and/or box prompts for one object on one frame and segment it.

      Args:
          inference_state: state dict created by ``init_state``.
          frame_idx: index of the frame receiving the prompts.
          obj_id: client-side object id; an unseen id registers a new object
              (only allowed before tracking starts).
          points: (x, y) point coordinates shaped (N, 2) or (B, N, 2), as a tensor
              or anything ``torch.tensor`` accepts; may be None when a box is given.
          labels: per-point labels shaped (N,) or (B, N); must be given together
              with ``points``.
          clear_old_points: if True, replace the points stored for this object on
              this frame; if False, append the new points to them.
          normalize_coords: if True, coordinates are in original-video pixels and
              are divided by the video width/height; if False they are assumed to
              be already normalized (either way they are then multiplied by
              ``self.image_size``).
          box: optional box with 4 values, reshaped into two (x, y) points with
              labels 2 and 3 that are placed before the user points.

      Returns:
          tuple: ``(frame_idx, obj_ids, video_res_masks)`` where ``obj_ids`` are
          all registered object ids and ``video_res_masks`` are the consolidated
          mask scores of all objects at the original video resolution.

      Raises:
          ValueError: if only one of ``points``/``labels`` is given, if neither
              points nor a box is given, if points and labels disagree on the
              number of points, or if a box is given with
              ``clear_old_points=False``.
          RuntimeError: if ``obj_id`` is new and tracking has already started.
      """
      if (points is not None) != (labels is not None):
          raise ValueError("points and labels must be provided together")
      if points is None and box is None:
          raise ValueError("at least one of points or box must be provided as input")

      # Normalize points/labels to tensors with a leading batch dimension.
      if points is None:
          points = torch.zeros(0, 2, dtype=torch.float32)
      elif not isinstance(points, torch.Tensor):
          points = torch.tensor(points, dtype=torch.float32)
      if labels is None:
          labels = torch.zeros(0, dtype=torch.int32)
      elif not isinstance(labels, torch.Tensor):
          labels = torch.tensor(labels, dtype=torch.int32)
      if points.dim() == 2:
          points = points.unsqueeze(0)  # add batch dimension
      if labels.dim() == 1:
          labels = labels.unsqueeze(0)  # add batch dimension
      # Each point needs exactly one label; report a mismatch here instead of
      # letting it fail later inside the model with an opaque indexing error.
      if points.shape[:-1] != labels.shape:
          raise ValueError(
              "points and labels must describe the same number of points, got "
              f"points of shape {tuple(points.shape)} and labels of shape "
              f"{tuple(labels.shape)}"
          )

      # If `box` is provided, we add it as the first two points with labels 2 and 3
      # along with the user-provided points (consistent with how SAM 2 is trained).
      if box is not None:
          if not clear_old_points:
              raise ValueError(
                  "cannot add box without clearing old points, since "
                  "box prompt must be provided before any point prompt "
                  "(please use clear_old_points=True instead)"
              )
          if inference_state["tracking_has_started"]:
              warnings.warn(
                  "You are adding a box after tracking starts. SAM 2 may not always be "
                  "able to incorporate a box prompt for *refinement*. If you intend to "
                  "use box prompt as an *initial* input before tracking, please call "
                  "'reset_state' on the inference state to restart from scratch.",
                  category=UserWarning,
                  stacklevel=2,
              )
          # `as_tensor` also moves/casts a user-supplied tensor box to the points'
          # device and float32, so the concatenation below cannot hit a mismatch.
          box = torch.as_tensor(box, dtype=torch.float32, device=points.device)
          box_coords = box.reshape(1, 2, 2)
          box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
          box_labels = box_labels.reshape(1, 2)
          points = torch.cat([box_coords, points], dim=1)
          labels = torch.cat([box_labels, labels], dim=1)

      if normalize_coords:
          video_H = inference_state["video_height"]
          video_W = inference_state["video_width"]
          points = points / torch.tensor([video_W, video_H]).to(points.device)
      # scale the (normalized) coordinates by the model's internal image size
      points = points * self.image_size
      points = points.to(inference_state["device"])
      labels = labels.to(inference_state["device"])

      # Look up (or register) the object only now, so invalid prompts above never
      # leave a newly registered object id without any inputs.
      obj_idx = self._obj_id_to_idx(inference_state, obj_id)
      point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
      mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

      if not clear_old_points:
          point_inputs = point_inputs_per_frame.get(frame_idx, None)
      else:
          point_inputs = None
      point_inputs = concat_points(point_inputs, points, labels)
      point_inputs_per_frame[frame_idx] = point_inputs
      # point prompts replace any mask prompt for this object on this frame
      mask_inputs_per_frame.pop(frame_idx, None)

      # If this frame hasn't been tracked before, we treat it as an initial conditioning
      # frame, meaning that the inputs points are to generate segments on this frame without
      # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
      # the input points will be used to correct the already tracked masks.
      is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
      # whether to track in reverse time order
      if is_init_cond_frame:
          reverse = False
      else:
          reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
      obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
      obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
      # Add a frame to conditioning output if it's an initial conditioning frame or
      # if the model sees all frames receiving clicks/mask as conditioning frames.
      is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
      storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

      # Get any previously predicted mask logits on this object and feed it along with
      # the new clicks into the SAM mask decoder.
      prev_sam_mask_logits = None
      # lookup temporary output dict first, which contains the most recent output
      # (if not found, then lookup conditioning and non-conditioning frame output)
      prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
      if prev_out is None:
          prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
          if prev_out is None:
              prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
      if prev_out is not None and prev_out["pred_masks"] is not None:
          device = inference_state["device"]
          prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
          # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
          prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)

      current_out, _ = self._run_single_frame_inference(
          inference_state=inference_state,
          output_dict=obj_output_dict,  # run on the slice of a single object
          frame_idx=frame_idx,
          batch_size=1,  # run on the slice of a single object
          is_init_cond_frame=is_init_cond_frame,
          point_inputs=point_inputs,
          mask_inputs=None,
          reverse=reverse,
          # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
          # at the beginning of `propagate_in_video` (after user finalize their clicks). This
          # allows us to enforce non-overlapping constraints on all objects before encoding
          # them into memory.
          run_mem_encoder=False,
          prev_sam_mask_logits=prev_sam_mask_logits,
      )
      # Add the output to the output dict (to be used as future memory)
      obj_temp_output_dict[storage_key][frame_idx] = current_out

      # Resize the output mask to the original video resolution
      obj_ids = inference_state["obj_ids"]
      consolidated_out = self._consolidate_temp_output_across_obj(
          inference_state,
          frame_idx,
          is_cond=is_cond,
          run_mem_encoder=False,
          consolidate_at_video_res=True,
      )
      _, video_res_masks = self._get_orig_video_res_output(
          inference_state, consolidated_out["pred_masks_video_res"]
      )
      return frame_idx, obj_ids, video_res_masks
  def add_new_points(self, *args, **kwargs):
      """Deprecated alias of ``add_new_points_or_box``.

      Emits a ``DeprecationWarning`` and then forwards every argument unchanged,
      returning whatever ``add_new_points_or_box`` returns.
      """
      warnings.warn(
          "`add_new_points` is deprecated; please use `add_new_points_or_box` instead.",
          category=DeprecationWarning,
          stacklevel=2,
      )
      return self.add_new_points_or_box(*args, **kwargs)
  @torch.inference_mode()
  def add_new_mask(
      self,
      inference_state,
      frame_idx,
      obj_id,
      mask,
  ):
      """Add a mask prompt for one object on one frame and segment it.

      The mask replaces any point prompts stored for this object on this frame.

      Args:
          inference_state: state dict created by ``init_state``.
          frame_idx: index of the frame receiving the mask.
          obj_id: client-side object id; an unseen id registers a new object
              (only allowed before tracking starts).
          mask: 2-D (H, W) mask, as a tensor or anything ``torch.tensor`` accepts
              (non-tensor inputs are converted with ``dtype=torch.bool``). If its
              size differs from ``self.image_size`` x ``self.image_size`` it is
              resized with antialiased bilinear interpolation and re-binarized at
              0.5.

      Returns:
          tuple: ``(frame_idx, obj_ids, video_res_masks)`` where ``obj_ids`` are
          all registered object ids and ``video_res_masks`` are the consolidated
          mask scores of all objects at the original video resolution.

      Raises:
          ValueError: if ``mask`` is not 2-D.
          RuntimeError: if ``obj_id`` is new and tracking has already started.
      """
      if not isinstance(mask, torch.Tensor):
          mask = torch.tensor(mask, dtype=torch.bool)
      # Validate with an exception (asserts are stripped under `python -O`) and
      # before registering the object, so a bad mask never leaves a new object
      # id registered without any inputs.
      if mask.dim() != 2:
          raise ValueError(
              f"mask must be a 2-D (H, W) tensor, got shape {tuple(mask.shape)}"
          )

      obj_idx = self._obj_id_to_idx(inference_state, obj_id)
      point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
      mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

      mask_H, mask_W = mask.shape
      mask_inputs_orig = mask[None, None]  # add batch and channel dimension
      mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
      # resize the mask if it doesn't match the model's image size
      if mask_H != self.image_size or mask_W != self.image_size:
          mask_inputs = torch.nn.functional.interpolate(
              mask_inputs_orig,
              size=(self.image_size, self.image_size),
              align_corners=False,
              mode="bilinear",
              antialias=True,  # use antialias for downsampling
          )
          mask_inputs = (mask_inputs >= 0.5).float()
      else:
          mask_inputs = mask_inputs_orig

      mask_inputs_per_frame[frame_idx] = mask_inputs
      # a mask prompt replaces any point prompts for this object on this frame
      point_inputs_per_frame.pop(frame_idx, None)

      # If this frame hasn't been tracked before, we treat it as an initial conditioning
      # frame, meaning that the input mask is used to generate segments on this frame
      # without using any memory from other frames, like in SAM. Otherwise (if it has
      # been tracked), the input mask is used to correct the already tracked masks.
      is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
      # whether to track in reverse time order
      if is_init_cond_frame:
          reverse = False
      else:
          reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
      obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
      obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
      # Add a frame to conditioning output if it's an initial conditioning frame or
      # if the model sees all frames receiving clicks/mask as conditioning frames.
      is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
      storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

      current_out, _ = self._run_single_frame_inference(
          inference_state=inference_state,
          output_dict=obj_output_dict,  # run on the slice of a single object
          frame_idx=frame_idx,
          batch_size=1,  # run on the slice of a single object
          is_init_cond_frame=is_init_cond_frame,
          point_inputs=None,
          mask_inputs=mask_inputs,
          reverse=reverse,
          # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
          # at the beginning of `propagate_in_video` (after user finalize their clicks). This
          # allows us to enforce non-overlapping constraints on all objects before encoding
          # them into memory.
          run_mem_encoder=False,
      )
      # Add the output to the output dict (to be used as future memory)
      obj_temp_output_dict[storage_key][frame_idx] = current_out

      # Resize the output mask to the original video resolution
      obj_ids = inference_state["obj_ids"]
      consolidated_out = self._consolidate_temp_output_across_obj(
          inference_state,
          frame_idx,
          is_cond=is_cond,
          run_mem_encoder=False,
          consolidate_at_video_res=True,
      )
      _, video_res_masks = self._get_orig_video_res_output(
          inference_state, consolidated_out["pred_masks_video_res"]
      )
      return frame_idx, obj_ids, video_res_masks
  369. def _get_orig_video_res_output(self, inference_state, any_res_masks):
  370. """
  371. Resize the object scores to the original video resolution (video_res_masks)
  372. and apply non-overlapping constraints for final output.
  373. """
  374. device = inference_state["device"]
  375. video_H = inference_state["video_height"]
  376. video_W = inference_state["video_width"]
  377. any_res_masks = any_res_masks.to(device, non_blocking=True)
  378. if any_res_masks.shape[-2:] == (video_H, video_W):
  379. video_res_masks = any_res_masks
  380. else:
  381. video_res_masks = torch.nn.functional.interpolate(
  382. any_res_masks,
  383. size=(video_H, video_W),
  384. mode="bilinear",
  385. align_corners=False,
  386. )
  387. if self.non_overlap_masks:
  388. video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
  389. return any_res_masks, video_res_masks
  def _consolidate_temp_output_across_obj(
      self,
      inference_state,
      frame_idx,
      is_cond,
      run_mem_encoder,
      consolidate_at_video_res=False,
  ):
      """
      Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
      a frame into a single output for all objects, including
      1) fill any missing objects either from `output_dict_per_obj` (if they exist in
      `output_dict_per_obj` for this frame) or leave them as placeholder values
      (if they don't exist in `output_dict_per_obj` for this frame);
      2) if specified, rerun the memory encoder after applying non-overlapping
      constraints on the object scores.

      Args:
          inference_state: state dict created by `init_state`.
          frame_idx: index of the frame to consolidate.
          is_cond: selects which temporary store is read for each object
              ("cond_frame_outputs" if True, else "non_cond_frame_outputs").
          run_mem_encoder: if True, objects without any output on this frame get
              a dummy object pointer (from `_get_empty_mask_ptr`), and the memory
              encoder is rerun on the consolidated masks.
          consolidate_at_video_res: if True, masks are consolidated at the original
              video height/width under the key "pred_masks_video_res"; otherwise at
              (image_size // 4) x (image_size // 4) under the key "pred_masks".

      Returns:
          dict with keys "maskmem_features" and "maskmem_pos_enc" (None unless the
          memory encoder ran), the mask key above (one row per object, on the
          storage device), "obj_ptr" and "object_score_logits" (on the compute
          device). Objects with no output keep NO_OBJ_SCORE mask scores.

      Raises:
          AssertionError: if both `consolidate_at_video_res` and `run_mem_encoder`
              are True.
      """
      # one consolidated row per registered object
      batch_size = self._get_obj_num(inference_state)
      storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
      # Optionally, we allow consolidating the temporary outputs at the original
      # video resolution (to provide a better editing experience for mask prompts).
      if consolidate_at_video_res:
          assert not run_mem_encoder, "memory encoder cannot run at video resolution"
          consolidated_H = inference_state["video_height"]
          consolidated_W = inference_state["video_width"]
          consolidated_mask_key = "pred_masks_video_res"
      else:
          consolidated_H = consolidated_W = self.image_size // 4
          consolidated_mask_key = "pred_masks"
      # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
      # will be added when rerunning the memory encoder after applying non-overlapping
      # constraints to object scores. Its "pred_masks" are prefilled with a large
      # negative value (NO_OBJ_SCORE) to represent missing objects.
      consolidated_out = {
          "maskmem_features": None,
          "maskmem_pos_enc": None,
          # mask scores live on the storage device (CPU when state is offloaded)
          consolidated_mask_key: torch.full(
              size=(batch_size, 1, consolidated_H, consolidated_W),
              fill_value=NO_OBJ_SCORE,
              dtype=torch.float32,
              device=inference_state["storage_device"],
          ),
          # placeholder pointers; overwritten per object below
          "obj_ptr": torch.full(
              size=(batch_size, self.hidden_dim),
              fill_value=NO_OBJ_SCORE,
              dtype=torch.float32,
              device=inference_state["device"],
          ),
          "object_score_logits": torch.full(
              size=(batch_size, 1),
              # default to 10.0 for object_score_logits, i.e. assuming the object is
              # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
              fill_value=10.0,
              dtype=torch.float32,
              device=inference_state["device"],
          ),
      }
      # computed lazily and shared by every object that has no output on this frame
      empty_mask_ptr = None
      for obj_idx in range(batch_size):
          obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
          obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
          out = obj_temp_output_dict[storage_key].get(frame_idx, None)
          # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
          # we fall back and look up its previous output in "output_dict_per_obj".
          # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
          # "output_dict_per_obj" to find a previous output for this object.
          if out is None:
              out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
          if out is None:
              out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
          # If the object doesn't appear in "output_dict_per_obj" either, we skip it
          # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
          # placeholder above) and set its object pointer to be a dummy pointer.
          if out is None:
              # Fill in dummy object pointers for those objects without any inputs or
              # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
              # i.e. when we need to build the memory for tracking).
              if run_mem_encoder:
                  if empty_mask_ptr is None:
                      empty_mask_ptr = self._get_empty_mask_ptr(
                          inference_state, frame_idx
                      )
                  # fill object pointer with a dummy pointer (based on an empty mask)
                  consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
              continue
          # Add the temporary object output mask to consolidated output mask
          # (in-place write into the preallocated tensor inside `consolidated_out`)
          obj_mask = out["pred_masks"]
          consolidated_pred_masks = consolidated_out[consolidated_mask_key]
          if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
              consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
          else:
              # Resize first if temporary object mask has a different resolution
              resized_obj_mask = torch.nn.functional.interpolate(
                  obj_mask,
                  size=consolidated_pred_masks.shape[-2:],
                  mode="bilinear",
                  align_corners=False,
              )
              consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
          consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
          consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
              "object_score_logits"
          ]

      # Optionally, apply non-overlapping constraints on the consolidated scores
      # and rerun the memory encoder
      if run_mem_encoder:
          # reaching here implies consolidate_at_video_res is False (asserted above),
          # so the consolidated masks are stored under "pred_masks"
          device = inference_state["device"]
          high_res_masks = torch.nn.functional.interpolate(
              consolidated_out["pred_masks"].to(device, non_blocking=True),
              size=(self.image_size, self.image_size),
              mode="bilinear",
              align_corners=False,
          )
          if self.non_overlap_masks_for_mem_enc:
              high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
          maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
              inference_state=inference_state,
              frame_idx=frame_idx,
              batch_size=batch_size,
              high_res_masks=high_res_masks,
              object_score_logits=consolidated_out["object_score_logits"],
              is_mask_from_pts=True,  # these frames are what the user interacted with
          )
          consolidated_out["maskmem_features"] = maskmem_features
          consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc

      return consolidated_out
  516. def _get_empty_mask_ptr(self, inference_state, frame_idx):
  517. """Get a dummy object pointer based on an empty mask on the current frame."""
  518. # A dummy (empty) mask with a single object
  519. batch_size = 1
  520. mask_inputs = torch.zeros(
  521. (batch_size, 1, self.image_size, self.image_size),
  522. dtype=torch.float32,
  523. device=inference_state["device"],
  524. )
  525. # Retrieve correct image features
  526. (
  527. _,
  528. _,
  529. current_vision_feats,
  530. current_vision_pos_embeds,
  531. feat_sizes,
  532. ) = self._get_image_feature(inference_state, frame_idx, batch_size)
  533. # Feed the empty mask and image feature above to get a dummy object pointer
  534. current_out = self.track_step(
  535. frame_idx=frame_idx,
  536. is_init_cond_frame=True,
  537. current_vision_feats=current_vision_feats,
  538. current_vision_pos_embeds=current_vision_pos_embeds,
  539. feat_sizes=feat_sizes,
  540. point_inputs=None,
  541. mask_inputs=mask_inputs,
  542. output_dict={},
  543. num_frames=inference_state["num_frames"],
  544. track_in_reverse=False,
  545. run_mem_encoder=False,
  546. prev_sam_mask_logits=None,
  547. )
  548. return current_out["obj_ptr"]
  549. @torch.inference_mode()
  550. def propagate_in_video_preflight(self, inference_state):
  551. """Prepare inference_state and consolidate temporary outputs before tracking."""
  552. # Tracking has started and we don't allow adding new objects until session is reset.
  553. inference_state["tracking_has_started"] = True
  554. batch_size = self._get_obj_num(inference_state)
  555. # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
  556. # add them into "output_dict".
  557. temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
  558. output_dict = inference_state["output_dict"]
  559. # "consolidated_frame_inds" contains indices of those frames where consolidated
  560. # temporary outputs have been added (either in this call or any previous calls
  561. # to `propagate_in_video_preflight`).
  562. consolidated_frame_inds = inference_state["consolidated_frame_inds"]
  563. for is_cond in [False, True]:
  564. # Separately consolidate conditioning and non-conditioning temp outputs
  565. storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
  566. # Find all the frames that contain temporary outputs for any objects
  567. # (these should be the frames that have just received clicks for mask inputs
  568. # via `add_new_points_or_box` or `add_new_mask`)
  569. temp_frame_inds = set()
  570. for obj_temp_output_dict in temp_output_dict_per_obj.values():
  571. temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
  572. consolidated_frame_inds[storage_key].update(temp_frame_inds)
  573. # consolidate the temporary output across all objects on this frame
  574. for frame_idx in temp_frame_inds:
  575. consolidated_out = self._consolidate_temp_output_across_obj(
  576. inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
  577. )
  578. # merge them into "output_dict" and also create per-object slices
  579. output_dict[storage_key][frame_idx] = consolidated_out
  580. self._add_output_per_object(
  581. inference_state, frame_idx, consolidated_out, storage_key
  582. )
  583. clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
  584. self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
  585. )
  586. if clear_non_cond_mem:
  587. # clear non-conditioning memory of the surrounding frames
  588. self._clear_non_cond_mem_around_input(inference_state, frame_idx)
  589. # clear temporary outputs in `temp_output_dict_per_obj`
  590. for obj_temp_output_dict in temp_output_dict_per_obj.values():
  591. obj_temp_output_dict[storage_key].clear()
  592. # edge case: if an output is added to "cond_frame_outputs", we remove any prior
  593. # output on the same frame in "non_cond_frame_outputs"
  594. for frame_idx in output_dict["cond_frame_outputs"]:
  595. output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
  596. for obj_output_dict in inference_state["output_dict_per_obj"].values():
  597. for frame_idx in obj_output_dict["cond_frame_outputs"]:
  598. obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
  599. for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
  600. assert frame_idx in output_dict["cond_frame_outputs"]
  601. consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
  602. # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
  603. # with either points or mask inputs (which should be true under a correct workflow).
  604. all_consolidated_frame_inds = (
  605. consolidated_frame_inds["cond_frame_outputs"]
  606. | consolidated_frame_inds["non_cond_frame_outputs"]
  607. )
  608. input_frames_inds = set()
  609. for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
  610. input_frames_inds.update(point_inputs_per_frame.keys())
  611. for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
  612. input_frames_inds.update(mask_inputs_per_frame.keys())
  613. assert all_consolidated_frame_inds == input_frames_inds
  614. @torch.inference_mode()
  615. def propagate_in_video(
  616. self,
  617. inference_state,
  618. start_frame_idx=None,
  619. max_frame_num_to_track=None,
  620. reverse=False,
  621. ):
  622. """Propagate the input points across frames to track in the entire video."""
  623. self.propagate_in_video_preflight(inference_state)
  624. output_dict = inference_state["output_dict"]
  625. consolidated_frame_inds = inference_state["consolidated_frame_inds"]
  626. obj_ids = inference_state["obj_ids"]
  627. num_frames = inference_state["num_frames"]
  628. batch_size = self._get_obj_num(inference_state)
  629. if len(output_dict["cond_frame_outputs"]) == 0:
  630. raise RuntimeError("No points are provided; please add points first")
  631. clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
  632. self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
  633. )
  634. # set start index, end index, and processing order
  635. if start_frame_idx is None:
  636. # default: start from the earliest frame with input points
  637. start_frame_idx = min(output_dict["cond_frame_outputs"])
  638. if max_frame_num_to_track is None:
  639. # default: track all the frames in the video
  640. max_frame_num_to_track = num_frames
  641. if reverse:
  642. end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
  643. if start_frame_idx > 0:
  644. processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
  645. else:
  646. processing_order = [] # skip reverse tracking if starting from frame 0
  647. else:
  648. end_frame_idx = min(
  649. start_frame_idx + max_frame_num_to_track, num_frames - 1
  650. )
  651. processing_order = range(start_frame_idx, end_frame_idx + 1)
  652. for frame_idx in tqdm(processing_order, desc="propagate in video"):
  653. # We skip those frames already in consolidated outputs (these are frames
  654. # that received input clicks or mask). Note that we cannot directly run
  655. # batched forward on them via `_run_single_frame_inference` because the
  656. # number of clicks on each object might be different.
  657. if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
  658. storage_key = "cond_frame_outputs"
  659. current_out = output_dict[storage_key][frame_idx]
  660. pred_masks = current_out["pred_masks"]
  661. if clear_non_cond_mem:
  662. # clear non-conditioning memory of the surrounding frames
  663. self._clear_non_cond_mem_around_input(inference_state, frame_idx)
  664. elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
  665. storage_key = "non_cond_frame_outputs"
  666. current_out = output_dict[storage_key][frame_idx]
  667. pred_masks = current_out["pred_masks"]
  668. else:
  669. storage_key = "non_cond_frame_outputs"
  670. current_out, pred_masks = self._run_single_frame_inference(
  671. inference_state=inference_state,
  672. output_dict=output_dict,
  673. frame_idx=frame_idx,
  674. batch_size=batch_size,
  675. is_init_cond_frame=False,
  676. point_inputs=None,
  677. mask_inputs=None,
  678. reverse=reverse,
  679. run_mem_encoder=True,
  680. )
  681. output_dict[storage_key][frame_idx] = current_out
  682. # Create slices of per-object outputs for subsequent interaction with each
  683. # individual object after tracking.
  684. self._add_output_per_object(
  685. inference_state, frame_idx, current_out, storage_key
  686. )
  687. inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
  688. # Resize the output mask to the original video resolution (we directly use
  689. # the mask scores on GPU for output to avoid any CPU conversion in between)
  690. _, video_res_masks = self._get_orig_video_res_output(
  691. inference_state, pred_masks
  692. )
  693. yield frame_idx, obj_ids, video_res_masks
  694. def _add_output_per_object(
  695. self, inference_state, frame_idx, current_out, storage_key
  696. ):
  697. """
  698. Split a multi-object output into per-object output slices and add them into
  699. `output_dict_per_obj`. The resulting slices share the same tensor storage.
  700. """
  701. maskmem_features = current_out["maskmem_features"]
  702. assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
  703. maskmem_pos_enc = current_out["maskmem_pos_enc"]
  704. assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
  705. output_dict_per_obj = inference_state["output_dict_per_obj"]
  706. for obj_idx, obj_output_dict in output_dict_per_obj.items():
  707. obj_slice = slice(obj_idx, obj_idx + 1)
  708. obj_out = {
  709. "maskmem_features": None,
  710. "maskmem_pos_enc": None,
  711. "pred_masks": current_out["pred_masks"][obj_slice],
  712. "obj_ptr": current_out["obj_ptr"][obj_slice],
  713. "object_score_logits": current_out["object_score_logits"][obj_slice],
  714. }
  715. if maskmem_features is not None:
  716. obj_out["maskmem_features"] = maskmem_features[obj_slice]
  717. if maskmem_pos_enc is not None:
  718. obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
  719. obj_output_dict[storage_key][frame_idx] = obj_out
  720. @torch.inference_mode()
  721. def clear_all_prompts_in_frame(
  722. self, inference_state, frame_idx, obj_id, need_output=True
  723. ):
  724. """Remove all input points or mask in a specific frame for a given object."""
  725. obj_idx = self._obj_id_to_idx(inference_state, obj_id)
  726. # Clear the conditioning information on the given frame
  727. inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
  728. inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
  729. temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
  730. temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
  731. temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
  732. # Check and see if there are still any inputs left on this frame
  733. batch_size = self._get_obj_num(inference_state)
  734. frame_has_input = False
  735. for obj_idx2 in range(batch_size):
  736. if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
  737. frame_has_input = True
  738. break
  739. if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
  740. frame_has_input = True
  741. break
  742. # If this frame has no remaining inputs for any objects, we further clear its
  743. # conditioning frame status
  744. if not frame_has_input:
  745. output_dict = inference_state["output_dict"]
  746. consolidated_frame_inds = inference_state["consolidated_frame_inds"]
  747. consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
  748. consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
  749. # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
  750. out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
  751. if out is not None:
  752. # The frame is not a conditioning frame anymore since it's not receiving inputs,
  753. # so we "downgrade" its output (if exists) to a non-conditioning frame output.
  754. output_dict["non_cond_frame_outputs"][frame_idx] = out
  755. inference_state["frames_already_tracked"].pop(frame_idx, None)
  756. # Similarly, do it for the sliced output on each object.
  757. for obj_idx2 in range(batch_size):
  758. obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
  759. obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
  760. if obj_out is not None:
  761. obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
  762. # If all the conditioning frames have been removed, we also clear the tracking outputs
  763. if len(output_dict["cond_frame_outputs"]) == 0:
  764. self._reset_tracking_results(inference_state)
  765. if not need_output:
  766. return
  767. # Finally, output updated masks per object (after removing the inputs above)
  768. obj_ids = inference_state["obj_ids"]
  769. is_cond = any(
  770. frame_idx in obj_temp_output_dict["cond_frame_outputs"]
  771. for obj_temp_output_dict in temp_output_dict_per_obj.values()
  772. )
  773. consolidated_out = self._consolidate_temp_output_across_obj(
  774. inference_state,
  775. frame_idx,
  776. is_cond=is_cond,
  777. run_mem_encoder=False,
  778. consolidate_at_video_res=True,
  779. )
  780. _, video_res_masks = self._get_orig_video_res_output(
  781. inference_state, consolidated_out["pred_masks_video_res"]
  782. )
  783. return frame_idx, obj_ids, video_res_masks
  784. @torch.inference_mode()
  785. def reset_state(self, inference_state):
  786. """Remove all input points or mask in all frames throughout the video."""
  787. self._reset_tracking_results(inference_state)
  788. # Remove all object ids
  789. inference_state["obj_id_to_idx"].clear()
  790. inference_state["obj_idx_to_id"].clear()
  791. inference_state["obj_ids"].clear()
  792. inference_state["point_inputs_per_obj"].clear()
  793. inference_state["mask_inputs_per_obj"].clear()
  794. inference_state["output_dict_per_obj"].clear()
  795. inference_state["temp_output_dict_per_obj"].clear()
  796. def _reset_tracking_results(self, inference_state):
  797. """Reset all tracking inputs and results across the videos."""
  798. for v in inference_state["point_inputs_per_obj"].values():
  799. v.clear()
  800. for v in inference_state["mask_inputs_per_obj"].values():
  801. v.clear()
  802. for v in inference_state["output_dict_per_obj"].values():
  803. v["cond_frame_outputs"].clear()
  804. v["non_cond_frame_outputs"].clear()
  805. for v in inference_state["temp_output_dict_per_obj"].values():
  806. v["cond_frame_outputs"].clear()
  807. v["non_cond_frame_outputs"].clear()
  808. inference_state["output_dict"]["cond_frame_outputs"].clear()
  809. inference_state["output_dict"]["non_cond_frame_outputs"].clear()
  810. inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
  811. inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
  812. inference_state["tracking_has_started"] = False
  813. inference_state["frames_already_tracked"].clear()
  814. def _get_image_feature(self, inference_state, frame_idx, batch_size):
  815. """Compute the image features on a given frame."""
  816. # Look up in the cache first
  817. image, backbone_out = inference_state["cached_features"].get(
  818. frame_idx, (None, None)
  819. )
  820. if backbone_out is None:
  821. # Cache miss -- we will run inference on a single image
  822. device = inference_state["device"]
  823. image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
  824. backbone_out = self.forward_image(image)
  825. # Cache the most recent frame's feature (for repeated interactions with
  826. # a frame; we can use an LRU cache for more frames in the future).
  827. inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
  828. # expand the features to have the same dimension as the number of objects
  829. expanded_image = image.expand(batch_size, -1, -1, -1)
  830. expanded_backbone_out = {
  831. "backbone_fpn": backbone_out["backbone_fpn"].copy(),
  832. "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
  833. }
  834. for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
  835. expanded_backbone_out["backbone_fpn"][i] = feat.expand(
  836. batch_size, -1, -1, -1
  837. )
  838. for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
  839. pos = pos.expand(batch_size, -1, -1, -1)
  840. expanded_backbone_out["vision_pos_enc"][i] = pos
  841. features = self._prepare_backbone_features(expanded_backbone_out)
  842. features = (expanded_image,) + features
  843. return features
  844. def _run_single_frame_inference(
  845. self,
  846. inference_state,
  847. output_dict,
  848. frame_idx,
  849. batch_size,
  850. is_init_cond_frame,
  851. point_inputs,
  852. mask_inputs,
  853. reverse,
  854. run_mem_encoder,
  855. prev_sam_mask_logits=None,
  856. ):
  857. """Run tracking on a single frame based on current inputs and previous memory."""
  858. # Retrieve correct image features
  859. (
  860. _,
  861. _,
  862. current_vision_feats,
  863. current_vision_pos_embeds,
  864. feat_sizes,
  865. ) = self._get_image_feature(inference_state, frame_idx, batch_size)
  866. # point and mask should not appear as input simultaneously on the same frame
  867. assert point_inputs is None or mask_inputs is None
  868. current_out = self.track_step(
  869. frame_idx=frame_idx,
  870. is_init_cond_frame=is_init_cond_frame,
  871. current_vision_feats=current_vision_feats,
  872. current_vision_pos_embeds=current_vision_pos_embeds,
  873. feat_sizes=feat_sizes,
  874. point_inputs=point_inputs,
  875. mask_inputs=mask_inputs,
  876. output_dict=output_dict,
  877. num_frames=inference_state["num_frames"],
  878. track_in_reverse=reverse,
  879. run_mem_encoder=run_mem_encoder,
  880. prev_sam_mask_logits=prev_sam_mask_logits,
  881. )
  882. # optionally offload the output to CPU memory to save GPU space
  883. storage_device = inference_state["storage_device"]
  884. maskmem_features = current_out["maskmem_features"]
  885. if maskmem_features is not None:
  886. maskmem_features = maskmem_features.to(torch.bfloat16)
  887. maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
  888. pred_masks_gpu = current_out["pred_masks"]
  889. # potentially fill holes in the predicted masks
  890. if self.fill_hole_area > 0:
  891. pred_masks_gpu = fill_holes_in_mask_scores(
  892. pred_masks_gpu, self.fill_hole_area
  893. )
  894. pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
  895. # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
  896. maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
  897. # object pointer is a small tensor, so we always keep it on GPU memory for fast access
  898. obj_ptr = current_out["obj_ptr"]
  899. object_score_logits = current_out["object_score_logits"]
  900. # make a compact version of this frame's output to reduce the state size
  901. compact_current_out = {
  902. "maskmem_features": maskmem_features,
  903. "maskmem_pos_enc": maskmem_pos_enc,
  904. "pred_masks": pred_masks,
  905. "obj_ptr": obj_ptr,
  906. "object_score_logits": object_score_logits,
  907. }
  908. return compact_current_out, pred_masks_gpu
  909. def _run_memory_encoder(
  910. self,
  911. inference_state,
  912. frame_idx,
  913. batch_size,
  914. high_res_masks,
  915. object_score_logits,
  916. is_mask_from_pts,
  917. ):
  918. """
  919. Run the memory encoder on `high_res_masks`. This is usually after applying
  920. non-overlapping constraints to object scores. Since their scores changed, their
  921. memory also need to be computed again with the memory encoder.
  922. """
  923. # Retrieve correct image features
  924. _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
  925. inference_state, frame_idx, batch_size
  926. )
  927. maskmem_features, maskmem_pos_enc = self._encode_new_memory(
  928. current_vision_feats=current_vision_feats,
  929. feat_sizes=feat_sizes,
  930. pred_masks_high_res=high_res_masks,
  931. object_score_logits=object_score_logits,
  932. is_mask_from_pts=is_mask_from_pts,
  933. )
  934. # optionally offload the output to CPU memory to save GPU space
  935. storage_device = inference_state["storage_device"]
  936. maskmem_features = maskmem_features.to(torch.bfloat16)
  937. maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
  938. # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
  939. maskmem_pos_enc = self._get_maskmem_pos_enc(
  940. inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
  941. )
  942. return maskmem_features, maskmem_pos_enc
  943. def _get_maskmem_pos_enc(self, inference_state, current_out):
  944. """
  945. `maskmem_pos_enc` is the same across frames and objects, so we cache it as
  946. a constant in the inference session to reduce session storage size.
  947. """
  948. model_constants = inference_state["constants"]
  949. # "out_maskmem_pos_enc" should be either a list of tensors or None
  950. out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
  951. if out_maskmem_pos_enc is not None:
  952. if "maskmem_pos_enc" not in model_constants:
  953. assert isinstance(out_maskmem_pos_enc, list)
  954. # only take the slice for one object, since it's same across objects
  955. maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
  956. model_constants["maskmem_pos_enc"] = maskmem_pos_enc
  957. else:
  958. maskmem_pos_enc = model_constants["maskmem_pos_enc"]
  959. # expand the cached maskmem_pos_enc to the actual batch size
  960. batch_size = out_maskmem_pos_enc[0].size(0)
  961. expanded_maskmem_pos_enc = [
  962. x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
  963. ]
  964. else:
  965. expanded_maskmem_pos_enc = None
  966. return expanded_maskmem_pos_enc
  967. @torch.inference_mode()
  968. def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
  969. """
  970. Remove an object id from the tracking state. If strict is True, we check whether
  971. the object id actually exists and raise an error if it doesn't exist.
  972. """
  973. old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
  974. updated_frames = []
  975. # Check whether this object_id to remove actually exists and possibly raise an error.
  976. if old_obj_idx_to_rm is None:
  977. if not strict:
  978. return inference_state["obj_ids"], updated_frames
  979. raise RuntimeError(
  980. f"Cannot remove object id {obj_id} as it doesn't exist. "
  981. f"All existing object ids: {inference_state['obj_ids']}."
  982. )
  983. # If this is the only remaining object id, we simply reset the state.
  984. if len(inference_state["obj_id_to_idx"]) == 1:
  985. self.reset_state(inference_state)
  986. return inference_state["obj_ids"], updated_frames
  987. # There are still remaining objects after removing this object id. In this case,
  988. # we need to delete the object storage from inference state tensors.
  989. # Step 0: clear the input on those frames where this object id has point or mask input
  990. # (note that this step is required as it might downgrade conditioning frames to
  991. # non-conditioning ones)
  992. obj_input_frames_inds = set()
  993. obj_input_frames_inds.update(
  994. inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
  995. )
  996. obj_input_frames_inds.update(
  997. inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
  998. )
  999. for frame_idx in obj_input_frames_inds:
  1000. self.clear_all_prompts_in_frame(
  1001. inference_state, frame_idx, obj_id, need_output=False
  1002. )
  1003. # Step 1: Update the object id mapping (note that it must be done after Step 0,
  1004. # since Step 0 still requires the old object id mappings in inference_state)
  1005. old_obj_ids = inference_state["obj_ids"]
  1006. old_obj_inds = list(range(len(old_obj_ids)))
  1007. remain_old_obj_inds = old_obj_inds.copy()
  1008. remain_old_obj_inds.remove(old_obj_idx_to_rm)
  1009. new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
  1010. new_obj_inds = list(range(len(new_obj_ids)))
  1011. # build new mappings
  1012. old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
  1013. inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
  1014. inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
  1015. inference_state["obj_ids"] = new_obj_ids
  1016. # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
  1017. # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
  1018. # it's already handled in Step 0)
  1019. def _map_keys(container):
  1020. new_kvs = []
  1021. for k in old_obj_inds:
  1022. v = container.pop(k)
  1023. if k in old_idx_to_new_idx:
  1024. new_kvs.append((old_idx_to_new_idx[k], v))
  1025. container.update(new_kvs)
  1026. _map_keys(inference_state["point_inputs_per_obj"])
  1027. _map_keys(inference_state["mask_inputs_per_obj"])
  1028. _map_keys(inference_state["output_dict_per_obj"])
  1029. _map_keys(inference_state["temp_output_dict_per_obj"])
  1030. # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
  1031. def _slice_state(output_dict, storage_key):
  1032. for frame_idx, out in output_dict[storage_key].items():
  1033. out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
  1034. out["maskmem_pos_enc"] = [
  1035. x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
  1036. ]
  1037. # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
  1038. out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
  1039. out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
  1040. out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
  1041. out["object_score_logits"] = out["object_score_logits"][
  1042. remain_old_obj_inds
  1043. ]
  1044. # also update the per-object slices
  1045. self._add_output_per_object(
  1046. inference_state, frame_idx, out, storage_key
  1047. )
  1048. _slice_state(inference_state["output_dict"], "cond_frame_outputs")
  1049. _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
  1050. # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
  1051. # could show an updated mask for objects previously occluded by the object being removed
  1052. if need_output:
  1053. temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
  1054. for frame_idx in obj_input_frames_inds:
  1055. is_cond = any(
  1056. frame_idx in obj_temp_output_dict["cond_frame_outputs"]
  1057. for obj_temp_output_dict in temp_output_dict_per_obj.values()
  1058. )
  1059. consolidated_out = self._consolidate_temp_output_across_obj(
  1060. inference_state,
  1061. frame_idx,
  1062. is_cond=is_cond,
  1063. run_mem_encoder=False,
  1064. consolidate_at_video_res=True,
  1065. )
  1066. _, video_res_masks = self._get_orig_video_res_output(
  1067. inference_state, consolidated_out["pred_masks_video_res"]
  1068. )
  1069. updated_frames.append((frame_idx, video_res_masks))
  1070. return inference_state["obj_ids"], updated_frames
  1071. def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
  1072. """
  1073. Remove the non-conditioning memory around the input frame. When users provide
  1074. correction clicks, the surrounding frames' non-conditioning memories can still
  1075. contain outdated object appearance information and could confuse the model.
  1076. This method clears those non-conditioning memories surrounding the interacted
  1077. frame to avoid giving the model both old and new information about the object.
  1078. """
  1079. r = self.memory_temporal_stride_for_eval
  1080. frame_idx_begin = frame_idx - r * self.num_maskmem
  1081. frame_idx_end = frame_idx + r * self.num_maskmem
  1082. output_dict = inference_state["output_dict"]
  1083. non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
  1084. for t in range(frame_idx_begin, frame_idx_end + 1):
  1085. non_cond_frame_outputs.pop(t, None)
  1086. for obj_output_dict in inference_state["output_dict_per_obj"].values():
  1087. obj_output_dict["non_cond_frame_outputs"].pop(t, None)