sam3_tracker_base.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import logging

import torch
import torch.nn.functional as F

from sam3.model.memory import SimpleMaskEncoder
from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
from sam3.sam.mask_decoder import MaskDecoder, MLP
from sam3.sam.prompt_encoder import PromptEncoder
from sam3.sam.transformer import TwoWayTransformer
from sam3.train.data.collator import BatchedDatapoint

try:
    from timm.layers import trunc_normal_
except ModuleNotFoundError:
    # compatibility for older timm versions
    from timm.models.layers import trunc_normal_

# a large negative value as a placeholder score for missing objects
NO_OBJ_SCORE = -1024.0


class Sam3TrackerBase(torch.nn.Module):
    def __init__(
        self,
        backbone,
        transformer,
        maskmem_backbone,
        num_maskmem=7,  # default 1 input frame + 6 previous frames as in CAE
        image_size=1008,
        backbone_stride=14,  # stride of the image backbone output
        # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
        # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
        # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
        max_cond_frames_in_attn=-1,
        # Whether to always keep the first conditioning frame in case we exceed the maximum number of conditioning frames allowed
        keep_first_cond_frame=False,
        # whether to output multiple (3) masks for the first click on initial conditioning frames
        multimask_output_in_sam=False,
        # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
        # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
        multimask_min_pt_num=1,
        multimask_max_pt_num=1,
        # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
        multimask_output_for_tracking=False,
        # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
        # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
        forward_backbone_per_frame_for_eval=False,
        # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
        # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
        # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
        memory_temporal_stride_for_eval=1,
        # whether to offload outputs to CPU memory during evaluation, to avoid GPU OOM on very long videos, very large resolutions, or too many objects
        # (it's recommended to use `forward_backbone_per_frame_for_eval=True` first before setting this option to True)
        offload_output_to_cpu_for_eval=False,
        # whether to trim the output of past non-conditioning frames (num_maskmem frames before the current frame) during evaluation
        # (this helps save GPU or CPU memory on very long videos for semi-supervised VOS eval, where only the first frame receives prompts)
        trim_past_non_cond_mem_for_eval=False,
        # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
        non_overlap_masks_for_mem_enc=False,
        # the maximum number of object pointers from other frames in encoder cross attention
        max_obj_ptrs_in_encoder=16,
        # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into the `MaskDecoder` class.
        sam_mask_decoder_extra_args=None,
        # whether to compile all the model components
        compile_all_components=False,
        # whether to select memory frames based on object existence
        use_memory_selection=False,
        # when using memory selection, the threshold to determine if a frame is good
        mf_threshold=0.01,
    ):
        super().__init__()
        # Part 1: the image backbone
        self.backbone = backbone
        self.num_feature_levels = 3
        self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
        # A conv layer to downsample the GT mask prompt to stride 4 (the same stride as
        # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
        # so that it can be fed into the SAM mask decoder to generate a pointer.
        self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
        # Part 2: encoder-only transformer to fuse current frame's visual features
        # with memories from past frames
        assert transformer.decoder is None, "transformer should be encoder-only"
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        # Part 3: memory encoder for the previous frame's outputs
        self.maskmem_backbone = maskmem_backbone
        self.mem_dim = self.hidden_dim
        if hasattr(self.maskmem_backbone, "out_proj") and hasattr(
            self.maskmem_backbone.out_proj, "weight"
        ):
            # if there is compression of memories along channel dim
            self.mem_dim = self.maskmem_backbone.out_proj.weight.shape[0]
        self.num_maskmem = num_maskmem  # Number of memories accessible
        # Temporal encoding of the memories
        self.maskmem_tpos_enc = torch.nn.Parameter(
            torch.zeros(num_maskmem, 1, 1, self.mem_dim)
        )
        trunc_normal_(self.maskmem_tpos_enc, std=0.02)
        # a single token to indicate no memory embedding from previous frames
        self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        trunc_normal_(self.no_mem_embed, std=0.02)
        trunc_normal_(self.no_mem_pos_enc, std=0.02)
        # Apply sigmoid to the output raw mask logits (to turn them from
        # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
        self.sigmoid_scale_for_mem_enc = 20.0
        self.sigmoid_bias_for_mem_enc = -10.0
        self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
        self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
        # On frames with mask input, whether to directly output the input mask without
        # using a SAM prompt encoder + mask decoder
        self.multimask_output_in_sam = multimask_output_in_sam
        self.multimask_min_pt_num = multimask_min_pt_num
        self.multimask_max_pt_num = multimask_max_pt_num
        self.multimask_output_for_tracking = multimask_output_for_tracking
        # Part 4: SAM-style prompt encoder (for both mask and point inputs)
        # and SAM-style mask decoder for the final mask output
        self.image_size = image_size
        self.backbone_stride = backbone_stride
        self.low_res_mask_size = self.image_size // self.backbone_stride * 4
        # we resize the mask if it doesn't match `self.input_mask_size` (which is always 4x
        # the low-res mask size, regardless of the actual input image size); this is because
        # `_use_mask_as_output` always downsamples the input masks by 4x
        self.input_mask_size = self.low_res_mask_size * 4
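        # For example, with the defaults above (image_size=1008, backbone_stride=14),
        # the backbone feature map is 72x72, so low_res_mask_size = 288 and
        # input_mask_size = 1152; mask prompts of any other size get resized.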
        self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval
        self.offload_output_to_cpu_for_eval = offload_output_to_cpu_for_eval
        self.trim_past_non_cond_mem_for_eval = trim_past_non_cond_mem_for_eval
        self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
        self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
        trunc_normal_(self.no_obj_ptr, std=0.02)
        self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
        trunc_normal_(self.no_obj_embed_spatial, std=0.02)
        self._build_sam_heads()
        self.max_cond_frames_in_attn = max_cond_frames_in_attn
        self.keep_first_cond_frame = keep_first_cond_frame
        # Use frame filtering according to SAM2Long
        self.use_memory_selection = use_memory_selection
        self.mf_threshold = mf_threshold
        # Compile all components of the model
        self.compile_all_components = compile_all_components
        if self.compile_all_components:
            self._compile_all_components()

    @property
    def device(self):
        return next(self.parameters()).device

    def _get_tpos_enc(self, rel_pos_list, device, max_abs_pos=None, dummy=False):
        if dummy:
            return torch.zeros(len(rel_pos_list), self.mem_dim, device=device)
        t_diff_max = max_abs_pos - 1 if max_abs_pos is not None else 1
        pos_enc = (
            torch.tensor(rel_pos_list).pin_memory().to(device=device, non_blocking=True)
            / t_diff_max
        )
        tpos_dim = self.hidden_dim
        pos_enc = get_1d_sine_pe(pos_enc, dim=tpos_dim)
        pos_enc = self.obj_ptr_tpos_proj(pos_enc)
        return pos_enc

    def _build_sam_heads(self):
        """Build SAM-style prompt encoder and mask decoder."""
        self.sam_prompt_embed_dim = self.hidden_dim
        self.sam_image_embedding_size = self.image_size // self.backbone_stride
        # build PromptEncoder and MaskDecoder from SAM
        # (their hyperparameters like `mask_in_chans=16` are from SAM code)
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
                self.sam_image_embedding_size,
                self.sam_image_embedding_size,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
        self.sam_mask_decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=True,
            iou_prediction_use_sigmoid=True,
            pred_obj_scores=True,
            pred_obj_scores_mlp=True,
            use_multimask_token_for_obj_ptr=True,
            **(self.sam_mask_decoder_extra_args or {}),
        )
        # a linear projection on SAM output tokens to turn them into object pointers
        self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
        self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
        # a linear projection on temporal positional encoding in object pointers to
        # avoid potential interference with spatial positional encoding
        self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
        gt_masks=None,
    ):
        """
        Forward SAM prompt encoders and mask heads.

        Inputs:
        - backbone_features: image features of [B, C, H, W] shape
        - point_inputs: a dictionary with "point_coords" and "point_labels", where
          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
             absolute pixel-unit coordinates in (x, y) format of the P input points
          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
             positive clicks, 0 means negative clicks, and -1 means padding
        - mask_inputs: a mask of [B, 1, image_size, image_size] shape, float or bool,
          with the same spatial size as the input image.
        - high_res_features: either 1) None or 2) a list of length 2 containing
          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
          which will be used as high-resolution feature maps for the SAM decoder.
        - multimask_output: if it's True, we output 3 candidate masks and their 3
          corresponding IoU estimates; if it's False, we output only 1 mask and
          its corresponding IoU estimate.

        Outputs:
        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
          output mask logits (before sigmoid) for the low-resolution masks, with 4x
          the resolution (1/4 stride) of the input backbone_features.
        - high_res_multimasks: [B, M, image_size, image_size] shape (where M = 3
          if `multimask_output=True` and M = 1 if `multimask_output=False`),
          upsampled from the low-resolution masks to the same size as the input image.
        - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
          if `multimask_output=False`), the estimated IoU of each output mask.
        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `low_res_multimasks`.
        - high_res_masks: [B, 1, image_size, image_size] shape, the best mask in
          `high_res_multimasks`. If `multimask_output=True`, it's the mask with the
          highest IoU estimate. If `multimask_output=False`, it's the same as
          `high_res_multimasks`.
        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
          based on the output token from the SAM mask decoder.
        """
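        # Shape example (a sketch under the default config, image_size=1008 and
        # backbone_stride=14): backbone_features is [B, C, 72, 72], so the low-res
        # (multi)masks come out as [B, M, 288, 288] and the high-res (multi)masks
        # are upsampled to [B, M, 1008, 1008].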
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into a low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        # Clone image_pe and the outputs of sam_prompt_encoder
        # to enable compilation
        sparse_embeddings = self._maybe_clone(sparse_embeddings)
        dense_embeddings = self._maybe_clone(dense_embeddings)
        image_pe = self._maybe_clone(self.sam_prompt_encoder.get_dense_pe())
        with torch.profiler.record_function("sam_mask_decoder"):
            (
                low_res_multimasks,
                ious,
                sam_output_tokens,
                object_score_logits,
            ) = self.sam_mask_decoder(
                image_embeddings=backbone_features,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
                repeat_image=False,  # the image is already batched
                high_res_features=high_res_features,
            )
        # Clone the output of sam_mask_decoder
        # to enable compilation
        low_res_multimasks = self._maybe_clone(low_res_multimasks)
        ious = self._maybe_clone(ious)
        sam_output_tokens = self._maybe_clone(sam_output_tokens)
        object_score_logits = self._maybe_clone(object_score_logits)
        if self.training and self.teacher_force_obj_scores_for_mem:
            # we use the GT to detect if there is an object or not, to
            # select the no-obj pointer and use an empty mask for spatial memory
            is_obj_appearing = torch.any(gt_masks.float().flatten(1) > 0, dim=1)
            is_obj_appearing = is_obj_appearing[..., None]
        else:
            is_obj_appearing = object_score_logits > 0

        # Mask used for spatial memories is always a *hard* choice between obj and no obj,
        # consistent with the actual mask prediction
        low_res_multimasks = torch.where(
            is_obj_appearing[:, None, None],
            low_res_multimasks,
            NO_OBJ_SCORE,
        )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        lambda_is_obj_appearing = is_obj_appearing.float()
        obj_ptr = lambda_is_obj_appearing * obj_ptr
        obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
        """
        Directly turn binary `mask_inputs` into output mask logits without using SAM
        (same input and output shapes as in `_forward_sam_heads` above).
        """
        # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
        out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
        mask_inputs_float = mask_inputs.float()
        high_res_masks = mask_inputs_float * out_scale + out_bias
        low_res_masks = F.interpolate(
            high_res_masks,
            size=(
                high_res_masks.size(-2) // self.backbone_stride * 4,
                high_res_masks.size(-1) // self.backbone_stride * 4,
            ),
            align_corners=False,
            mode="bilinear",
            antialias=True,  # use antialias for downsampling
        )
        # a dummy IoU prediction of all 1's under mask input
        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
        # produce an object pointer using the SAM decoder from the mask input
        _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
            backbone_features=backbone_features,
            mask_inputs=self.mask_downsample(mask_inputs_float),
            high_res_features=high_res_features,
            gt_masks=mask_inputs,
        )
        # In this method, we treat mask_inputs as the output, e.g. using it directly
        # to create the spatial memory; following the same principle, we also use
        # mask_inputs (instead of the object scores from the SAM decoder) to decide
        # whether the object appears or not.
        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
        is_obj_appearing = is_obj_appearing[..., None]
        lambda_is_obj_appearing = is_obj_appearing.float()
        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
        obj_ptr = lambda_is_obj_appearing * obj_ptr
        obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
        return (
            low_res_masks,
            high_res_masks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def forward(self, input: BatchedDatapoint, is_inference=False):
        raise NotImplementedError(
            "Please use the corresponding methods in SAM3VideoPredictor for inference. "
            "See examples/sam3_dense_video_tracking.ipynb for an inference example."
        )

    def forward_image(self, img_batch):
        """Get the image features on the input batch."""
        # This line is the only change from the parent class
        # to use the SAM3 backbone instead of the SAM2 backbone.
        backbone_out = self.backbone.forward_image(img_batch)["sam2_backbone_out"]
        # precompute projected level 0 and level 1 features in the SAM decoder
        # to avoid running them again on every SAM click
        backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
            backbone_out["backbone_fpn"][0]
        )
        backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
            backbone_out["backbone_fpn"][1]
        )
        # Clone to help torch.compile
        for i in range(len(backbone_out["backbone_fpn"])):
            backbone_out["backbone_fpn"][i] = self._maybe_clone(
                backbone_out["backbone_fpn"][i]
            )
            backbone_out["vision_pos_enc"][i] = self._maybe_clone(
                backbone_out["vision_pos_enc"][i]
            )
        return backbone_out

    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features (same as in the MDETR_API model)."""
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
        """Compute the image backbone features on the fly for the given img_ids."""
        # Only forward the backbone on unique image ids to avoid repetitive computation
        # (if `img_ids` has only one element, it's already unique so we skip this step).
        if img_ids.numel() > 1:
            unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
        else:
            unique_img_ids, inv_ids = img_ids, None

        # Compute the image features on those unique image ids
        image = img_batch[unique_img_ids]
        backbone_out = self.forward_image(image)
        (
            _,
            vision_feats,
            vision_pos_embeds,
            feat_sizes,
        ) = self._prepare_backbone_features(backbone_out)
        # Inverse-map image features for `unique_img_ids` to the final image features
        # for the original input `img_ids`.
        if inv_ids is not None:
            image = image[inv_ids]
            vision_feats = [x[:, inv_ids] for x in vision_feats]
            vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
        return image, vision_feats, vision_pos_embeds, feat_sizes

    def cal_mem_score(self, object_score_logits, iou_score):
        object_score_norm = torch.where(
            object_score_logits > 0,
            object_score_logits.sigmoid() * 2 - 1,  # rescale to [0, 1]
            torch.zeros_like(object_score_logits),
        )
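        # In other words, each frame's memory score is
        #   mean(max(2 * sigmoid(object_score_logits) - 1, 0) * iou_score),
        # so frames with a low object score or a low predicted IoU are scored low.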
        score_per_frame = (object_score_norm * iou_score).mean()
        return score_per_frame

    def frame_filter(self, output_dict, track_in_reverse, frame_idx, num_frames, r):
        if (frame_idx == 0 and not track_in_reverse) or (
            frame_idx == num_frames - 1 and track_in_reverse
        ):
            return []
        max_num = min(
            num_frames, self.max_obj_ptrs_in_encoder
        )  # maximum number of pointer memory frames to consider
        if not track_in_reverse:
            start = frame_idx - 1
            end = 0
            step = -r
            must_include = frame_idx - 1
        else:
            start = frame_idx + 1
            end = num_frames
            step = r
            must_include = frame_idx + 1
        valid_indices = []
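        # Example of the scan below: when tracking forward at frame_idx=30 with stride r=2,
        # we consider frames 29, 27, 25, ... and keep only those whose `eff_iou_score`
        # exceeds `mf_threshold`; the immediately preceding frame (29) is force-included
        # at the end if it was filtered out.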
        for i in range(start, end, step):
            if (
                i not in output_dict["non_cond_frame_outputs"]
                or "eff_iou_score" not in output_dict["non_cond_frame_outputs"][i]
            ):
                continue
            score_per_frame = output_dict["non_cond_frame_outputs"][i]["eff_iou_score"]
            if score_per_frame > self.mf_threshold:  # threshold
                valid_indices.insert(0, i)
            if len(valid_indices) >= max_num - 1:
                break
        if must_include not in valid_indices:
            valid_indices.append(must_include)
        return valid_indices

    def _prepare_memory_conditioned_features(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        use_prev_mem_frame=True,
    ):
        """Fuse the current frame's visual feature map with previous memory."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        device = current_vision_feats[-1].device
        # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
        # In this case, we skip the fusion with any memory.
        if self.num_maskmem == 0:  # Disable memory and skip fusion
            pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
            return pix_feat

        num_obj_ptr_tokens = 0
        tpos_sign_mul = -1 if track_in_reverse else 1
        # Step 1: condition the visual features of the current frame on previous memories
        if not is_init_cond_frame and use_prev_mem_frame:
            # Retrieve the memories encoded with the maskmem backbone
            to_cat_prompt, to_cat_prompt_mask, to_cat_prompt_pos_embed = [], [], []
            # Add the conditioning frames' outputs first (all cond frames have t_pos=0
            # when getting the temporal positional embedding below)
            assert len(output_dict["cond_frame_outputs"]) > 0
            # Select a maximum number of temporally closest cond frames for cross attention
            cond_outputs = output_dict["cond_frame_outputs"]
            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
                frame_idx,
                cond_outputs,
                self.max_cond_frames_in_attn,
                keep_first_cond_frame=self.keep_first_cond_frame,
            )
            t_pos_and_prevs = [
                ((frame_idx - t) * tpos_sign_mul, out, True)
                for t, out in selected_cond_outputs.items()
            ]
            # Add the last (self.num_maskmem - 1) frames before the current frame as non-conditioning memory;
            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1.
            # We also allow taking the memory frames non-consecutively (with r>1), in which case
            # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
            r = 1 if self.training else self.memory_temporal_stride_for_eval
            if self.use_memory_selection:
                valid_indices = self.frame_filter(
                    output_dict, track_in_reverse, frame_idx, num_frames, r
                )
            for t_pos in range(1, self.num_maskmem):
                t_rel = self.num_maskmem - t_pos  # how many frames before the current frame
                if self.use_memory_selection:
                    if t_rel > len(valid_indices):
                        continue
                    prev_frame_idx = valid_indices[-t_rel]
                else:
                    if t_rel == 1:
                        # for t_rel == 1, we take the last frame (regardless of r)
                        if not track_in_reverse:
                            # the frame immediately before this frame (i.e. frame_idx - 1)
                            prev_frame_idx = frame_idx - t_rel
                        else:
                            # the frame immediately after this frame (i.e. frame_idx + 1)
                            prev_frame_idx = frame_idx + t_rel
                    else:
                        # for t_rel >= 2, we take the memory frame from every r-th frames
                        if not track_in_reverse:
                            # first find the nearest frame among every r-th frames before this frame
                            # for r=1, this would be (frame_idx - 2)
                            prev_frame_idx = ((frame_idx - 2) // r) * r
                            # then seek further among every r-th frames
                            prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
                        else:
                            # first find the nearest frame among every r-th frames after this frame
                            # for r=1, this would be (frame_idx + 2)
                            prev_frame_idx = -(-(frame_idx + 2) // r) * r
                            # then seek further among every r-th frames
                            prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                if out is None:
                    # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                    # frames, we still attend to it as if it's a non-conditioning frame.
                    out = unselected_cond_outputs.get(prev_frame_idx, None)
                t_pos_and_prevs.append((t_pos, out, False))

            for t_pos, prev, is_selected_cond_frame in t_pos_and_prevs:
                if prev is None:
                    continue  # skip padding frames
                # "maskmem_features" might have been offloaded to CPU in demo use cases,
                # so we load it back to GPU (it's a no-op if it's already on GPU).
                feats = prev["maskmem_features"].cuda(non_blocking=True)
                seq_len = feats.shape[-2] * feats.shape[-1]
                to_cat_prompt.append(feats.flatten(2).permute(2, 0, 1))
                to_cat_prompt_mask.append(
                    torch.zeros(B, seq_len, device=device, dtype=bool)
                )
                # Spatial positional encoding (it might have been offloaded to CPU in eval)
                maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                if (
                    is_selected_cond_frame
                    and getattr(self, "cond_frame_spatial_embedding", None) is not None
                ):
                    # add a spatial embedding for the conditioning frame
                    maskmem_enc = maskmem_enc + self.cond_frame_spatial_embedding
                # Temporal positional encoding
                t = t_pos if not is_selected_cond_frame else 0
                maskmem_enc = (
                    maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t - 1]
                )
                to_cat_prompt_pos_embed.append(maskmem_enc)

            # Optionally, select only a subset of spatial memory frames during training
            if (
                self.training
                and self.prob_to_dropout_spatial_mem > 0
                and self.rng.random() < self.prob_to_dropout_spatial_mem
            ):
                num_spatial_mem_keep = self.rng.integers(len(to_cat_prompt) + 1)
                keep = self.rng.choice(
                    range(len(to_cat_prompt)), num_spatial_mem_keep, replace=False
                ).tolist()
                to_cat_prompt = [to_cat_prompt[i] for i in keep]
                to_cat_prompt_mask = [to_cat_prompt_mask[i] for i in keep]
                to_cat_prompt_pos_embed = [to_cat_prompt_pos_embed[i] for i in keep]

            # Construct the list of past object pointers
            max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
            # First add those object pointers from selected conditioning frames
            # (optionally, only include object pointers in the past during evaluation)
            if not self.training:
                ptr_cond_outputs = {
                    t: out
                    for t, out in selected_cond_outputs.items()
                    if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                }
            else:
                ptr_cond_outputs = selected_cond_outputs
            pos_and_ptrs = [
                # Temporal pos encoding contains how far away each pointer is from the current frame
                (
                    (frame_idx - t) * tpos_sign_mul,
                    out["obj_ptr"],
                    True,  # is_selected_cond_frame
                )
                for t, out in ptr_cond_outputs.items()
            ]
            # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before the current frame
            for t_diff in range(1, max_obj_ptrs_in_encoder):
                if not self.use_memory_selection:
                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                    if t < 0 or (num_frames is not None and t >= num_frames):
                        break
                else:
                    if -t_diff <= -len(valid_indices):
                        break
                    t = valid_indices[-t_diff]
                out = output_dict["non_cond_frame_outputs"].get(
                    t, unselected_cond_outputs.get(t, None)
                )
                if out is not None:
                    pos_and_ptrs.append((t_diff, out["obj_ptr"], False))
            # If we have at least one object pointer, add them to the cross attention
            if len(pos_and_ptrs) > 0:
                pos_list, ptrs_list, is_selected_cond_frame_list = zip(*pos_and_ptrs)
                # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                obj_ptrs = torch.stack(ptrs_list, dim=0)
                if getattr(self, "cond_frame_obj_ptr_embedding", None) is not None:
                    obj_ptrs = (
                        obj_ptrs
                        + self.cond_frame_obj_ptr_embedding
                        * torch.tensor(is_selected_cond_frame_list, device=device)[
                            ..., None, None
                        ].float()
                    )
                # a temporal positional embedding based on how far each object pointer is from
                # the current frame (sine embedding normalized by the max pointer num).
                obj_pos = self._get_tpos_enc(
                    pos_list,
                    max_abs_pos=max_obj_ptrs_in_encoder,
                    device=device,
                )
                # expand to batch size
                obj_pos = obj_pos.unsqueeze(1).expand(-1, B, -1)
                if self.mem_dim < C:
                    # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                    obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
                    obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                    obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                to_cat_prompt.append(obj_ptrs)
                to_cat_prompt_mask.append(None)  # "to_cat_prompt_mask" is not used
                to_cat_prompt_pos_embed.append(obj_pos)
                num_obj_ptr_tokens = obj_ptrs.shape[0]
            else:
                num_obj_ptr_tokens = 0
        else:
            # directly add the no-mem embedding (instead of using the transformer encoder)
            pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
            pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
            return pix_feat_with_mem

            # Use a dummy token on the first frame (to avoid an empty memory input to the transformer encoder)
            to_cat_prompt = [self.no_mem_embed.expand(1, B, self.mem_dim)]
            to_cat_prompt_mask = [torch.zeros(B, 1, device=device, dtype=bool)]
            to_cat_prompt_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

        # Step 2: Concatenate the memories and forward through the transformer encoder
        prompt = torch.cat(to_cat_prompt, dim=0)
        prompt_mask = None  # for now, the masks are always zeros anyway, so we just pass None
        prompt_pos_embed = torch.cat(to_cat_prompt_pos_embed, dim=0)
        encoder_out = self.transformer.encoder(
            src=current_vision_feats,
            src_key_padding_mask=[None],
            src_pos=current_vision_pos_embeds,
            prompt=prompt,
            prompt_pos=prompt_pos_embed,
            prompt_key_padding_mask=prompt_mask,
            feat_sizes=feat_sizes,
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )
        # reshape the output (HW)BC => BCHW
        pix_feat_with_mem = encoder_out["memory"].permute(1, 2, 0).view(B, C, H, W)
        return pix_feat_with_mem

    def _encode_new_memory(
        self,
        image,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
        output_dict=None,
        is_init_cond_frame=False,
    ):
        """Encode the current image and its prediction into a memory feature."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1).
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )
        # scale the raw mask logits with a temperature before applying sigmoid
        if is_mask_from_pts and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        if isinstance(self.maskmem_backbone, SimpleMaskEncoder):
            pix_feat = pix_feat.view_as(pix_feat)
            maskmem_out = self.maskmem_backbone(
                pix_feat, mask_for_mem, skip_mask_sigmoid=True
            )
        else:
            maskmem_out = self.maskmem_backbone(image, pix_feat, mask_for_mem)
        # Clone the feats and pos_enc to enable compilation
        maskmem_features = self._maybe_clone(maskmem_out["vision_features"])
        maskmem_pos_enc = [self._maybe_clone(m) for m in maskmem_out["vision_pos_enc"]]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        is_obj_appearing = (object_score_logits > 0).float()
        maskmem_features += (
            1 - is_obj_appearing[..., None, None]
        ) * self.no_obj_embed_spatial[..., None, None].expand(*maskmem_features.shape)
        return maskmem_features, maskmem_pos_enc

    def forward_tracking(self, backbone_out, input, return_dict=False):
        """Forward video tracking on each frame (and sample correction clicks)."""
        img_feats_already_computed = backbone_out["backbone_fpn"] is not None
        if img_feats_already_computed:
            # Prepare the backbone features
            # - vision_feats and vision_pos_embeds are in (HW)BC format
            (
                _,
                vision_feats,
                vision_pos_embeds,
                feat_sizes,
            ) = self._prepare_backbone_features(backbone_out)

        # Starting the stage loop
        num_frames = backbone_out["num_frames"]
        init_cond_frames = backbone_out["init_cond_frames"]
        frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
        # first process all the initial conditioning frames to encode them as memory,
        # and then condition on them to track the remaining frames
        processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
        output_dict = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        for stage_id in processing_order:
            # Get the image features for the current frames
            img_ids = input.find_inputs[stage_id].img_ids
            if img_feats_already_computed:
                # Retrieve image features according to img_ids (if they are already computed).
                current_image = input.img_batch[img_ids]
                current_vision_feats = [x[:, img_ids] for x in vision_feats]
                current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
            else:
                # Otherwise, compute the image features on the fly for the given img_ids
                # (this might be used for evaluation on long videos to avoid backbone OOM).
                (
                    current_image,
                    current_vision_feats,
                    current_vision_pos_embeds,
                    feat_sizes,
                ) = self._prepare_backbone_features_per_frame(input.img_batch, img_ids)
            # Get output masks based on this frame's prompts and previous memory
            current_out = self.track_step(
                frame_idx=stage_id,
                is_init_cond_frame=stage_id in init_cond_frames,
                current_vision_feats=current_vision_feats,
                current_vision_pos_embeds=current_vision_pos_embeds,
                feat_sizes=feat_sizes,
                image=current_image,
                point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
                mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
                output_dict=output_dict,
                num_frames=num_frames,
            )
            # Append the output, depending on whether it's a conditioning frame
            add_output_as_cond_frame = stage_id in init_cond_frames or (
                self.add_all_frames_to_correct_as_cond
                and stage_id in frames_to_add_correction_pt
            )
            if add_output_as_cond_frame:
                output_dict["cond_frame_outputs"][stage_id] = current_out
            else:
                output_dict["non_cond_frame_outputs"][stage_id] = current_out

        if return_dict:
            return output_dict
        # turn `output_dict` into a list for the loss function
        all_frame_outputs = {}
        all_frame_outputs.update(output_dict["cond_frame_outputs"])
        all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
        all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
        # Make DDP happy with activation checkpointing by removing unused keys
        all_frame_outputs = [
            {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
        ]
        return all_frame_outputs

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        image,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
        use_prev_mem_frame=True,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None:
            # When a mask input is given, directly output it (see it as a GT mask)
            # without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # fuse the visual feature with previous memory features in the memory bank
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
                use_prev_mem_frame=use_prev_mem_frame,
            )
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, the SAM mask decoder should have `self.iter_use_prev_mask_pred=True`, and
            # any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert self.iter_use_prev_mask_pred
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        (
            _,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs
        # Use the final prediction (after all correction steps) for output and eval
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        if self.use_memory_selection:
            current_out["object_score_logits"] = object_score_logits
            iou_score = ious.max(-1)[0]
            current_out["iou_score"] = iou_score
            current_out["eff_iou_score"] = self.cal_mem_score(
                object_score_logits, iou_score
            )
        if not self.training:
            # Only add this in inference (to avoid unused param in activation checkpointing;
            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
            current_out["object_score_logits"] = object_score_logits

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        # (note that `self.num_maskmem == 0` is primarily used for reproducing SAM on
        # images, in which case we'll just skip the memory encoder to save compute).
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                image=image,
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                object_score_logits=object_score_logits,
                is_mask_from_pts=(point_inputs is not None),
                output_dict=output_dict,
                is_init_cond_frame=is_init_cond_frame,
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

        # Optionally, offload the outputs to CPU memory during evaluation to avoid
        # GPU OOM on very long videos, very large resolutions, or too many objects
        if self.offload_output_to_cpu_for_eval and not self.training:
            # Here we only keep those keys needed for evaluation to get a compact output
            trimmed_out = {
                "pred_masks": current_out["pred_masks"].cpu(),
                "pred_masks_high_res": current_out["pred_masks_high_res"].cpu(),
                # other items for evaluation (these are small tensors so we keep them on GPU)
                "obj_ptr": current_out["obj_ptr"],
                "object_score_logits": current_out["object_score_logits"],
            }
            if run_mem_encoder and self.num_maskmem > 0:
                trimmed_out["maskmem_features"] = maskmem_features.cpu()
                trimmed_out["maskmem_pos_enc"] = [x.cpu() for x in maskmem_pos_enc]
            if self.use_memory_selection:
                trimmed_out["iou_score"] = current_out["iou_score"].cpu()
                trimmed_out["eff_iou_score"] = current_out["eff_iou_score"].cpu()
            current_out = trimmed_out

        # Optionally, trim the output of a past non-conditioning frame (r * num_maskmem frames
        # before the current frame) during evaluation. This is intended to save GPU or CPU
        # memory for semi-supervised VOS eval, where only the first frame receives prompts.
        def _trim_past_out(past_out, current_out):
            if past_out is None:
                return None
            return {
                "pred_masks": past_out["pred_masks"],
                "obj_ptr": past_out["obj_ptr"],
                "object_score_logits": past_out["object_score_logits"],
            }

        if self.trim_past_non_cond_mem_for_eval and not self.training:
            r = self.memory_temporal_stride_for_eval
            past_frame_idx = frame_idx - r * self.num_maskmem
            past_out = output_dict["non_cond_frame_outputs"].get(past_frame_idx, None)
            if past_out is not None:
                if (
                    self.use_memory_selection
                    and past_out.get("eff_iou_score", 0) < self.mf_threshold
                ) or not self.use_memory_selection:
                    output_dict["non_cond_frame_outputs"][past_frame_idx] = (
                        _trim_past_out(past_out, current_out)
                    )
        if (
            self.use_memory_selection and not self.offload_output_to_cpu_for_eval
        ):  # designed for memory selection: trim frames that are too old, to save memory
            far_old_frame_idx = frame_idx - 20 * self.max_obj_ptrs_in_encoder
            past_out = output_dict["non_cond_frame_outputs"].get(
                far_old_frame_idx, None
            )
            if past_out is not None:
                output_dict["non_cond_frame_outputs"][far_old_frame_idx] = (
                    _trim_past_out(past_out, current_out)
                )
        return current_out

    def _use_multimask(self, is_init_cond_frame, point_inputs):
        """Whether to use multimask output in the SAM head."""
        num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
        multimask_output = (
            self.multimask_output_in_sam
            and (is_init_cond_frame or self.multimask_output_for_tracking)
            and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
        )
        return multimask_output

    def _apply_non_overlapping_constraints(self, pred_masks):
        """
        Apply non-overlapping constraints to the object scores in pred_masks. Here we
        keep only the highest scoring object at each spatial location in pred_masks.
        """
        batch_size = pred_masks.size(0)
        if batch_size == 1:
            return pred_masks

        device = pred_masks.device
        # "max_obj_inds": object index of the object with the highest score at each location
        max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
        # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
        batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
        keep = max_obj_inds == batch_obj_inds
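        # For example, if objects 0 and 1 both claim the same pixel with logits 5.0 and 7.0,
        # only object 1 keeps its logit there; object 0's logit at that pixel is clamped
        # to at most -10.0 below.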
        # suppress overlapping regions' scores below -10.0 so that the foreground regions
        # don't overlap (here sigmoid(-10.0)=4.5398e-05)
        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks

    def _compile_all_components(self):
        """Compile all model components for faster inference."""
        # a larger cache size to hold varying number of shapes for torch.compile
        # see https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_dynamo/config.py#L42-L49
        torch._dynamo.config.cache_size_limit = 64
        torch._dynamo.config.accumulated_cache_size_limit = 2048
        from sam3.perflib.compile import compile_wrapper

        logging.info("Compiling all components. First time may be very slow.")
        self.maskmem_backbone.forward = compile_wrapper(
            self.maskmem_backbone.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )
        self.transformer.encoder.forward = compile_wrapper(
            self.transformer.encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=True,  # Num. of memories varies
        )
        # We disable compilation of sam_prompt_encoder as it sometimes gives a large accuracy regression,
        # especially when sam_mask_prompt (previous mask logits) is not None
        # self.sam_prompt_encoder.forward = torch.compile(
        #     self.sam_prompt_encoder.forward,
        #     mode="max-autotune",
        #     fullgraph=True,
        #     dynamic=False,  # Accuracy regression on True
        # )
        self.sam_mask_decoder.forward = compile_wrapper(
            self.sam_mask_decoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,  # Accuracy regression on True
        )

    def _maybe_clone(self, x):
        """Clone a tensor if and only if `self.compile_all_components` is True."""
        return x.clone() if self.compile_all_components else x


def concat_points(old_point_inputs, new_points, new_labels):
    """Add new points and labels to previous point inputs (add at the end)."""
    if old_point_inputs is None:
        points, labels = new_points, new_labels
    else:
        points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
        labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
    return {"point_coords": points, "point_labels": labels}
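

# A minimal, illustrative usage sketch for `concat_points` (not part of the original
# module; the coordinates and labels below are arbitrary example values).
if __name__ == "__main__":
    # start with a single positive click for a batch of 1
    first_click = concat_points(
        None,
        new_points=torch.tensor([[[100.0, 200.0]]]),  # [B=1, P=1, 2], (x, y) in pixels
        new_labels=torch.ones(1, 1, dtype=torch.int32),  # 1 = positive click
    )
    # append a negative click; points/labels are concatenated along the P dimension
    two_clicks = concat_points(
        first_click,
        new_points=torch.tensor([[[300.0, 50.0]]]),
        new_labels=torch.zeros(1, 1, dtype=torch.int32),  # 0 = negative click
    )
    assert two_clicks["point_coords"].shape == (1, 2, 2)
    assert two_clicks["point_labels"].shape == (1, 2)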