saco_veval_eval.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import argparse
import json
import os
from collections import defaultdict

from iopath.common.file_io import g_pathmgr
from sam3.eval.saco_veval_evaluators import (
    VideoCGF1Evaluator,
    VideoPhraseApEvaluator,
    VideoPhraseHotaEvaluator,
    VideoTetaEvaluator,
    YTVISPredFileEvaluator,
)
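
# Evaluation driver for the saco_veval video grounding datasets: given a ground
# truth annotation file and a prediction file, it runs the full suite of
# evaluators (mAP, Phrase AP, TETA, HOTA, cgF1), merges their per-dataset and
# per-(video, category) results, and writes the combined metrics to a JSON file.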


class VEvalEvaluator:
    def __init__(self, gt_annot_file: str, eval_res_file: str):
        self.gt_annot_file = gt_annot_file
        self.eval_res_file = eval_res_file
        self.evaluators = [
            # mAP
            YTVISPredFileEvaluator(gt_annot_file),
            # Phrase AP
            VideoPhraseApEvaluator(gt_annot_file),
            # TETA
            VideoTetaEvaluator(gt_annot_file, use_mask=True, is_exhaustive=True),
            # HOTA
            VideoPhraseHotaEvaluator(gt_annot_file),
            # cgF1
            VideoCGF1Evaluator(gt_annot_file),
        ]

    def run_eval(self, pred_file: str):
        dataset_results = {}
        video_np_results = defaultdict(dict)
        for evaluator in self.evaluators:
            d_res, v_np_res = evaluator.evaluate(pred_file)
            dataset_results.update(d_res)
            for (video_id, category_id), res in v_np_res.items():
                video_np_results[(video_id, category_id)].update(res)
        if len(dataset_results) == 0:
            dataset_results = {"": 0.0}
        formatted_video_np_results = [
            {"video_id": video_id, "category_id": category_id, **res}
            for (video_id, category_id), res in video_np_results.items()
        ]
        eval_metrics = {
            "dataset_results": dataset_results,
            "video_np_results": formatted_video_np_results,
        }
        with g_pathmgr.open(self.eval_res_file, "w") as f:
            json.dump(eval_metrics, f)
        return eval_metrics
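
# Example (sketch): using VEvalEvaluator programmatically rather than via the
# CLI below. The paths are hypothetical placeholders; point them at an actual
# annotation/prediction pair.
#
#     evaluator = VEvalEvaluator(
#         gt_annot_file="/path/to/saco_veval_sav_val.json",
#         eval_res_file="/path/to/saco_veval_sav_val_eval_res.json",
#     )
#     metrics = evaluator.run_eval(pred_file="/path/to/saco_veval_sav_val_preds.json")
#     print(metrics["dataset_results"])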


def run_main_all(dataset_name, args):
    gt_annot_file = os.path.join(args.gt_annot_dir, dataset_name + ".json")
    pred_file = os.path.join(args.pred_dir, dataset_name + "_preds.json")
    eval_res_file = os.path.join(args.eval_res_dir, dataset_name + "_eval_res.json")
    print(f"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===")
    veval_evaluator = VEvalEvaluator(
        gt_annot_file=gt_annot_file, eval_res_file=eval_res_file
    )
    _ = veval_evaluator.run_eval(pred_file=pred_file)
    print(f"=== Results saved to {eval_res_file} ===")
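
# Per-dataset file layout assumed by the "all" command (derived from the path
# joins in run_main_all above):
#     <gt_annot_dir>/<dataset_name>.json
#     <pred_dir>/<dataset_name>_preds.json
#     <eval_res_dir>/<dataset_name>_eval_res.json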


def main_all(args):
    saco_veval_dataset_names = [
        "saco_veval_sav_test",
        "saco_veval_sav_val",
        "saco_veval_yt1b_test",
        "saco_veval_yt1b_val",
        "saco_veval_smartglasses_test",
        "saco_veval_smartglasses_val",
    ]
    # Multiprocessing may not really work here since the inner evaluators also
    # use multiprocessing, so we just loop over the datasets sequentially.
    for dataset_name in saco_veval_dataset_names:
        print(f"=== Running evaluation for dataset {dataset_name} ===")
        run_main_all(dataset_name=dataset_name, args=args)


def main_one(args):
    gt_annot_file = args.gt_annot_file
    pred_file = args.pred_file
    eval_res_file = args.eval_res_file
    print(f"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===")
    veval_evaluator = VEvalEvaluator(
        gt_annot_file=gt_annot_file, eval_res_file=eval_res_file
    )
    _ = veval_evaluator.run_eval(pred_file=pred_file)
    print(f"=== Results saved to {eval_res_file} ===")


def main():
    parser = argparse.ArgumentParser(description="Run video grounding evaluators")
    # Create subparsers for different commands
    subparsers = parser.add_subparsers(dest="command", required=True)
    # Run evaluation for all datasets
    all_parser = subparsers.add_parser("all", help="Run evaluation for all datasets")
    all_parser.add_argument(
        "--gt_annot_dir",
        type=str,
        help="Directory that contains the ground truth annotation files",
    )
    all_parser.add_argument(
        "--pred_dir",
        type=str,
        help="Directory that contains the prediction files",
    )
    all_parser.add_argument(
        "--eval_res_dir",
        type=str,
        help="Directory where the eval results files will be written",
    )
    all_parser.set_defaults(func=main_all)
    # Run evaluation for one dataset
    one_parser = subparsers.add_parser("one", help="Run evaluation for one dataset")
    one_parser.add_argument(
        "--gt_annot_file",
        type=str,
        help="Path to the ground truth annotation file",
    )
    one_parser.add_argument(
        "--pred_file",
        type=str,
        help="Path to the prediction file",
    )
    one_parser.add_argument(
        "--eval_res_file",
        type=str,
        help="Path to the eval results file",
    )
    one_parser.set_defaults(func=main_one)
    # Parse and dispatch
    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
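
# Example invocations (sketch; all directory and file paths below are
# hypothetical placeholders):
#
#     python saco_veval_eval.py all \
#         --gt_annot_dir /data/saco_veval/annotations \
#         --pred_dir /data/saco_veval/predictions \
#         --eval_res_dir /data/saco_veval/eval_results
#
#     python saco_veval_eval.py one \
#         --gt_annot_file /data/saco_veval/annotations/saco_veval_sav_val.json \
#         --pred_file /data/saco_veval/predictions/saco_veval_sav_val_preds.json \
#         --eval_res_file /data/saco_veval/eval_results/saco_veval_sav_val_eval_res.json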