postprocessors.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Postprocessor classes to transform MDETR output according to the downstream task"""
import dataclasses
import logging
from collections import defaultdict
from typing import Dict, List, Optional

import numpy as np
import torch
from sam3.model import box_ops
from sam3.model.data_misc import BatchedInferenceMetadata, interpolate
from sam3.train.masks_ops import rle_encode, robust_rle_encode
from torch import nn


class PostProcessNullOp(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, input):
        pass

    def process_results(self, **kwargs):
        return kwargs["find_stages"]


class PostProcessImage(nn.Module):
    """This module converts the model's output into the format expected by the coco api"""

    def __init__(
        self,
        max_dets_per_img: int,
        iou_type="bbox",
        to_cpu: bool = True,
        use_original_ids: bool = False,
        use_original_sizes_box: bool = False,
        use_original_sizes_mask: bool = False,
        convert_mask_to_rle: bool = False,
        always_interpolate_masks_on_gpu: bool = True,
        use_presence: bool = True,
        detection_threshold: float = -1.0,
    ) -> None:
        super().__init__()
        self.max_dets_per_img = max_dets_per_img
        self.iou_type = iou_type
        self.to_cpu = to_cpu
        self.convert_mask_to_rle = convert_mask_to_rle
        self.always_interpolate_masks_on_gpu = always_interpolate_masks_on_gpu
        self.use_presence = use_presence
        self.detection_threshold = detection_threshold
        self.use_original_ids = use_original_ids
        self.use_original_sizes_box = use_original_sizes_box
        self.use_original_sizes_mask = use_original_sizes_mask

    @torch.no_grad()
    def forward(
        self,
        outputs,
        target_sizes_boxes,
        target_sizes_masks,
        forced_labels=None,
        consistent=False,
        ret_tensordict: bool = False,  # This is experimental
    ):
        """Perform the computation

        Parameters:
            outputs: raw outputs of the model
            target_sizes_boxes: tensor of dimension [batch_size x 2] containing the size of each image of the batch.
                For evaluation, this must be the original image size (before any data augmentation).
                For visualization, this should be the image size after data augmentation, but before padding.
            target_sizes_masks: same, but used to resize masks
            forced_labels: tensor of dimension [batch_size] containing the label to force for each image of the batch.
                This is useful when evaluating the model using standard metrics (e.g. on COCO, LVIS). In that case,
                we query the model with every possible class label, so when we pass the predictions to the evaluator,
                we want to make sure that the predicted "class" matches the one that was queried.
            consistent: whether all target sizes are equal
            ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of
                dictionaries for easier manipulation.
        """
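        # Illustrative shape conventions assumed below (inferred from the docstring and the
        # indexing in this method; they are not enforced anywhere):
        #   outputs["pred_logits"]: [batch_size, num_queries, num_classes]
        #   outputs["pred_boxes"]:  [batch_size, num_queries, 4], normalized cxcywh
        #   target_sizes_boxes / target_sizes_masks: [batch_size, 2] as (h, w)
        #   forced_labels: [batch_size] (optional)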
        if ret_tensordict:
            assert consistent is True, (
                "We don't support returning TensorDict if the outputs have different shapes"
            )  # NOTE: It's possible but we don't support it.
            assert self.detection_threshold <= 0.0, "TODO: implement?"
            try:
                from tensordict import TensorDict
            except ImportError:
                logging.info(
                    "tensordict is not installed. Install it by running `pip install tensordict --no-deps`. Falling back to `ret_tensordict=False`"
                )
                ret_tensordict = False

        out_bbox = outputs["pred_boxes"] if "pred_boxes" in outputs else None
        out_logits = outputs["pred_logits"]
        pred_masks = outputs["pred_masks"] if self.iou_type == "segm" else None

        out_probs = out_logits.sigmoid()
        if self.use_presence:
            presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
            out_probs = out_probs * presence_score

        assert target_sizes_boxes.shape[1] == 2
        assert target_sizes_masks.shape[1] == 2
        batch_size = target_sizes_boxes.shape[0]

        boxes, scores, labels, keep = self._process_boxes_and_labels(
            target_sizes_boxes, forced_labels, out_bbox, out_probs
        )
        assert boxes is None or len(boxes) == batch_size

        out_masks = self._process_masks(
            target_sizes_masks, pred_masks, consistent=consistent, keep=keep
        )
        del pred_masks

        if boxes is None:
            assert out_masks is not None
            assert not ret_tensordict, (
                "We don't support returning TensorDict if the output does not contain boxes"
            )
            B = len(out_masks)
            boxes = [None] * B
            scores = [None] * B
            labels = [None] * B

        results = {
            "scores": scores,
            "labels": labels,
            "boxes": boxes,
        }
        if out_masks is not None:
            if self.convert_mask_to_rle:
                results.update(masks_rle=out_masks)
            else:
                results.update(masks=out_masks)

        if ret_tensordict:
            results = TensorDict(results).auto_batch_size_()
            if self.to_cpu:
                results = results.cpu()
        else:
            # Convert a dictionary of lists/tensors to a list of dictionaries
            results = [
                dict(zip(results.keys(), res_tuple))
                for res_tuple in zip(*results.values())
            ]
        return results

    def _process_masks(self, target_sizes, pred_masks, consistent=True, keep=None):
        if pred_masks is None:
            return None

        if self.always_interpolate_masks_on_gpu:
            gpu_device = target_sizes.device
            assert gpu_device.type == "cuda"
            pred_masks = pred_masks.to(device=gpu_device)

        if consistent:
            assert keep is None, "TODO: implement?"
            # All masks should have the same shape, expected when processing a batch of size 1
            target_size = target_sizes.unique(dim=0)
            assert target_size.size(0) == 1, "Expecting all target sizes to be equal"
            out_masks = (
                interpolate(
                    pred_masks,
                    target_size.squeeze().tolist(),
                    mode="bilinear",
                    align_corners=False,
                ).sigmoid()
                > 0.5
            )
            if self.convert_mask_to_rle:
                raise RuntimeError("TODO: implement?")
            if self.to_cpu:
                out_masks = out_masks.cpu()
        else:
            out_masks = [[]] * len(pred_masks)
            assert keep is None or len(keep) == len(pred_masks)
            for i, mask in enumerate(pred_masks):
                h, w = target_sizes[i]
                if keep is not None:
                    mask = mask[keep[i]]
                # Use the GPU version first; move the masks to the CPU if it fails
                try:
                    interpolated = (
                        interpolate(
                            mask.unsqueeze(1),
                            (h, w),
                            mode="bilinear",
                            align_corners=False,
                        ).sigmoid()
                        > 0.5
                    )
                except Exception as e:
                    logging.info("Issue found, reverting to CPU mode!")
                    mask_device = mask.device
                    mask = mask.cpu()
                    interpolated = (
                        interpolate(
                            mask.unsqueeze(1),
                            (h, w),
                            mode="bilinear",
                            align_corners=False,
                        ).sigmoid()
                        > 0.5
                    )
                    interpolated = interpolated.to(mask_device)
                if self.convert_mask_to_rle:
                    out_masks[i] = robust_rle_encode(interpolated.squeeze(1))
                else:
                    out_masks[i] = interpolated
                    if self.to_cpu:
                        out_masks[i] = out_masks[i].cpu()
        return out_masks

    def _process_boxes_and_labels(
        self, target_sizes, forced_labels, out_bbox, out_probs
    ):
        if out_bbox is None:
            return None, None, None, None

        assert len(out_probs) == len(target_sizes)
        if self.to_cpu:
            out_probs = out_probs.cpu()
        scores, labels = out_probs.max(-1)
        if forced_labels is None:
            labels = torch.ones_like(labels)
        else:
            labels = forced_labels[:, None].expand_as(labels)

        # convert to [x0, y0, x1, y1] format
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]
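        # e.g. (illustrative numbers) a normalized cxcywh box (0.5, 0.5, 0.2, 0.4) on an
        # (h, w) = (480, 640) image becomes the absolute xyxy box (256.0, 144.0, 384.0, 336.0)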
        if self.to_cpu:
            boxes = boxes.cpu()

        keep = None
        if self.detection_threshold > 0:
            # Filter out the boxes with scores below the detection threshold
            keep = scores > self.detection_threshold
            assert len(keep) == len(boxes) == len(scores) == len(labels)
            boxes = [b[k.to(b.device)] for b, k in zip(boxes, keep)]
            scores = [s[k.to(s.device)] for s, k in zip(scores, keep)]
            labels = [l[k.to(l.device)] for l, k in zip(labels, keep)]
        return boxes, scores, labels, keep

    def process_results(
        self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
    ):
        if find_stages.loss_stages is not None:
            find_metadatas = [find_metadatas[i] for i in find_stages.loss_stages]
        assert len(find_stages) == len(find_metadatas)
        results = {}
        for outputs, meta in zip(find_stages, find_metadatas):
            img_size_for_boxes = (
                meta.original_size
                if self.use_original_sizes_box
                else torch.ones_like(meta.original_size)
            )
            img_size_for_masks = (
                meta.original_size
                if self.use_original_sizes_mask
                else torch.ones_like(meta.original_size)
            )
            detection_results = self(
                outputs,
                img_size_for_boxes,
                img_size_for_masks,
                forced_labels=(
                    meta.original_category_id if self.use_original_ids else None
                ),
            )
            ids = (
                meta.original_image_id if self.use_original_ids else meta.coco_image_id
            )
            assert len(detection_results) == len(ids)
            for img_id, result in zip(ids, detection_results):
                if img_id.item() not in results:
                    results[img_id.item()] = result
                else:
                    assert set(results[img_id.item()].keys()) == set(result.keys())
                    for k in result.keys():
                        if isinstance(result[k], torch.Tensor):
                            results[img_id.item()][k] = torch.cat(
                                [results[img_id.item()][k], result[k]], dim=0
                            )
                        elif isinstance(result[k], list):
                            results[img_id.item()][k] += result[k]
                        else:
                            raise NotImplementedError(
                                f"Unexpected type {type(result[k])} in result."
                            )

        # Prune the results to the max number of detections per image.
        for img_id, result in results.items():
            if (
                self.max_dets_per_img > 0
                and len(result["scores"]) > self.max_dets_per_img
            ):
                _, topk_indexes = torch.topk(
                    result["scores"], self.max_dets_per_img, dim=0
                )
                if self.to_cpu:
                    topk_indexes = topk_indexes.cpu()
                for k in result.keys():
                    if isinstance(results[img_id][k], list):
                        results[img_id][k] = [
                            results[img_id][k][i] for i in topk_indexes.tolist()
                        ]
                    else:
                        results[img_id][k] = results[img_id][k].to(topk_indexes.device)[
                            topk_indexes
                        ]
        return results


class PostProcessAPIVideo(PostProcessImage):
    """This module converts the video model's output into the format expected by the YT-VIS api"""

    def __init__(
        self,
        *args,
        to_cpu: bool = True,
        convert_mask_to_rle: bool = False,
        always_interpolate_masks_on_gpu: bool = True,
        prob_thresh: float = 0.5,
        use_presence: bool = False,
        **kwargs,
    ):
        super().__init__(
            *args,
            # Here we always set `convert_mask_to_rle=False` in the base `PostProcessAPI` class
            # (so that its `_process_masks` won't return a list of RLEs). If we want to return
            # RLEs for video masklets, we handle it in this `PostProcessAPIVideo` class instead.
            convert_mask_to_rle=False,
            # Here we always set `to_cpu=False` in the base `PostProcessAPI` class (so that
            # the interpolated masks won't be automatically moved back to CPU). We will handle
            # it in this `PostProcessAPIVideo` class instead.
            to_cpu=False,
            always_interpolate_masks_on_gpu=always_interpolate_masks_on_gpu,
            use_presence=use_presence,
            **kwargs,
        )
        # Expected keys in the output dict to postprocess
        self.EXPECTED_KEYS = [
            "pred_logits",
            "pred_boxes",
            "pred_masks",
        ]
        # Whether to post-process video masklets (under packed representation) into RLE format
        self.convert_mask_to_rle_for_video = convert_mask_to_rle
        self.to_cpu_for_video = to_cpu
        self.prob_thresh = prob_thresh

    def process_results(
        self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
    ):
        """
        Tracking postprocessor for the SAM 3 video model.

        This function takes the output of the SAM 3 video model and processes it to extract all the
        tracklet predictions.

        Args:
            find_stages: A list of tensors representing the output of the SAM 3 video model.
            find_metadatas: A list of BatchedInferenceMetadata objects containing metadata about each frame.
            **kwargs: Additional keyword arguments.

        Returns:
            A dictionary of predictions with video_id as key.
        """
        # Import tensordict here to avoid a global dependency.
        try:
            from tensordict import TensorDict
        except ImportError as e:
            logging.error(
                "tensordict is not installed, please install it by running `pip install tensordict --no-deps`"
            )
            raise e

        # Notes and assumptions:
        # 1- This postprocessor assumes results only for a single video.
        # 2- There are N stage outputs corresponding to N video frames.
        # 3- Each stage output contains PxQ preds, where P is the number of prompts and Q is the number of
        #    object queries. The output should also contain the tracking object ids corresponding to each
        #    object query.
        # 4- The tracking object id has a default value of -1, indicating that the object query is not
        #    tracking any object in the frame, and hence its predictions can be ignored for a given frame.
        # 5- Some objects may be tracked in a subset of frames only. So, we first extract the predictions
        #    in a packed representation (for efficient postprocessing -- especially memory) and then we
        #    convert the packed representation into a padded one, where we zero pad boxes/masks for objects
        #    that are not tracked in some frames.
        # 6- We refer to objects by an object id, which is a tuple (prompt_idx, obj_id).
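        #
        # Toy example (hypothetical numbers) of the packed -> padded conversion performed below:
        #   3 frames, 2 tracked objects: object (0, 7) appears in frames 0 and 2, object (1, 3) in frame 1 only.
        #   packed: predictions are stored in order of appearance, giving
        #           tracked_objects_packed_idx = {(0, 7): [0, 2], (1, 3): [1]}
        #           tracked_objects_frame_idx  = {(0, 7): [0, 2], (1, 3): [1]}
        #   padded: tensors of shape [num_objects=2, num_frames=3, ...], zero-filled (with a very low
        #           score) wherever an object is not tracked on a frame.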
        assert len(find_stages) > 0, "There is nothing to postprocess?"
        PROMPT_AXIS, OBJ_QUERY_AXIS = (0, 1)
        NO_OBJ_ID = -1
        # Maps object ID -> [indices in packed tensor]
        tracked_objects_packed_idx = defaultdict(list)
        # Maps object ID -> [indices in padded tensor (abs frame index)]
        tracked_objects_frame_idx = defaultdict(list)
        total_num_preds = 0
        # This will hold the packed representation of predictions.
        vid_preds_packed: List[TensorDict] = []
        vid_masklets_rle_packed: List[Optional[Dict]] = []
        video_id = -1  # We assume single-video postprocessing; this ID should be unique in the datapoint.
        for frame_idx, (frame_outs, meta) in enumerate(
            zip(find_stages, find_metadatas)
        ):
            # only store the keys we need to extract the results
            frame_outs_td = TensorDict(
                {k: frame_outs[k] for k in self.EXPECTED_KEYS}
            ).auto_batch_size_()  # Shape is [P,Q,...]
            meta_td = TensorDict(
                dataclasses.asdict(meta)
            ).auto_batch_size_()  # Shape is [P,...]
            unique_vid_id = meta.original_image_id.unique()
            assert unique_vid_id.size(0) == 1
            if video_id == -1:
                video_id = unique_vid_id.item()
            else:
                assert video_id == unique_vid_id.item(), (
                    "We can only postprocess one video per datapoint"
                )
            # keeping track of which objects appear in the current frame
            obj_ids_per_frame = frame_outs["pred_object_ids"]
            assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2)
            if self.prob_thresh is not None:
                # only keep the predictions on this frame with probability above the threshold
                # (remove those predictions during the keep-alive period of a tracking query,
                # where its "pred_object_ids" is still the tracked object ID rather than -1)
                pred_probs = frame_outs["pred_logits"].sigmoid().squeeze(-1)
                obj_ids_per_frame = torch.where(
                    pred_probs >= self.prob_thresh, obj_ids_per_frame, NO_OBJ_ID
                )
            tracked_obj_ids_idx = torch.where(obj_ids_per_frame != NO_OBJ_ID)
            # The object id is a tuple of (prompt_idx, obj_id), because the model can assign the same
            # obj_id to two different prompts.
            tracked_obj_ids = [
                (p_id.item(), obj_ids_per_frame[p_id, q_id].item())
                for p_id, q_id in zip(
                    tracked_obj_ids_idx[PROMPT_AXIS],
                    tracked_obj_ids_idx[OBJ_QUERY_AXIS],
                )
            ]
            if len(tracked_obj_ids) == 0:
                continue
            # For each object, we keep track of the packed and padded (frame index) indices
            for oid in tracked_obj_ids:
                tracked_objects_packed_idx[oid].append(total_num_preds)
                tracked_objects_frame_idx[oid].append(frame_idx)
                total_num_preds += 1
            # Since we have P*Q masks per frame, mask interpolation is the GPU memory bottleneck (or the
            # time bottleneck in case of CPU processing). Instead, we first extract results only for tracked
            # objects, reducing the number of masks to K = sum_i(tracked_objs_per_ith_prompt), hopefully <<< P*Q
            tracked_objs_outs_td = frame_outs_td[
                tracked_obj_ids_idx
            ]  # [P,Q,...] --> [K,...]
            meta_td = meta_td[tracked_obj_ids_idx[PROMPT_AXIS].cpu()]
            if self.always_interpolate_masks_on_gpu:
                gpu_device = meta_td["original_size"].device
                assert gpu_device.type == "cuda"
                tracked_objs_outs_td = tracked_objs_outs_td.to(device=gpu_device)
            frame_results_td = self(
                tracked_objs_outs_td.unsqueeze(1),
                (
                    meta_td["original_size"]
                    if self.use_original_sizes_box
                    else torch.ones_like(meta_td["original_size"])
                ),
                (
                    meta_td["original_size"]
                    if self.use_original_sizes_mask
                    else torch.ones_like(meta_td["original_size"])
                ),
                forced_labels=(
                    meta_td["original_category_id"] if self.use_original_ids else None
                ),
                consistent=True,
                ret_tensordict=True,
            ).squeeze(1)
            del tracked_objs_outs_td
            # Optionally, remove "masks" from the output tensor dict and directly encode them
            # to RLE format under the packed representation
            if self.convert_mask_to_rle_for_video:
                interpolated_binary_masks = frame_results_td.pop("masks")
                rle_list = rle_encode(interpolated_binary_masks, return_areas=True)
                vid_masklets_rle_packed.extend(rle_list)
            # Optionally, move the output TensorDict to CPU (do this after the RLE encoding step above)
            if self.to_cpu_for_video:
                frame_results_td = frame_results_td.cpu()
            vid_preds_packed.append(frame_results_td)

        if len(vid_preds_packed) == 0:
            logging.debug(f"Video {video_id} has no predictions")
            return {video_id: []}
        vid_preds_packed = torch.cat(vid_preds_packed, dim=0)

        ############### Construct a padded representation of the predictions ###############
        num_preds = len(tracked_objects_packed_idx)
        num_frames = len(find_stages)
        # We zero pad any missing prediction
        # NOTE: here, we also have padded tensors for "scores" and "labels", but we overwrite them later.
        padded_frames_results = TensorDict(
            {
                k: torch.zeros(
                    num_preds, num_frames, *v.shape[1:], device=v.device, dtype=v.dtype
                )
                for k, v in vid_preds_packed.items()
            },
            batch_size=[
                num_preds,
                num_frames,
            ],
        )
        padded_frames_results["scores"][...] = -1e8  # a very low score for empty objects
        # Track the scores and labels of each predicted tracklet, only for frames where the model
        # was able to track that object
        tracklet_scores = []
        tracklet_labels = []
        # Optionally, fill the list of RLEs for masklets
        # note: only frames with actual predicted masks (in packed format) will be
        # filled with RLEs; the rest will remain None in results["masks_rle"]
        if self.convert_mask_to_rle_for_video:
            vid_masklets_rle_padded = [[None] * num_frames for _ in range(num_preds)]
        for o_idx, oid in enumerate(tracked_objects_packed_idx):
            oid2packed_idx = tracked_objects_packed_idx[oid]
            oid2padded_idx = tracked_objects_frame_idx[oid]
            obj_packed_results = vid_preds_packed[oid2packed_idx]
            padded_frames_results[o_idx][oid2padded_idx] = obj_packed_results
            if self.convert_mask_to_rle_for_video:
                for packed_idx, padded_idx in zip(oid2packed_idx, oid2padded_idx):
                    vid_masklets_rle_padded[o_idx][padded_idx] = (
                        vid_masklets_rle_packed[packed_idx]
                    )
            # NOTE: We need a single confidence score per tracklet for the mAP metric.
            # We use the average confidence score across time. (How does this impact AP?)
            tracklet_scores.append(obj_packed_results["scores"].mean())
            # We also need a unique category id per tracklet.
            # This is not a problem for phrase AP; however, for mAP we do majority voting across time.
            tracklet_labels.append(obj_packed_results["labels"].mode()[0])
        results = padded_frames_results.to_dict()
        results["scores"] = torch.stack(tracklet_scores, dim=0)
        results["labels"] = torch.stack(tracklet_labels, dim=0)
        if self.convert_mask_to_rle_for_video:
            results["masks_rle"] = vid_masklets_rle_padded
        # we keep the frame-level scores since they're needed by some evaluation scripts
        results["per_frame_scores"] = padded_frames_results["scores"]
        return {video_id: results}


class PostProcessTracking(PostProcessImage):
    """This module converts the model's output into the format expected by the coco api"""

    def __init__(
        self,
        max_dets_per_img: int,
        iou_type="bbox",
        force_single_mask: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(max_dets_per_img=max_dets_per_img, iou_type=iou_type, **kwargs)
        self.force_single_mask = force_single_mask

    def process_results(
        self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
    ):
        assert len(find_stages) == len(find_metadatas)
        results = {}
        for outputs, meta in zip(find_stages, find_metadatas):
            if self.force_single_mask:
                scores, labels = outputs["pred_logits"].max(-1)
                m = []
                for i in range(len(outputs["pred_masks"])):
                    score, idx = scores[i].max(0)
                    m.append(outputs["pred_masks"][i][idx])
                outputs["pred_masks"] = torch.stack(m, 0).unsqueeze(1)
            detection_results = self(
                outputs, meta.original_size, meta.original_size, consistent=False
            )
            assert len(detection_results) == len(meta.coco_image_id)
            results.update(
                {
                    (media_id.item(), object_id.item(), frame_index.item()): result
                    for media_id, object_id, frame_index, result in zip(
                        meta.original_image_id,
                        meta.object_id,
                        meta.frame_index,
                        detection_results,
                    )
                }
            )
        return results


class PostProcessCounting(nn.Module):
    """This module converts the model's output to be evaluated for counting tasks"""

    def __init__(
        self,
        use_original_ids: bool = False,
        threshold: float = 0.5,
        use_presence: bool = False,
    ) -> None:
        """
        Args:
            use_original_ids: whether to use the original image ids or the coco ids
            threshold: threshold for counting (values above this are counted)
            use_presence: whether to additionally gate the per-query scores with the presence score
        """
        super().__init__()
        self.use_original_ids = use_original_ids
        self.threshold = threshold
        self.use_presence = use_presence

    def forward(self, outputs, target_sizes):
        """Perform the computation

        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
        """
        # Extract scores from model outputs and apply sigmoid
        scores = torch.sigmoid(outputs["pred_logits"]).squeeze(-1)  # [B, N]
        if self.use_presence:
            presence_score = outputs["presence_logit_dec"].sigmoid()
            if presence_score.ndim == 1:
                presence_score = presence_score.unsqueeze(1)  # [B, 1]
            scores = scores * presence_score  # [B, N]
        # Calculate counts by summing values above threshold
        counts = (scores > self.threshold).float().sum(dim=1)
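        # e.g. with threshold=0.5 and scores=[[0.9, 0.2, 0.7]], counts is [2.0]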
        assert len(counts) == len(target_sizes)
        results = []
        for count in counts:
            results.append({"count": count.item()})
        return results

    @torch.no_grad()
    def process_results(
        self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs
    ):
        assert len(find_stages) == len(find_metadatas)
        results = {}
        for outputs, meta in zip(find_stages, find_metadatas):
            detection_results = self(
                outputs,
                meta.original_size,
            )
            ids = (
                meta.original_image_id if self.use_original_ids else meta.coco_image_id
            )
            assert len(detection_results) == len(ids)
            for img_id, result in zip(ids, detection_results):
                results[img_id.item()] = result
        return results
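

# Minimal usage sketch (illustrative only, not part of the original module): it drives
# PostProcessCounting with synthetic outputs. The "pred_logits" key and the [B, N, 1] logit
# shape follow forward() above; the numbers themselves are made up for the example.
if __name__ == "__main__":
    _post = PostProcessCounting(threshold=0.5)
    _fake_outputs = {"pred_logits": torch.tensor([[[2.0], [-3.0], [1.0]]])}  # [B=1, N=3, 1]
    _target_sizes = torch.tensor([[480, 640]])  # [B, 2] as (h, w)
    # sigmoid(2.0) and sigmoid(1.0) are above 0.5, sigmoid(-3.0) is not -> count == 2
    print(_post(_fake_outputs, _target_sizes))  # [{'count': 2.0}]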