_base_metric.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # fmt: off
  2. # flake8: noqa
  3. # pyre-unsafe
  4. from abc import ABC, abstractmethod
  5. import numpy as np
  6. from .. import _timing
  7. from ..utils import TrackEvalException
  8. class _BaseMetric(ABC):
  9. @abstractmethod
  10. def __init__(self):
  11. self.plottable = False
  12. self.integer_fields = []
  13. self.float_fields = []
  14. self.array_labels = []
  15. self.integer_array_fields = []
  16. self.float_array_fields = []
  17. self.fields = []
  18. self.summary_fields = []
  19. self.registered = False
  20. #####################################################################
  21. # Abstract functions for subclasses to implement
  22. @_timing.time
  23. @abstractmethod
  24. def eval_sequence(self, data):
  25. ...
  26. @abstractmethod
  27. def combine_sequences(self, all_res):
  28. ...
  29. @abstractmethod
  30. def combine_classes_class_averaged(self, all_res, ignore_empty=False):
  31. ...
  32. @abstractmethod
  33. def combine_classes_det_averaged(self, all_res):
  34. ...
  35. def plot_single_tracker_results(self, all_res, tracker, output_folder, cls):
  36. """Plot results, only valid for metrics with self.plottable."""
  37. if self.plottable:
  38. raise NotImplementedError(
  39. f"plot_results is not implemented for metric {self.get_name()}"
  40. )
  41. else:
  42. pass
  43. #####################################################################
  44. # Helper functions which are useful for all metrics:
  45. @classmethod
  46. def get_name(cls):
  47. return cls.__name__
  48. @staticmethod
  49. def _combine_sum(all_res, field):
  50. """Combine sequence results via sum"""
  51. return sum([all_res[k][field] for k in all_res.keys()])
  52. @staticmethod
  53. def _combine_weighted_av(all_res, field, comb_res, weight_field):
  54. """Combine sequence results via weighted average."""
  55. return sum(
  56. [all_res[k][field] * all_res[k][weight_field] for k in all_res.keys()]
  57. ) / np.maximum(1.0, comb_res[weight_field])
  58. def print_table(self, table_res, tracker, cls):
  59. """Print table of results for all sequences."""
  60. print("")
  61. metric_name = self.get_name()
  62. self._row_print(
  63. [metric_name + ": " + tracker + "-" + cls] + self.summary_fields
  64. )
  65. for seq, results in sorted(table_res.items()):
  66. if seq == "COMBINED_SEQ":
  67. continue
  68. summary_res = self._summary_row(results)
  69. self._row_print([seq] + summary_res)
  70. summary_res = self._summary_row(table_res["COMBINED_SEQ"])
  71. self._row_print(["COMBINED"] + summary_res)
  72. def _summary_row(self, results_):
  73. vals = []
  74. for h in self.summary_fields:
  75. if h in self.float_array_fields:
  76. vals.append("{0:1.5g}".format(100 * np.mean(results_[h])))
  77. elif h in self.float_fields:
  78. vals.append("{0:1.5g}".format(100 * float(results_[h])))
  79. elif h in self.integer_fields:
  80. vals.append("{0:d}".format(int(results_[h])))
  81. else:
  82. raise NotImplementedError(
  83. "Summary function not implemented for this field type."
  84. )
  85. return vals
  86. @staticmethod
  87. def _row_print(*argv):
  88. """Print results in evenly spaced rows, with more space in first row."""
  89. if len(argv) == 1:
  90. argv = argv[0]
  91. to_print = "%-35s" % argv[0]
  92. for v in argv[1:]:
  93. to_print += "%-10s" % str(v)
  94. print(to_print)
  95. def summary_results(self, table_res):
  96. """Return a simple summary of final results for a tracker."""
  97. return dict(
  98. zip(self.summary_fields, self._summary_row(table_res["COMBINED_SEQ"]),)
  99. )
  100. def detailed_results(self, table_res):
  101. """Return detailed final results for a tracker."""
  102. # Get detailed field information
  103. detailed_fields = self.float_fields + self.integer_fields
  104. for h in self.float_array_fields + self.integer_array_fields:
  105. for alpha in [int(100 * x) for x in self.array_labels]:
  106. detailed_fields.append(h + "___" + str(alpha))
  107. detailed_fields.append(h + "___AUC")
  108. # Get detailed results
  109. detailed_results = {}
  110. for seq, res in table_res.items():
  111. detailed_row = self._detailed_row(res)
  112. if len(detailed_row) != len(detailed_fields):
  113. raise TrackEvalException(
  114. f"Field names and data have different sizes "
  115. f"({len(detailed_row)} and {len(detailed_fields)})"
  116. )
  117. detailed_results[seq] = dict(zip(detailed_fields, detailed_row))
  118. return detailed_results
  119. def _detailed_row(self, res):
  120. detailed_row = []
  121. for h in self.float_fields + self.integer_fields:
  122. detailed_row.append(res[h])
  123. for h in self.float_array_fields + self.integer_array_fields:
  124. for i, _ in enumerate([int(100 * x) for x in self.array_labels]):
  125. detailed_row.append(res[h][i])
  126. detailed_row.append(np.mean(res[h]))
  127. return detailed_row