sam2_video_predictor.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import warnings
  6. from collections import OrderedDict
  7. import torch
  8. import torch.nn.functional as F
  9. from tqdm import tqdm
  10. from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
  11. from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
  12. class SAM2VideoPredictor(SAM2Base):
  13. """The predictor class to handle user interactions and manage inference states."""
  def __init__(
      self,
      fill_hole_area=0,
      non_overlap_masks=False,
      clear_non_cond_mem_around_input=False,
      add_all_frames_to_correct_as_cond=False,
      **kwargs,
  ):
      """Store the interactive-tracking options and initialize the base model.

      Args:
          fill_hole_area: kept on the instance as ``self.fill_hole_area``
              (it is only stored here, not interpreted).
          non_overlap_masks: whether to apply non-overlapping constraints on
              the output object masks.
          clear_non_cond_mem_around_input: whether to clear the non-conditioning
              memory of the surrounding frames (which may contain outdated
              information) after correction clicks are added. Note that this
              would only apply to *single-object tracking* unless
              `clear_non_cond_mem_for_multi_obj` is also set to True.
          add_all_frames_to_correct_as_cond: if True, any frame that receives
              a later correction click is also appended to the conditioning
              frame list; if False, the conditioning frame list is limited to
              the initial conditioning frames.
          **kwargs: passed through unchanged to the parent constructor.
      """
      super().__init__(**kwargs)
      self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
      self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
      self.non_overlap_masks = non_overlap_masks
      self.fill_hole_area = fill_hole_area
  @torch.inference_mode()
  def init_state(
      self,
      video_path,
      offload_video_to_cpu=False,
      offload_state_to_cpu=False,
      async_loading_frames=False,
  ):
      """Load a video and build a fresh inference-state dictionary for it.

      Args:
          video_path: forwarded to `load_video_frames`.
          offload_video_to_cpu: whether to offload the video frames to CPU
              memory (saves GPU memory with only a very small overhead).
          offload_state_to_cpu: whether to keep the inference state on CPU;
              when set, "storage_device" is the CPU instead of the model
              device (saves GPU memory at the cost of a lower tracking fps).
          async_loading_frames: forwarded to `load_video_frames`.

      Returns:
          dict: the inference state that the other methods of this class
          take as their `inference_state` argument.
      """
      model_device = self.device  # device of the model
      frames, orig_height, orig_width = load_video_frames(
          video_path=video_path,
          image_size=self.image_size,
          offload_video_to_cpu=offload_video_to_cpu,
          async_loading_frames=async_loading_frames,
          compute_device=model_device,
      )
      # Tracking results live either on CPU (when offloading) or next to the model.
      storage_device = torch.device("cpu") if offload_state_to_cpu else model_device

      state = {
          "images": frames,
          "num_frames": len(frames),
          "offload_video_to_cpu": offload_video_to_cpu,
          "offload_state_to_cpu": offload_state_to_cpu,
          # original video size, used for resizing final output scores
          "video_height": orig_height,
          "video_width": orig_width,
          "device": model_device,
          "storage_device": storage_device,
          # per-object prompts received on each frame
          "point_inputs_per_obj": {},
          "mask_inputs_per_obj": {},
          # visual features of a few recently visited frames, for quick interactions
          "cached_features": {},
          # values that don't change across frames (only one copy is held)
          "constants": {},
          # two-way mapping between client-side object id and model-side object index
          "obj_id_to_idx": OrderedDict(),
          "obj_idx_to_id": OrderedDict(),
          "obj_ids": [],
          # per-object tracking results
          "output_dict_per_obj": {},
          # per-object outputs produced while the user is still adding clicks or
          # masks; they are merged into the tracking results before propagation
          "temp_output_dict_per_obj": {},
          # per-object metadata of each tracked frame (e.g. the tracking direction)
          "frames_tracked_per_obj": {},
      }
      # Warm up the visual backbone and cache the image feature of frame 0.
      self._get_image_feature(state, frame_idx=0, batch_size=1)
      return state
  @classmethod
  def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
      """Build a predictor from a pretrained checkpoint on the Hugging Face hub.

      Arguments:
        model_id (str): The Hugging Face repository ID.
        **kwargs: Additional arguments to pass to the model constructor.

      Returns:
        (SAM2VideoPredictor): The loaded model.
      """
      # Imported at call time rather than at module import time.
      from sam2.build_sam import build_sam2_video_predictor_hf

      return build_sam2_video_predictor_hf(model_id, **kwargs)
  104. def _obj_id_to_idx(self, inference_state, obj_id):
  105. """Map client-side object id to model-side object index."""
  106. obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
  107. if obj_idx is not None:
  108. return obj_idx
  109. # We always allow adding new objects (including after tracking starts).
  110. allow_new_object = True
  111. if allow_new_object:
  112. # get the next object slot
  113. obj_idx = len(inference_state["obj_id_to_idx"])
  114. inference_state["obj_id_to_idx"][obj_id] = obj_idx
  115. inference_state["obj_idx_to_id"][obj_idx] = obj_id
  116. inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
  117. # set up input and output structures for this object
  118. inference_state["point_inputs_per_obj"][obj_idx] = {}
  119. inference_state["mask_inputs_per_obj"][obj_idx] = {}
  120. inference_state["output_dict_per_obj"][obj_idx] = {
  121. "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  122. "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  123. }
  124. inference_state["temp_output_dict_per_obj"][obj_idx] = {
  125. "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  126. "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
  127. }
  128. inference_state["frames_tracked_per_obj"][obj_idx] = {}
  129. return obj_idx
  130. else:
  131. raise RuntimeError(
  132. f"Cannot add new object id {obj_id} after tracking starts. "
  133. f"All existing object ids: {inference_state['obj_ids']}. "
  134. f"Please call 'reset_state' to restart from scratch."
  135. )
  136. def _obj_idx_to_id(self, inference_state, obj_idx):
  137. """Map model-side object index to client-side object id."""
  138. return inference_state["obj_idx_to_id"][obj_idx]
  139. def _get_obj_num(self, inference_state):
  140. """Get the total number of unique object ids received so far in this session."""
  141. return len(inference_state["obj_idx_to_id"])
  @torch.inference_mode()
  def add_new_points_or_box(
      self,
      inference_state,
      frame_idx,
      obj_id,
      points=None,
      labels=None,
      clear_old_points=True,
      normalize_coords=True,
      box=None,
  ):
      """Add point and/or box prompts for one object on one frame.

      Args:
          inference_state: the state dict for the current video.
          frame_idx: index of the frame receiving the prompts.
          obj_id: client-side object id (registered on first use).
          points: point coordinates, as a tensor or anything accepted by
              `torch.tensor`; must be given together with `labels`.
          labels: point labels, as a tensor or anything accepted by
              `torch.tensor`; must be given together with `points`.
          clear_old_points: if False, the new points are combined with the
              points already stored for this frame; if True, they replace them.
          normalize_coords: if True, coordinates are divided by the original
              video (width, height) before being scaled to the model's
              image size.
          box: optional box prompt holding four values (reshaped to 1x2x2).

      Returns:
          tuple: `(frame_idx, obj_ids, video_res_masks)`.

      Raises:
          ValueError: if only one of `points`/`labels` is given, if neither
              points nor a box is given, or if a box is combined with
              `clear_old_points=False`.
      """
      obj_idx = self._obj_id_to_idx(inference_state, obj_id)
      frame_points = inference_state["point_inputs_per_obj"][obj_idx]
      frame_masks = inference_state["mask_inputs_per_obj"][obj_idx]

      if (points is None) != (labels is None):
          raise ValueError("points and labels must be provided together")
      if points is None and box is None:
          raise ValueError("at least one of points or box must be provided as input")

      # Bring points and labels into tensor form (empty tensors when absent).
      if points is None:
          pts = torch.zeros(0, 2, dtype=torch.float32)
      elif isinstance(points, torch.Tensor):
          pts = points
      else:
          pts = torch.tensor(points, dtype=torch.float32)
      if labels is None:
          lbls = torch.zeros(0, dtype=torch.int32)
      elif isinstance(labels, torch.Tensor):
          lbls = labels
      else:
          lbls = torch.tensor(labels, dtype=torch.int32)
      if pts.dim() == 2:
          pts = pts.unsqueeze(0)  # add batch dimension
      if lbls.dim() == 1:
          lbls = lbls.unsqueeze(0)  # add batch dimension

      # A box is encoded as two leading points with labels 2 and 3, placed
      # ahead of the user-provided points (consistent with how SAM 2 is trained).
      if box is not None:
          if not clear_old_points:
              raise ValueError(
                  "cannot add box without clearing old points, since "
                  "box prompt must be provided before any point prompt "
                  "(please use clear_old_points=True instead)"
              )
          if isinstance(box, torch.Tensor):
              box_tensor = box
          else:
              box_tensor = torch.tensor(box, dtype=torch.float32, device=pts.device)
          corner_coords = box_tensor.reshape(1, 2, 2)
          corner_labels = torch.tensor(
              [2, 3], dtype=torch.int32, device=lbls.device
          ).reshape(1, 2)
          pts = torch.cat([corner_coords, pts], dim=1)
          lbls = torch.cat([corner_labels, lbls], dim=1)

      if normalize_coords:
          frame_h = inference_state["video_height"]
          frame_w = inference_state["video_width"]
          pts = pts / torch.tensor([frame_w, frame_h]).to(pts.device)
      # scale the (normalized) coordinates by the model's internal image size
      pts = pts * self.image_size
      compute_device = inference_state["device"]
      pts = pts.to(compute_device)
      lbls = lbls.to(compute_device)

      # Start from the stored points only when the caller asked to keep them.
      earlier_inputs = None if clear_old_points else frame_points.get(frame_idx, None)
      point_inputs = concat_points(earlier_inputs, pts, lbls)

      # Points replace any mask prompt previously stored for this frame.
      frame_points[frame_idx] = point_inputs
      frame_masks.pop(frame_idx, None)

      # A frame that hasn't been tracked yet is an initial conditioning frame:
      # the points generate a segment without using memory from other frames,
      # like in SAM. On an already tracked frame the points instead correct
      # the existing mask, and the frame's tracking direction is reused.
      tracked = inference_state["frames_tracked_per_obj"][obj_idx]
      is_init_cond_frame = frame_idx not in tracked
      reverse = False if is_init_cond_frame else tracked[frame_idx]["reverse"]

      obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
      obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
      # The output is stored as conditioning if this is an initial conditioning
      # frame or if every frame receiving clicks/mask counts as conditioning.
      is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
      storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

      # Find the most recent output for this object on this frame: the temporary
      # outputs first, then the conditioning and non-conditioning outputs.
      prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
      if prev_out is None:
          prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
      if prev_out is None:
          prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

      # Previously predicted mask logits (if any) are fed to the SAM mask decoder
      # along with the new clicks.
      prev_sam_mask_logits = None
      if prev_out is not None and prev_out["pred_masks"] is not None:
          moved_logits = prev_out["pred_masks"].to(
              inference_state["device"], non_blocking=True
          )
          # Clamp the scale of the logits to avoid rare numerical issues.
          prev_sam_mask_logits = torch.clamp(moved_logits, -32.0, 32.0)

      current_out, _ = self._run_single_frame_inference(
          inference_state=inference_state,
          output_dict=obj_output_dict,  # run on the slice of a single object
          frame_idx=frame_idx,
          batch_size=1,  # run on the slice of a single object
          is_init_cond_frame=is_init_cond_frame,
          point_inputs=point_inputs,
          mask_inputs=None,
          reverse=reverse,
          # The memory encoder is skipped while clicks or masks are being added;
          # it runs at the beginning of `propagate_in_video` (after the user
          # finalizes their clicks), so non-overlapping constraints can be
          # enforced on all objects before they are encoded into memory.
          run_mem_encoder=False,
          prev_sam_mask_logits=prev_sam_mask_logits,
      )
      # Keep the output in the temporary dict (to be used as future memory).
      obj_temp_output_dict[storage_key][frame_idx] = current_out

      # Consolidate all objects on this frame at the original video resolution.
      obj_ids = inference_state["obj_ids"]
      consolidated_out = self._consolidate_temp_output_across_obj(
          inference_state,
          frame_idx,
          is_cond=is_cond,
          consolidate_at_video_res=True,
      )
      _, video_res_masks = self._get_orig_video_res_output(
          inference_state, consolidated_out["pred_masks_video_res"]
      )
      return frame_idx, obj_ids, video_res_masks
  def add_new_points(self, *args, **kwargs):
      """Deprecated alias that forwards every argument to `add_new_points_or_box`."""
      forwarded_result = self.add_new_points_or_box(*args, **kwargs)
      return forwarded_result
  @torch.inference_mode()
  def add_new_mask(
      self,
      inference_state,
      frame_idx,
      obj_id,
      mask,
  ):
      """Add a mask prompt for one object on one frame.

      Args:
          inference_state: the state dict for the current video.
          frame_idx: index of the frame receiving the mask.
          obj_id: client-side object id (registered on first use).
          mask: a 2-D mask, as a tensor or anything accepted by
              `torch.tensor(..., dtype=torch.bool)`.

      Returns:
          tuple: `(frame_idx, obj_ids, video_res_masks)`.
      """
      obj_idx = self._obj_id_to_idx(inference_state, obj_id)
      frame_points = inference_state["point_inputs_per_obj"][obj_idx]
      frame_masks = inference_state["mask_inputs_per_obj"][obj_idx]

      if not isinstance(mask, torch.Tensor):
          mask = torch.tensor(mask, dtype=torch.bool)
      assert mask.dim() == 2
      in_h, in_w = mask.shape
      # add batch and channel dimensions, then move to the compute device as float
      full_res_mask = mask[None, None].float().to(inference_state["device"])

      model_size = self.image_size
      if in_h == model_size and in_w == model_size:
          mask_inputs = full_res_mask
      else:
          # The mask doesn't match the model's image size: resize it, then
          # binarize the interpolated values again at 0.5.
          resized = torch.nn.functional.interpolate(
              full_res_mask,
              size=(model_size, model_size),
              align_corners=False,
              mode="bilinear",
              antialias=True,  # use antialias for downsampling
          )
          mask_inputs = (resized >= 0.5).float()

      # The mask replaces any point prompt previously stored for this frame.
      frame_masks[frame_idx] = mask_inputs
      frame_points.pop(frame_idx, None)

      # A frame that hasn't been tracked yet is an initial conditioning frame:
      # the input generates a segment without using memory from other frames,
      # like in SAM. On an already tracked frame the input instead corrects
      # the existing mask, and the frame's tracking direction is reused.
      tracked = inference_state["frames_tracked_per_obj"][obj_idx]
      is_init_cond_frame = frame_idx not in tracked
      reverse = False if is_init_cond_frame else tracked[frame_idx]["reverse"]

      obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
      obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
      # The output is stored as conditioning if this is an initial conditioning
      # frame or if every frame receiving clicks/mask counts as conditioning.
      is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
      storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

      current_out, _ = self._run_single_frame_inference(
          inference_state=inference_state,
          output_dict=obj_output_dict,  # run on the slice of a single object
          frame_idx=frame_idx,
          batch_size=1,  # run on the slice of a single object
          is_init_cond_frame=is_init_cond_frame,
          point_inputs=None,
          mask_inputs=mask_inputs,
          reverse=reverse,
          # The memory encoder is skipped while clicks or masks are being added;
          # it runs at the beginning of `propagate_in_video` (after the user
          # finalizes their clicks), so non-overlapping constraints can be
          # enforced on all objects before they are encoded into memory.
          run_mem_encoder=False,
      )
      # Keep the output in the temporary dict (to be used as future memory).
      obj_temp_output_dict[storage_key][frame_idx] = current_out

      # Consolidate all objects on this frame at the original video resolution.
      obj_ids = inference_state["obj_ids"]
      consolidated_out = self._consolidate_temp_output_across_obj(
          inference_state,
          frame_idx,
          is_cond=is_cond,
          consolidate_at_video_res=True,
      )
      _, video_res_masks = self._get_orig_video_res_output(
          inference_state, consolidated_out["pred_masks_video_res"]
      )
      return frame_idx, obj_ids, video_res_masks
  348. def _get_orig_video_res_output(self, inference_state, any_res_masks):
  349. """
  350. Resize the object scores to the original video resolution (video_res_masks)
  351. and apply non-overlapping constraints for final output.
  352. """
  353. device = inference_state["device"]
  354. video_H = inference_state["video_height"]
  355. video_W = inference_state["video_width"]
  356. any_res_masks = any_res_masks.to(device, non_blocking=True)
  357. if any_res_masks.shape[-2:] == (video_H, video_W):
  358. video_res_masks = any_res_masks
  359. else:
  360. video_res_masks = torch.nn.functional.interpolate(
  361. any_res_masks,
  362. size=(video_H, video_W),
  363. mode="bilinear",
  364. align_corners=False,
  365. )
  366. if self.non_overlap_masks:
  367. video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
  368. return any_res_masks, video_res_masks
  def _consolidate_temp_output_across_obj(
      self,
      inference_state,
      frame_idx,
      is_cond,
      consolidate_at_video_res=False,
  ):
      """Merge the per-object outputs on one frame into a single batched mask.

      For every registered object the output is taken from
      `temp_output_dict_per_obj` (under the storage key selected by `is_cond`)
      or, when missing there, from `output_dict_per_obj` (conditioning outputs
      first, then non-conditioning ones). Objects without any output on this
      frame keep the NO_OBJ_SCORE placeholder.

      Args:
          inference_state: the state dict for the current video.
          frame_idx: index of the frame to consolidate.
          is_cond: selects "cond_frame_outputs" (True) or
              "non_cond_frame_outputs" (False) in the temporary outputs.
          consolidate_at_video_res: if True, the result has the original video
              resolution and is stored under "pred_masks_video_res"; otherwise
              it has a quarter of the model image size per side and is stored
              under "pred_masks".

      Returns:
          dict: a single-entry dict mapping the mask key to a float32 tensor
          of shape (num_objects, 1, H, W) on the storage device.
      """
      num_objs = self._get_obj_num(inference_state)
      storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

      # Consolidating at the original video resolution is optional (it gives a
      # better editing experience for mask prompts).
      if consolidate_at_video_res:
          out_h = inference_state["video_height"]
          out_w = inference_state["video_width"]
          mask_key = "pred_masks_video_res"
      else:
          out_h = out_w = self.image_size // 4
          mask_key = "pred_masks"

      # Prefill with a large negative value (NO_OBJ_SCORE) standing for
      # "object missing"; slots of objects that have an output are overwritten.
      merged_masks = torch.full(
          size=(num_objs, 1, out_h, out_w),
          fill_value=NO_OBJ_SCORE,
          dtype=torch.float32,
          device=inference_state["storage_device"],
      )
      consolidated_out = {mask_key: merged_masks}

      for obj_idx in range(num_objs):
          temp_outputs = inference_state["temp_output_dict_per_obj"][obj_idx]
          stored_outputs = inference_state["output_dict_per_obj"][obj_idx]

          # Prefer the temporary output; fall back to previously stored outputs.
          out = temp_outputs[storage_key].get(frame_idx, None)
          if out is None:
              out = stored_outputs["cond_frame_outputs"].get(frame_idx, None)
          if out is None:
              out = stored_outputs["non_cond_frame_outputs"].get(frame_idx, None)
          if out is None:
              # Nothing known for this object on this frame: keep the placeholder.
              continue

          obj_mask = out["pred_masks"]
          if obj_mask.shape[-2:] != merged_masks.shape[-2:]:
              # Resize first if the object mask has a different resolution.
              obj_mask = torch.nn.functional.interpolate(
                  obj_mask,
                  size=merged_masks.shape[-2:],
                  mode="bilinear",
                  align_corners=False,
              )
          merged_masks[obj_idx : obj_idx + 1] = obj_mask

      return consolidated_out
  @torch.inference_mode()
  def propagate_in_video_preflight(self, inference_state):
      """Merge temporary per-object outputs into the tracking results.

      For each object, every output held in `temp_output_dict_per_obj` is
      given memory features (if it has none yet), moved into
      `output_dict_per_obj` and removed from the temporary storage.

      Raises:
          RuntimeError: if no object is registered, or if an object has no
              conditioning frame output after the merge.
      """
      num_objs = self._get_obj_num(inference_state)
      if num_objs == 0:
          raise RuntimeError(
              "No input points or masks are provided for any object; please add inputs first."
          )

      for obj_idx in range(num_objs):
          stored_outputs = inference_state["output_dict_per_obj"][obj_idx]
          temp_outputs = inference_state["temp_output_dict_per_obj"][obj_idx]

          # Non-conditioning and conditioning temp outputs are merged separately,
          # non-conditioning first.
          for storage_key in ("non_cond_frame_outputs", "cond_frame_outputs"):
              # These are the frames that just received clicks or mask inputs via
              # `add_new_points_or_box` or `add_new_mask`.
              for frame_idx, out in temp_outputs[storage_key].items():
                  if out["maskmem_features"] is None:
                      # The memory feature is missing: run the memory encoder on
                      # the predicted mask upsampled to the model image size.
                      upsampled_masks = torch.nn.functional.interpolate(
                          out["pred_masks"].to(inference_state["device"]),
                          size=(self.image_size, self.image_size),
                          mode="bilinear",
                          align_corners=False,
                      )
                      mem_features, mem_pos_enc = self._run_memory_encoder(
                          inference_state=inference_state,
                          frame_idx=frame_idx,
                          batch_size=1,  # run on the slice of a single object
                          high_res_masks=upsampled_masks,
                          object_score_logits=out["object_score_logits"],
                          # these frames are what the user interacted with
                          is_mask_from_pts=True,
                      )
                      out["maskmem_features"] = mem_features
                      out["maskmem_pos_enc"] = mem_pos_enc

                  stored_outputs[storage_key][frame_idx] = out
                  if self.clear_non_cond_mem_around_input:
                      # clear non-conditioning memory of the surrounding frames
                      self._clear_obj_non_cond_mem_around_input(
                          inference_state, frame_idx, obj_idx
                      )

              # Everything under this key is merged now; drop the temporary copies.
              temp_outputs[storage_key].clear()

          # Every object must have received input points or masks by now.
          stored_outputs = inference_state["output_dict_per_obj"][obj_idx]
          cond_outputs = stored_outputs["cond_frame_outputs"]
          if not cond_outputs:
              obj_id = self._obj_idx_to_id(inference_state, obj_idx)
              raise RuntimeError(
                  f"No input points or masks are provided for object id {obj_id}; please add inputs first."
              )
          # Edge case: a frame that is now a conditioning frame must not keep a
          # prior output in "non_cond_frame_outputs".
          for cond_frame_idx in cond_outputs:
              stored_outputs["non_cond_frame_outputs"].pop(cond_frame_idx, None)
  @torch.inference_mode()
  def propagate_in_video(
      self,
      inference_state,
      start_frame_idx=None,
      max_frame_num_to_track=None,
      reverse=False,
  ):
      """Track all registered objects through the video, frame by frame.

      Args:
          inference_state: the state dict for the current video.
          start_frame_idx: frame to start from; defaults to the earliest
              conditioning frame of any object.
          max_frame_num_to_track: limit on how far to track from the start
              frame; defaults to the number of frames in the video.
          reverse: if True, frames are visited in decreasing index order.

      Yields:
          tuple: `(frame_idx, obj_ids, video_res_masks)` for each visited frame.
      """
      self.propagate_in_video_preflight(inference_state)

      obj_ids = inference_state["obj_ids"]
      total_frames = inference_state["num_frames"]
      num_objs = self._get_obj_num(inference_state)

      # Resolve the start index, the end index, and the visiting order.
      if start_frame_idx is None:
          # default: start from the earliest frame with input points
          start_frame_idx = min(
              cond_idx
              for per_obj in inference_state["output_dict_per_obj"].values()
              for cond_idx in per_obj["cond_frame_outputs"]
          )
      if max_frame_num_to_track is None:
          # default: track all the frames in the video
          max_frame_num_to_track = total_frames

      if not reverse:
          last_idx = min(start_frame_idx + max_frame_num_to_track, total_frames - 1)
          frame_order = range(start_frame_idx, last_idx + 1)
      elif start_frame_idx > 0:
          last_idx = max(start_frame_idx - max_frame_num_to_track, 0)
          frame_order = range(start_frame_idx, last_idx - 1, -1)
      else:
          frame_order = []  # skip reverse tracking if starting from frame 0

      for frame_idx in tqdm(frame_order, desc="propagate in video"):
          masks_per_obj = []
          for obj_idx in range(num_objs):
              obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
              cond_outputs = obj_output_dict["cond_frame_outputs"]
              if frame_idx in cond_outputs:
                  # Frames that received input clicks or a mask already have an
                  # output, so it is reused. They cannot be run through a batched
                  # forward via `_run_single_frame_inference` because the number
                  # of clicks on each object might be different.
                  pred_masks = cond_outputs[frame_idx]["pred_masks"].to(
                      inference_state["device"], non_blocking=True
                  )
                  if self.clear_non_cond_mem_around_input:
                      # clear non-conditioning memory of the surrounding frames
                      self._clear_obj_non_cond_mem_around_input(
                          inference_state, frame_idx, obj_idx
                      )
              else:
                  current_out, pred_masks = self._run_single_frame_inference(
                      inference_state=inference_state,
                      output_dict=obj_output_dict,
                      frame_idx=frame_idx,
                      batch_size=1,  # run on the slice of a single object
                      is_init_cond_frame=False,
                      point_inputs=None,
                      mask_inputs=None,
                      reverse=reverse,
                      run_mem_encoder=True,
                  )
                  obj_output_dict["non_cond_frame_outputs"][frame_idx] = current_out

              # Remember the direction in which this frame was tracked.
              tracked = inference_state["frames_tracked_per_obj"][obj_idx]
              tracked[frame_idx] = {"reverse": reverse}
              masks_per_obj.append(pred_masks)

          # Resize the output masks to the original video resolution (the mask
          # scores are used directly on the compute device, with no CPU
          # conversion in between).
          if len(masks_per_obj) > 1:
              all_pred_masks = torch.cat(masks_per_obj, dim=0)
          else:
              all_pred_masks = masks_per_obj[0]
          _, video_res_masks = self._get_orig_video_res_output(
              inference_state, all_pred_masks
          )
          yield frame_idx, obj_ids, video_res_masks
  @torch.inference_mode()
  def clear_all_prompts_in_frame(
      self, inference_state, frame_idx, obj_id, need_output=True
  ):
      """
      Drop every prompt (points and mask) given to one object on one frame.

      The object's temporary outputs on that frame are discarded, and a stored
      conditioning output (if any) is moved to the non-conditioning outputs.

      Returns `None` when `need_output` is false; otherwise the tuple
      `(frame_idx, obj_ids, video_res_masks)` built from the re-consolidated
      outputs of all objects on that frame.
      """
      obj_idx = self._obj_id_to_idx(inference_state, obj_id)

      # Forget the prompts recorded for this object on this frame
      for inputs_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
          inference_state[inputs_key][obj_idx].pop(frame_idx, None)

      # Forget this object's not-yet-consolidated outputs on this frame
      temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
      obj_temp_outputs = temp_output_dict_per_obj[obj_idx]
      for storage_key in ("cond_frame_outputs", "non_cond_frame_outputs"):
          obj_temp_outputs[storage_key].pop(frame_idx, None)

      # Without inputs the frame is no longer a conditioning frame for this object:
      # an existing output is "downgraded" to a non-conditioning one and the frame
      # is dropped from the object's tracked-frame record.
      obj_outputs = inference_state["output_dict_per_obj"][obj_idx]
      demoted_out = obj_outputs["cond_frame_outputs"].pop(frame_idx, None)
      if demoted_out is not None:
          obj_outputs["non_cond_frame_outputs"][frame_idx] = demoted_out
          inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)

      if not need_output:
          return None

      # Report the masks of all objects on this frame now that the inputs are gone
      obj_ids = inference_state["obj_ids"]
      frame_is_cond = False
      for per_obj_temp in temp_output_dict_per_obj.values():
          if frame_idx in per_obj_temp["cond_frame_outputs"]:
              frame_is_cond = True
              break
      consolidated_out = self._consolidate_temp_output_across_obj(
          inference_state,
          frame_idx,
          is_cond=frame_is_cond,
          consolidate_at_video_res=True,
      )
      _, video_res_masks = self._get_orig_video_res_output(
          inference_state, consolidated_out["pred_masks_video_res"]
      )
      return frame_idx, obj_ids, video_res_masks
  @torch.inference_mode()
  def reset_state(self, inference_state):
      """
      Wipe the session: clear all tracking inputs and results, then forget every
      registered object (id mappings, id list and all per-object containers).
      The containers are emptied in place.
      """
      self._reset_tracking_results(inference_state)
      # Object bookkeeping first, then the per-object storage
      session_keys = (
          "obj_id_to_idx",
          "obj_idx_to_id",
          "obj_ids",
          "point_inputs_per_obj",
          "mask_inputs_per_obj",
          "output_dict_per_obj",
          "temp_output_dict_per_obj",
          "frames_tracked_per_obj",
      )
      for key in session_keys:
          inference_state[key].clear()
  633. def _reset_tracking_results(self, inference_state):
  634. """Reset all tracking inputs and results across the videos."""
  635. for v in inference_state["point_inputs_per_obj"].values():
  636. v.clear()
  637. for v in inference_state["mask_inputs_per_obj"].values():
  638. v.clear()
  639. for v in inference_state["output_dict_per_obj"].values():
  640. v["cond_frame_outputs"].clear()
  641. v["non_cond_frame_outputs"].clear()
  642. for v in inference_state["temp_output_dict_per_obj"].values():
  643. v["cond_frame_outputs"].clear()
  644. v["non_cond_frame_outputs"].clear()
  645. for v in inference_state["frames_tracked_per_obj"].values():
  646. v.clear()
  647. def _get_image_feature(self, inference_state, frame_idx, batch_size):
  648. """Compute the image features on a given frame."""
  649. # Look up in the cache first
  650. image, backbone_out = inference_state["cached_features"].get(
  651. frame_idx, (None, None)
  652. )
  653. if backbone_out is None:
  654. # Cache miss -- we will run inference on a single image
  655. device = inference_state["device"]
  656. image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
  657. backbone_out = self.forward_image(image)
  658. # Cache the most recent frame's feature (for repeated interactions with
  659. # a frame; we can use an LRU cache for more frames in the future).
  660. inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
  661. # expand the features to have the same dimension as the number of objects
  662. expanded_image = image.expand(batch_size, -1, -1, -1)
  663. expanded_backbone_out = {
  664. "backbone_fpn": backbone_out["backbone_fpn"].copy(),
  665. "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
  666. }
  667. for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
  668. expanded_backbone_out["backbone_fpn"][i] = feat.expand(
  669. batch_size, -1, -1, -1
  670. )
  671. for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
  672. pos = pos.expand(batch_size, -1, -1, -1)
  673. expanded_backbone_out["vision_pos_enc"][i] = pos
  674. features = self._prepare_backbone_features(expanded_backbone_out)
  675. features = (expanded_image,) + features
  676. return features
  def _run_single_frame_inference(
      self,
      inference_state,
      output_dict,
      frame_idx,
      batch_size,
      is_init_cond_frame,
      point_inputs,
      mask_inputs,
      reverse,
      run_mem_encoder,
      prev_sam_mask_logits=None,
  ):
      """
      Track one frame from the current inputs and the memory in `output_dict`.

      Returns `(compact_current_out, pred_masks_gpu)`: a reduced copy of the
      `track_step` output whose memory features and masks have been moved to
      `inference_state["storage_device"]`, and the predicted masks before that
      move (after optional hole filling).
      """
      # Fetch the image features of this frame (only the last three are used)
      image_features = self._get_image_feature(inference_state, frame_idx, batch_size)
      _, _, vision_feats, vision_pos_embeds, feat_sizes = image_features

      # point and mask should not appear as input simultaneously on the same frame
      assert point_inputs is None or mask_inputs is None
      current_out = self.track_step(
          frame_idx=frame_idx,
          is_init_cond_frame=is_init_cond_frame,
          current_vision_feats=vision_feats,
          current_vision_pos_embeds=vision_pos_embeds,
          feat_sizes=feat_sizes,
          point_inputs=point_inputs,
          mask_inputs=mask_inputs,
          output_dict=output_dict,
          num_frames=inference_state["num_frames"],
          track_in_reverse=reverse,
          run_mem_encoder=run_mem_encoder,
          prev_sam_mask_logits=prev_sam_mask_logits,
      )

      # Memory features are stored as bfloat16 on the storage device, which may
      # be the CPU to save GPU space
      storage_device = inference_state["storage_device"]
      maskmem_features = current_out["maskmem_features"]
      if maskmem_features is not None:
          maskmem_features = maskmem_features.to(torch.bfloat16).to(
              storage_device, non_blocking=True
          )

      # Optionally fill holes in the predicted masks before storing them
      pred_masks_gpu = current_out["pred_masks"]
      if self.fill_hole_area > 0:
          pred_masks_gpu = fill_holes_in_mask_scores(
              pred_masks_gpu, self.fill_hole_area
          )
      pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)

      # "maskmem_pos_enc" is the same across frames, so a single cached copy is reused
      maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)

      # Keep only what later steps need. The object pointer and score logits are
      # stored as returned by `track_step` (not moved to the storage device).
      compact_current_out = {
          "maskmem_features": maskmem_features,
          "maskmem_pos_enc": maskmem_pos_enc,
          "pred_masks": pred_masks,
          "obj_ptr": current_out["obj_ptr"],
          "object_score_logits": current_out["object_score_logits"],
      }
      return compact_current_out, pred_masks_gpu
  def _run_memory_encoder(
      self,
      inference_state,
      frame_idx,
      batch_size,
      high_res_masks,
      object_score_logits,
      is_mask_from_pts,
  ):
      """
      Encode `high_res_masks` into memory for the given frame. This is usually
      needed after non-overlapping constraints changed the object scores, which
      makes the previously encoded memory stale.

      Returns `(maskmem_features, maskmem_pos_enc)`, with the features cast to
      bfloat16 and moved to `inference_state["storage_device"]`.
      """
      # Only the vision features and their sizes are needed here
      image_features = self._get_image_feature(inference_state, frame_idx, batch_size)
      _, _, current_vision_feats, _, feat_sizes = image_features

      encoded = self._encode_new_memory(
          current_vision_feats=current_vision_feats,
          feat_sizes=feat_sizes,
          pred_masks_high_res=high_res_masks,
          object_score_logits=object_score_logits,
          is_mask_from_pts=is_mask_from_pts,
      )
      maskmem_features, maskmem_pos_enc = encoded

      # The storage device may be the CPU, to save GPU space
      storage_device = inference_state["storage_device"]
      maskmem_features = maskmem_features.to(torch.bfloat16).to(
          storage_device, non_blocking=True
      )
      # "maskmem_pos_enc" is the same across frames, so a single cached copy is reused
      maskmem_pos_enc = self._get_maskmem_pos_enc(
          inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
      )
      return maskmem_features, maskmem_pos_enc
  776. def _get_maskmem_pos_enc(self, inference_state, current_out):
  777. """
  778. `maskmem_pos_enc` is the same across frames and objects, so we cache it as
  779. a constant in the inference session to reduce session storage size.
  780. """
  781. model_constants = inference_state["constants"]
  782. # "out_maskmem_pos_enc" should be either a list of tensors or None
  783. out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
  784. if out_maskmem_pos_enc is not None:
  785. if "maskmem_pos_enc" not in model_constants:
  786. assert isinstance(out_maskmem_pos_enc, list)
  787. # only take the slice for one object, since it's same across objects
  788. maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
  789. model_constants["maskmem_pos_enc"] = maskmem_pos_enc
  790. else:
  791. maskmem_pos_enc = model_constants["maskmem_pos_enc"]
  792. # expand the cached maskmem_pos_enc to the actual batch size
  793. batch_size = out_maskmem_pos_enc[0].size(0)
  794. expanded_maskmem_pos_enc = [
  795. x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
  796. ]
  797. else:
  798. expanded_maskmem_pos_enc = None
  799. return expanded_maskmem_pos_enc
  @torch.inference_mode()
  def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
      """
      Remove an object id from the tracking state.

      An unknown `obj_id` is ignored unless `strict` is true, in which case a
      `RuntimeError` is raised. Removing the last remaining object resets the
      whole state.

      Returns `(obj_ids, updated_frames)`: the object ids left in the state and,
      when `need_output` is true, a list of `(frame_idx, video_res_masks)` for the
      frames on which the removed object had point or mask inputs.
      """
      updated_frames = []
      removed_idx = inference_state["obj_id_to_idx"].get(obj_id, None)

      # Unknown object id: raise in strict mode, otherwise nothing to do
      if removed_idx is None:
          if strict:
              raise RuntimeError(
                  f"Cannot remove object id {obj_id} as it doesn't exist. "
                  f"All existing object ids: {inference_state['obj_ids']}."
              )
          return inference_state["obj_ids"], updated_frames

      # Last remaining object: simply reset the state
      if len(inference_state["obj_id_to_idx"]) == 1:
          self.reset_state(inference_state)
          return inference_state["obj_ids"], updated_frames

      # Step 0: clear this object's inputs on every frame where it was prompted
      # (required, as it might downgrade conditioning frames to non-conditioning ones)
      prompted_frames = set()
      for inputs_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
          prompted_frames.update(inference_state[inputs_key][removed_idx])
      for frame_idx in prompted_frames:
          self.clear_all_prompts_in_frame(
              inference_state, frame_idx, obj_id, need_output=False
          )

      # Step 1: rebuild the id <-> index mappings without the removed object
      # (must come after Step 0, which still relies on the old mappings)
      old_obj_ids = inference_state["obj_ids"]
      old_obj_inds = list(range(len(old_obj_ids)))
      kept_old_inds = old_obj_inds.copy()
      kept_old_inds.remove(removed_idx)
      new_obj_ids = [old_obj_ids[old_idx] for old_idx in kept_old_inds]
      old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(kept_old_inds)}
      inference_state["obj_id_to_idx"] = {
          kept_id: new_idx for new_idx, kept_id in enumerate(new_obj_ids)
      }
      inference_state["obj_idx_to_id"] = dict(enumerate(new_obj_ids))
      inference_state["obj_ids"] = new_obj_ids

      # Step 2: shift the obj_idx keys of every per-object container
      def _reindex(container):
          # take every old entry out first, then re-insert the kept ones
          popped = [(old_idx, container.pop(old_idx)) for old_idx in old_obj_inds]
          container.update(
              (old_to_new[old_idx], value)
              for old_idx, value in popped
              if old_idx in old_to_new
          )

      for container_key in (
          "point_inputs_per_obj",
          "mask_inputs_per_obj",
          "output_dict_per_obj",
          "temp_output_dict_per_obj",
          "frames_tracked_per_obj",
      ):
          _reindex(inference_state[container_key])

      # Step 3: re-collect the outputs on the frames touched in Step 0, which could
      # show an updated mask for objects previously occluded by the removed one
      if not need_output:
          return inference_state["obj_ids"], updated_frames

      temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
      for frame_idx in prompted_frames:
          frame_is_cond = False
          for per_obj_temp in temp_output_dict_per_obj.values():
              if frame_idx in per_obj_temp["cond_frame_outputs"]:
                  frame_is_cond = True
                  break
          consolidated_out = self._consolidate_temp_output_across_obj(
              inference_state,
              frame_idx,
              is_cond=frame_is_cond,
              consolidate_at_video_res=True,
          )
          _, video_res_masks = self._get_orig_video_res_output(
              inference_state, consolidated_out["pred_masks_video_res"]
          )
          updated_frames.append((frame_idx, video_res_masks))
      return inference_state["obj_ids"], updated_frames
  882. def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
  883. """
  884. Remove the non-conditioning memory around the input frame. When users provide
  885. correction clicks, the surrounding frames' non-conditioning memories can still
  886. contain outdated object appearance information and could confuse the model.
  887. This method clears those non-conditioning memories surrounding the interacted
  888. frame to avoid giving the model both old and new information about the object.
  889. """
  890. r = self.memory_temporal_stride_for_eval
  891. frame_idx_begin = frame_idx - r * self.num_maskmem
  892. frame_idx_end = frame_idx + r * self.num_maskmem
  893. batch_size = self._get_obj_num(inference_state)
  894. for obj_idx in range(batch_size):
  895. obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
  896. non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
  897. for t in range(frame_idx_begin, frame_idx_end + 1):
  898. non_cond_frame_outputs.pop(t, None)
  899. class SAM2VideoPredictorVOS(SAM2VideoPredictor):
  900. """Optimized for the VOS setting"""
  def __init__(self, *args, **kwargs):
      """Build the parent predictor from the given arguments, then compile its components."""
      super().__init__(*args, **kwargs)
      # Wrap the components' forward passes with `torch.compile` right away
      self._compile_all_components()
  def _compile_all_components(self):
      """
      Replace the `forward` of the memory encoder, memory attention, SAM prompt
      encoder and SAM mask decoder with `torch.compile`d versions
      (`mode="max-autotune"`, `fullgraph=True`).
      """
      print("Compiling all components for VOS setting. First time may be very slow.")
      # (attribute name, `dynamic` flag), in compilation order
      compile_plan = (
          ("memory_encoder", False),
          ("memory_attention", True),  # Num. of memories varies
          ("sam_prompt_encoder", False),  # Accuracy regression on True
          ("sam_mask_decoder", False),  # Accuracy regression on True
      )
      for component_name, use_dynamic_shapes in compile_plan:
          component = getattr(self, component_name)
          component.forward = torch.compile(
              component.forward,
              mode="max-autotune",
              fullgraph=True,
              dynamic=use_dynamic_shapes,
          )
  def forward_image(self, img_batch: torch.Tensor):
      """
      Identical to the corresponding method in the parent (SAM2VideoPredictor), but
      cloning the backbone features and pos encoding to enable compilation.

      Returns the image encoder's output, with every "backbone_fpn" level and the
      matching "vision_pos_enc" entry replaced by a clone.
      """
      backbone_out = self.image_encoder(img_batch)
      fpn_levels = backbone_out["backbone_fpn"]
      if self.use_high_res_features_in_sam:
          # project the level 0 and level 1 features with the SAM decoder once
          # here, to avoid running it again on every SAM click
          fpn_levels[0] = self.sam_mask_decoder.conv_s0(fpn_levels[0])
          fpn_levels[1] = self.sam_mask_decoder.conv_s1(fpn_levels[1])

      # Clone to help torch.compile
      pos_enc_levels = backbone_out["vision_pos_enc"]
      for level_idx, level_feat in enumerate(fpn_levels):
          fpn_levels[level_idx] = level_feat.clone()
          pos_enc_levels[level_idx] = pos_enc_levels[level_idx].clone()
      return backbone_out
  def _forward_sam_heads(
      self,
      backbone_features,
      point_inputs=None,
      mask_inputs=None,
      high_res_features=None,
      multimask_output=False,
  ):
      """
      Identical to the corresponding method in the parent (SAM2VideoPredictor), but
      cloning the outputs of prompt_encoder and mask_decoder to enable compilation.

      Runs the SAM prompt encoder and mask decoder on `backbone_features` and
      derives the object pointer from the decoder's output token.

      Args:
          backbone_features: 4-D tensor whose dim 1 must equal
              `self.sam_prompt_embed_dim` and whose dims 2 and 3 must equal
              `self.sam_image_embedding_size` (checked by the asserts below).
          point_inputs: optional dict with "point_coords" and "point_labels",
              both batched like `backbone_features`. When None, one padding
              point with label -1 is used per batch entry.
          mask_inputs: optional 4-D mask prompt whose first two dims are (B, 1);
              it is resized to the prompt encoder's `mask_input_size` if needed.
          high_res_features: passed through to the mask decoder.
          multimask_output: passed to the mask decoder; when true, the returned
              single masks are the ones with the highest predicted IoU.

      Returns:
          The tuple `(low_res_multimasks, high_res_multimasks, ious,
          low_res_masks, high_res_masks, obj_ptr, object_score_logits)`.
      """
      B = backbone_features.size(0)
      device = backbone_features.device
      assert backbone_features.size(1) == self.sam_prompt_embed_dim
      assert backbone_features.size(2) == self.sam_image_embedding_size
      assert backbone_features.size(3) == self.sam_image_embedding_size

      # a) Handle point prompts
      if point_inputs is not None:
          sam_point_coords = point_inputs["point_coords"]
          sam_point_labels = point_inputs["point_labels"]
          assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
      else:
          # If no points are provided, pad with an empty point (with label -1)
          sam_point_coords = torch.zeros(B, 1, 2, device=device)
          sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

      # b) Handle mask prompts
      if mask_inputs is not None:
          # If mask_inputs is provided, downsize it into low-res mask input if needed
          # and feed it as a dense mask prompt into the SAM mask encoder
          assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
          if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
              sam_mask_prompt = F.interpolate(
                  mask_inputs.float(),
                  size=self.sam_prompt_encoder.mask_input_size,
                  align_corners=False,
                  mode="bilinear",
                  antialias=True,  # use antialias for downsampling
              )
          else:
              sam_mask_prompt = mask_inputs
      else:
          # Otherwise, simply feed None (and SAM's prompt encoder will add
          # a learned `no_mask_embed` to indicate no mask input in this case).
          sam_mask_prompt = None

      sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
          points=(sam_point_coords, sam_point_labels),
          boxes=None,
          masks=sam_mask_prompt,
      )
      # Clone image_pe and the outputs of sam_prompt_encoder
      # to enable compilation
      sparse_embeddings = sparse_embeddings.clone()
      dense_embeddings = dense_embeddings.clone()
      image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
      (
          low_res_multimasks,
          ious,
          sam_output_tokens,
          object_score_logits,
      ) = self.sam_mask_decoder(
          image_embeddings=backbone_features,
          image_pe=image_pe,
          sparse_prompt_embeddings=sparse_embeddings,
          dense_prompt_embeddings=dense_embeddings,
          multimask_output=multimask_output,
          repeat_image=False,  # the image is already batched
          high_res_features=high_res_features,
      )
      # Clone the output of sam_mask_decoder
      # to enable compilation
      low_res_multimasks = low_res_multimasks.clone()
      ious = ious.clone()
      sam_output_tokens = sam_output_tokens.clone()
      object_score_logits = object_score_logits.clone()

      if self.pred_obj_scores:
          # an object counts as present when its score logit is positive
          is_obj_appearing = object_score_logits > 0

          # Mask used for spatial memories is always a *hard* choice between obj and no obj,
          # consistent with the actual mask prediction
          low_res_multimasks = torch.where(
              is_obj_appearing[:, None, None],
              low_res_multimasks,
              NO_OBJ_SCORE,
          )

      # convert masks from possibly bfloat16 (or float16) to float32
      # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
      low_res_multimasks = low_res_multimasks.float()
      high_res_multimasks = F.interpolate(
          low_res_multimasks,
          size=(self.image_size, self.image_size),
          mode="bilinear",
          align_corners=False,
      )

      # default to the first output token; replaced below when the decoder
      # returned one token per candidate mask
      sam_output_token = sam_output_tokens[:, 0]
      if multimask_output:
          # take the best mask prediction (with the highest IoU estimation)
          best_iou_inds = torch.argmax(ious, dim=-1)
          batch_inds = torch.arange(B, device=device)
          low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
          high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
          if sam_output_tokens.size(1) > 1:
              sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
      else:
          low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

      # Extract object pointer from the SAM output token (with occlusion handling)
      obj_ptr = self.obj_ptr_proj(sam_output_token)
      if self.pred_obj_scores:
          # Allow *soft* no obj ptr, unlike for masks
          if self.soft_no_obj_ptr:
              lambda_is_obj_appearing = object_score_logits.sigmoid()
          else:
              lambda_is_obj_appearing = is_obj_appearing.float()

          # blend the pointer with the `no_obj_ptr` embedding according to the
          # presence weight; the pointer itself is scaled only when
          # `fixed_no_obj_ptr` is set
          if self.fixed_no_obj_ptr:
              obj_ptr = lambda_is_obj_appearing * obj_ptr
          obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

      return (
          low_res_multimasks,
          high_res_multimasks,
          ious,
          low_res_masks,
          high_res_masks,
          obj_ptr,
          object_score_logits,
      )
  def _encode_new_memory(
      self,
      current_vision_feats,
      feat_sizes,
      pred_masks_high_res,
      object_score_logits,
      is_mask_from_pts,
  ):
      """
      Identical to the corresponding method in the parent (SAM2VideoPredictor), but
      cloning the memories and their pos enc to enable compilation.

      Encodes the top-level image feature together with the predicted masks
      into memory features.

      Args:
          current_vision_feats: sequence of feature tensors; only the last one
              is used, laid out as (HW, B, C).
          feat_sizes: sequence of (H, W) sizes; only the last one is used.
          pred_masks_high_res: raw mask logits to encode.
          object_score_logits: object score logits; a positive value means the
              object is treated as appearing in the frame.
          is_mask_from_pts: whether the masks come from point inputs; together
              with `binarize_mask_from_pts_for_mem_enc` it selects hard
              binarization instead of a sigmoid (outside training only).

      Returns:
          `(maskmem_features, maskmem_pos_enc)`, where `maskmem_pos_enc` is a
          list of tensors.
      """
      B = current_vision_feats[-1].size(1)  # batch size on this frame
      C = self.hidden_dim
      H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
      # top-level feature, (HW)BC => BCHW
      pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
      if self.non_overlap_masks_for_mem_enc and not self.training:
          # optionally, apply non-overlapping constraints to the masks (it's applied
          # in the batch dimension and should only be used during eval, where all
          # the objects come from the same video under batch size 1).
          pred_masks_high_res = self._apply_non_overlapping_constraints(
              pred_masks_high_res
          )
      # turn the raw mask logits into the mask fed to the memory encoder: either
      # a hard 0/1 mask or sigmoid probabilities
      binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
      if binarize and not self.training:
          mask_for_mem = (pred_masks_high_res > 0).float()
      else:
          # apply sigmoid on the raw mask logits to turn them into range (0, 1)
          mask_for_mem = torch.sigmoid(pred_masks_high_res)
      # apply scale and bias terms to the sigmoid probabilities
      if self.sigmoid_scale_for_mem_enc != 1.0:
          mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
      if self.sigmoid_bias_for_mem_enc != 0.0:
          mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
      maskmem_out = self.memory_encoder(
          pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
      )
      # Clone the feats and pos_enc to enable compilation
      maskmem_features = maskmem_out["vision_features"].clone()
      maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
      # add a no-object embedding to the spatial memory to indicate that the frame
      # is predicted to be occluded (i.e. no object is appearing in the frame)
      if self.no_obj_embed_spatial is not None:
          is_obj_appearing = (object_score_logits > 0).float()
          # in-place add on the clone made above; the term is zero wherever the
          # object is appearing
          maskmem_features += (
              1 - is_obj_appearing[..., None, None]
          ) * self.no_obj_embed_spatial[..., None, None].expand(
              *maskmem_features.shape
          )

      return maskmem_features, maskmem_pos_enc