# flake8: noqa
# pyre-unsafe
import os
import time
import traceback
from functools import partial
from multiprocessing.pool import Pool

import numpy as np

from . import _timing, utils
from .metrics import Count
from .utils import TrackEvalException

try:
    import tqdm

    TQDM_IMPORTED = True
except ImportError:
    TQDM_IMPORTED = False
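
# Progress bars are optional: they are only shown when tqdm is installed
# (TQDM_IMPORTED) and evaluate() is called with show_progressbar=True.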


class Evaluator:
    """Evaluator class for evaluating different metrics for different datasets"""

    @staticmethod
    def get_default_eval_config():
        """Returns the default config values for evaluation"""
        code_path = utils.get_code_path()
        default_config = {
            "USE_PARALLEL": False,
            "NUM_PARALLEL_CORES": 8,
            "BREAK_ON_ERROR": True,  # Raises exception and exits with error
            "RETURN_ON_ERROR": False,  # if not BREAK_ON_ERROR, then returns from function on error
            "LOG_ON_ERROR": os.path.join(
                code_path, "error_log.txt"
            ),  # if not None, save any errors into a log file.
            "PRINT_RESULTS": True,
            "PRINT_ONLY_COMBINED": False,
            "PRINT_CONFIG": True,
            "TIME_PROGRESS": True,
            "DISPLAY_LESS_PROGRESS": True,
            "OUTPUT_SUMMARY": True,
            "OUTPUT_EMPTY_CLASSES": True,  # If False, summary files are not output for classes with no detections
            "OUTPUT_DETAILED": True,
            "PLOT_CURVES": True,
        }
        return default_config

    def __init__(self, config=None):
        """Initialise the evaluator with a config dict; missing keys use the defaults."""
        self.config = utils.init_config(config, self.get_default_eval_config(), "Eval")
        # Only run timing analysis if not run in parallel.
        if self.config["TIME_PROGRESS"] and not self.config["USE_PARALLEL"]:
            _timing.DO_TIMING = True
            if self.config["DISPLAY_LESS_PROGRESS"]:
                _timing.DISPLAY_LESS_PROGRESS = True
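
    # A minimal usage sketch (illustrative only): construct the evaluator with a
    # partial config dict; any keys omitted here keep the defaults returned by
    # get_default_eval_config(), e.g.
    #
    #     evaluator = Evaluator({"USE_PARALLEL": True, "NUM_PARALLEL_CORES": 4})
    #
    # The dataset and metric objects passed to evaluate() are assumed to be
    # provided by the rest of the package; their concrete classes are dataset-specific.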

    def _combine_results(
        self,
        res,
        metrics_list,
        metric_names,
        dataset,
        res_field="COMBINED_SEQ",
        target_tag=None,
    ):
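        """Combine per-sequence results into a single res[res_field] entry.

        `res` maps sequence name -> class -> metric_name -> metric fields. This
        adds res[res_field][cls][metric_name] aggregated over the (optionally
        tag-filtered) sequences, plus class-combined and super-category entries
        where the dataset supports them, and returns the updated res together
        with the list of combined class keys.
        """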
        assert res_field.startswith("COMBINED_SEQ")
        # collecting combined cls keys (cls averaged, det averaged, super classes)
        tracker_list, seq_list, class_list = dataset.get_eval_info()
        combined_cls_keys = []
        res[res_field] = {}

        # narrow the target for evaluation
        if target_tag is not None:
            target_video_ids = [
                annot["video_id"]
                for annot in dataset.gt_data["annotations"]
                if target_tag in annot["tags"]
            ]
            vid2name = {
                video["id"]: video["file_names"][0].split("/")[0]
                for video in dataset.gt_data["videos"]
            }
            target_video_ids = set(target_video_ids)
            target_video = [vid2name[video_id] for video_id in target_video_ids]
            if len(target_video) == 0:
                raise TrackEvalException(
                    "No sequences found with the tag %s" % target_tag
                )
            target_annotations = [
                annot
                for annot in dataset.gt_data["annotations"]
                if annot["video_id"] in target_video_ids
            ]
            assert all(target_tag in annot["tags"] for annot in target_annotations), (
                f"Not all annotations in the target sequences have the target tag {target_tag}. "
                "We currently only support a target tag at the sequence level, not at the annotation level."
            )
        else:
            target_video = seq_list

        # combine sequences for each class
        for c_cls in class_list:
            res[res_field][c_cls] = {}
            for metric, metric_name in zip(metrics_list, metric_names):
                curr_res = {
                    seq_key: seq_value[c_cls][metric_name]
                    for seq_key, seq_value in res.items()
                    if not seq_key.startswith("COMBINED_SEQ")
                    and seq_key in target_video
                }
                res[res_field][c_cls][metric_name] = metric.combine_sequences(curr_res)

        # combine classes
        if dataset.should_classes_combine:
            combined_cls_keys += [
                "cls_comb_cls_av",
                "cls_comb_det_av",
                "all",
            ]
            res[res_field]["cls_comb_cls_av"] = {}
            res[res_field]["cls_comb_det_av"] = {}
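            # "cls_comb_cls_av" averages the per-class results with equal weight per
            # class, while "cls_comb_det_av" pools detections across classes before
            # computing the metric (detection-weighted average).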
            for metric, metric_name in zip(metrics_list, metric_names):
                cls_res = {
                    cls_key: cls_value[metric_name]
                    for cls_key, cls_value in res[res_field].items()
                    if cls_key not in combined_cls_keys
                }
                res[res_field]["cls_comb_cls_av"][metric_name] = (
                    metric.combine_classes_class_averaged(cls_res)
                )
                res[res_field]["cls_comb_det_av"][metric_name] = (
                    metric.combine_classes_det_averaged(cls_res)
                )

        # combine classes to super classes
        if dataset.use_super_categories:
            for cat, sub_cats in dataset.super_categories.items():
                combined_cls_keys.append(cat)
                res[res_field][cat] = {}
                for metric, metric_name in zip(metrics_list, metric_names):
                    cat_res = {
                        cls_key: cls_value[metric_name]
                        for cls_key, cls_value in res[res_field].items()
                        if cls_key in sub_cats
                    }
                    res[res_field][cat][metric_name] = (
                        metric.combine_classes_det_averaged(cat_res)
                    )

        return res, combined_cls_keys

    def _summarize_results(
        self,
        res,
        tracker,
        metrics_list,
        metric_names,
        dataset,
        res_field,
        combined_cls_keys,
    ):
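        """Print and write out results for one tracker, one class at a time.

        Depending on the config, this prints result tables and writes summary,
        detailed and plot outputs for every class key present in res[res_field]
        (the evaluated classes plus any combined/super-category keys).
        """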
        config = self.config
        output_fol = dataset.get_output_fol(tracker)
        tracker_display_name = dataset.get_display_name(tracker)
        # iterate over class_list plus any combined classes that were calculated
        for c_cls in res[res_field].keys():
            summaries = []
            details = []
            num_dets = res[res_field][c_cls]["Count"]["Dets"]
            if config["OUTPUT_EMPTY_CLASSES"] or num_dets > 0:
                for metric, metric_name in zip(metrics_list, metric_names):
                    # for combined classes there is no per sequence evaluation
                    if c_cls in combined_cls_keys:
                        table_res = {res_field: res[res_field][c_cls][metric_name]}
                    else:
                        table_res = {
                            seq_key: seq_value[c_cls][metric_name]
                            for seq_key, seq_value in res.items()
                        }
                    if config["PRINT_RESULTS"] and config["PRINT_ONLY_COMBINED"]:
                        dont_print = (
                            dataset.should_classes_combine
                            and c_cls not in combined_cls_keys
                        )
                        if not dont_print:
                            metric.print_table(
                                {res_field: table_res[res_field]},
                                tracker_display_name,
                                c_cls,
                                res_field,
                                res_field,
                            )
                    elif config["PRINT_RESULTS"]:
                        metric.print_table(
                            table_res, tracker_display_name, c_cls, res_field, res_field
                        )
                    if config["OUTPUT_SUMMARY"]:
                        summaries.append(metric.summary_results(table_res))
                    if config["OUTPUT_DETAILED"]:
                        details.append(metric.detailed_results(table_res))
                    if config["PLOT_CURVES"]:
                        metric.plot_single_tracker_results(
                            table_res,
                            tracker_display_name,
                            c_cls,
                            output_fol,
                        )
                if config["OUTPUT_SUMMARY"]:
                    utils.write_summary_results(summaries, c_cls, output_fol)
                if config["OUTPUT_DETAILED"]:
                    utils.write_detailed_results(details, c_cls, output_fol)

    @_timing.time
    def evaluate(self, dataset_list, metrics_list, show_progressbar=False):
        """Evaluate a set of metrics on a set of datasets"""
        config = self.config
        metrics_list = metrics_list + [Count()]  # Count metrics are always run
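        # The Count results feed the OUTPUT_EMPTY_CLASSES check in
        # _summarize_results (res[...][c_cls]["Count"]["Dets"]).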
        metric_names = utils.validate_metrics_list(metrics_list)
        dataset_names = [dataset.get_name() for dataset in dataset_list]
        output_res = {}
        output_msg = {}
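        # output_res[dataset_name][tracker] ends up holding the nested results dict
        # (or None on failure); output_msg[dataset_name][tracker] holds "Success" or
        # an error message.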

        for dataset, dataset_name in zip(dataset_list, dataset_names):
            # Get dataset info about what to evaluate
            output_res[dataset_name] = {}
            output_msg[dataset_name] = {}
            tracker_list, seq_list, class_list = dataset.get_eval_info()
            print(
                "\nEvaluating %i tracker(s) on %i sequence(s) for %i class(es) on %s dataset using the following "
                "metrics: %s\n"
                % (
                    len(tracker_list),
                    len(seq_list),
                    len(class_list),
                    dataset_name,
                    ", ".join(metric_names),
                )
            )

            # Evaluate each tracker
            for tracker in tracker_list:
                # if not config['BREAK_ON_ERROR'] then go to next tracker without breaking
                try:
                    # Evaluate each sequence in parallel or in series.
                    # returns a nested dict (res), indexed like: res[seq][class][metric_name][sub_metric field]
                    # e.g. res[seq_0001][pedestrian][hota][DetA]
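                    # After _combine_results below, the same dict also carries the
                    # aggregated entries, e.g. res["COMBINED_SEQ"][class][metric_name].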
                    print("\nEvaluating %s\n" % tracker)
                    time_start = time.time()
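                    # In the parallel path, functools.partial freezes every argument
                    # of eval_sequence except the sequence name, so the worker pool
                    # only has to map over sequence names.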
                    if config["USE_PARALLEL"]:
                        if show_progressbar and TQDM_IMPORTED:
                            seq_list_sorted = sorted(seq_list)
                            with (
                                Pool(config["NUM_PARALLEL_CORES"]) as pool,
                                tqdm.tqdm(total=len(seq_list)) as pbar,
                            ):
                                _eval_sequence = partial(
                                    eval_sequence,
                                    dataset=dataset,
                                    tracker=tracker,
                                    class_list=class_list,
                                    metrics_list=metrics_list,
                                    metric_names=metric_names,
                                )
                                results = []
                                for r in pool.imap(
                                    _eval_sequence, seq_list_sorted, chunksize=20
                                ):
                                    results.append(r)
                                    pbar.update()
                                res = dict(zip(seq_list_sorted, results))
                        else:
                            with Pool(config["NUM_PARALLEL_CORES"]) as pool:
                                _eval_sequence = partial(
                                    eval_sequence,
                                    dataset=dataset,
                                    tracker=tracker,
                                    class_list=class_list,
                                    metrics_list=metrics_list,
                                    metric_names=metric_names,
                                )
                                results = pool.map(_eval_sequence, seq_list)
                                res = dict(zip(seq_list, results))
                    else:
                        res = {}
                        if show_progressbar and TQDM_IMPORTED:
                            seq_list_sorted = sorted(seq_list)
                            for curr_seq in tqdm.tqdm(seq_list_sorted):
                                res[curr_seq] = eval_sequence(
                                    curr_seq,
                                    dataset,
                                    tracker,
                                    class_list,
                                    metrics_list,
                                    metric_names,
                                )
                        else:
                            for curr_seq in sorted(seq_list):
                                res[curr_seq] = eval_sequence(
                                    curr_seq,
                                    dataset,
                                    tracker,
                                    class_list,
                                    metrics_list,
                                    metric_names,
                                )

                    # Combine results over all sequences and then over all classes
                    res, combined_cls_keys = self._combine_results(
                        res, metrics_list, metric_names, dataset, "COMBINED_SEQ"
                    )
                    if np.all(
                        ["tags" in annot for annot in dataset.gt_data["annotations"]]
                    ):
                        # Combine results over the challenging sequences and then over all classes
                        # currently only the "tracking_challenging_pair" tag is supported
                        res, _ = self._combine_results(
                            res,
                            metrics_list,
                            metric_names,
                            dataset,
                            "COMBINED_SEQ_CHALLENGING",
                            "tracking_challenging_pair",
                        )

                    # Print and output results in various formats
                    if config["TIME_PROGRESS"]:
                        print(
                            "\nAll sequences for %s finished in %.2f seconds"
                            % (tracker, time.time() - time_start)
                        )
                    self._summarize_results(
                        res,
                        tracker,
                        metrics_list,
                        metric_names,
                        dataset,
                        "COMBINED_SEQ",
                        combined_cls_keys,
                    )
                    if "COMBINED_SEQ_CHALLENGING" in res:
                        self._summarize_results(
                            res,
                            tracker,
                            metrics_list,
                            metric_names,
                            dataset,
                            "COMBINED_SEQ_CHALLENGING",
                            combined_cls_keys,
                        )

                    # Output for returning from function
                    output_res[dataset_name][tracker] = res
                    output_msg[dataset_name][tracker] = "Success"
                except Exception as err:
                    output_res[dataset_name][tracker] = None
                    if isinstance(err, TrackEvalException):
                        output_msg[dataset_name][tracker] = str(err)
                    else:
                        output_msg[dataset_name][tracker] = "Unknown error occurred."
                    print("Tracker %s was unable to be evaluated." % tracker)
                    print(err)
                    traceback.print_exc()
                    if config["LOG_ON_ERROR"] is not None:
                        with open(config["LOG_ON_ERROR"], "a") as f:
                            print(dataset_name, file=f)
                            print(tracker, file=f)
                            print(traceback.format_exc(), file=f)
                            print("\n\n\n", file=f)
                    if config["BREAK_ON_ERROR"]:
                        raise err
                    elif config["RETURN_ON_ERROR"]:
                        return output_res, output_msg

        return output_res, output_msg
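
# eval_sequence is defined at module level (rather than as a method), presumably so
# the functools.partial wrapper built around it in the parallel path can be pickled
# by multiprocessing.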


@_timing.time
def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names):
    """Function for evaluating a single sequence"""
    raw_data = dataset.get_raw_seq_data(tracker, seq)
    seq_res = {}
    for cls in class_list:
        seq_res[cls] = {}
        data = dataset.get_preprocessed_seq_data(raw_data, cls)
        for metric, met_name in zip(metrics_list, metric_names):
            seq_res[cls][met_name] = metric.eval_sequence(data)
    return seq_res
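

# A minimal end-to-end sketch of how this module is typically driven (illustrative
# only: `SomeDataset` and `SomeMetric` are hypothetical placeholders for concrete
# dataset and metric classes provided elsewhere in this package):
#
#     evaluator = Evaluator({"USE_PARALLEL": False, "PRINT_RESULTS": True})
#     dataset = SomeDataset(dataset_config)   # must implement get_eval_info() etc.
#     metric = SomeMetric()                   # must implement eval_sequence() etc.
#     output_res, output_msg = evaluator.evaluate([dataset], [metric])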
|