sav_frame_extraction_submitit.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import argparse
import os
from pathlib import Path

import cv2
import numpy as np
import submitit
import tqdm


def get_args_parser():
    parser = argparse.ArgumentParser(
        description="[SA-V Preprocessing] Extracting JPEG frames",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # ------------
    # DATA
    # ------------
    data_parser = parser.add_argument_group(
        title="SA-V dataset data root",
        description="What data to load and how to process it.",
    )
    data_parser.add_argument(
        "--sav-vid-dir",
        type=str,
        required=True,
        help="Where to find the SAV videos",
    )
    data_parser.add_argument(
        "--sav-frame-sample-rate",
        type=int,
        default=4,
        help="Rate at which to sub-sample frames",
    )

    # ------------
    # LAUNCH
    # ------------
    launch_parser = parser.add_argument_group(
        title="Cluster launch settings",
        description="Number of jobs and retry settings.",
    )
    launch_parser.add_argument(
        "--n-jobs",
        type=int,
        required=True,
        help="Shard the run over this many jobs.",
    )
    launch_parser.add_argument(
        "--timeout", type=int, required=True, help="SLURM timeout parameter in minutes."
    )
    launch_parser.add_argument(
        "--partition", type=str, required=True, help="Partition to launch on."
    )
    launch_parser.add_argument(
        "--account", type=str, required=True, help="Account to launch on."
    )
    launch_parser.add_argument("--qos", type=str, required=True, help="QOS.")

    # ------------
    # OUTPUT
    # ------------
    output_parser = parser.add_argument_group(
        title="Setting for results output", description="Where and how to save results."
    )
    output_parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Where to dump the extracted jpeg frames",
    )
    output_parser.add_argument(
        "--slurm-output-root-dir",
        type=str,
        required=True,
        help="Where to save slurm outputs",
    )
    return parser


def decode_video(video_path: str):
    assert os.path.exists(video_path)
    video = cv2.VideoCapture(video_path)
    video_frames = []
    while video.isOpened():
        ret, frame = video.read()
        if ret:
            video_frames.append(frame)
        else:
            break
    # Release the capture handle so it is not held open across many videos.
    video.release()
    return video_frames


def extract_frames(video_path, sample_rate):
    frames = decode_video(video_path)
    return frames[::sample_rate]
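

# A minimal local sanity check before launching on SLURM (the path below is a
# placeholder, not part of the dataset layout):
#
#   frames = extract_frames("/path/to/sav_000001.mp4", sample_rate=4)
#   print(f"Extracted {len(frames)} frames")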


def submitit_launch(video_paths, sample_rate, save_root):
    for path in tqdm.tqdm(video_paths):
        frames = extract_frames(path, sample_rate)
        output_folder = os.path.join(save_root, Path(path).stem)
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)
        for fid, frame in enumerate(frames):
            # Name each JPEG by its frame index in the original (un-subsampled) video.
            frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg")
            cv2.imwrite(frame_path, frame)
    print(f"Saved output to {save_root}")


if __name__ == "__main__":
    parser = get_args_parser()
    args = parser.parse_args()

    sav_vid_dir = args.sav_vid_dir
    save_root = args.output_dir
    sample_rate = args.sav_frame_sample_rate

    # List all SA-V videos
    mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")])
    mp4_files = np.array(mp4_files)
    chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)]

    print(f"Processing videos in: {sav_vid_dir}")
    print(f"Processing {len(mp4_files)} files")
    print(f"Beginning processing in {args.n_jobs} processes")

    # Submitit params
    jobs_dir = os.path.join(args.slurm_output_root_dir, "%j")
    cpus_per_task = 4
    executor = submitit.AutoExecutor(folder=jobs_dir)
    executor.update_parameters(
        timeout_min=args.timeout,
        gpus_per_node=0,
        tasks_per_node=1,
        slurm_array_parallelism=args.n_jobs,
        cpus_per_task=cpus_per_task,
        slurm_partition=args.partition,
        slurm_account=args.account,
        slurm_qos=args.qos,
    )
    executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"])

    # Launch
    jobs = []
    with executor.batch():
        for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)):
            job = executor.submit(
                submitit_launch,
                video_paths=mp4_chunk,
                sample_rate=sample_rate,
                save_root=save_root,
            )
            jobs.append(job)

    for j in jobs:
        print(f"Slurm JobID: {j.job_id}")
    print(f"Saving outputs to {save_root}")
    print(f"Slurm outputs at {args.slurm_output_root_dir}")