| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- # fmt: off
- # flake8: noqa
- # pyre-unsafe
- from abc import ABC, abstractmethod
- import numpy as np
- from .. import _timing
- from ..utils import TrackEvalException
class _BaseMetric(ABC):
    """Abstract base class providing shared bookkeeping and reporting helpers.

    Subclasses fill in the field-name lists created in ``__init__`` and
    implement the abstract evaluation/combination methods. This base class
    supplies the helpers that combine per-sequence values and that turn result
    dictionaries into printed tables, summary rows and detailed rows.
    """

    @abstractmethod
    def __init__(self):
        # When True, plot_single_tracker_results must be overridden; the base
        # implementation raises NotImplementedError in that case.
        self.plottable = False
        # Names of scalar result fields, grouped by how they are formatted.
        self.integer_fields = []
        self.float_fields = []
        # One label per entry of every array-valued field; used (scaled by 100)
        # to name the per-entry columns in detailed_results.
        self.array_labels = []
        # Names of array-valued result fields.
        self.integer_array_fields = []
        self.float_array_fields = []
        # Not read by this base class.
        # NOTE(review): presumably the full list of field names; confirm
        # against subclasses.
        self.fields = []
        # Fields shown, in order, by print_table / summary_results.
        self.summary_fields = []
        # Not read by this base class.
        # NOTE(review): presumably set by the code that registers metrics.
        self.registered = False

    #####################################################################
    # Abstract functions for subclasses to implement

    @_timing.time
    @abstractmethod
    def eval_sequence(self, data):
        """Compute this metric for a single sequence (subclass-defined).

        NOTE(review): the structure of ``data`` and of the returned value is
        defined by the subclasses and callers, not by this base class.
        """

    @abstractmethod
    def combine_sequences(self, all_res):
        """Combine per-sequence results into one result (subclass-defined)."""

    @abstractmethod
    def combine_classes_class_averaged(self, all_res, ignore_empty=False):
        """Combine per-class results, averaging over classes (subclass-defined).

        NOTE(review): the exact meaning of ``ignore_empty`` is defined by the
        subclasses.
        """

    @abstractmethod
    def combine_classes_det_averaged(self, all_res):
        """Combine per-class results, averaging over detections (subclass-defined)."""

    def plot_single_tracker_results(self, all_res, tracker, output_folder, cls):
        """Plot results, only valid for metrics with self.plottable.

        The base implementation does nothing for non-plottable metrics.

        Raises:
            NotImplementedError: if ``self.plottable`` is true, i.e. the
                subclass declared itself plottable but did not override this
                method.
        """
        if self.plottable:
            raise NotImplementedError(
                f"plot_results is not implemented for metric {self.get_name()}"
            )

    #####################################################################
    # Helper functions which are useful for all metrics:

    @classmethod
    def get_name(cls):
        """Return the metric's name, which is the name of its class."""
        return cls.__name__

    @staticmethod
    def _combine_sum(all_res, field):
        """Combine sequence results via sum.

        Args:
            all_res: mapping whose values are result mappings.
            field: key to read from every result mapping.

        Returns:
            The sum of ``res[field]`` over all results (``0`` when empty).
        """
        return sum(res[field] for res in all_res.values())

    @staticmethod
    def _combine_weighted_av(all_res, field, comb_res, weight_field):
        """Combine sequence results via weighted average.

        Args:
            all_res: mapping whose values are result mappings.
            field: key of the value being averaged.
            comb_res: already-combined result mapping; its ``weight_field``
                entry is used as the total weight.
            weight_field: key of the per-result weight.

        Returns:
            The weighted sum of ``field`` divided by the combined weight.
        """
        weighted_total = sum(
            res[field] * res[weight_field] for res in all_res.values()
        )
        # The denominator is clamped to at least 1.0 so that a zero total
        # weight does not cause a division by zero.
        return weighted_total / np.maximum(1.0, comb_res[weight_field])

    def print_table(self, table_res, tracker, cls):
        """Print table of results for all sequences.

        Prints a blank line, a header row, one row per sequence in sorted key
        order, and finally the ``"COMBINED_SEQ"`` entry labelled ``COMBINED``.

        Args:
            table_res: mapping from sequence name to result mapping; must
                contain the key ``"COMBINED_SEQ"``.
            tracker: tracker name, used in the header.
            cls: class name, used in the header.
        """
        print("")
        header_title = self.get_name() + ": " + tracker + "-" + cls
        self._row_print([header_title] + self.summary_fields)
        for seq, results in sorted(table_res.items()):
            # The combined entry is printed last, under its own label.
            if seq == "COMBINED_SEQ":
                continue
            self._row_print([seq] + self._summary_row(results))
        combined_row = self._summary_row(table_res["COMBINED_SEQ"])
        self._row_print(["COMBINED"] + combined_row)

    def _summary_row(self, results_):
        """Format the summary fields of one result mapping as strings.

        Float array fields are reported as 100 times their mean, float fields
        as 100 times their value, and integer fields as plain integers.

        Raises:
            NotImplementedError: if a summary field is in none of
                ``float_array_fields``, ``float_fields``, ``integer_fields``.
        """
        vals = []
        for h in self.summary_fields:
            if h in self.float_array_fields:
                vals.append("{0:1.5g}".format(100 * np.mean(results_[h])))
            elif h in self.float_fields:
                vals.append("{0:1.5g}".format(100 * float(results_[h])))
            elif h in self.integer_fields:
                vals.append("{0:d}".format(int(results_[h])))
            else:
                raise NotImplementedError(
                    "Summary function not implemented for this field type."
                )
        return vals

    @staticmethod
    def _row_print(*argv):
        """Print results in evenly spaced columns, with a wider first column.

        Accepts either the cells as separate arguments or a single sequence
        of cells. The first cell is left-justified to 35 characters, every
        other cell to 10.
        """
        if len(argv) == 1:
            argv = argv[0]
        cells = ["%-35s" % argv[0]]
        cells.extend("%-10s" % str(v) for v in argv[1:])
        print("".join(cells))

    def summary_results(self, table_res):
        """Return a simple summary of final results for a tracker.

        Returns:
            Dict mapping each summary field name to its formatted value taken
            from ``table_res["COMBINED_SEQ"]``.
        """
        combined_row = self._summary_row(table_res["COMBINED_SEQ"])
        return dict(zip(self.summary_fields, combined_row))

    def detailed_results(self, table_res):
        """Return detailed final results for a tracker.

        Every scalar field keeps its name. Every array field ``h`` expands to
        one column ``h___<label>`` per entry of ``array_labels`` (label scaled
        by 100 and rounded to the nearest integer) plus ``h___AUC``.

        Returns:
            Dict mapping each key of ``table_res`` to a dict of
            column name -> value.

        Raises:
            TrackEvalException: if a row's length differs from the number of
                column names.
        """
        # Get detailed field information.
        detailed_fields = self.float_fields + self.integer_fields
        # Round instead of truncating: 100 * x may land just below the
        # intended integer (e.g. 100 * 0.29 == 28.999999999999996), and
        # truncation would then mislabel the column or create duplicates.
        label_suffixes = [str(int(round(100 * x))) for x in self.array_labels]
        for h in self.float_array_fields + self.integer_array_fields:
            detailed_fields.extend(h + "___" + suffix for suffix in label_suffixes)
            detailed_fields.append(h + "___AUC")

        # Get detailed results.
        detailed_results = {}
        for seq, res in table_res.items():
            detailed_row = self._detailed_row(res)
            if len(detailed_row) != len(detailed_fields):
                raise TrackEvalException(
                    f"Field names and data have different sizes "
                    f"({len(detailed_row)} and {len(detailed_fields)})"
                )
            detailed_results[seq] = dict(zip(detailed_fields, detailed_row))
        return detailed_results

    def _detailed_row(self, res):
        """Flatten one result mapping into the column order of detailed_results.

        Scalar fields come first. Each array field then contributes one value
        per entry of ``array_labels`` followed by the mean of the whole array
        (the ``___AUC`` column).
        """
        detailed_row = [res[h] for h in self.float_fields + self.integer_fields]
        n_labels = len(self.array_labels)
        for h in self.float_array_fields + self.integer_array_fields:
            detailed_row.extend(res[h][i] for i in range(n_labels))
            detailed_row.append(np.mean(res[h]))
        return detailed_row
|