count.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # flake8: noqa
  2. # pyre-unsafe
  3. from .. import _timing
  4. from ._base_metric import _BaseMetric
  5. class Count(_BaseMetric):
  6. """Class which simply counts the number of tracker and gt detections and ids."""
  7. def __init__(self, config=None):
  8. super().__init__()
  9. self.integer_fields = ["Dets", "GT_Dets", "IDs", "GT_IDs"]
  10. self.fields = self.integer_fields
  11. self.summary_fields = self.fields
  12. @_timing.time
  13. def eval_sequence(self, data):
  14. """Returns counts for one sequence"""
  15. # Get results
  16. res = {
  17. "Dets": data["num_tracker_dets"],
  18. "GT_Dets": data["num_gt_dets"],
  19. "IDs": data["num_tracker_ids"],
  20. "GT_IDs": data["num_gt_ids"],
  21. "Frames": data["num_timesteps"],
  22. }
  23. return res
  24. def combine_sequences(self, all_res):
  25. """Combines metrics across all sequences"""
  26. res = {}
  27. for field in self.integer_fields:
  28. res[field] = self._combine_sum(all_res, field)
  29. return res
  30. def combine_classes_class_averaged(self, all_res, ignore_empty_classes=None):
  31. """Combines metrics across all classes by averaging over the class values"""
  32. res = {}
  33. for field in self.integer_fields:
  34. res[field] = self._combine_sum(all_res, field)
  35. return res
  36. def combine_classes_det_averaged(self, all_res):
  37. """Combines metrics across all classes by averaging over the detection values"""
  38. res = {}
  39. for field in self.integer_fields:
  40. res[field] = self._combine_sum(all_res, field)
  41. return res