# fmt: off
# flake8: noqa
# pyre-unsafe
"""Track Every Thing Accuracy metric."""
import numpy as np
from scipy.optimize import linear_sum_assignment

from .. import _timing
from ._base_metric import _BaseMetric

EPS = np.finfo("float").eps  # epsilon


class TETA(_BaseMetric):
    """TETA metric."""

    def __init__(self, exhaustive=False, config=None):
        """Initialize metric."""
        super().__init__()
        self.plottable = True
        self.array_labels = np.arange(0.0, 0.99, 0.05)
        self.cls_array_labels = np.arange(0.5, 0.99, 0.05)
        self.integer_array_fields = [
            "Loc_TP",
            "Loc_FN",
            "Loc_FP",
            "Cls_TP",
            "Cls_FN",
            "Cls_FP",
        ]
        self.float_array_fields = (
            ["TETA", "LocA", "AssocA", "ClsA"]
            + ["LocRe", "LocPr"]
            + ["AssocRe", "AssocPr"]
            + ["ClsRe", "ClsPr"]
        )
        self.fields = self.float_array_fields + self.integer_array_fields
        self.summary_fields = self.float_array_fields
        self.exhaustive = exhaustive
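    # Reader's note on the two threshold grids above (inferred from how they
    # are indexed later in this file): `array_labels` holds the 20
    # localization thresholds 0.00, 0.05, ..., 0.95, while `cls_array_labels`
    # holds only the last 10 of those (0.50, ..., 0.95) for classification.
    # That offset is why classification arrays are indexed with `a - 10`
    # inside loops over `array_labels`.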
    def compute_global_assignment(self, data_thr, alpha=0.5):
        """Compute global assignment of TP."""
        res = {
            thr: {t: {} for t in range(data_thr[thr]["num_timesteps"])}
            for thr in data_thr
        }
        for thr in data_thr:
            data = data_thr[thr]
            # return empty result if tracker or gt sequence is empty
            if data["num_tk_overlap_dets"] == 0 or data["num_gt_dets"] == 0:
                return res
            # global alignment score
            ga_score, _, _ = self.compute_global_alignment_score(data)
            # calculate scores for each timestep
            for t, (gt_ids_t, tk_ids_t) in enumerate(
                zip(data["gt_ids"], data["tk_ids"])
            ):
                # get matches optimizing for TETA
                amatch_rows, amatch_cols = self.compute_matches(
                    data, t, ga_score, gt_ids_t, tk_ids_t, alpha=alpha
                )
                gt_ids = [data["gt_id_map"][tid] for tid in gt_ids_t[amatch_rows[0]]]
                matched_ids = [
                    data["tk_id_map"][tid] for tid in tk_ids_t[amatch_cols[0]]
                ]
                res[thr][t] = dict(zip(gt_ids, matched_ids))
        return res
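    # A sketch of the structure the method above returns, based on how `res`
    # is built (the ids and threshold key are illustrative, not from any
    # real run):
    #
    #   res = {
    #       50: {                  # similarity threshold key of `data_thr`
    #           0: {7: 12, 9: 3},  # timestep -> {orig gt id: matched tk id}
    #           1: {},             # stays empty if nothing matched at t
    #       },
    #   }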
    def eval_sequence_single_thr(self, data, cls, cid2clsname, cls_fp_thr, thr):
        """Computes the TETA metric for one threshold for one sequence."""
        res = {}
        class_info_list = []
        for field in self.float_array_fields + self.integer_array_fields:
            if field.startswith("Cls"):
                res[field] = np.zeros(len(self.cls_array_labels), dtype=float)
            else:
                res[field] = np.zeros((len(self.array_labels)), dtype=float)

        # return empty result if tracker or gt sequence is empty
        if data["num_tk_overlap_dets"] == 0:
            res["Loc_FN"] = data["num_gt_dets"] * np.ones(
                (len(self.array_labels)), dtype=float
            )
            if self.exhaustive:
                cls_fp_thr[cls] = data["num_tk_cls_dets"] * np.ones(
                    (len(self.cls_array_labels)), dtype=float
                )
            res = self._compute_final_fields(res)
            return res, cls_fp_thr, class_info_list
        if data["num_gt_dets"] == 0:
            if self.exhaustive:
                cls_fp_thr[cls] = data["num_tk_cls_dets"] * np.ones(
                    (len(self.cls_array_labels)), dtype=float
                )
            res = self._compute_final_fields(res)
            return res, cls_fp_thr, class_info_list

        # global alignment score
        ga_score, gt_id_count, tk_id_count = self.compute_global_alignment_score(data)
        matches_counts = [np.zeros_like(ga_score) for _ in self.array_labels]
        # calculate scores for each timestep
        for t, (gt_ids_t, tk_ids_t, tk_overlap_ids_t, tk_cls_ids_t) in enumerate(
            zip(
                data["gt_ids"],
                data["tk_ids"],
                data["tk_overlap_ids"],
                data["tk_class_eval_tk_ids"],
            )
        ):
            # deal with the case that there are no gt_det/tk_det in a timestep
            if len(gt_ids_t) == 0:
                if self.exhaustive:
                    cls_fp_thr[cls] += len(tk_cls_ids_t)
                continue
            # get matches optimizing for TETA
            amatch_rows, amatch_cols = self.compute_matches(
                data, t, ga_score, gt_ids_t, tk_ids_t, list(self.array_labels)
            )
            # map overlap_ids to original ids; candidates for localization
            # false positives are overlap dets whose similarity with some gt
            # reaches the threshold
            if len(tk_overlap_ids_t) != 0:
                sorter = np.argsort(tk_ids_t)
                indexes = sorter[
                    np.searchsorted(tk_ids_t, tk_overlap_ids_t, sorter=sorter)
                ]
                sim_t = data["sim_scores"][t][:, indexes]
                fpl_candidates = tk_overlap_ids_t[(sim_t >= (thr / 100)).any(axis=0)]
                fpl_candidates_ori_ids_t = np.array(
                    [data["tk_id_map"][tid] for tid in fpl_candidates]
                )
            else:
                fpl_candidates_ori_ids_t = []
            if self.exhaustive:
                cls_fp_thr[cls] += len(tk_cls_ids_t) - len(tk_overlap_ids_t)
            # calculate and accumulate basic statistics
            for a, alpha in enumerate(self.array_labels):
                match_row, match_col = amatch_rows[a], amatch_cols[a]
                num_matches = len(match_row)
                matched_ori_ids = set(
                    [data["tk_id_map"][tid] for tid in tk_ids_t[match_col]]
                )
                match_tk_cls = data["tk_classes"][t][match_col]
                wrong_tk_cls = match_tk_cls[match_tk_cls != data["gt_classes"][t]]
                num_class_and_det_matches = np.sum(
                    match_tk_cls == data["gt_classes"][t]
                )
                # classification stats only exist for alpha >= 0.5
                if alpha >= 0.5:
                    for cid in wrong_tk_cls:
                        if cid in cid2clsname:
                            cname = cid2clsname[cid]
                            cls_fp_thr[cname][a - 10] += 1
                    res["Cls_TP"][a - 10] += num_class_and_det_matches
                    res["Cls_FN"][a - 10] += num_matches - num_class_and_det_matches
                res["Loc_TP"][a] += num_matches
                res["Loc_FN"][a] += len(gt_ids_t) - num_matches
                res["Loc_FP"][a] += len(set(fpl_candidates_ori_ids_t) - matched_ori_ids)
                if num_matches > 0:
                    matches_counts[a][gt_ids_t[match_row], tk_ids_t[match_col]] += 1
        # calculate AssocA, AssocRe, AssocPr
        self.compute_association_scores(res, matches_counts, gt_id_count, tk_id_count)
        # calculate final scores
        res = self._compute_final_fields(res)
        return res, cls_fp_thr, class_info_list
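    # Judging by the counters above, classification false positives come from
    # two sources: matched detections whose predicted class disagrees with
    # the gt class (always counted, via `cid2clsname`), and, only in
    # `exhaustive` mode, class-eval detections beyond those overlapping the
    # ground truth (`len(tk_cls_ids_t) - len(tk_overlap_ids_t)` per timestep).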
    def compute_global_alignment_score(self, data):
        """Computes the global alignment score."""
        num_matches = np.zeros((data["num_gt_ids"], data["num_tk_ids"]))
        gt_id_count = np.zeros((data["num_gt_ids"], 1))
        tk_id_count = np.zeros((1, data["num_tk_ids"]))
        # loop through each timestep and accumulate global track info
        for t, (gt_ids_t, tk_ids_t) in enumerate(zip(data["gt_ids"], data["tk_ids"])):
            # count potential matches between ids in each timestep;
            # these are normalized, weighted by match similarity
            sim = data["sim_scores"][t]
            sim_iou_denom = sim.sum(0, keepdims=True) + sim.sum(1, keepdims=True) - sim
            sim_iou = np.zeros_like(sim)
            mask = sim_iou_denom > (0 + EPS)
            sim_iou[mask] = sim[mask] / sim_iou_denom[mask]
            num_matches[gt_ids_t[:, None], tk_ids_t[None, :]] += sim_iou
            # calculate the total number of dets for each gt_id and tk_id
            gt_id_count[gt_ids_t] += 1
            tk_id_count[0, tk_ids_t] += 1
        # calculate the overall Jaccard alignment score between IDs
        ga_score = num_matches / (gt_id_count + tk_id_count - num_matches)
        return ga_score, gt_id_count, tk_id_count
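    # Written out, the score above is a soft Jaccard index per id pair (i, j):
    #
    #   ga_score[i, j] = A[i, j] / (|T_i| + |T_j| - A[i, j])
    #
    # where A[i, j] = sum over t of sim_iou[i, j] is the similarity-weighted
    # count of potential matches, and |T_i|, |T_j| are the numbers of frames
    # in which gt id i and tracker id j appear. Values lie in [0, 1]; 1 means
    # the tracks co-occur in every frame with exclusive, perfect similarity.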
    def compute_matches(self, data, t, ga_score, gt_ids, tk_ids, alpha):
        """Compute matches based on the alignment score."""
        sim = data["sim_scores"][t]
        score_mat = ga_score[gt_ids[:, None], tk_ids[None, :]] * sim
        # Hungarian algorithm to find the best matches
        match_rows, match_cols = linear_sum_assignment(-score_mat)
        if not isinstance(alpha, list):
            alpha = [alpha]
        alpha_match_rows, alpha_match_cols = [], []
        for a in alpha:
            matched_mask = sim[match_rows, match_cols] >= a - EPS
            alpha_match_rows.append(match_rows[matched_mask])
            alpha_match_cols.append(match_cols[matched_mask])
        return alpha_match_rows, alpha_match_cols
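    # `linear_sum_assignment` minimizes cost, so negating `score_mat` makes
    # it maximize the summed ga_score-weighted similarity; weighting by the
    # global score biases matching toward track-consistent pairs, similar to
    # HOTA-style matching. The single assignment is then filtered per alpha:
    # a pair is kept at threshold `a` only if its raw frame similarity
    # reaches `a`, which yields all per-alpha match lists in one pass.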
    def compute_association_scores(self, res, matches_counts, gt_id_count, tk_id_count):
        """Calculate association scores for each alpha.

        First calculate scores per gt_id/tk_id combo,
        and then average over the number of detections.
        """
        for a, _ in enumerate(self.array_labels):
            matches_count = matches_counts[a]
            ass_a = matches_count / np.maximum(
                1, gt_id_count + tk_id_count - matches_count
            )
            res["AssocA"][a] = np.sum(matches_count * ass_a) / np.maximum(
                1, res["Loc_TP"][a]
            )
            ass_re = matches_count / np.maximum(1, gt_id_count)
            res["AssocRe"][a] = np.sum(matches_count * ass_re) / np.maximum(
                1, res["Loc_TP"][a]
            )
            ass_pr = matches_count / np.maximum(1, tk_id_count)
            res["AssocPr"][a] = np.sum(matches_count * ass_pr) / np.maximum(
                1, res["Loc_TP"][a]
            )
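    # In effect, with TPA[i, j] = matches_count[i, j] (true positive
    # associations), the loop above computes per alpha:
    #
    #   AssocA  = (1 / Loc_TP) * sum_{i,j} TPA * TPA / (|T_i| + |T_j| - TPA)
    #   AssocRe = (1 / Loc_TP) * sum_{i,j} TPA * TPA / |T_i|
    #   AssocPr = (1 / Loc_TP) * sum_{i,j} TPA * TPA / |T_j|
    #
    # Multiplying each id-pair score by its TPA count before summing makes
    # the average run over matched detections rather than over id pairs.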
    @_timing.time
    def eval_sequence(self, data, cls, cls_id_name_mapping, cls_fp):
        """Evaluate a single sequence across all thresholds."""
        res = {}
        class_info_dict = {}
        for thr in data:
            res[thr], cls_fp[thr], cls_info = self.eval_sequence_single_thr(
                data[thr], cls, cls_id_name_mapping, cls_fp[thr], thr
            )
            class_info_dict[thr] = cls_info
        return res, cls_fp, class_info_dict
    def combine_sequences(self, all_res):
        """Combines metrics across all sequences."""
        data = {}
        res = {}
        if all_res:
            thresholds = list(list(all_res.values())[0].keys())
        else:
            thresholds = [50]
        for thr in thresholds:
            data[thr] = {}
            for seq_key in all_res:
                data[thr][seq_key] = all_res[seq_key][thr]
        for thr in thresholds:
            res[thr] = self._combine_sequences_thr(data[thr])
        return res

    def _combine_sequences_thr(self, all_res):
        """Combines sequences over each threshold."""
        res = {}
        for field in self.integer_array_fields:
            res[field] = self._combine_sum(all_res, field)
        for field in ["AssocRe", "AssocPr", "AssocA"]:
            res[field] = self._combine_weighted_av(
                all_res, field, res, weight_field="Loc_TP"
            )
        res = self._compute_final_fields(res)
        return res
    def combine_classes_class_averaged(self, all_res, ignore_empty=False):
        """Combines metrics across all classes by averaging over classes.

        If 'ignore_empty' is True, then it only sums over classes
        with at least one gt or predicted detection.
        """
        data = {}
        res = {}
        if all_res:
            thresholds = list(list(all_res.values())[0].keys())
        else:
            thresholds = [50]
        for thr in thresholds:
            data[thr] = {}
            for cls_key in all_res:
                data[thr][cls_key] = all_res[cls_key][thr]
        for thr in data:
            res[thr] = self._combine_classes_class_averaged_thr(
                data[thr], ignore_empty=ignore_empty
            )
        return res

    def _combine_classes_class_averaged_thr(self, all_res, ignore_empty=False):
        """Combines classes over each threshold."""
        res = {}

        def check_empty(val):
            """Returns True if empty."""
            return not (val["Loc_TP"] + val["Loc_FN"] + val["Loc_FP"] > 0 + EPS).any()

        for field in self.integer_array_fields:
            if ignore_empty:
                res_field = {k: v for k, v in all_res.items() if not check_empty(v)}
            else:
                res_field = {k: v for k, v in all_res.items()}
            res[field] = self._combine_sum(res_field, field)
        for field in self.float_array_fields:
            if ignore_empty:
                res_field = [v[field] for v in all_res.values() if not check_empty(v)]
            else:
                res_field = [v[field] for v in all_res.values()]
            res[field] = np.mean(res_field, axis=0)
        return res
    def combine_classes_det_averaged(self, all_res):
        """Combines metrics across all classes by averaging over detections."""
        data = {}
        res = {}
        if all_res:
            thresholds = list(list(all_res.values())[0].keys())
        else:
            thresholds = [50]
        for thr in thresholds:
            data[thr] = {}
            for cls_key in all_res:
                data[thr][cls_key] = all_res[cls_key][thr]
        for thr in data:
            res[thr] = self._combine_classes_det_averaged_thr(data[thr])
        return res

    def _combine_classes_det_averaged_thr(self, all_res):
        """Combines detections over each threshold."""
        res = {}
        for field in self.integer_array_fields:
            res[field] = self._combine_sum(all_res, field)
        for field in ["AssocRe", "AssocPr", "AssocA"]:
            res[field] = self._combine_weighted_av(
                all_res, field, res, weight_field="Loc_TP"
            )
        res = self._compute_final_fields(res)
        return res
    @staticmethod
    def _compute_final_fields(res):
        """Calculate final metric values.

        This function is used both for per-sequence calculation
        and for combining values across sequences.
        """
        # LocA
        res["LocRe"] = res["Loc_TP"] / np.maximum(1, res["Loc_TP"] + res["Loc_FN"])
        res["LocPr"] = res["Loc_TP"] / np.maximum(1, res["Loc_TP"] + res["Loc_FP"])
        res["LocA"] = res["Loc_TP"] / np.maximum(
            1, res["Loc_TP"] + res["Loc_FN"] + res["Loc_FP"]
        )
        # ClsA
        res["ClsRe"] = res["Cls_TP"] / np.maximum(1, res["Cls_TP"] + res["Cls_FN"])
        res["ClsPr"] = res["Cls_TP"] / np.maximum(1, res["Cls_TP"] + res["Cls_FP"])
        res["ClsA"] = res["Cls_TP"] / np.maximum(
            1, res["Cls_TP"] + res["Cls_FN"] + res["Cls_FP"]
        )
        res["ClsRe"] = np.mean(res["ClsRe"])
        res["ClsPr"] = np.mean(res["ClsPr"])
        res["ClsA"] = np.mean(res["ClsA"])
        res["TETA"] = (res["LocA"] + res["AssocA"] + res["ClsA"]) / 3
        return res
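    # Summary of the combination above, per localization threshold alpha:
    #
    #   LocA = Loc_TP / (Loc_TP + Loc_FN + Loc_FP)
    #   ClsA = mean over class thresholds of Cls_TP / (Cls_TP + Cls_FN + Cls_FP)
    #   TETA = (LocA + AssocA + ClsA) / 3
    #
    # Note that ClsRe/ClsPr/ClsA are reduced to scalars before the TETA
    # average, so the same class score is broadcast across every alpha entry.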
    def print_summary_table(self, thr_res, thr, tracker, cls):
        """Prints a summary table of results."""
        print("")
        metric_name = self.get_name()
        self._row_print(
            [f"{metric_name}{str(thr)}: {tracker}-{cls}"] + self.summary_fields
        )
        self._row_print(["COMBINED"] + thr_res)
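
# A minimal usage sketch (assumed harness, not part of this module). The
# evaluator elsewhere in the package is expected to supply, per similarity
# threshold, a `data` dict with the keys this class reads: "gt_ids",
# "tk_ids", "tk_overlap_ids", "tk_class_eval_tk_ids", "gt_classes",
# "tk_classes", "sim_scores", the "num_*" counters, and the
# "gt_id_map"/"tk_id_map" lookups. Given such dicts keyed by threshold:
#
#   metric = TETA(exhaustive=False)
#   seq_res, cls_fp, _ = metric.eval_sequence(
#       data, cls, cls_id_name_mapping, cls_fp
#   )
#   combined = metric.combine_sequences({"seq-0001": seq_res})
#
# How the summary row passed to `print_summary_table` is formatted depends on
# `_BaseMetric._row_print`, which lives outside this file.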