sam2_base.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed
import torch.nn.functional as F

from torch.nn.init import trunc_normal_

from sam2.modeling.sam.mask_decoder import MaskDecoder
from sam2.modeling.sam.prompt_encoder import PromptEncoder
from sam2.modeling.sam.transformer import TwoWayTransformer
from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames

# a large negative value as a placeholder score for missing objects
NO_OBJ_SCORE = -1024.0


class SAM2Base(torch.nn.Module):
    def __init__(
        self,
        image_encoder,
        memory_attention,
        memory_encoder,
        num_maskmem=7,  # default 1 input frame + 6 previous frames
        image_size=512,
        backbone_stride=16,  # stride of the image backbone output
        sigmoid_scale_for_mem_enc=1.0,  # scale factor for mask sigmoid prob
        sigmoid_bias_for_mem_enc=0.0,  # bias factor for mask sigmoid prob
        # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
        binarize_mask_from_pts_for_mem_enc=False,
        use_mask_input_as_output_without_sam=False,  # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
        # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
        # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
        # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
        max_cond_frames_in_attn=-1,
        # on the first frame, whether to directly add the no-memory embedding to the image feature
        # (instead of using the transformer encoder)
        directly_add_no_mem_embed=False,
        # whether to use high-resolution feature maps in the SAM mask decoder
        use_high_res_features_in_sam=False,
        # whether to output multiple (3) masks for the first click on initial conditioning frames
        multimask_output_in_sam=False,
        # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
        # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
        multimask_min_pt_num=1,
        multimask_max_pt_num=1,
        # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
        multimask_output_for_tracking=False,
        # Whether to use multimask tokens for obj ptr; only relevant when both
        # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
        use_multimask_token_for_obj_ptr: bool = False,
        # whether to use sigmoid to restrict the IoU prediction to [0, 1]
        iou_prediction_use_sigmoid=False,
        # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
        # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
        # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
        memory_temporal_stride_for_eval=1,
        # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
        non_overlap_masks_for_mem_enc=False,
        # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
        use_obj_ptrs_in_encoder=False,
        # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
        max_obj_ptrs_in_encoder=16,
        # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
        add_tpos_enc_to_obj_ptrs=True,
        # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
        # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
        proj_tpos_enc_in_obj_ptrs=False,
        # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
        # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
        use_signed_tpos_enc_to_obj_ptrs=False,
        # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
        # (only relevant when `use_obj_ptrs_in_encoder=True`; this might keep pointer information from too far in the future from distracting the initial tracking)
        only_obj_ptrs_in_the_past_for_eval=False,
        # Whether to predict if there is an object in the frame
        pred_obj_scores: bool = False,
        # Whether to use an MLP to predict object scores
        pred_obj_scores_mlp: bool = False,
        # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
        # whether to use a fixed no-obj pointer when there is no object present,
        # or to use it as an additive embedding with the obj_ptr produced by the decoder
        fixed_no_obj_ptr: bool = False,
        # Soft no-object, i.e. mix in no_obj_ptr softly,
        # hoping to make recovery easier after a mistake and to mitigate accumulation of errors
        soft_no_obj_ptr: bool = False,
        use_mlp_for_obj_ptr_proj: bool = False,
        # add a no-obj embedding to spatial frames
        no_obj_embed_spatial: bool = False,
        # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into the `MaskDecoder` class.
        sam_mask_decoder_extra_args=None,
        compile_image_encoder: bool = False,
    ):
        super().__init__()

        # Part 1: the image backbone
        self.image_encoder = image_encoder
        # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
        self.use_high_res_features_in_sam = use_high_res_features_in_sam
        self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
        self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
        self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
        if use_obj_ptrs_in_encoder:
            # A conv layer to downsample the mask prompt to stride 4 (the same stride as
            # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
            # so that it can be fed into the SAM mask decoder to generate a pointer.
            self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
        self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
        if proj_tpos_enc_in_obj_ptrs:
            assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
        self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
        self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval

        # Part 2: memory attention to condition current frame's visual features
        # with memories (and obj ptrs) from past frames
        self.memory_attention = memory_attention
        self.hidden_dim = image_encoder.neck.d_model

        # Part 3: memory encoder for the previous frame's outputs
        self.memory_encoder = memory_encoder
        self.mem_dim = self.hidden_dim
        if hasattr(self.memory_encoder, "out_proj") and hasattr(
            self.memory_encoder.out_proj, "weight"
        ):
            # if there is compression of memories along channel dim
            self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
        self.num_maskmem = num_maskmem  # Number of memories accessible
        # Temporal encoding of the memories
        self.maskmem_tpos_enc = torch.nn.Parameter(
            torch.zeros(num_maskmem, 1, 1, self.mem_dim)
        )
        trunc_normal_(self.maskmem_tpos_enc, std=0.02)
        # a single token to indicate no memory embedding from previous frames
        self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        trunc_normal_(self.no_mem_embed, std=0.02)
        trunc_normal_(self.no_mem_pos_enc, std=0.02)
        self.directly_add_no_mem_embed = directly_add_no_mem_embed
        # Apply sigmoid to the output raw mask logits (to turn them from
        # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
        self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
        self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
        self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
        self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
        self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
        # On frames with mask input, whether to directly output the input mask without
        # using a SAM prompt encoder + mask decoder
        self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
        self.multimask_output_in_sam = multimask_output_in_sam
        self.multimask_min_pt_num = multimask_min_pt_num
        self.multimask_max_pt_num = multimask_max_pt_num
        self.multimask_output_for_tracking = multimask_output_for_tracking
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
        self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid

        # Part 4: SAM-style prompt encoder (for both mask and point inputs)
        # and SAM-style mask decoder for the final mask output
        self.image_size = image_size
        self.backbone_stride = backbone_stride
        self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
        self.pred_obj_scores = pred_obj_scores
        self.pred_obj_scores_mlp = pred_obj_scores_mlp
        self.fixed_no_obj_ptr = fixed_no_obj_ptr
        self.soft_no_obj_ptr = soft_no_obj_ptr
        if self.fixed_no_obj_ptr:
            assert self.pred_obj_scores
            assert self.use_obj_ptrs_in_encoder
        if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
            self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
            trunc_normal_(self.no_obj_ptr, std=0.02)
        self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
        self.no_obj_embed_spatial = None
        if no_obj_embed_spatial:
            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
            trunc_normal_(self.no_obj_embed_spatial, std=0.02)

        self._build_sam_heads()
        self.max_cond_frames_in_attn = max_cond_frames_in_attn

        # Model compilation
        if compile_image_encoder:
            # Compile the forward function (not the full module) to allow loading checkpoints.
            print(
                "Image encoder compilation is enabled. First forward pass will be slow."
            )
            self.image_encoder.forward = torch.compile(
                self.image_encoder.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=False,
            )
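        # Design note (added here for clarity, not an original comment): compiling only
        # `image_encoder.forward` rather than wrapping the whole module with torch.compile
        # leaves the parameter/state_dict names untouched, which is what allows pretrained
        # checkpoints to load as usual.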

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Please use the corresponding methods in SAM2VideoPredictor for inference "
            "or SAM2Train for training/fine-tuning. "
            "See notebooks/video_predictor_example.ipynb for an inference example."
        )
    def _build_sam_heads(self):
        """Build SAM-style prompt encoder and mask decoder."""
        self.sam_prompt_embed_dim = self.hidden_dim
        self.sam_image_embedding_size = self.image_size // self.backbone_stride

        # build PromptEncoder and MaskDecoder from SAM
        # (their hyperparameters like `mask_in_chans=16` are from SAM code)
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
                self.sam_image_embedding_size,
                self.sam_image_embedding_size,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
        self.sam_mask_decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=self.use_high_res_features_in_sam,
            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
            pred_obj_scores=self.pred_obj_scores,
            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
            **(self.sam_mask_decoder_extra_args or {}),
        )
        if self.use_obj_ptrs_in_encoder:
            # a linear projection on SAM output tokens to turn them into object pointers
            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
            if self.use_mlp_for_obj_ptr_proj:
                self.obj_ptr_proj = MLP(
                    self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
                )
        else:
            self.obj_ptr_proj = torch.nn.Identity()
        if self.proj_tpos_enc_in_obj_ptrs:
            # a linear projection on temporal positional encoding in object pointers to
            # avoid potential interference with spatial positional encoding
            self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
        else:
            self.obj_ptr_tpos_proj = torch.nn.Identity()
    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Forward SAM prompt encoders and mask heads.

        Inputs:
        - backbone_features: image features of [B, C, H, W] shape
        - point_inputs: a dictionary with "point_coords" and "point_labels", where
          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
             absolute pixel-unit coordinates in (x, y) format of the P input points
          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
             positive clicks, 0 means negative clicks, and -1 means padding
        - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
          same spatial size as the image.
        - high_res_features: either 1) None or 2) a list of length 2 containing
          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
          which will be used as high-resolution feature maps for the SAM decoder.
        - multimask_output: if it's True, we output 3 candidate masks and their 3
          corresponding IoU estimates, and if it's False, we output only 1 mask and
          its corresponding IoU estimate.

        Outputs:
        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
          output mask logits (before sigmoid) for the low-resolution masks, with 4x
          the resolution (1/4 stride) of the input backbone_features.
        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
          if `multimask_output=True` and M = 1 if `multimask_output=False`),
          upsampled from the low-resolution masks, with the same size as the image
          (stride is 1 pixel).
        - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
          if `multimask_output=False`), the estimated IoU of each output mask.
        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `low_res_multimasks`.
        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `high_res_multimasks`.
        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
          based on the output token from the SAM mask decoder.
        """
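        # Shape illustration (added note, assuming the default image_size=512 and
        # backbone_stride=16): backbone_features is [B, C, 32, 32], the low-res masks
        # come out at [B, M, 128, 128] (4x the 32x32 feature grid), and the high-res
        # masks are upsampled below to [B, M, 512, 512], i.e. the full image size.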
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0
            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )
    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
        """
        Directly turn binary `mask_inputs` into output mask logits without using SAM.
        (same input and output shapes as in _forward_sam_heads above).
        """
        # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
        out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
        mask_inputs_float = mask_inputs.float()
        high_res_masks = mask_inputs_float * out_scale + out_bias
        low_res_masks = F.interpolate(
            high_res_masks,
            size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
            align_corners=False,
            mode="bilinear",
            antialias=True,  # use antialias for downsampling
        )
        # a dummy IoU prediction of all 1's under mask input
        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
        if not self.use_obj_ptrs_in_encoder:
            # all zeros as a dummy object pointer (of shape [B, C])
            obj_ptr = torch.zeros(
                mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
            )
        else:
            # produce an object pointer using the SAM decoder from the mask input
            _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
                backbone_features=backbone_features,
                mask_inputs=self.mask_downsample(mask_inputs_float),
                high_res_features=high_res_features,
            )
        # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
        # below, we follow the same design axiom and use mask_input to decide whether the object appears,
        # instead of relying on the object_scores from the SAM decoder.
        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
        is_obj_appearing = is_obj_appearing[..., None]
        lambda_is_obj_appearing = is_obj_appearing.float()
        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
        if self.pred_obj_scores:
            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_masks,
            high_res_masks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )
    def forward_image(self, img_batch: torch.Tensor):
        """Get the image feature on the input batch."""
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
                backbone_out["backbone_fpn"][0]
            )
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
                backbone_out["backbone_fpn"][1]
            )
        return backbone_out
    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features."""
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
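        # Illustration (added note): with the default 512x512 input and, say, a
        # 256-dim top-level feature map, a [1, 256, 32, 32] tensor becomes
        # [1024, 1, 256] after flatten(2).permute(2, 0, 1), i.e. a sequence of
        # 32*32=1024 tokens laid out as (HW, N, C) for the memory attention.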
        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
    def _prepare_memory_conditioned_features(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    ):
        """Fuse the current frame's visual feature map with previous memory."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        device = current_vision_feats[-1].device
        # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
        # In this case, we skip the fusion with any memory.
        if self.num_maskmem == 0:  # Disable memory and skip fusion
            pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
            return pix_feat

        num_obj_ptr_tokens = 0
        tpos_sign_mul = -1 if track_in_reverse else 1
        # Step 1: condition the visual features of the current frame on previous memories
        if not is_init_cond_frame:
            # Retrieve the memories encoded with the maskmem backbone
            to_cat_memory, to_cat_memory_pos_embed = [], []
            # Add the conditioning frames' outputs first (all cond frames have t_pos=0
            # when getting the temporal positional embedding below)
            assert len(output_dict["cond_frame_outputs"]) > 0
            # Select a maximum number of temporally closest cond frames for cross attention
            cond_outputs = output_dict["cond_frame_outputs"]
            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
                frame_idx, cond_outputs, self.max_cond_frames_in_attn
            )
            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
            # Add the last (self.num_maskmem - 1) frames before the current frame as non-conditioning memory;
            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1.
            # We also allow taking memory frames non-consecutively (with stride>1), in which case
            # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
            stride = 1 if self.training else self.memory_temporal_stride_for_eval
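            # Worked example (added note): with num_maskmem=7, stride r=2, frame_idx=15,
            # and forward tracking, the loop below picks frame 14 for t_rel=1 (always the
            # immediately previous frame) and frames 12, 10, 8, 6, 4 for t_rel=2..6
            # (every r-th frame), so the non-conditioning memory comes from
            # frames {4, 6, 8, 10, 12, 14}.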
            for t_pos in range(1, self.num_maskmem):
                t_rel = self.num_maskmem - t_pos  # how many frames before the current frame
                if t_rel == 1:
                    # for t_rel == 1, we take the last frame (regardless of r)
                    if not track_in_reverse:
                        # the frame immediately before this frame (i.e. frame_idx - 1)
                        prev_frame_idx = frame_idx - t_rel
                    else:
                        # the frame immediately after this frame (i.e. frame_idx + 1)
                        prev_frame_idx = frame_idx + t_rel
                else:
                    # for t_rel >= 2, we take the memory frame from every r-th frames
                    if not track_in_reverse:
                        # first find the nearest frame among every r-th frames before this frame
                        # for r=1, this would be (frame_idx - 2)
                        prev_frame_idx = ((frame_idx - 2) // stride) * stride
                        # then seek further among every r-th frames
                        prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
                    else:
                        # first find the nearest frame among every r-th frames after this frame
                        # for r=1, this would be (frame_idx + 2)
                        prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
                        # then seek further among every r-th frames
                        prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                if out is None:
                    # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                    # frames, we still attend to it as if it's a non-conditioning frame.
                    out = unselected_cond_outputs.get(prev_frame_idx, None)
                t_pos_and_prevs.append((t_pos, out))

            for t_pos, prev in t_pos_and_prevs:
                if prev is None:
                    continue  # skip padding frames
                # "maskmem_features" might have been offloaded to CPU in demo use cases,
                # so we load it back to GPU (it's a no-op if it's already on GPU).
                feats = prev["maskmem_features"].to(device, non_blocking=True)
                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                # Spatial positional encoding (it might have been offloaded to CPU in eval)
                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                # Temporal positional encoding
                maskmem_enc = (
                    maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
                )
                to_cat_memory_pos_embed.append(maskmem_enc)

            # Construct the list of past object pointers
            if self.use_obj_ptrs_in_encoder:
                max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
                # First add those object pointers from selected conditioning frames
                # (optionally, only include object pointers in the past during evaluation)
                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                    ptr_cond_outputs = {
                        t: out
                        for t, out in selected_cond_outputs.items()
                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                    }
                else:
                    ptr_cond_outputs = selected_cond_outputs
                pos_and_ptrs = [
                    # Temporal pos encoding contains how far away each pointer is from the current frame
                    (
                        (
                            (frame_idx - t) * tpos_sign_mul
                            if self.use_signed_tpos_enc_to_obj_ptrs
                            else abs(frame_idx - t)
                        ),
                        out["obj_ptr"],
                    )
                    for t, out in ptr_cond_outputs.items()
                ]
                # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before the current frame
                for t_diff in range(1, max_obj_ptrs_in_encoder):
                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                    if t < 0 or (num_frames is not None and t >= num_frames):
                        break
                    out = output_dict["non_cond_frame_outputs"].get(
                        t, unselected_cond_outputs.get(t, None)
                    )
                    if out is not None:
                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
                # If we have at least one object pointer, add them to the cross attention
                if len(pos_and_ptrs) > 0:
                    pos_list, ptrs_list = zip(*pos_and_ptrs)
                    # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                    obj_ptrs = torch.stack(ptrs_list, dim=0)
                    # a temporal positional embedding based on how far each object pointer is from
                    # the current frame (sine embedding normalized by the max pointer num).
                    if self.add_tpos_enc_to_obj_ptrs:
                        t_diff_max = max_obj_ptrs_in_encoder - 1
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                        obj_pos = (
                            torch.tensor(pos_list)
                            .pin_memory()
                            .to(device=device, non_blocking=True)
                        )
                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                    else:
                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                    if self.mem_dim < C:
                        # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                        obj_ptrs = obj_ptrs.reshape(
                            -1, B, C // self.mem_dim, self.mem_dim
                        )
                        obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
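                        # For example (added note), with C=256 and mem_dim=64 each object
                        # pointer is split into 256/64 = 4 memory tokens, and its temporal
                        # positional embedding is repeated 4x to stay aligned with them.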
                    to_cat_memory.append(obj_ptrs)
                    to_cat_memory_pos_embed.append(obj_pos)
                    num_obj_ptr_tokens = obj_ptrs.shape[0]
                else:
                    num_obj_ptr_tokens = 0
        else:
            # for initial conditioning frames, encode them without using any previous memory
            if self.directly_add_no_mem_embed:
                # directly add no-mem embedding (instead of using the transformer encoder)
                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
                return pix_feat_with_mem

            # Use a dummy token on the first frame (to avoid empty memory input to the transformer encoder)
            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

        # Step 2: Concatenate the memories and forward through the transformer encoder
        memory = torch.cat(to_cat_memory, dim=0)
        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

        pix_feat_with_mem = self.memory_attention(
            curr=current_vision_feats,
            curr_pos=current_vision_pos_embeds,
            memory=memory,
            memory_pos=memory_pos_embed,
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )
        # reshape the output (HW)BC => BCHW
        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
        return pix_feat_with_mem
    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """Encode the current image and its prediction into a memory feature."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1).
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )
        # scale the raw mask logits with a temperature before applying sigmoid
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(
            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
        )
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (
                1 - is_obj_appearing[..., None, None]
            ) * self.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )

        return maskmem_features, maskmem_pos_enc
    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # fuse the visual feature with previous memory features in the memory bank
            pix_feat = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        return current_out, sam_outputs, high_res_features, pix_feat
    def _encode_memory_in_output(
        self,
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                object_score_logits=object_score_logits,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None
    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
    ):
        current_out, sam_outputs, _, _ = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )
        (
            _,
            _,
            _,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        if not self.training:
            # Only add this in inference (to avoid unused param in activation checkpointing;
            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
            current_out["object_score_logits"] = object_score_logits

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )

        return current_out
    def _use_multimask(self, is_init_cond_frame, point_inputs):
        """Whether to use multimask output in the SAM head."""
        num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
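        # Note (added): with the defaults multimask_min_pt_num = multimask_max_pt_num = 1,
        # multimask output is only used when multimask_output_in_sam is enabled and exactly
        # one point is given; a box prompt counts as two points and therefore falls outside
        # the range checked below.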
        multimask_output = (
            self.multimask_output_in_sam
            and (is_init_cond_frame or self.multimask_output_for_tracking)
            and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
        )
        return multimask_output
    def _apply_non_overlapping_constraints(self, pred_masks):
        """
        Apply non-overlapping constraints to the object scores in pred_masks. Here we
        keep only the highest scoring object at each spatial location in pred_masks.
        """
        batch_size = pred_masks.size(0)
        if batch_size == 1:
            return pred_masks

        device = pred_masks.device
        # "max_obj_inds": object index of the object with the highest score at each location
        max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
        # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
        batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
        keep = max_obj_inds == batch_obj_inds
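        # Illustration (added note): if three objects have logits (2.0, 5.0, -1.0) at the
        # same pixel, only the second object keeps its score there; the other two are
        # clamped to at most -10.0 below, so their foreground probabilities after
        # sigmoid become negligible.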
        # suppress overlapping regions' scores below -10.0 so that the foreground regions
        # don't overlap (here sigmoid(-10.0)=4.5398e-05)
        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks
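

# ---------------------------------------------------------------------------
# Usage sketch (added, not part of the original module): SAM2Base is not meant
# to be called directly (see `forward` above); it is instantiated from a Hydra
# config and wrapped by SAM2VideoPredictor for inference or SAM2Train for
# training/fine-tuning. A minimal example, assuming the `build_sam2` helper
# from sam2/build_sam.py and hypothetical local config/checkpoint paths:
#
#   from sam2.build_sam import build_sam2
#
#   model = build_sam2(
#       config_file="configs/sam2.1/sam2.1_hiera_l.yaml",  # hypothetical path
#       ckpt_path="checkpoints/sam2.1_hiera_large.pt",     # hypothetical path
#   )
#   assert isinstance(model, SAM2Base)
#   backbone_out = model.forward_image(
#       torch.zeros(1, 3, model.image_size, model.image_size, device=model.device)
#   )
# ---------------------------------------------------------------------------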