train.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import random
import sys
import traceback
from argparse import ArgumentParser

import submitit
import torch

from hydra import compose, initialize_config_module
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf

from training.utils.train_utils import makedir, register_omegaconf_resolvers

os.environ["HYDRA_FULL_ERROR"] = "1"


def single_proc_run(local_rank, main_port, cfg, world_size):
    """Single GPU process"""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(main_port)
    os.environ["RANK"] = str(local_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    try:
        register_omegaconf_resolvers()
    except Exception as e:
        logging.info(e)

    trainer = instantiate(cfg.trainer, _recursive_=False)
    trainer.run()


def single_node_runner(cfg, main_port: int):
    assert cfg.launcher.num_nodes == 1
    num_proc = cfg.launcher.gpus_per_node
    torch.multiprocessing.set_start_method(
        "spawn"
    )  # CUDA runtime does not support `fork`
    if num_proc == 1:
        # directly call single_proc so we can easily set breakpoints
        # mp.spawn does not let us set breakpoints
        single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
    else:
        mp_runner = torch.multiprocessing.start_processes
        args = (main_port, cfg, num_proc)
        # Note: using "fork" below, "spawn" causes time and error regressions. Using
        # spawn changes the default multiprocessing context to spawn, which doesn't
        # interact well with the dataloaders (likely due to the use of OpenCV).
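        # torch.multiprocessing.start_processes calls single_proc_run(i, *args) for each
        # process index i, so i is bound to local_rank and (main_port, cfg, num_proc)
        # fill the remaining parameters.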
        mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="fork")


def format_exception(e: Exception, limit=20):
    traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit))
    return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}"


class SubmititRunner(submitit.helpers.Checkpointable):
    """A callable which is passed to submitit to launch the jobs."""

    def __init__(self, port, cfg):
        self.cfg = cfg
        self.port = port
        self.has_setup = False

    def run_trainer(self):
        job_env = submitit.JobEnvironment()
        # Need to add this again so the hydra.job.set_env PYTHONPATH
        # is also set when launching jobs.
        add_pythonpath_to_sys_path()
        os.environ["MASTER_ADDR"] = job_env.hostnames[0]
        os.environ["MASTER_PORT"] = str(self.port)
        os.environ["RANK"] = str(job_env.global_rank)
        os.environ["LOCAL_RANK"] = str(job_env.local_rank)
        os.environ["WORLD_SIZE"] = str(job_env.num_tasks)

        register_omegaconf_resolvers()
        cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)
        cfg_resolved = OmegaConf.create(cfg_resolved)

        trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
        trainer.run()

    def __call__(self):
        job_env = submitit.JobEnvironment()
        self.setup_job_info(job_env.job_id, job_env.global_rank)
        try:
            self.run_trainer()
        except Exception as e:
            # Log the exception, then raise it again (preserving SubmititRunner's current behavior).
            message = format_exception(e)
            logging.error(message)
            raise e

    def setup_job_info(self, job_id, rank):
        """Set up slurm job info"""
        self.job_info = {
            "job_id": job_id,
            "rank": rank,
            "cluster": self.cfg.get("cluster", None),
            "experiment_log_dir": self.cfg.launcher.experiment_log_dir,
        }
        self.has_setup = True


def add_pythonpath_to_sys_path():
    if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]:
        return
    sys.path = os.environ["PYTHONPATH"].split(":") + sys.path


def main(args) -> None:
    cfg = compose(config_name=args.config)
    if cfg.launcher.experiment_log_dir is None:
        cfg.launcher.experiment_log_dir = os.path.join(
            os.getcwd(), "sam2_logs", args.config
        )
    print("###################### Train App Config ####################")
    print(OmegaConf.to_yaml(cfg))
    print("############################################################")

    add_pythonpath_to_sys_path()
    makedir(cfg.launcher.experiment_log_dir)
    with g_pathmgr.open(
        os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg))

    cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
    cfg_resolved = OmegaConf.create(cfg_resolved)

    with g_pathmgr.open(
        os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))

    submitit_conf = cfg.get("submitit", None)
    assert submitit_conf is not None, "Missing submitit config"

    submitit_dir = cfg.launcher.experiment_log_dir
    submitit_dir = os.path.join(submitit_dir, "submitit_logs")
    # Prioritize cmd line args
    cfg.launcher.gpus_per_node = (
        args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
    )
    cfg.launcher.num_nodes = (
        args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
    )
    submitit_conf.use_cluster = (
        args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
    )
    if submitit_conf.use_cluster:
        executor = submitit.AutoExecutor(folder=submitit_dir)
        submitit_conf.partition = (
            args.partition
            if args.partition is not None
            else submitit_conf.get("partition", None)
        )
        submitit_conf.account = (
            args.account
            if args.account is not None
            else submitit_conf.get("account", None)
        )
        submitit_conf.qos = (
            args.qos if args.qos is not None else submitit_conf.get("qos", None)
        )
        job_kwargs = {
            "timeout_min": 60 * submitit_conf.timeout_hour,
            "name": (
                submitit_conf.name if hasattr(submitit_conf, "name") else args.config
            ),
            "slurm_partition": submitit_conf.partition,
            "gpus_per_node": cfg.launcher.gpus_per_node,
            "tasks_per_node": cfg.launcher.gpus_per_node,  # one task per GPU
            "cpus_per_task": submitit_conf.cpus_per_task,
            "nodes": cfg.launcher.num_nodes,
            "slurm_additional_parameters": {
                "exclude": " ".join(submitit_conf.get("exclude_nodes", [])),
            },
        }
        if "include_nodes" in submitit_conf:
            assert (
                len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
            ), "Not enough nodes"
            job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
                submitit_conf["include_nodes"]
            )
        if submitit_conf.account is not None:
            job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
        if submitit_conf.qos is not None:
            job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos
        if submitit_conf.get("mem_gb", None) is not None:
            job_kwargs["mem_gb"] = submitit_conf.mem_gb
        elif submitit_conf.get("mem", None) is not None:
            job_kwargs["slurm_mem"] = submitit_conf.mem

        if submitit_conf.get("constraints", None) is not None:
            job_kwargs["slurm_constraint"] = submitit_conf.constraints

        if submitit_conf.get("comment", None) is not None:
            job_kwargs["slurm_comment"] = submitit_conf.comment

        # Supports only cpu-bind option within srun_args. New options can be added here
        if submitit_conf.get("srun_args", None) is not None:
            job_kwargs["slurm_srun_args"] = []
            if submitit_conf.srun_args.get("cpu_bind", None) is not None:
                job_kwargs["slurm_srun_args"].extend(
                    ["--cpu-bind", submitit_conf.srun_args.cpu_bind]
                )

        print("###################### SLURM Config ####################")
        print(job_kwargs)
        print("##########################################")
        executor.update_parameters(**job_kwargs)

        main_port = random.randint(
            submitit_conf.port_range[0], submitit_conf.port_range[1]
        )
        runner = SubmititRunner(main_port, cfg)
        job = executor.submit(runner)
        print(f"Submitit Job ID: {job.job_id}")
        runner.setup_job_info(job.job_id, rank=0)
    else:
        cfg.launcher.num_nodes = 1
        main_port = random.randint(
            submitit_conf.port_range[0], submitit_conf.port_range[1]
        )
        single_node_runner(cfg, main_port)


if __name__ == "__main__":
    initialize_config_module("sam2", version_base="1.2")
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        required=True,
        type=str,
        help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)",
    )
    parser.add_argument(
        "--use-cluster",
        type=int,
        default=None,
        help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
    )
    parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
    parser.add_argument("--account", type=str, default=None, help="SLURM account")
    parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
    parser.add_argument(
        "--num-gpus", type=int, default=None, help="number of GPUs per node"
    )
    parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
    args = parser.parse_args()
    args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
    register_omegaconf_resolvers()
    main(args)
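
# Example invocations (illustrative sketch; assumes this script lives at training/train.py
# and is run from the repository root, and uses the config path from the --config help text.
# Adjust the config path, partition/account/qos placeholders, and resource counts to your setup):
#
#   Run locally on a single node with 8 GPUs:
#     python training/train.py \
#         -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
#         --use-cluster 0 --num-gpus 8
#
#   Submit to a SLURM cluster on 2 nodes:
#     python training/train.py \
#         -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
#         --use-cluster 1 --num-gpus 8 --num-nodes 2 \
#         --partition <partition> --account <account> --qos <qos>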