# cgf1_eval.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import contextlib
import copy
import json
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import pycocotools.mask as maskUtils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm


@dataclass
class Metric:
    name: str
    # whether the metric is computed at the image level or the box level
    image_level: bool
    # iou threshold (None is used for image level metrics or to indicate
    # averaging over all thresholds in [0.5:0.95])
    iou_threshold: Union[float, None]


CGF1_METRICS = [
    Metric(name="cgF1", image_level=False, iou_threshold=None),
    Metric(name="precision", image_level=False, iou_threshold=None),
    Metric(name="recall", image_level=False, iou_threshold=None),
    Metric(name="F1", image_level=False, iou_threshold=None),
    Metric(name="positive_macro_F1", image_level=False, iou_threshold=None),
    Metric(name="positive_micro_F1", image_level=False, iou_threshold=None),
    Metric(name="positive_micro_precision", image_level=False, iou_threshold=None),
    Metric(name="IL_precision", image_level=True, iou_threshold=None),
    Metric(name="IL_recall", image_level=True, iou_threshold=None),
    Metric(name="IL_F1", image_level=True, iou_threshold=None),
    Metric(name="IL_FPR", image_level=True, iou_threshold=None),
    Metric(name="IL_MCC", image_level=True, iou_threshold=None),
    Metric(name="cgF1", image_level=False, iou_threshold=0.5),
    Metric(name="precision", image_level=False, iou_threshold=0.5),
    Metric(name="recall", image_level=False, iou_threshold=0.5),
    Metric(name="F1", image_level=False, iou_threshold=0.5),
    Metric(name="positive_macro_F1", image_level=False, iou_threshold=0.5),
    Metric(name="positive_micro_F1", image_level=False, iou_threshold=0.5),
    Metric(name="positive_micro_precision", image_level=False, iou_threshold=0.5),
    Metric(name="cgF1", image_level=False, iou_threshold=0.75),
    Metric(name="precision", image_level=False, iou_threshold=0.75),
    Metric(name="recall", image_level=False, iou_threshold=0.75),
    Metric(name="F1", image_level=False, iou_threshold=0.75),
    Metric(name="positive_macro_F1", image_level=False, iou_threshold=0.75),
    Metric(name="positive_micro_F1", image_level=False, iou_threshold=0.75),
    Metric(name="positive_micro_precision", image_level=False, iou_threshold=0.75),
]
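
# CGF1Eval.summarize() reports the metrics in CGF1_METRICS in exactly this order, and
# CGF1Evaluator.evaluate() uses the same order to name its output keys
# (e.g. "cgF1_eval_segm_cgF1" or "cgF1_eval_segm_precision@0.5").
# The headline metric, cgF1, is computed in CGF1Eval.accumulate() as
# positive_micro_F1 * IL_MCC: box-level micro F1 over positive images scaled by the
# image-level Matthews correlation coefficient.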


class COCOCustom(COCO):
    """COCO class from pycocotools with tiny modifications for speed"""

    def createIndex(self):
        # create index
        print("creating index...")
        anns, cats, imgs = {}, {}, {}
        imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
        if "annotations" in self.dataset:
            for ann in self.dataset["annotations"]:
                imgToAnns[ann["image_id"]].append(ann)
                anns[ann["id"]] = ann
        if "images" in self.dataset:
            # MODIFICATION: do not reload imgs if they are already there
            if self.imgs:
                imgs = self.imgs
            else:
                for img in self.dataset["images"]:
                    imgs[img["id"]] = img
            # END MODIFICATION
        if "categories" in self.dataset:
            for cat in self.dataset["categories"]:
                cats[cat["id"]] = cat
        if "annotations" in self.dataset and "categories" in self.dataset:
            for ann in self.dataset["annotations"]:
                catToImgs[ann["category_id"]].append(ann["image_id"])
        print("index created!")
        # create class members
        self.anns = anns
        self.imgToAnns = imgToAnns
        self.catToImgs = catToImgs
        self.imgs = imgs
        self.cats = cats

    def loadRes(self, resFile):
        """
        Load result file and return a result api object.
        :param resFile (str) : file name of result file
        :return: res (obj) : result api object
        """
        res = COCOCustom()
        res.dataset["info"] = copy.deepcopy(self.dataset.get("info", {}))
        # MODIFICATION: no copy
        # res.dataset['images'] = [img for img in self.dataset['images']]
        res.dataset["images"] = self.dataset["images"]
        # END MODIFICATION
        print("Loading and preparing results...")
        tic = time.time()
        if type(resFile) == str:
            with open(resFile) as f:
                anns = json.load(f)
        elif type(resFile) == np.ndarray:
            anns = self.loadNumpyAnnotations(resFile)
        else:
            anns = resFile
        assert type(anns) == list, "results is not an array of objects"
        annsImgIds = [ann["image_id"] for ann in anns]
        # MODIFICATION: faster and cached subset check
        if not hasattr(self, "img_id_set"):
            self.img_id_set = set(self.getImgIds())
        assert set(annsImgIds).issubset(self.img_id_set), (
            "Results do not correspond to current coco set"
        )
        # END MODIFICATION
        if "caption" in anns[0]:
            imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
                [ann["image_id"] for ann in anns]
            )
            res.dataset["images"] = [
                img for img in res.dataset["images"] if img["id"] in imgIds
            ]
            for id, ann in enumerate(anns):
                ann["id"] = id + 1
        elif "bbox" in anns[0] and not anns[0]["bbox"] == []:
            res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
            for id, ann in enumerate(anns):
                bb = ann["bbox"]
                x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
                if not "segmentation" in ann:
                    ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
                ann["area"] = bb[2] * bb[3]
                ann["id"] = id + 1
                ann["iscrowd"] = 0
        elif "segmentation" in anns[0]:
            res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
            for id, ann in enumerate(anns):
                # now only support compressed RLE format as segmentation results
                ann["area"] = maskUtils.area(ann["segmentation"])
                if not "bbox" in ann:
                    ann["bbox"] = maskUtils.toBbox(ann["segmentation"])
                ann["id"] = id + 1
                ann["iscrowd"] = 0
        elif "keypoints" in anns[0]:
            res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
            for id, ann in enumerate(anns):
                s = ann["keypoints"]
                x = s[0::3]
                y = s[1::3]
                x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
                ann["area"] = (x1 - x0) * (y1 - y0)
                ann["id"] = id + 1
                ann["bbox"] = [x0, y0, x1 - x0, y1 - y0]
        print("DONE (t={:0.2f}s)".format(time.time() - tic))
        res.dataset["annotations"] = anns
        # MODIFICATION: inherit images
        res.imgs = self.imgs
        # END MODIFICATION
        res.createIndex()
        return res
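

# Illustrative prediction format (field names taken from COCOCustom.loadRes above and
# CGF1Eval.evaluateImg below): a json list of detections such as
#   {"image_id": 17, "category_id": 1, "bbox": [x, y, w, h], "score": 0.87}
# or, for iouType "segm", a compressed-RLE "segmentation" field instead of "bbox".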


class CGF1Eval(COCOeval):
    """
    This evaluator is based upon COCO evaluation, but evaluates the model in a more realistic
    setting for downstream applications.
    See the SAM3 paper for the details on the cgF1 metric.
    Do not use this evaluator directly. Prefer the CGF1Evaluator wrapper.

    Notes:
    - This evaluator does not support per-category evaluation (in the way defined by pycocotools)
    - In open vocabulary settings, we have different noun-phrases for each image. What we call an
      "image_id" here is actually an (image, noun-phrase) pair. So in every "image_id" there is
      only one category, implied by the noun-phrase. Thus we can ignore the usual coco "category"
      field of the predictions.
    """

    def __init__(
        self,
        coco_gt=None,
        coco_dt=None,
        iouType="segm",
        threshold=0.5,
    ):
        """
        Args:
            coco_gt (COCO): ground truth COCO API
            coco_dt (COCO): detections COCO API
            iouType (str): type of IoU to evaluate
            threshold (float): score threshold below which predictions are discarded
        """
        super().__init__(coco_gt, coco_dt, iouType)
        self.threshold = threshold
        self.params.useCats = False
        self.params.areaRng = [[0**2, 1e5**2]]
        self.params.areaRngLbl = ["all"]
        self.params.maxDets = [1000000]

    def computeIoU(self, imgId, catId):
        # Same as the original COCOeval.computeIoU, but without sorting
        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        if len(gt) == 0 and len(dt) == 0:
            return []
        if p.iouType == "segm":
            g = [g["segmentation"] for g in gt]
            d = [d["segmentation"] for d in dt]
        elif p.iouType == "bbox":
            g = [g["bbox"] for g in gt]
            d = [d["bbox"] for d in dt]
        else:
            raise Exception("unknown iouType for iou computation")
        # compute iou between each dt and gt region
        iscrowd = [int(o["iscrowd"]) for o in gt]
        ious = maskUtils.iou(d, g, iscrowd)
        return ious

    def evaluateImg(self, imgId, catId, aRng, maxDet):
        """
        perform evaluation for single category and image
        :return: dict (single image results)
        """
        p = self.params
        assert not p.useCats, "This evaluator does not support per-category evaluation."
        assert catId == -1
        all_gts = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
        keep_gt = np.array([not g["ignore"] for g in all_gts], dtype=bool)
        gt = [g for g in all_gts if not g["ignore"]]
        all_dts = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        keep_dt = np.array([d["score"] >= self.threshold for d in all_dts], dtype=bool)
        dt = [d for d in all_dts if d["score"] >= self.threshold]
        if len(gt) == 0 and len(dt) == 0:
            # This is a "true negative" case, where there are no GTs and no predictions
            # The box-level metrics are ill-defined, so we don't add them to this dict
            return {
                "image_id": imgId,
                "IL_TP": 0,
                "IL_TN": 1,
                "IL_FP": 0,
                "IL_FN": 0,
                "num_dt": len(dt),
            }
        if len(gt) > 0 and len(dt) == 0:
            # This is a "false negative" case, where there are GTs but no predictions
            return {
                "image_id": imgId,
                "IL_TP": 0,
                "IL_TN": 0,
                "IL_FP": 0,
                "IL_FN": 1,
                "TPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "FPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "FNs": np.ones((len(p.iouThrs),), dtype=np.int64) * len(gt),
                "local_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "local_positive_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "num_dt": len(dt),
            }
        # Load pre-computed ious
        ious = self.ious[(imgId, catId)]
        # compute matching
        if len(ious) == 0:
            ious = np.zeros((len(dt), len(gt)))
        else:
            ious = ious[keep_dt, :][:, keep_gt]
        assert ious.shape == (len(dt), len(gt))
        matched_dt, matched_gt = linear_sum_assignment(-ious)
        match_scores = ious[matched_dt, matched_gt]
        TPs, FPs, FNs = [], [], []
        IL_perfect = []
        for thresh in p.iouThrs:
            TP = (match_scores >= thresh).sum()
            FP = len(dt) - TP
            FN = len(gt) - TP
            assert FP >= 0 and FN >= 0, (
                f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, "
                f"len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
            )
            TPs.append(TP)
            FPs.append(FP)
            FNs.append(FN)
            if FP == FN and FP == 0:
                IL_perfect.append(1)
            else:
                IL_perfect.append(0)
        TPs = np.array(TPs, dtype=np.int64)
        FPs = np.array(FPs, dtype=np.int64)
        FNs = np.array(FNs, dtype=np.int64)
        IL_perfect = np.array(IL_perfect, dtype=np.int64)
        # compute precision recall and F1
        precision = TPs / (TPs + FPs + 1e-4)
        assert np.all(precision <= 1)
        recall = TPs / (TPs + FNs + 1e-4)
        assert np.all(recall <= 1)
        F1 = 2 * precision * recall / (precision + recall + 1e-4)
        result = {
            "image_id": imgId,
            "TPs": TPs,
            "FPs": FPs,
            "FNs": FNs,
            "local_F1s": F1,
            "IL_TP": (len(gt) > 0) and (len(dt) > 0),
            "IL_FP": (len(gt) == 0) and (len(dt) > 0),
            "IL_TN": (len(gt) == 0) and (len(dt) == 0),
            "IL_FN": (len(gt) > 0) and (len(dt) == 0),
            "num_dt": len(dt),
        }
        if len(gt) > 0 and len(dt) > 0:
            result["local_positive_F1s"] = F1
        return result
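
    # Worked example (illustrative): with 2 ground truths and 2 kept predictions whose
    # best one-to-one matching has IoUs [0.8, 0.4], the counts at IoU threshold 0.5
    # (and 0.75) are TP=1, FP=1, FN=1, while at threshold 0.85 they become
    # TP=0, FP=2, FN=2.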

    def accumulate(self, p=None):
        """
        Accumulate per image evaluation results and store the result in self.eval
        :param p: input params for evaluation
        :return: None
        """
        if self.evalImgs is None or len(self.evalImgs) == 0:
            print("Please run evaluate() first")
        # allows input customized parameters
        if p is None:
            p = self.params
        setImgIds = set(p.imgIds)
        # TPs, FPs, FNs
        TPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        FPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        pmFPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        FNs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        local_F1s = np.zeros((len(p.iouThrs),), dtype=np.float64)
        # Image level metrics
        IL_TPs = 0
        IL_FPs = 0
        IL_TNs = 0
        IL_FNs = 0
        valid_img_count = 0
        valid_F1_count = 0
        evaledImgIds = set()
        for res in self.evalImgs:
            if res["image_id"] not in setImgIds:
                continue
            evaledImgIds.add(res["image_id"])
            IL_TPs += res["IL_TP"]
            IL_FPs += res["IL_FP"]
            IL_TNs += res["IL_TN"]
            IL_FNs += res["IL_FN"]
            if "TPs" not in res:
                continue
            TPs += res["TPs"]
            FPs += res["FPs"]
            FNs += res["FNs"]
            valid_img_count += 1
            if "local_positive_F1s" in res:
                local_F1s += res["local_positive_F1s"]
                pmFPs += res["FPs"]
                if res["num_dt"] > 0:
                    valid_F1_count += 1
        assert len(setImgIds - evaledImgIds) == 0, (
            f"{len(setImgIds - evaledImgIds)} images not evaluated. "
            f"Here are the IDs of the first 3: {list(setImgIds - evaledImgIds)[:3]}"
        )
        # compute precision recall and F1
        precision = TPs / (TPs + FPs + 1e-4)
        positive_micro_precision = TPs / (TPs + pmFPs + 1e-4)
        assert np.all(precision <= 1)
        recall = TPs / (TPs + FNs + 1e-4)
        assert np.all(recall <= 1)
        F1 = 2 * precision * recall / (precision + recall + 1e-4)
        positive_micro_F1 = (
            2
            * positive_micro_precision
            * recall
            / (positive_micro_precision + recall + 1e-4)
        )
        IL_rec = IL_TPs / (IL_TPs + IL_FNs + 1e-6)
        IL_prec = IL_TPs / (IL_TPs + IL_FPs + 1e-6)
        IL_F1 = 2 * IL_prec * IL_rec / (IL_prec + IL_rec + 1e-6)
        IL_FPR = IL_FPs / (IL_FPs + IL_TNs + 1e-6)
        IL_MCC = float(IL_TPs * IL_TNs - IL_FPs * IL_FNs) / (
            (
                float(IL_TPs + IL_FPs)
                * float(IL_TPs + IL_FNs)
                * float(IL_TNs + IL_FPs)
                * float(IL_TNs + IL_FNs)
            )
            ** 0.5
            + 1e-6
        )
        self.eval = {
            "params": p,
            "TPs": TPs,
            "FPs": FPs,
            "positive_micro_FPs": pmFPs,
            "FNs": FNs,
            "precision": precision,
            "positive_micro_precision": positive_micro_precision,
            "recall": recall,
            "F1": F1,
            "positive_micro_F1": positive_micro_F1,
            "positive_macro_F1": local_F1s / valid_F1_count,
            "IL_recall": IL_rec,
            "IL_precision": IL_prec,
            "IL_F1": IL_F1,
            "IL_FPR": IL_FPR,
            "IL_MCC": IL_MCC,
        }
        self.eval["cgF1"] = self.eval["positive_micro_F1"] * self.eval["IL_MCC"]

    def summarize(self):
        """
        Compute and display summary metrics for evaluation results.
        """
        if not self.eval:
            raise Exception("Please run accumulate() first")

        def _summarize(iouThr=None, metric=""):
            p = self.params
            iStr = " {:<18} @[ IoU={:<9}] = {:0.3f}"
            titleStr = "Average " + metric
            iouStr = (
                "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
                if iouThr is None
                else "{:0.2f}".format(iouThr)
            )
            s = self.eval[metric]
            # IoU
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            if len(s[s > -1]) == 0:
                mean_s = -1
            else:
                mean_s = np.mean(s[s > -1])
            print(iStr.format(titleStr, iouStr, mean_s))
            return mean_s

        def _summarize_single(metric=""):
            titleStr = "Average " + metric
            iStr = " {:<35} = {:0.3f}"
            s = self.eval[metric]
            print(iStr.format(titleStr, s))
            return s

        def _summarizeDets():
            stats = []
            for metric in CGF1_METRICS:
                if metric.image_level:
                    stats.append(_summarize_single(metric=metric.name))
                else:
                    stats.append(
                        _summarize(iouThr=metric.iou_threshold, metric=metric.name)
                    )
            return np.asarray(stats)

        summarize = _summarizeDets
        self.stats = summarize()

    def _evaluate(self):
        """
        Run per image evaluation on given images and return the results as an array of per-image dicts
        """
        p = self.params
        p.imgIds = list(np.unique(p.imgIds))
        p.useCats = False
        p.maxDets = sorted(p.maxDets)
        self.params = p
        self._prepare()
        # loop through images, area range, max detection number
        catIds = [-1]
        if p.iouType == "segm" or p.iouType == "bbox":
            computeIoU = self.computeIoU
        else:
            raise RuntimeError(f"Unsupported iou {p.iouType}")
        self.ious = {
            (imgId, catId): computeIoU(imgId, catId)
            for imgId in p.imgIds
            for catId in catIds
        }
        maxDet = p.maxDets[-1]
        evalImgs = [
            self.evaluateImg(imgId, catId, areaRng, maxDet)
            for catId in catIds
            for areaRng in p.areaRng
            for imgId in p.imgIds
        ]
        # this is NOT in the pycocotools code, but could be done outside
        evalImgs = np.asarray(evalImgs).reshape(
            len(catIds), len(p.areaRng), len(p.imgIds)
        )
        return p.imgIds, evalImgs


class CGF1Evaluator:
    """
    Wrapper class for cgF1 evaluation.
    This supports the oracle setting (when several ground-truths are available per image)
    """

    def __init__(
        self,
        gt_path: Union[str, List[str]],
        iou_type="segm",
        verbose=False,
    ):
        """
        Args:
            gt_path (str or list of str): path(s) to ground truth COCO json file(s)
            iou_type (str): type of IoU to evaluate
            verbose (bool): if True, print progress information
        """
        self.gt_paths = gt_path if isinstance(gt_path, list) else [gt_path]
        self.iou_type = iou_type
        self.coco_gts = [COCOCustom(gt) for gt in self.gt_paths]
        self.verbose = verbose
        self.coco_evals = []
        for i, coco_gt in enumerate(self.coco_gts):
            self.coco_evals.append(
                CGF1Eval(
                    coco_gt=coco_gt,
                    iouType=iou_type,
                )
            )
            self.coco_evals[i].useCats = False
        exclude_img_ids = set()
        # exclude_img_ids are the ids that are not exhaustively annotated in at least
        # one of the additional ground truths
        for coco_gt in self.coco_gts[1:]:
            exclude_img_ids = exclude_img_ids.union(
                {
                    img["id"]
                    for img in coco_gt.dataset["images"]
                    if not img["is_instance_exhaustive"]
                }
            )
        # we only eval on instance exhaustive queries
        self.eval_img_ids = [
            img["id"]
            for img in self.coco_gts[0].dataset["images"]
            if (img["is_instance_exhaustive"] and img["id"] not in exclude_img_ids)
        ]
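
    # Illustrative GT layout (field names taken from the code above): every entry of the
    # COCO "images" list is expected to carry an "is_instance_exhaustive" flag next to
    # the usual "id", e.g.
    #   {"id": 17, "file_name": "000017.jpg", "is_instance_exhaustive": True, ...}
    # Only queries that are instance-exhaustive in every provided ground truth are
    # evaluated.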

    def evaluate(self, pred_file: str):
        """
        Evaluate the detections using the cgF1 metric.
        Args:
            pred_file: path to the predictions COCO json file
        """
        assert len(self.coco_gts) > 0, "No ground truth provided for evaluation."
        assert len(self.coco_gts) == len(self.coco_evals), (
            "Mismatch in number of ground truths and evaluators."
        )
        if self.verbose:
            print(f"Loading predictions from {pred_file}")
        with open(pred_file, "r") as f:
            preds = json.load(f)
        if self.verbose:
            print(f"Loaded {len(preds)} predictions")
        img2preds = defaultdict(list)
        for pred in preds:
            img2preds[pred["image_id"]].append(pred)
        all_eval_imgs = []
        for img_id in tqdm(self.eval_img_ids, disable=not self.verbose):
            results = img2preds[img_id]
            all_scorings = []
            for cur_coco_gt, coco_eval in zip(self.coco_gts, self.coco_evals):
                # suppress pycocotools prints
                with open(os.devnull, "w") as devnull:
                    with contextlib.redirect_stdout(devnull):
                        coco_dt = (
                            cur_coco_gt.loadRes(results) if results else COCOCustom()
                        )
                        coco_eval.cocoDt = coco_dt
                        coco_eval.params.imgIds = [img_id]
                        coco_eval.params.useCats = False
                        img_ids, eval_imgs = coco_eval._evaluate()
                all_scorings.append(eval_imgs)
            selected = self._select_best_scoring(all_scorings)
            all_eval_imgs.append(selected)
        # After this point, we have selected the best scoring per image among several ground truths
        # we can now accumulate and summarize, using only the first coco_eval
        self.coco_evals[0].evalImgs = list(
            np.concatenate(all_eval_imgs, axis=2).flatten()
        )
        self.coco_evals[0].params.imgIds = self.eval_img_ids
        self.coco_evals[0]._paramsEval = copy.deepcopy(self.coco_evals[0].params)
        if self.verbose:
            print("Accumulating results")
        self.coco_evals[0].accumulate()
        print("cgF1 metric, IoU type={}".format(self.iou_type))
        self.coco_evals[0].summarize()
        print()
        out = {}
        for i, value in enumerate(self.coco_evals[0].stats):
            name = CGF1_METRICS[i].name
            if CGF1_METRICS[i].iou_threshold is not None:
                name = f"{name}@{CGF1_METRICS[i].iou_threshold}"
            out[f"cgF1_eval_{self.iou_type}_{name}"] = float(value)
        return out

    @staticmethod
    def _select_best_scoring(scorings):
        # This function is used for "oracle" type evaluation.
        # It accepts the evaluation results with respect to several ground truths, and picks the best
        if len(scorings) == 1:
            return scorings[0]
        assert scorings[0].ndim == 3, (
            f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
        )
        assert scorings[0].shape[0] == 1, (
            f"Expecting a single category, got {scorings[0].shape[0]}"
        )
        for scoring in scorings:
            assert scoring.shape == scorings[0].shape, (
                f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
            )
        selected_imgs = []
        for img_id in range(scorings[0].shape[-1]):
            best = scorings[0][:, :, img_id]
            for scoring in scorings[1:]:
                current = scoring[:, :, img_id]
                if "local_F1s" in best[0, 0] and "local_F1s" in current[0, 0]:
                    # we were able to compute an F1 score for this particular image in both evaluations
                    # best["local_F1s"] contains the results at various IoU thresholds.
                    # We simply take the average for comparison
                    best_score = best[0, 0]["local_F1s"].mean()
                    current_score = current[0, 0]["local_F1s"].mean()
                    if current_score > best_score:
                        best = current
                else:
                    # If we're here, it means that in some evaluation we were not able to get a
                    # valid local F1. This happens when both the predictions and targets are empty.
                    # In that case, we can assume it's a perfect prediction
                    if "local_F1s" not in current[0, 0]:
                        best = current
            selected_imgs.append(best)
        result = np.stack(selected_imgs, axis=-1)
        assert result.shape == scorings[0].shape
        return result
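

# Minimal usage sketch (illustrative, not part of the evaluator itself). The file paths
# below are placeholders; any COCO-format ground-truth json whose "images" entries carry
# the "is_instance_exhaustive" flag, together with a COCO-format prediction json (one
# "score" per detection), should work.
if __name__ == "__main__":
    evaluator = CGF1Evaluator(
        gt_path="gt_annotations.json", iou_type="segm", verbose=True
    )
    metrics = evaluator.evaluate("predictions.json")
    for key, value in sorted(metrics.items()):
        print(f"{key}: {value:.4f}")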