run_ytvis_eval.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
# flake8: noqa
# pyre-unsafe
"""run_ytvis_eval.py

Evaluate YouTube-VIS predictions with TrackEval (currently the HOTA metric only).

This module uses a relative import (``from . import trackeval``), so it must be
run as a module of its package rather than as a standalone script, e.g.:

    python -m <package>.run_ytvis_eval --USE_PARALLEL False --METRICS HOTA --TRACKERS_TO_EVAL STEm_Seg

From Python, ``run_ytvis_eval(args, gt_json, dt_json)`` can additionally receive
the ground-truth and prediction JSON data as in-memory objects instead of files.

Every key listed below can be overridden on the command line as ``--KEY value``.
Boolean options must be given literally as ``True`` or ``False``; options whose
default is a list or None accept one or more values.

Command Line Arguments: Defaults, # Comments
    Eval arguments:
        'USE_PARALLEL': False,
        'NUM_PARALLEL_CORES': 8,
        'BREAK_ON_ERROR': True,  # Raises exception and exits with error
        'RETURN_ON_ERROR': False,  # if not BREAK_ON_ERROR, then returns from function on error
        'LOG_ON_ERROR': os.path.join(code_path, 'error_log.txt'),  # if not None, save any errors into a log file.
        'PRINT_RESULTS': True,
        'PRINT_ONLY_COMBINED': True,  # Set to True by this script (TrackMAP is undefined per sequence)
        'PRINT_CONFIG': True,
        'TIME_PROGRESS': True,
        'DISPLAY_LESS_PROGRESS': True,
        'OUTPUT_SUMMARY': True,
        'OUTPUT_EMPTY_CLASSES': True,  # If False, summary files are not output for classes with no detections
        'OUTPUT_DETAILED': True,
        'PLOT_CURVES': True,
    Dataset arguments:
        'GT_FOLDER': os.path.join(code_path, 'data/gt/youtube_vis/youtube_vis_training'),  # Location of GT data
        'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/youtube_vis/youtube_vis_training'),  # Trackers location
        'OUTPUT_FOLDER': None,  # Where to save eval results (if None, same as TRACKERS_FOLDER)
        'TRACKERS_TO_EVAL': None,  # Filenames of trackers to eval (if None, all in folder)
        'CLASSES_TO_EVAL': None,  # Classes to eval (if None, all classes)
        'SPLIT_TO_EVAL': 'training',  # Valid: 'training', 'val'
        'PRINT_CONFIG': True,  # Whether to print current config
        'OUTPUT_SUB_FOLDER': '',  # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER
        'TRACKER_SUB_FOLDER': 'data',  # Tracker files are in TRACKERS_FOLDER/tracker_name/TRACKER_SUB_FOLDER
        'TRACKER_DISPLAY_NAMES': None,  # Names of trackers to display, if None: TRACKERS_TO_EVAL
    Metric arguments:
        'METRICS': ['HOTA']  # Only HOTA is currently wired up in this script

The authoritative defaults are those returned by
``trackeval.Evaluator.get_default_eval_config()`` and
``trackeval.datasets.YouTubeVIS.get_default_dataset_config()``.
"""
  37. import argparse
  38. import os
  39. import sys
  40. from multiprocessing import freeze_support
  41. from . import trackeval
  42. def run_ytvis_eval(args=None, gt_json=None, dt_json=None):
  43. # Command line interface:
  44. default_eval_config = trackeval.Evaluator.get_default_eval_config()
  45. # print only combined since TrackMAP is undefined for per sequence breakdowns
  46. default_eval_config["PRINT_ONLY_COMBINED"] = True
  47. default_dataset_config = trackeval.datasets.YouTubeVIS.get_default_dataset_config()
  48. default_metrics_config = {"METRICS": ["HOTA"]}
  49. config = {
  50. **default_eval_config,
  51. **default_dataset_config,
  52. **default_metrics_config,
  53. } # Merge default configs
  54. parser = argparse.ArgumentParser()
  55. for setting in config.keys():
  56. if type(config[setting]) == list or type(config[setting]) == type(None):
  57. parser.add_argument("--" + setting, nargs="+")
  58. else:
  59. parser.add_argument("--" + setting)
  60. args = parser.parse_args(args).__dict__
  61. for setting in args.keys():
  62. if args[setting] is not None:
  63. if type(config[setting]) == type(True):
  64. if args[setting] == "True":
  65. x = True
  66. elif args[setting] == "False":
  67. x = False
  68. else:
  69. raise Exception(
  70. "Command line parameter " + setting + "must be True or False"
  71. )
  72. elif type(config[setting]) == type(1):
  73. x = int(args[setting])
  74. elif type(args[setting]) == type(None):
  75. x = None
  76. else:
  77. x = args[setting]
  78. config[setting] = x
  79. eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()}
  80. dataset_config = {
  81. k: v for k, v in config.items() if k in default_dataset_config.keys()
  82. }
  83. metrics_config = {
  84. k: v for k, v in config.items() if k in default_metrics_config.keys()
  85. }
  86. # Run code
  87. evaluator = trackeval.Evaluator(eval_config)
  88. # allow directly specifying the GT JSON data and Tracker (result)
  89. # JSON data as Python objects, without reading from files.
  90. dataset_config["GT_JSON_OBJECT"] = gt_json
  91. dataset_config["TRACKER_JSON_OBJECT"] = dt_json
  92. dataset_list = [trackeval.datasets.YouTubeVIS(dataset_config)]
  93. metrics_list = []
  94. # for metric in [trackeval.metrics.TrackMAP, trackeval.metrics.HOTA, trackeval.metrics.CLEAR,
  95. # trackeval.metrics.Identity]:
  96. for metric in [trackeval.metrics.HOTA]:
  97. if metric.get_name() in metrics_config["METRICS"]:
  98. metrics_list.append(metric())
  99. if len(metrics_list) == 0:
  100. raise Exception("No metrics selected for evaluation")
  101. output_res, output_msg = evaluator.evaluate(dataset_list, metrics_list)
  102. return output_res, output_msg
  103. if __name__ == "__main__":
  104. import sys
  105. freeze_support()
  106. run_ytvis_eval(sys.argv[1:])