# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch
import torch.distributed
from sam2.modeling.sam2_base import SAM2Base
from sam2.modeling.sam2_utils import (
    get_1d_sine_pe,
    get_next_point,
    sample_box_points,
    select_closest_cond_frames,
)

from sam2.utils.misc import concat_points

from training.utils.data_utils import BatchedVideoDatapoint


class SAM2Train(SAM2Base):
    def __init__(
        self,
        image_encoder,
        memory_attention=None,
        memory_encoder=None,
        prob_to_use_pt_input_for_train=0.0,
        prob_to_use_pt_input_for_eval=0.0,
        prob_to_use_box_input_for_train=0.0,
        prob_to_use_box_input_for_eval=0.0,
        # if it is greater than 1, we perform interactive point sampling on the 1st frame and other randomly selected frames
        num_frames_to_correct_for_train=1,  # default: only iteratively sample on first frame
        num_frames_to_correct_for_eval=1,  # default: only iteratively sample on first frame
        rand_frames_to_correct_for_train=False,
        rand_frames_to_correct_for_eval=False,
        # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
        # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
        # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
        # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
        # these are initial conditioning frames because as we track the video, more conditioning frames might be added
        # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
        num_init_cond_frames_for_train=1,  # default: only use the first frame as initial conditioning frame
        num_init_cond_frames_for_eval=1,  # default: only use the first frame as initial conditioning frame
        rand_init_cond_frames_for_train=True,  # default: random 1~num_init_cond_frames_for_train cond frames (to be consistent w/ previous TA data loader)
        rand_init_cond_frames_for_eval=False,
        # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only those initial conditioning frames
        add_all_frames_to_correct_as_cond=False,
        # how many additional correction points to sample (on each frame selected to be corrected)
        # note that the first frame receives an initial input click (in addition to any correction clicks)
        num_correction_pt_per_frame=7,
        # method for point sampling during evaluation
        # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
        # default to "center" to be consistent with evaluation in the SAM paper
        pt_sampling_for_eval="center",
        # During training, we optionally allow sampling the correction points from GT regions
        # instead of the prediction error regions with a small probability. This might allow the
        # model to overfit less to the error regions in training datasets
        prob_to_sample_from_gt_for_train=0.0,
        use_act_ckpt_iterative_pt_sampling=False,
        # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
        # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
        forward_backbone_per_frame_for_eval=False,
        freeze_image_encoder=False,
        **kwargs,
    ):
        super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
        self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
        self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval

        # Point sampler and conditioning frames
        self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
        self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
        self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
        self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
        if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
            logging.info(
                f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
            )
            assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
            assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval

        self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
        self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
        self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
        self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
        # Initial multi-conditioning frames
        self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
        self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
        self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
        self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
        self.num_correction_pt_per_frame = num_correction_pt_per_frame
        self.pt_sampling_for_eval = pt_sampling_for_eval
        self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
        # A random number generator with a fixed initial seed across GPUs
        self.rng = np.random.default_rng(seed=42)

        if freeze_image_encoder:
            for p in self.image_encoder.parameters():
                p.requires_grad = False
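
    # Illustrative construction only (not from this file): in the actual training
    # setup the sub-modules and these arguments come from the training configs.
    #
    #   model = SAM2Train(
    #       image_encoder=image_encoder,            # image backbone (e.g. Hiera)
    #       memory_attention=memory_attention,
    #       memory_encoder=memory_encoder,
    #       prob_to_use_pt_input_for_train=0.5,     # half the clips use click/box prompts
    #       prob_to_use_box_input_for_train=0.5,    # of those, half start from a box
    #       num_frames_to_correct_for_train=2,      # sample correction clicks on 2 frames
    #       num_init_cond_frames_for_train=2,       # up to 2 initial conditioning frames
    #       rand_init_cond_frames_for_train=True,
    #   )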

    def forward(self, input: BatchedVideoDatapoint):
        if self.training or not self.forward_backbone_per_frame_for_eval:
            # precompute image features on all frames before tracking
            backbone_out = self.forward_image(input.flat_img_batch)
        else:
            # defer image feature computation on a frame until it's being tracked
            backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
        backbone_out = self.prepare_prompt_inputs(backbone_out, input)
        previous_stages_out = self.forward_tracking(backbone_out, input)
        return previous_stages_out
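
    # A minimal sketch of how this forward pass might be driven in a training
    # loop (`loss_fn` and `optimizer` are assumptions, not defined in this file):
    #
    #   outputs = model(batch)                  # list with one output dict per frame
    #   loss = loss_fn(outputs, batch.masks)    # supervise all multistep predictions
    #   loss.backward()
    #   optimizer.step()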

    def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
        """Compute the image backbone features on the fly for the given img_ids."""
        # Only forward backbone on unique image ids to avoid repetitive computation
        # (if `img_ids` has only one element, it's already unique so we skip this step).
        if img_ids.numel() > 1:
            unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
        else:
            unique_img_ids, inv_ids = img_ids, None
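
        # For example, img_ids = [3, 3, 5] yields unique_img_ids = [3, 5] and
        # inv_ids = [0, 0, 1]: the backbone runs on 2 unique images, and the
        # resulting features are gathered back to 3 per-object entries below.
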
        # Compute the image features on those unique image ids
        image = img_batch[unique_img_ids]
        backbone_out = self.forward_image(image)
        (
            _,
            vision_feats,
            vision_pos_embeds,
            feat_sizes,
        ) = self._prepare_backbone_features(backbone_out)
        # Inverse-map image features for `unique_img_ids` to the final image features
        # for the original input `img_ids`.
        if inv_ids is not None:
            image = image[inv_ids]
            vision_feats = [x[:, inv_ids] for x in vision_feats]
            vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
        return image, vision_feats, vision_pos_embeds, feat_sizes

    def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
        """
        Prepare input mask, point or box prompts. Optionally, we allow tracking from
        a custom `start_frame_idx` to the end of the video (for evaluation purposes).
        """
        # Load the ground-truth masks on all frames (so that we can later
        # sample correction points from them)
        # gt_masks_per_frame = {
        #     stage_id: targets.segments.unsqueeze(1)  # [B, 1, H_im, W_im]
        #     for stage_id, targets in enumerate(input.find_targets)
        # }
        gt_masks_per_frame = {
            stage_id: masks.unsqueeze(1)  # [B, 1, H_im, W_im]
            for stage_id, masks in enumerate(input.masks)
        }
        # gt_masks_per_frame = input.masks.unsqueeze(2)  # [T,B,1,H_im,W_im] keep everything in tensor form
        backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
        num_frames = input.num_frames
        backbone_out["num_frames"] = num_frames

        # Randomly decide whether to use point inputs or mask inputs
        if self.training:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
            prob_to_use_box_input = self.prob_to_use_box_input_for_train
            num_frames_to_correct = self.num_frames_to_correct_for_train
            rand_frames_to_correct = self.rand_frames_to_correct_for_train
            num_init_cond_frames = self.num_init_cond_frames_for_train
            rand_init_cond_frames = self.rand_init_cond_frames_for_train
        else:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
            prob_to_use_box_input = self.prob_to_use_box_input_for_eval
            num_frames_to_correct = self.num_frames_to_correct_for_eval
            rand_frames_to_correct = self.rand_frames_to_correct_for_eval
            num_init_cond_frames = self.num_init_cond_frames_for_eval
            rand_init_cond_frames = self.rand_init_cond_frames_for_eval
        if num_frames == 1:
            # here we handle a special case for mixing video + SAM on image training,
            # where we force using point input for the SAM task on static images
            prob_to_use_pt_input = 1.0
            num_frames_to_correct = 1
            num_init_cond_frames = 1
        assert num_init_cond_frames >= 1
        # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)
        use_pt_input = self.rng.random() < prob_to_use_pt_input
        if rand_init_cond_frames and num_init_cond_frames > 1:
            # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
            num_init_cond_frames = self.rng.integers(
                1, num_init_cond_frames, endpoint=True
            )
        if (
            use_pt_input
            and rand_frames_to_correct
            and num_frames_to_correct > num_init_cond_frames
        ):
            # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
            # correction clicks (only for the case of point input)
            num_frames_to_correct = self.rng.integers(
                num_init_cond_frames, num_frames_to_correct, endpoint=True
            )
        backbone_out["use_pt_input"] = use_pt_input

        # Sample initial conditioning frames
        if num_init_cond_frames == 1:
            init_cond_frames = [start_frame_idx]  # starting frame
        else:
            # starting frame + randomly selected remaining frames (without replacement)
            init_cond_frames = [start_frame_idx] + self.rng.choice(
                range(start_frame_idx + 1, num_frames),
                num_init_cond_frames - 1,
                replace=False,
            ).tolist()
        backbone_out["init_cond_frames"] = init_cond_frames
        backbone_out["frames_not_in_init_cond"] = [
            t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
        ]
        # Prepare mask or point inputs on initial conditioning frames
        backbone_out["mask_inputs_per_frame"] = {}  # {frame_idx: <input_masks>}
        backbone_out["point_inputs_per_frame"] = {}  # {frame_idx: <input_points>}
        for t in init_cond_frames:
            if not use_pt_input:
                backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
            else:
                # During training, P(box input) = prob_to_use_pt_input * prob_to_use_box_input
                use_box_input = self.rng.random() < prob_to_use_box_input
                if use_box_input:
                    points, labels = sample_box_points(
                        gt_masks_per_frame[t],
                    )
                else:
                    # (here we only sample **one initial point** on initial conditioning frames from the
                    # ground-truth mask; we may sample more correction points on the fly)
                    points, labels = get_next_point(
                        gt_masks=gt_masks_per_frame[t],
                        pred_masks=None,
                        method=(
                            "uniform" if self.training else self.pt_sampling_for_eval
                        ),
                    )
                point_inputs = {"point_coords": points, "point_labels": labels}
                backbone_out["point_inputs_per_frame"][t] = point_inputs
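
        # Marginal prompt-type distribution for an initial conditioning frame
        # (before any sampling):
        #   P(mask prompt)  = 1 - prob_to_use_pt_input
        #   P(box prompt)   = prob_to_use_pt_input * prob_to_use_box_input
        #   P(single click) = prob_to_use_pt_input * (1 - prob_to_use_box_input)
        # e.g. with both probabilities at 0.5: mask 0.5, box 0.25, click 0.25.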

        # Sample frames where we will add correction clicks on the fly
        # based on the error between prediction and ground-truth masks
        if not use_pt_input:
            # no correction points will be sampled when using mask inputs
            frames_to_add_correction_pt = []
        elif num_frames_to_correct == num_init_cond_frames:
            frames_to_add_correction_pt = init_cond_frames
        else:
            assert num_frames_to_correct > num_init_cond_frames
            # initial cond frame + randomly selected remaining frames (without replacement)
            extra_num = num_frames_to_correct - num_init_cond_frames
            frames_to_add_correction_pt = (
                init_cond_frames
                + self.rng.choice(
                    backbone_out["frames_not_in_init_cond"], extra_num, replace=False
                ).tolist()
            )
        backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt
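
        # At this point `backbone_out` carries the prompt bookkeeping consumed by
        # `forward_tracking`, e.g. for a 4-frame clip with point input and 2 init
        # conditioning frames (illustrative values):
        #   use_pt_input                = True
        #   init_cond_frames            = [0, 2]
        #   frames_not_in_init_cond     = [1, 3]
        #   frames_to_add_correction_pt = [0, 2]
        #   point_inputs_per_frame      = {0: {...}, 2: {...}}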
        return backbone_out

    def forward_tracking(
        self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
    ):
        """Forward video tracking on each frame (and sample correction clicks)."""
        img_feats_already_computed = backbone_out["backbone_fpn"] is not None
        if img_feats_already_computed:
            # Prepare the backbone features
            # - vision_feats and vision_pos_embeds are in (HW)BC format
            (
                _,
                vision_feats,
                vision_pos_embeds,
                feat_sizes,
            ) = self._prepare_backbone_features(backbone_out)

        # Starting the stage loop
        num_frames = backbone_out["num_frames"]
        init_cond_frames = backbone_out["init_cond_frames"]
        frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
        # first process all the initial conditioning frames to encode them as memory,
        # and then condition on them to track the remaining frames
        processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
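        # Example: with num_frames=5 and init_cond_frames=[0, 3], the processing
        # order is [0, 3, 1, 2, 4]: the conditioning frames are encoded into memory
        # first, and the remaining frames are then tracked against that memory.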
        output_dict = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        for stage_id in processing_order:
            # Get the image features for the current frames
            # img_ids = input.find_inputs[stage_id].img_ids
            img_ids = input.flat_obj_to_img_idx[stage_id]
            if img_feats_already_computed:
                # Retrieve image features according to img_ids (if they are already computed).
                current_vision_feats = [x[:, img_ids] for x in vision_feats]
                current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
            else:
                # Otherwise, compute the image features on the fly for the given img_ids
                # (this might be used for evaluation on long videos to avoid backbone OOM).
                (
                    _,
                    current_vision_feats,
                    current_vision_pos_embeds,
                    feat_sizes,
                ) = self._prepare_backbone_features_per_frame(
                    input.flat_img_batch, img_ids
                )

            # Get output masks based on this frame's prompts and previous memory
            current_out = self.track_step(
                frame_idx=stage_id,
                is_init_cond_frame=stage_id in init_cond_frames,
                current_vision_feats=current_vision_feats,
                current_vision_pos_embeds=current_vision_pos_embeds,
                feat_sizes=feat_sizes,
                point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
                mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
                gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
                frames_to_add_correction_pt=frames_to_add_correction_pt,
                output_dict=output_dict,
                num_frames=num_frames,
            )
            # Append the output, depending on whether it's a conditioning frame
            add_output_as_cond_frame = stage_id in init_cond_frames or (
                self.add_all_frames_to_correct_as_cond
                and stage_id in frames_to_add_correction_pt
            )
            if add_output_as_cond_frame:
                output_dict["cond_frame_outputs"][stage_id] = current_out
            else:
                output_dict["non_cond_frame_outputs"][stage_id] = current_out

        if return_dict:
            return output_dict
        # turn `output_dict` into a list for the loss function
        all_frame_outputs = {}
        all_frame_outputs.update(output_dict["cond_frame_outputs"])
        all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
        all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
        # Make DDP happy with activation checkpointing by removing unused keys
        all_frame_outputs = [
            {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
        ]

        return all_frame_outputs

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        run_mem_encoder=True,  # Whether to run the memory encoder on the predicted masks.
        prev_sam_mask_logits=None,  # The previously predicted SAM mask logits.
        frames_to_add_correction_pt=None,
        gt_masks=None,
    ):
        if frames_to_add_correction_pt is None:
            frames_to_add_correction_pt = []
        current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )

        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs

        current_out["multistep_pred_masks"] = low_res_masks
        current_out["multistep_pred_masks_high_res"] = high_res_masks
        current_out["multistep_pred_multimasks"] = [low_res_multimasks]
        current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
        current_out["multistep_pred_ious"] = [ious]
        current_out["multistep_point_inputs"] = [point_inputs]
        current_out["multistep_object_score_logits"] = [object_score_logits]
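
        # If this frame was selected for correction clicks, the "multistep_*"
        # entries above are overwritten below with one prediction per correction
        # round, so the loss can supervise every intermediate mask.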

        # Optionally, sample correction points iteratively to correct the mask
        if frame_idx in frames_to_add_correction_pt:
            point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
                is_init_cond_frame,
                point_inputs,
                gt_masks,
                high_res_features,
                pix_feat,
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                object_score_logits,
                current_out,
            )
            (
                _,
                _,
                _,
                low_res_masks,
                high_res_masks,
                obj_ptr,
                object_score_logits,
            ) = final_sam_outputs

        # Use the final prediction (after all correction steps) for output and eval
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )

        return current_out

    def _iter_correct_pt_sampling(
        self,
        is_init_cond_frame,
        point_inputs,
        gt_masks,
        high_res_features,
        pix_feat_with_mem,
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        assert gt_masks is not None
        all_pred_masks = [low_res_masks]
        all_pred_high_res_masks = [high_res_masks]
        all_pred_multimasks = [low_res_multimasks]
        all_pred_high_res_multimasks = [high_res_multimasks]
        all_pred_ious = [ious]
        all_point_inputs = [point_inputs]
        all_object_score_logits = [object_score_logits]
        for _ in range(self.num_correction_pt_per_frame):
            # sample a new point from the error between prediction and ground-truth
            # (with a small probability, directly sample from GT masks instead of errors)
            if self.training and self.prob_to_sample_from_gt_for_train > 0:
                sample_from_gt = (
                    self.rng.random() < self.prob_to_sample_from_gt_for_train
                )
            else:
                sample_from_gt = False
            # if `pred_for_new_pt` is None, only GT masks will be used for point sampling
            pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
            new_points, new_labels = get_next_point(
                gt_masks=gt_masks,
                pred_masks=pred_for_new_pt,
                method="uniform" if self.training else self.pt_sampling_for_eval,
            )
            point_inputs = concat_points(point_inputs, new_points, new_labels)
            # Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
            # For tracking, this means that when the user adds a correction click, we also feed
            # the tracking output mask logits along with the click as input to the SAM decoder.
            mask_inputs = low_res_masks
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
                sam_outputs = torch.utils.checkpoint.checkpoint(
                    self._forward_sam_heads,
                    backbone_features=pix_feat_with_mem,
                    point_inputs=point_inputs,
                    mask_inputs=mask_inputs,
                    high_res_features=high_res_features,
                    multimask_output=multimask_output,
                    use_reentrant=False,
                )
            else:
                sam_outputs = self._forward_sam_heads(
                    backbone_features=pix_feat_with_mem,
                    point_inputs=point_inputs,
                    mask_inputs=mask_inputs,
                    high_res_features=high_res_features,
                    multimask_output=multimask_output,
                )
            (
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                _,
                object_score_logits,
            ) = sam_outputs
            all_pred_masks.append(low_res_masks)
            all_pred_high_res_masks.append(high_res_masks)
            all_pred_multimasks.append(low_res_multimasks)
            all_pred_high_res_multimasks.append(high_res_multimasks)
            all_pred_ious.append(ious)
            all_point_inputs.append(point_inputs)
            all_object_score_logits.append(object_score_logits)

        # Concatenate the masks along channel (to compute losses on all of them,
        # using `MultiStepIteractiveMasks`)
        current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
        current_out["multistep_pred_masks_high_res"] = torch.cat(
            all_pred_high_res_masks, dim=1
        )
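        # With the default num_correction_pt_per_frame=7 this yields 1 + 7 = 8
        # mask channels per object: the initial prediction plus one prediction
        # after each correction click.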
        current_out["multistep_pred_multimasks"] = all_pred_multimasks
        current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
        current_out["multistep_pred_ious"] = all_pred_ious
        current_out["multistep_point_inputs"] = all_point_inputs
        current_out["multistep_object_score_logits"] = all_object_score_logits
        return point_inputs, sam_outputs