utils.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # fmt: off
  2. # flake8: noqa
  3. # pyre-unsafe
  4. import csv
  5. import os
  6. from collections import OrderedDict
  7. def validate_metrics_list(metrics_list):
  8. """Get names of metric class and ensures they are unique, further checks that the fields within each metric class
  9. do not have overlapping names.
  10. """
  11. metric_names = [metric.get_name() for metric in metrics_list]
  12. # check metric names are unique
  13. if len(metric_names) != len(set(metric_names)):
  14. raise TrackEvalException(
  15. "Code being run with multiple metrics of the same name"
  16. )
  17. fields = []
  18. for m in metrics_list:
  19. fields += m.fields
  20. # check metric fields are unique
  21. if len(fields) != len(set(fields)):
  22. raise TrackEvalException(
  23. "Code being run with multiple metrics with fields of the same name"
  24. )
  25. return metric_names
  26. def get_track_id_str(ann):
  27. """Get name of track ID in annotation."""
  28. if "track_id" in ann:
  29. tk_str = "track_id"
  30. elif "instance_id" in ann:
  31. tk_str = "instance_id"
  32. elif "scalabel_id" in ann:
  33. tk_str = "scalabel_id"
  34. else:
  35. assert False, "No track/instance ID."
  36. return tk_str
  37. class TrackEvalException(Exception):
  38. """Custom exception for catching expected errors."""
  39. ...