eval.py

# flake8: noqa
# pyre-unsafe
import os
import time
import traceback
from functools import partial
from multiprocessing.pool import Pool

import numpy as np

from . import _timing, utils
from .metrics import Count
from .utils import TrackEvalException

try:
    import tqdm
    TQDM_IMPORTED = True
except ImportError:
    TQDM_IMPORTED = False


class Evaluator:
    """Evaluator class for evaluating different metrics for different datasets"""

    @staticmethod
    def get_default_eval_config():
        """Returns the default config values for evaluation"""
        code_path = utils.get_code_path()
        default_config = {
            "USE_PARALLEL": False,
            "NUM_PARALLEL_CORES": 8,
            "BREAK_ON_ERROR": True,  # Raises exception and exits with error
            "RETURN_ON_ERROR": False,  # if not BREAK_ON_ERROR, then returns from function on error
            "LOG_ON_ERROR": os.path.join(
                code_path, "error_log.txt"
            ),  # if not None, save any errors into a log file.
            "PRINT_RESULTS": True,
            "PRINT_ONLY_COMBINED": False,
            "PRINT_CONFIG": True,
            "TIME_PROGRESS": True,
            "DISPLAY_LESS_PROGRESS": True,
            "OUTPUT_SUMMARY": True,
            "OUTPUT_EMPTY_CLASSES": True,  # If False, summary files are not output for classes with no detections
            "OUTPUT_DETAILED": True,
            "PLOT_CURVES": True,
        }
        return default_config
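
    # Illustrative sketch (not part of the original file): the defaults above are only a
    # starting point. utils.init_config(config, defaults, "Eval") in __init__ below
    # presumably fills in any keys missing from a user-supplied dict, so a caller can
    # override just a few values, e.g.:
    #
    #     evaluator = Evaluator({"USE_PARALLEL": True, "NUM_PARALLEL_CORES": 4})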

    def __init__(self, config=None):
        """Initialise the evaluator with a config dict"""
        self.config = utils.init_config(config, self.get_default_eval_config(), "Eval")
        # Only run timing analysis if not run in parallel.
        if self.config["TIME_PROGRESS"] and not self.config["USE_PARALLEL"]:
            _timing.DO_TIMING = True
            if self.config["DISPLAY_LESS_PROGRESS"]:
                _timing.DISPLAY_LESS_PROGRESS = True

    def _combine_results(
        self,
        res,
        metrics_list,
        metric_names,
        dataset,
        res_field="COMBINED_SEQ",
        target_tag=None,
    ):
        assert res_field.startswith("COMBINED_SEQ")
        # collecting combined cls keys (cls averaged, det averaged, super classes)
        tracker_list, seq_list, class_list = dataset.get_eval_info()
        combined_cls_keys = []
        res[res_field] = {}
        # narrow the target for evaluation
        if target_tag is not None:
            target_video_ids = [
                annot["video_id"]
                for annot in dataset.gt_data["annotations"]
                if target_tag in annot["tags"]
            ]
            vid2name = {
                video["id"]: video["file_names"][0].split("/")[0]
                for video in dataset.gt_data["videos"]
            }
            target_video_ids = set(target_video_ids)
            target_video = [vid2name[video_id] for video_id in target_video_ids]
            if len(target_video) == 0:
                raise TrackEvalException(
                    "No sequences found with the tag %s" % target_tag
                )
            target_annotations = [
                annot
                for annot in dataset.gt_data["annotations"]
                if annot["video_id"] in target_video_ids
            ]
            assert all(target_tag in annot["tags"] for annot in target_annotations), (
                f"Not all annotations in the target sequences have the target tag {target_tag}. "
                "We currently only support a target tag at the sequence level, not at the annotation level."
            )
        else:
            target_video = seq_list
        # combine sequences for each class
        for c_cls in class_list:
            res[res_field][c_cls] = {}
            for metric, metric_name in zip(metrics_list, metric_names):
                curr_res = {
                    seq_key: seq_value[c_cls][metric_name]
                    for seq_key, seq_value in res.items()
                    if not seq_key.startswith("COMBINED_SEQ")
                    and seq_key in target_video
                }
                res[res_field][c_cls][metric_name] = metric.combine_sequences(curr_res)
        # combine classes
        if dataset.should_classes_combine:
            combined_cls_keys += [
                "cls_comb_cls_av",
                "cls_comb_det_av",
                "all",
            ]
            res[res_field]["cls_comb_cls_av"] = {}
            res[res_field]["cls_comb_det_av"] = {}
            for metric, metric_name in zip(metrics_list, metric_names):
                cls_res = {
                    cls_key: cls_value[metric_name]
                    for cls_key, cls_value in res[res_field].items()
                    if cls_key not in combined_cls_keys
                }
                res[res_field]["cls_comb_cls_av"][metric_name] = (
                    metric.combine_classes_class_averaged(cls_res)
                )
                res[res_field]["cls_comb_det_av"][metric_name] = (
                    metric.combine_classes_det_averaged(cls_res)
                )
        # combine classes to super classes
        if dataset.use_super_categories:
            for cat, sub_cats in dataset.super_categories.items():
                combined_cls_keys.append(cat)
                res[res_field][cat] = {}
                for metric, metric_name in zip(metrics_list, metric_names):
                    cat_res = {
                        cls_key: cls_value[metric_name]
                        for cls_key, cls_value in res[res_field].items()
                        if cls_key in sub_cats
                    }
                    res[res_field][cat][metric_name] = (
                        metric.combine_classes_det_averaged(cat_res)
                    )
        return res, combined_cls_keys
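
    # Layout produced above: res[res_field][class][metric_name] holds per-class results
    # combined over the selected sequences. combined_cls_keys records which keys of
    # res[res_field] are combined classes ("cls_comb_cls_av", "cls_comb_det_av", "all",
    # plus any super-category names) rather than real classes, so _summarize_results can
    # treat them differently.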

    def _summarize_results(
        self,
        res,
        tracker,
        metrics_list,
        metric_names,
        dataset,
        res_field,
        combined_cls_keys,
    ):
        config = self.config
        output_fol = dataset.get_output_fol(tracker)
        tracker_display_name = dataset.get_display_name(tracker)
        for c_cls in res[res_field].keys():  # class_list + combined classes if calculated
            summaries = []
            details = []
            num_dets = res[res_field][c_cls]["Count"]["Dets"]
            if config["OUTPUT_EMPTY_CLASSES"] or num_dets > 0:
                for metric, metric_name in zip(metrics_list, metric_names):
                    # for combined classes there is no per sequence evaluation
                    if c_cls in combined_cls_keys:
                        table_res = {res_field: res[res_field][c_cls][metric_name]}
                    else:
                        table_res = {
                            seq_key: seq_value[c_cls][metric_name]
                            for seq_key, seq_value in res.items()
                        }
                    if config["PRINT_RESULTS"] and config["PRINT_ONLY_COMBINED"]:
                        dont_print = (
                            dataset.should_classes_combine
                            and c_cls not in combined_cls_keys
                        )
                        if not dont_print:
                            metric.print_table(
                                {res_field: table_res[res_field]},
                                tracker_display_name,
                                c_cls,
                                res_field,
                                res_field,
                            )
                    elif config["PRINT_RESULTS"]:
                        metric.print_table(
                            table_res, tracker_display_name, c_cls, res_field, res_field
                        )
                    if config["OUTPUT_SUMMARY"]:
                        summaries.append(metric.summary_results(table_res))
                    if config["OUTPUT_DETAILED"]:
                        details.append(metric.detailed_results(table_res))
                    if config["PLOT_CURVES"]:
                        metric.plot_single_tracker_results(
                            table_res,
                            tracker_display_name,
                            c_cls,
                            output_fol,
                        )
                if config["OUTPUT_SUMMARY"]:
                    utils.write_summary_results(summaries, c_cls, output_fol)
                if config["OUTPUT_DETAILED"]:
                    utils.write_detailed_results(details, c_cls, output_fol)
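
    # For reference: in the loop above, PRINT_RESULTS / PRINT_ONLY_COMBINED gate
    # metric.print_table, OUTPUT_SUMMARY gates utils.write_summary_results,
    # OUTPUT_DETAILED gates utils.write_detailed_results, PLOT_CURVES gates
    # metric.plot_single_tracker_results, and OUTPUT_EMPTY_CLASSES decides whether
    # classes with zero detections are summarized at all.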

    @_timing.time
    def evaluate(self, dataset_list, metrics_list, show_progressbar=False):
        """Evaluate a set of metrics on a set of datasets"""
        config = self.config
        metrics_list = metrics_list + [Count()]  # Count metrics are always run
        metric_names = utils.validate_metrics_list(metrics_list)
        dataset_names = [dataset.get_name() for dataset in dataset_list]
        output_res = {}
        output_msg = {}

        for dataset, dataset_name in zip(dataset_list, dataset_names):
            # Get dataset info about what to evaluate
            output_res[dataset_name] = {}
            output_msg[dataset_name] = {}
            tracker_list, seq_list, class_list = dataset.get_eval_info()
            print(
                "\nEvaluating %i tracker(s) on %i sequence(s) for %i class(es) on %s dataset using the following "
                "metrics: %s\n"
                % (
                    len(tracker_list),
                    len(seq_list),
                    len(class_list),
                    dataset_name,
                    ", ".join(metric_names),
                )
            )

            # Evaluate each tracker
            for tracker in tracker_list:
                # if not config['BREAK_ON_ERROR'] then go to next tracker without breaking
                try:
                    # Evaluate each sequence in parallel or in series.
                    # returns a nested dict (res), indexed like: res[seq][class][metric_name][sub_metric field]
                    # e.g. res[seq_0001][pedestrian][hota][DetA]
                    print("\nEvaluating %s\n" % tracker)
                    time_start = time.time()
                    if config["USE_PARALLEL"]:
                        if show_progressbar and TQDM_IMPORTED:
                            seq_list_sorted = sorted(seq_list)
                            with (
                                Pool(config["NUM_PARALLEL_CORES"]) as pool,
                                tqdm.tqdm(total=len(seq_list)) as pbar,
                            ):
                                _eval_sequence = partial(
                                    eval_sequence,
                                    dataset=dataset,
                                    tracker=tracker,
                                    class_list=class_list,
                                    metrics_list=metrics_list,
                                    metric_names=metric_names,
                                )
                                results = []
                                for r in pool.imap(
                                    _eval_sequence, seq_list_sorted, chunksize=20
                                ):
                                    results.append(r)
                                    pbar.update()
                                res = dict(zip(seq_list_sorted, results))
                        else:
                            with Pool(config["NUM_PARALLEL_CORES"]) as pool:
                                _eval_sequence = partial(
                                    eval_sequence,
                                    dataset=dataset,
                                    tracker=tracker,
                                    class_list=class_list,
                                    metrics_list=metrics_list,
                                    metric_names=metric_names,
                                )
                                results = pool.map(_eval_sequence, seq_list)
                                res = dict(zip(seq_list, results))
                    else:
                        res = {}
                        if show_progressbar and TQDM_IMPORTED:
                            seq_list_sorted = sorted(seq_list)
                            for curr_seq in tqdm.tqdm(seq_list_sorted):
                                res[curr_seq] = eval_sequence(
                                    curr_seq,
                                    dataset,
                                    tracker,
                                    class_list,
                                    metrics_list,
                                    metric_names,
                                )
                        else:
                            for curr_seq in sorted(seq_list):
                                res[curr_seq] = eval_sequence(
                                    curr_seq,
                                    dataset,
                                    tracker,
                                    class_list,
                                    metrics_list,
                                    metric_names,
                                )

                    # Combine results over all sequences and then over all classes
                    res, combined_cls_keys = self._combine_results(
                        res, metrics_list, metric_names, dataset, "COMBINED_SEQ"
                    )
                    if np.all(
                        ["tags" in annot for annot in dataset.gt_data["annotations"]]
                    ):
                        # Combine results over the challenging sequences and then over all classes
                        # currently only support "tracking_challenging_pair"
                        res, _ = self._combine_results(
                            res,
                            metrics_list,
                            metric_names,
                            dataset,
                            "COMBINED_SEQ_CHALLENGING",
                            "tracking_challenging_pair",
                        )

                    # Print and output results in various formats
                    if config["TIME_PROGRESS"]:
                        print(
                            "\nAll sequences for %s finished in %.2f seconds"
                            % (tracker, time.time() - time_start)
                        )
                    self._summarize_results(
                        res,
                        tracker,
                        metrics_list,
                        metric_names,
                        dataset,
                        "COMBINED_SEQ",
                        combined_cls_keys,
                    )
                    if "COMBINED_SEQ_CHALLENGING" in res:
                        self._summarize_results(
                            res,
                            tracker,
                            metrics_list,
                            metric_names,
                            dataset,
                            "COMBINED_SEQ_CHALLENGING",
                            combined_cls_keys,
                        )

                    # Output for returning from function
                    output_res[dataset_name][tracker] = res
                    output_msg[dataset_name][tracker] = "Success"

                except Exception as err:
                    output_res[dataset_name][tracker] = None
                    if isinstance(err, TrackEvalException):
                        output_msg[dataset_name][tracker] = str(err)
                    else:
                        output_msg[dataset_name][tracker] = "Unknown error occurred."
                    print("Tracker %s was unable to be evaluated." % tracker)
                    print(err)
                    traceback.print_exc()
                    if config["LOG_ON_ERROR"] is not None:
                        with open(config["LOG_ON_ERROR"], "a") as f:
                            print(dataset_name, file=f)
                            print(tracker, file=f)
                            print(traceback.format_exc(), file=f)
                            print("\n\n\n", file=f)
                    if config["BREAK_ON_ERROR"]:
                        raise err
                    elif config["RETURN_ON_ERROR"]:
                        return output_res, output_msg

        return output_res, output_msg
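

# Example usage (illustrative sketch, not part of this module): `SomeDataset`,
# `SomeMetric`, and `dataset_config` below are placeholders for whichever concrete
# dataset/metric classes this package provides.
#
#     evaluator = Evaluator({"USE_PARALLEL": False, "PRINT_RESULTS": True})
#     output_res, output_msg = evaluator.evaluate(
#         dataset_list=[SomeDataset(dataset_config)],
#         metrics_list=[SomeMetric()],
#         show_progressbar=True,
#     )
#     # output_res[dataset_name][tracker] is the nested result dict (or None on failure);
#     # output_msg[dataset_name][tracker] is "Success" or an error message.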


@_timing.time
def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names):
    """Function for evaluating a single sequence"""
    raw_data = dataset.get_raw_seq_data(tracker, seq)
    seq_res = {}
    for cls in class_list:
        seq_res[cls] = {}
        data = dataset.get_preprocessed_seq_data(raw_data, cls)
        for metric, met_name in zip(metrics_list, metric_names):
            seq_res[cls][met_name] = metric.eval_sequence(data)
    return seq_res
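

# Minimal metric interface assumed by this evaluator (inferred from the calls made in
# this module; a sketch, not a formal spec). A metric object is expected to provide:
#   eval_sequence(data)                   -> per-sequence result dict
#   combine_sequences(res)                -> combination over sequences
#   combine_classes_class_averaged(res)   -> class-averaged combination
#   combine_classes_det_averaged(res)     -> detection-averaged combination
#   print_table(...), summary_results(...), detailed_results(...),
#   plot_single_tracker_results(...)      -> reporting/output helpers
# The Count metric from .metrics is always appended in evaluate() so the "Dets" count
# is available when summarizing.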