demo_eval.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""
This evaluator is based upon COCO evaluation, but evaluates the model in a "demo" setting.
This means that the model's predictions are thresholded and evaluated as "hard" predictions.
"""

import logging
from typing import Optional

import numpy as np
import pycocotools.mask as maskUtils
from pycocotools.cocoeval import COCOeval
from sam3.eval.coco_eval import CocoEvaluator
from sam3.train.masks_ops import compute_F_measure
from sam3.train.utils.distributed import is_main_process
from scipy.optimize import linear_sum_assignment


class DemoEval(COCOeval):
    """
    This evaluator is based upon COCO evaluation, but evaluates the model in a "demo" setting.
    This means that the model's predictions are thresholded and evaluated as "hard" predictions.
    """

    def __init__(
        self,
        coco_gt=None,
        coco_dt=None,
        iouType="bbox",
        threshold=0.5,
        compute_JnF=False,
    ):
        """
        Args:
            coco_gt (COCO): ground truth COCO API
            coco_dt (COCO): detections COCO API
            iouType (str): type of IoU to evaluate ("bbox" or "segm")
            threshold (float): score threshold above which predictions are kept
            compute_JnF (bool): whether to also report the J&F metric
        """
        super().__init__(coco_gt, coco_dt, iouType)
        self.threshold = threshold
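        # Class-agnostic evaluation over a single "all" area range, with no cap on
        # the number of detections per image.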
        self.params.useCats = False
        self.params.areaRng = [[0**2, 1e5**2]]
        self.params.areaRngLbl = ["all"]
        self.params.maxDets = [100000]
        self.compute_JnF = compute_JnF

    def computeIoU(self, imgId, catId):
        # Same as the original COCOeval.computeIoU, but without sorting
        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        if len(gt) == 0 and len(dt) == 0:
            return []
        if p.iouType == "segm":
            g = [g["segmentation"] for g in gt]
            d = [d["segmentation"] for d in dt]
        elif p.iouType == "bbox":
            g = [g["bbox"] for g in gt]
            d = [d["bbox"] for d in dt]
        else:
            raise Exception("unknown iouType for iou computation")
        # compute iou between each dt and gt region
        iscrowd = [int(o["iscrowd"]) for o in gt]
        ious = maskUtils.iou(d, g, iscrowd)
        return ious

    def evaluateImg(self, imgId, catId, aRng, maxDet):
        """
        perform evaluation for single category and image
        :return: dict (single image results)
        """
        p = self.params
        assert not p.useCats, "This evaluator does not support per-category evaluation."
        assert catId == -1
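        # Drop ground-truth boxes flagged as "ignore" and keep only detections whose
        # score clears the demo threshold, i.e. treat them as hard predictions.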
        all_gts = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
        keep_gt = np.array([not g["ignore"] for g in all_gts], dtype=bool)
        gt = [g for g in all_gts if not g["ignore"]]
        all_dts = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        keep_dt = np.array([d["score"] >= self.threshold for d in all_dts], dtype=bool)
        dt = [d for d in all_dts if d["score"] >= self.threshold]
        if len(gt) == 0 and len(dt) == 0:
            # This is a "true negative" case, where there are no GTs and no predictions
            # The box-level metrics are ill-defined, so we don't add them to this dict
            return {
                "image_id": imgId,
                "IL_TP": 0,
                "IL_TN": 1,
                "IL_FP": 0,
                "IL_FN": 0,
                "IL_perfect_neg": np.ones((len(p.iouThrs),), dtype=np.int64),
                "num_dt": len(dt),
            }
        if len(gt) > 0 and len(dt) == 0:
            # This is a "false negative" case, where there are GTs but no predictions
            return {
                "image_id": imgId,
                "IL_TP": 0,
                "IL_TN": 0,
                "IL_FP": 0,
                "IL_FN": 1,
                "TPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "FPs": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "FNs": np.ones((len(p.iouThrs),), dtype=np.int64) * len(gt),
                "local_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "local_positive_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "IL_perfect_pos": np.zeros((len(p.iouThrs),), dtype=np.int64),
                "num_dt": len(dt),
            }
        # Load pre-computed ious
        ious = self.ious[(imgId, catId)]
        # compute matching
        if len(ious) == 0:
            ious = np.zeros((len(dt), len(gt)))
        else:
            ious = ious[keep_dt, :][:, keep_gt]
        assert ious.shape == (len(dt), len(gt))
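        # One-to-one Hungarian matching between the kept detections and GTs,
        # maximizing the total IoU of the matched pairs.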
        matched_dt, matched_gt = linear_sum_assignment(-ious)
        match_scores = ious[matched_dt, matched_gt]
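        # Optional J&F: J is the mean IoU over the matched pairs, F is the boundary
        # F-measure averaged over the same matches.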
        if self.compute_JnF and len(match_scores) > 0:
            j_score = match_scores.mean()
            f_measure = 0
            for dt_id, gt_id in zip(matched_dt, matched_gt):
                f_measure += compute_F_measure(
                    gt_boundary_rle=gt[gt_id]["boundary"],
                    gt_dilated_boundary_rle=gt[gt_id]["dilated_boundary"],
                    dt_boundary_rle=dt[dt_id]["boundary"],
                    dt_dilated_boundary_rle=dt[dt_id]["dilated_boundary"],
                )
            f_measure /= len(match_scores) + 1e-9
            JnF = (j_score + f_measure) * 0.5
        else:
            j_score = f_measure = JnF = -1
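        # Count TPs/FPs/FNs at every IoU threshold; a matched pair only counts as a TP
        # if its IoU clears the threshold.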
        TPs, FPs, FNs = [], [], []
        IL_perfect = []
        for thresh in p.iouThrs:
            TP = (match_scores >= thresh).sum()
            FP = len(dt) - TP
            FN = len(gt) - TP
            assert FP >= 0 and FN >= 0, (
                f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, "
                f"len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
            )
            TPs.append(TP)
            FPs.append(FP)
            FNs.append(FN)
            if FP == FN and FP == 0:
                IL_perfect.append(1)
            else:
                IL_perfect.append(0)
        TPs = np.array(TPs, dtype=np.int64)
        FPs = np.array(FPs, dtype=np.int64)
        FNs = np.array(FNs, dtype=np.int64)
        IL_perfect = np.array(IL_perfect, dtype=np.int64)
        # compute precision recall and F1
        precision = TPs / (TPs + FPs + 1e-4)
        assert np.all(precision <= 1)
        recall = TPs / (TPs + FNs + 1e-4)
        assert np.all(recall <= 1)
        F1 = 2 * precision * recall / (precision + recall + 1e-4)
        result = {
            "image_id": imgId,
            "TPs": TPs,
            "FPs": FPs,
            "FNs": FNs,
            "local_F1s": F1,
            "IL_TP": (len(gt) > 0) and (len(dt) > 0),
            "IL_FP": (len(gt) == 0) and (len(dt) > 0),
            "IL_TN": (len(gt) == 0) and (len(dt) == 0),
            "IL_FN": (len(gt) > 0) and (len(dt) == 0),
            ("IL_perfect_pos" if len(gt) > 0 else "IL_perfect_neg"): IL_perfect,
            "F": f_measure,
            "J": j_score,
            "J&F": JnF,
            "num_dt": len(dt),
        }
        if len(gt) > 0 and len(dt) > 0:
            result["local_positive_F1s"] = F1
        return result

    def accumulate(self, p=None):
        """
        Accumulate per image evaluation results and store the result in self.eval
        :param p: input params for evaluation
        :return: None
        """
        if not self.evalImgs:
            print("Please run evaluate() first")
        # allows input customized parameters
        if p is None:
            p = self.params
        setImgIds = set(p.imgIds)
        # TPs, FPs, FNs
        TPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        FPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        pmFPs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        FNs = np.zeros((len(p.iouThrs),), dtype=np.int64)
        local_F1s = np.zeros((len(p.iouThrs),), dtype=np.float64)
        # Image level metrics
        IL_TPs = 0
        IL_FPs = 0
        IL_TNs = 0
        IL_FNs = 0
        IL_perfects_neg = np.zeros((len(p.iouThrs),), dtype=np.int64)
        IL_perfects_pos = np.zeros((len(p.iouThrs),), dtype=np.int64)
        # JnF metric
        total_J = 0
        total_F = 0
        total_JnF = 0
        valid_img_count = 0
        total_pos_count = 0
        total_neg_count = 0
        valid_J_count = 0
        valid_F1_count = 0
        valid_F1_count_w0dt = 0
        for res in self.evalImgs:
            if res["image_id"] not in setImgIds:
                continue
            IL_TPs += res["IL_TP"]
            IL_FPs += res["IL_FP"]
            IL_TNs += res["IL_TN"]
            IL_FNs += res["IL_FN"]
            if "IL_perfect_neg" in res:
                IL_perfects_neg += res["IL_perfect_neg"]
                total_neg_count += 1
            else:
                assert "IL_perfect_pos" in res
                IL_perfects_pos += res["IL_perfect_pos"]
                total_pos_count += 1
            if "TPs" not in res:
                continue
            TPs += res["TPs"]
            FPs += res["FPs"]
            FNs += res["FNs"]
            valid_img_count += 1
            if "local_positive_F1s" in res:
                local_F1s += res["local_positive_F1s"]
                pmFPs += res["FPs"]
                valid_F1_count_w0dt += 1
                if res["num_dt"] > 0:
                    valid_F1_count += 1
            if "J" in res and res["J"] > -1e-9:
                total_J += res["J"]
                total_F += res["F"]
                total_JnF += res["J&F"]
                valid_J_count += 1
        # compute precision recall and F1
        precision = TPs / (TPs + FPs + 1e-4)
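        # "Positive micro" precision only counts false positives coming from images
        # that contain at least one ground-truth object (pmFPs).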
        positive_micro_precision = TPs / (TPs + pmFPs + 1e-4)
        assert np.all(precision <= 1)
        recall = TPs / (TPs + FNs + 1e-4)
        assert np.all(recall <= 1)
        F1 = 2 * precision * recall / (precision + recall + 1e-4)
        positive_micro_F1 = (
            2
            * positive_micro_precision
            * recall
            / (positive_micro_precision + recall + 1e-4)
        )
        IL_rec = IL_TPs / (IL_TPs + IL_FNs + 1e-6)
        IL_prec = IL_TPs / (IL_TPs + IL_FPs + 1e-6)
        IL_F1 = 2 * IL_prec * IL_rec / (IL_prec + IL_rec + 1e-6)
        IL_FPR = IL_FPs / (IL_FPs + IL_TNs + 1e-6)
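        # Matthews correlation coefficient over the image-level presence/absence
        # confusion matrix.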
        IL_MCC = float(IL_TPs * IL_TNs - IL_FPs * IL_FNs) / (
            (
                float(IL_TPs + IL_FPs)
                * float(IL_TPs + IL_FNs)
                * float(IL_TNs + IL_FPs)
                * float(IL_TNs + IL_FNs)
            )
            ** 0.5
            + 1e-6
        )
        IL_perfect_pos = IL_perfects_pos / (total_pos_count + 1e-9)
        IL_perfect_neg = IL_perfects_neg / (total_neg_count + 1e-9)
        total_J = total_J / (valid_J_count + 1e-9)
        total_F = total_F / (valid_J_count + 1e-9)
        total_JnF = total_JnF / (valid_J_count + 1e-9)
        self.eval = {
            "params": p,
            "TPs": TPs,
            "FPs": FPs,
            "positive_micro_FPs": pmFPs,
            "FNs": FNs,
            "precision": precision,
            "positive_micro_precision": positive_micro_precision,
            "recall": recall,
            "F1": F1,
            "positive_micro_F1": positive_micro_F1,
            "positive_macro_F1": local_F1s / valid_F1_count,
            "positive_w0dt_macro_F1": local_F1s / valid_F1_count_w0dt,
            "IL_recall": IL_rec,
            "IL_precision": IL_prec,
            "IL_F1": IL_F1,
            "IL_FPR": IL_FPR,
            "IL_MCC": IL_MCC,
            "IL_perfect_pos": IL_perfect_pos,
            "IL_perfect_neg": IL_perfect_neg,
            "J": total_J,
            "F": total_F,
            "J&F": total_JnF,
        }
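        # CGF1 couples localization quality on positive images (macro/micro F1) with
        # image-level classification quality (MCC).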
        self.eval["CGF1"] = self.eval["positive_macro_F1"] * self.eval["IL_MCC"]
        self.eval["CGF1_w0dt"] = (
            self.eval["positive_w0dt_macro_F1"] * self.eval["IL_MCC"]
        )
        self.eval["CGF1_micro"] = self.eval["positive_micro_F1"] * self.eval["IL_MCC"]

    def summarize(self):
        """
        Compute and display summary metrics for evaluation results.
        Note this function can *only* be applied on the default parameter setting
        """
        if not self.eval:
            raise Exception("Please run accumulate() first")

        def _summarize(iouThr=None, metric=""):
            p = self.params
            iStr = " {:<18} @[ IoU={:<9}] = {:0.3f}"
            titleStr = "Average " + metric
            iouStr = (
                "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
                if iouThr is None
                else "{:0.2f}".format(iouThr)
            )
            s = self.eval[metric]
            # IoU
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            if len(s[s > -1]) == 0:
                mean_s = -1
            else:
                mean_s = np.mean(s[s > -1])
            print(iStr.format(titleStr, iouStr, mean_s))
            return mean_s

        def _summarize_single(metric=""):
            titleStr = "Average " + metric
            iStr = " {:<35} = {:0.3f}"
            s = self.eval[metric]
            print(iStr.format(titleStr, s))
            return s

        def _summarizeDets():
            # note: the indices of these metrics are also used in video Demo F1 evaluation;
            # when adding new metrics, please update the indices in video Demo F1 evaluation
            # in the "evaluate" method of the "VideoDemoF1Evaluator" class
            stats = np.zeros((len(DEMO_METRICS),))
            stats[0] = _summarize(metric="CGF1")
            stats[1] = _summarize(metric="precision")
            stats[2] = _summarize(metric="recall")
            stats[3] = _summarize(metric="F1")
            stats[4] = _summarize(metric="positive_macro_F1")
            stats[5] = _summarize_single(metric="IL_precision")
            stats[6] = _summarize_single(metric="IL_recall")
            stats[7] = _summarize_single(metric="IL_F1")
            stats[8] = _summarize_single(metric="IL_FPR")
            stats[9] = _summarize_single(metric="IL_MCC")
            stats[10] = _summarize(metric="IL_perfect_pos")
            stats[11] = _summarize(metric="IL_perfect_neg")
            stats[12] = _summarize(iouThr=0.5, metric="CGF1")
            stats[13] = _summarize(iouThr=0.5, metric="precision")
            stats[14] = _summarize(iouThr=0.5, metric="recall")
            stats[15] = _summarize(iouThr=0.5, metric="F1")
            stats[16] = _summarize(iouThr=0.5, metric="positive_macro_F1")
            stats[17] = _summarize(iouThr=0.5, metric="IL_perfect_pos")
            stats[18] = _summarize(iouThr=0.5, metric="IL_perfect_neg")
            stats[19] = _summarize(iouThr=0.75, metric="CGF1")
            stats[20] = _summarize(iouThr=0.75, metric="precision")
            stats[21] = _summarize(iouThr=0.75, metric="recall")
            stats[22] = _summarize(iouThr=0.75, metric="F1")
            stats[23] = _summarize(iouThr=0.75, metric="positive_macro_F1")
            stats[24] = _summarize(iouThr=0.75, metric="IL_perfect_pos")
            stats[25] = _summarize(iouThr=0.75, metric="IL_perfect_neg")
            stats[26] = _summarize_single(metric="J")
            stats[27] = _summarize_single(metric="F")
            stats[28] = _summarize_single(metric="J&F")
            stats[29] = _summarize(metric="CGF1_micro")
            stats[30] = _summarize(metric="positive_micro_precision")
            stats[31] = _summarize(metric="positive_micro_F1")
            stats[32] = _summarize(iouThr=0.5, metric="CGF1_micro")
            stats[33] = _summarize(iouThr=0.5, metric="positive_micro_precision")
            stats[34] = _summarize(iouThr=0.5, metric="positive_micro_F1")
            stats[35] = _summarize(iouThr=0.75, metric="CGF1_micro")
            stats[36] = _summarize(iouThr=0.75, metric="positive_micro_precision")
            stats[37] = _summarize(iouThr=0.75, metric="positive_micro_F1")
            stats[38] = _summarize(metric="CGF1_w0dt")
            stats[39] = _summarize(metric="positive_w0dt_macro_F1")
            stats[40] = _summarize(iouThr=0.5, metric="CGF1_w0dt")
            stats[41] = _summarize(iouThr=0.5, metric="positive_w0dt_macro_F1")
            stats[42] = _summarize(iouThr=0.75, metric="CGF1_w0dt")
            stats[43] = _summarize(iouThr=0.75, metric="positive_w0dt_macro_F1")
            return stats

        summarize = _summarizeDets
        self.stats = summarize()


DEMO_METRICS = [
    "CGF1",
    "Precision",
    "Recall",
    "F1",
    "Macro_F1",
    "IL_Precision",
    "IL_Recall",
    "IL_F1",
    "IL_FPR",
    "IL_MCC",
    "IL_perfect_pos",
    "IL_perfect_neg",
    "CGF1@0.5",
    "Precision@0.5",
    "Recall@0.5",
    "F1@0.5",
    "Macro_F1@0.5",
    "IL_perfect_pos@0.5",
    "IL_perfect_neg@0.5",
    "CGF1@0.75",
    "Precision@0.75",
    "Recall@0.75",
    "F1@0.75",
    "Macro_F1@0.75",
    "IL_perfect_pos@0.75",
    "IL_perfect_neg@0.75",
    "J",
    "F",
    "J&F",
    "CGF1_micro",
    "positive_micro_Precision",
    "positive_micro_F1",
    "CGF1_micro@0.5",
    "positive_micro_Precision@0.5",
    "positive_micro_F1@0.5",
    "CGF1_micro@0.75",
    "positive_micro_Precision@0.75",
    "positive_micro_F1@0.75",
    "CGF1_w0dt",
    "positive_w0dt_macro_F1",
    "CGF1_w0dt@0.5",
    "positive_w0dt_macro_F1@0.5",
    "CGF1_w0dt@0.75",
    "positive_w0dt_macro_F1@0.75",
]


class DemoEvaluator(CocoEvaluator):
    def __init__(
        self,
        coco_gt,
        iou_types,
        dump_dir: Optional[str],
        postprocessor,
        threshold=0.5,
        average_by_rarity=False,
        gather_pred_via_filesys=False,
        exhaustive_only=False,
        all_exhaustive_only=True,
        compute_JnF=False,
        metrics_dump_dir: Optional[str] = None,
    ):
        self.iou_types = iou_types
        self.threshold = threshold
        super().__init__(
            coco_gt=coco_gt,
            iou_types=iou_types,
            useCats=False,
            dump_dir=dump_dir,
            postprocessor=postprocessor,
            # average_by_rarity=average_by_rarity,
            gather_pred_via_filesys=gather_pred_via_filesys,
            exhaustive_only=exhaustive_only,
            all_exhaustive_only=all_exhaustive_only,
            metrics_dump_dir=metrics_dump_dir,
        )
        self.use_self_evaluate = True
        self.compute_JnF = compute_JnF

    def _lazy_init(self):
        if self.initialized:
            return
        super()._lazy_init()
        self.use_self_evaluate = True
        self.reset()

    def select_best_scoring(self, scorings):
        # This function is used for "oracle" type evaluation.
        # It accepts the evaluation results with respect to several ground truths, and picks the best.
        if len(scorings) == 1:
            return scorings[0]
        assert scorings[0].ndim == 3, (
            f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
        )
        assert scorings[0].shape[0] == 1, (
            f"Expecting a single category, got {scorings[0].shape[0]}"
        )
        for scoring in scorings:
            assert scoring.shape == scorings[0].shape, (
                f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
            )
        selected_imgs = []
        for img_id in range(scorings[0].shape[-1]):
            best = scorings[0][:, :, img_id]
            for scoring in scorings[1:]:
                current = scoring[:, :, img_id]
                if "local_F1s" in best[0, 0] and "local_F1s" in current[0, 0]:
                    # We were able to compute an F1 score for this particular image in both evaluations.
                    # best["local_F1s"] contains the results at various IoU thresholds; we simply take the average for comparison.
                    best_score = best[0, 0]["local_F1s"].mean()
                    current_score = current[0, 0]["local_F1s"].mean()
                    if current_score > best_score:
                        best = current
                else:
                    # If we're here, it means that in some evaluation we were not able to get a valid local F1.
                    # This happens when both the predictions and targets are empty. In that case, we can assume it's a perfect prediction.
                    if "local_F1s" not in current[0, 0]:
                        best = current
            selected_imgs.append(best)
        result = np.stack(selected_imgs, axis=-1)
        assert result.shape == scorings[0].shape
        return result

    def summarize(self):
        self._lazy_init()
        logging.info("Demo evaluator: Summarizing")
        if not is_main_process():
            return {}
        outs = {}
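        # When several ground-truth sets are available, the reported numbers are
        # "oracle" metrics (best-matching annotation per image, see select_best_scoring).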
        prefix = "oracle_" if len(self.coco_evals) > 1 else ""
        # if self.rarity_buckets is None:
        self.accumulate(self.eval_img_ids)
        for iou_type, coco_eval in self.coco_evals[0].items():
            print("Demo metric, IoU type={}".format(iou_type))
            coco_eval.summarize()
        if "bbox" in self.coco_evals[0]:
            for i, value in enumerate(self.coco_evals[0]["bbox"].stats):
                outs[f"coco_eval_bbox_{prefix}{DEMO_METRICS[i]}"] = value
        if "segm" in self.coco_evals[0]:
            for i, value in enumerate(self.coco_evals[0]["segm"].stats):
                outs[f"coco_eval_masks_{prefix}{DEMO_METRICS[i]}"] = value
        # else:
        #     total_stats = {}
        #     for bucket, img_list in self.rarity_buckets.items():
        #         self.accumulate(imgIds=img_list)
        #         bucket_name = RARITY_BUCKETS[bucket]
        #         for iou_type, coco_eval in self.coco_evals[0].items():
        #             print(
        #                 "Demo metric, IoU type={}, Rarity bucket={}".format(
        #                     iou_type, bucket_name
        #                 )
        #             )
        #             coco_eval.summarize()
        #         if "bbox" in self.coco_evals[0]:
        #             if "bbox" not in total_stats:
        #                 total_stats["bbox"] = np.zeros_like(
        #                     self.coco_evals[0]["bbox"].stats
        #                 )
        #             total_stats["bbox"] += self.coco_evals[0]["bbox"].stats
        #             for i, value in enumerate(self.coco_evals[0]["bbox"].stats):
        #                 outs[
        #                     f"coco_eval_bbox_{bucket_name}_{prefix}{DEMO_METRICS[i]}"
        #                 ] = value
        #         if "segm" in self.coco_evals[0]:
        #             if "segm" not in total_stats:
        #                 total_stats["segm"] = np.zeros_like(
        #                     self.coco_evals[0]["segm"].stats
        #                 )
        #             total_stats["segm"] += self.coco_evals[0]["segm"].stats
        #             for i, value in enumerate(self.coco_evals[0]["segm"].stats):
        #                 outs[
        #                     f"coco_eval_masks_{bucket_name}_{prefix}{DEMO_METRICS[i]}"
        #                 ] = value
        #     if "bbox" in total_stats:
        #         total_stats["bbox"] /= len(self.rarity_buckets)
        #         for i, value in enumerate(total_stats["bbox"]):
        #             outs[f"coco_eval_bbox_{prefix}{DEMO_METRICS[i]}"] = value
        #     if "segm" in total_stats:
        #         total_stats["segm"] /= len(self.rarity_buckets)
        #         for i, value in enumerate(total_stats["segm"]):
        #             outs[f"coco_eval_masks_{prefix}{DEMO_METRICS[i]}"] = value
        return outs

    def accumulate(self, imgIds=None):
        self._lazy_init()
        logging.info(
            f"demo evaluator: Accumulating on {len(imgIds) if imgIds is not None else 'all'} images"
        )
        if not is_main_process():
            return
        if imgIds is not None:
            for coco_eval in self.coco_evals[0].values():
                coco_eval.params.imgIds = list(imgIds)
        for coco_eval in self.coco_evals[0].values():
            coco_eval.accumulate()

    def reset(self):
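        # One DemoEval per ground-truth set (used for oracle-style evaluation) and per IoU type.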
        self.coco_evals = [{} for _ in range(len(self.coco_gts))]
        for i, coco_gt in enumerate(self.coco_gts):
            for iou_type in self.iou_types:
                self.coco_evals[i][iou_type] = DemoEval(
                    coco_gt=coco_gt,
                    iouType=iou_type,
                    threshold=self.threshold,
                    compute_JnF=self.compute_JnF,
                )
                self.coco_evals[i][iou_type].useCats = False
        self.img_ids = []
        self.eval_imgs = {k: [] for k in self.iou_types}
        if self.dump is not None:
            self.dump = []