coco_eval_offline.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""
This evaluator is meant for regular COCO mAP evaluation, for example on the COCO val set.
For category mAP, the model needs to make predictions for all categories on every single image.
Since the number of classes can be large and the API model makes predictions individually for
each (image, class) pair, the inference process for a given image may need to be split into
several chunks.
"""

import logging
from collections import defaultdict

import torch
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from sam3.train.utils.distributed import is_main_process

try:
    from tidecv import datasets, TIDE

    HAS_TIDE = True
except ImportError:
    HAS_TIDE = False
    print("WARNING: TIDE not installed. Detailed analysis will not be available.")

# the COCO detection metrics (https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L460-L471)
COCO_METRICS = [
    "AP",
    "AP_50",
    "AP_75",
    "AP_small",
    "AP_medium",
    "AP_large",
    "AR_maxDets@1",
    "AR_maxDets@10",
    "AR_maxDets@100",
    "AR_small",
    "AR_medium",
    "AR_large",
]


def convert_to_xywh(boxes):
    """Convert bounding boxes from xyxy format to xywh format."""
    xmin, ymin, xmax, ymax = boxes.unbind(-1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=-1)
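
# Illustrative example (assumed values): a box from (10, 20) to (30, 60) in xyxy
# format becomes (x=10, y=20, w=20, h=40) in xywh format:
#
#   convert_to_xywh(torch.tensor([[10.0, 20.0, 30.0, 60.0]]))
#   # -> tensor([[10., 20., 20., 40.]])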


class HeapElement:
    """Utility class to make a heap with a custom comparator"""

    def __init__(self, val):
        self.val = val

    def __lt__(self, other):
        return self.val["score"] < other.val["score"]
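
# Illustrative sketch (not used in this file): HeapElement lets `heapq` order
# detection dicts by their "score" field, e.g. to keep the top-k detections.
# `detections` and `k` are placeholder names.
#
#   import heapq
#   heap = []
#   for det in detections:           # each det is a dict with a "score" key
#       heapq.heappush(heap, HeapElement(det))
#       if len(heap) > k:
#           heapq.heappop(heap)      # drop the lowest-scoring detection
#   top_k = [elem.val for elem in heap]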


class COCOevalCustom(COCOeval):
    """
    This is a slightly modified version of the original COCO API with added support for positive split evaluation.
    """

    def __init__(
        self, cocoGt=None, cocoDt=None, iouType="segm", dt_only_positive=False
    ):
        super().__init__(cocoGt, cocoDt, iouType)
        self.dt_only_positive = dt_only_positive

    def _prepare(self):
        """
        Prepare ._gts and ._dts for evaluation based on params
        :return: None
        """

        def _toMask(anns, coco):
            # modify ann['segmentation'] by reference
            for ann in anns:
                rle = coco.annToRLE(ann)
                ann["segmentation"] = rle

        p = self.params
        if p.useCats:
            gts = self.cocoGt.loadAnns(
                self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
            )
            dts = self.cocoDt.loadAnns(
                self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
            )
        else:
            gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
            dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))

        # convert ground truth to mask if iouType == 'segm'
        if p.iouType == "segm":
            _toMask(gts, self.cocoGt)
            _toMask(dts, self.cocoDt)
        # set ignore flag
        for gt in gts:
            gt["ignore"] = gt["ignore"] if "ignore" in gt else 0
            gt["ignore"] = "iscrowd" in gt and gt["iscrowd"]
            if p.iouType == "keypoints":
                gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"]
        self._gts = defaultdict(list)  # gt for evaluation
        self._dts = defaultdict(list)  # dt for evaluation
        _gts_cat_ids = defaultdict(set)  # gt for evaluation on positive split
        for gt in gts:
            self._gts[gt["image_id"], gt["category_id"]].append(gt)
            _gts_cat_ids[gt["image_id"]].add(gt["category_id"])
        #### BEGIN MODIFICATION ####
        for dt in dts:
            if (
                self.dt_only_positive
                and dt["category_id"] not in _gts_cat_ids[dt["image_id"]]
            ):
                continue
            self._dts[dt["image_id"], dt["category_id"]].append(dt)
        #### END MODIFICATION ####
        self.evalImgs = defaultdict(list)  # per-image per-category evaluation results
        self.eval = {}  # accumulated evaluation results


class CocoEvaluatorOfflineWithPredFileEvaluators:
    def __init__(
        self,
        gt_path,
        tide: bool = True,
        iou_type: str = "bbox",
        positive_split=False,
    ):
        self.gt_path = gt_path
        self.tide_enabled = HAS_TIDE and tide
        self.positive_split = positive_split
        self.iou_type = iou_type

    def evaluate(self, dumped_file):
        if not is_main_process():
            return {}
        logging.info("OfflineCoco evaluator: Loading groundtruth")
        self.gt = COCO(self.gt_path)
        # Creating the result file
        logging.info("Coco evaluator: Creating the result file")
        cocoDt = self.gt.loadRes(str(dumped_file))
        # Run the evaluation
        logging.info("Coco evaluator: Running evaluation")
        coco_eval = COCOevalCustom(
            self.gt, cocoDt, iouType=self.iou_type, dt_only_positive=self.positive_split
        )
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        outs = {}
        for i, value in enumerate(coco_eval.stats):
            outs[f"coco_eval_{self.iou_type}_{COCO_METRICS[i]}"] = value
        if self.tide_enabled:
            logging.info("Coco evaluator: Loading TIDE")
            self.tide_gt = datasets.COCO(self.gt_path)
            self.tide = TIDE(mode="mask" if self.iou_type == "segm" else "bbox")
            # Run TIDE
            logging.info("Coco evaluator: Running TIDE")
            self.tide.evaluate(
                self.tide_gt, datasets.COCOResult(str(dumped_file)), name="coco_eval"
            )
            self.tide.summarize()
            for k, v in self.tide.get_main_errors()["coco_eval"].items():
                outs[f"coco_eval_{self.iou_type}_TIDE_{k}"] = v
            for k, v in self.tide.get_special_errors()["coco_eval"].items():
                outs[f"coco_eval_{self.iou_type}_TIDE_{k}"] = v
        return outs
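

# Illustrative usage sketch (file paths below are placeholders, not from this repo).
# The evaluator reads a COCO-format ground-truth JSON and a dumped prediction file in
# COCO result format, and returns a dict of metrics on the main process only.
#
#   evaluator = CocoEvaluatorOfflineWithPredFileEvaluators(
#       gt_path="/path/to/instances_val2017.json",
#       tide=True,              # TIDE error analysis runs only if tidecv is installed
#       iou_type="bbox",        # or "segm" for mask evaluation
#       positive_split=False,   # True keeps only detections for categories present in each image's GT
#   )
#   metrics = evaluator.evaluate("/path/to/predictions.json")
#   # e.g. metrics["coco_eval_bbox_AP"], metrics["coco_eval_bbox_AP_50"], ...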