train.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
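
"""Launch SAM3 training from a Hydra config, either locally via
torch.multiprocessing or on a SLURM cluster through submitit."""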

import logging
import os
import random
import sys
import traceback
from argparse import ArgumentParser
from copy import deepcopy

import submitit
import torch
from hydra import compose, initialize_config_module
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
from tqdm import tqdm

os.environ["HYDRA_FULL_ERROR"] = "1"


class SlurmEvent:
    QUEUED = "QUEUED"
    START = "START"
    FINISH = "FINISH"
    JOB_ERROR = "JOB_ERROR"
    SLURM_SIGNAL = "SLURM_SIGNAL"


def handle_custom_resolving(cfg):
    # We'll resolve the config here, so we can catch mistakes early.
    # However, we need to pass the un-resolved config to the launcher
    # (because DVC resolving needs to be done on the node it will run on)
    # First, do a copy without triggering resolving
    cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
    cfg_resolved = OmegaConf.create(cfg_resolved)
    return cfg_resolved


def single_proc_run(local_rank, main_port, cfg, world_size):
    """Single GPU process"""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(main_port)
    os.environ["RANK"] = str(local_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    try:
        register_omegaconf_resolvers()
    except Exception as e:
        logging.info(e)

    trainer = instantiate(cfg.trainer, _recursive_=False)
    trainer.run()
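

# Runs training on the local machine (no SLURM): one process per GPU,
# spawned with torch.multiprocessing.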
def single_node_runner(cfg, main_port: int):
    assert cfg.launcher.num_nodes == 1
    # assert cfg.launcher.gpus_per_node == 1
    num_proc = cfg.launcher.gpus_per_node
    torch.multiprocessing.set_start_method(
        "spawn"
    )  # CUDA runtime does not support `fork`
    if num_proc == 1:
        # directly call single_proc so we can easily set breakpoints
        # mp.spawn does not let us set breakpoints
        single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
    else:
        mp_runner = torch.multiprocessing.start_processes
        args = (main_port, cfg, num_proc)
        # Note: using "fork" below, "spawn" causes time and error regressions. Using
        # spawn changes the default multiprocessing context to spawn, which doesn't
        # interact well with the dataloaders (likely due to the use of OpenCV).
        mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn")


def format_exception(e: Exception, limit=20):
    traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit))
    return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}"


class SubmititRunner(submitit.helpers.Checkpointable):
    """A callable which is passed to submitit to launch the jobs."""

    def __init__(self, port, cfg):
        self.cfg = cfg
        self.port = port
        self.has_setup = False

    def run_trainer(self):
        job_env = submitit.JobEnvironment()
        # Need to add this again so the hydra.job.set_env PYTHONPATH
        # is also set when launching jobs.
        add_pythonpath_to_sys_path()
        os.environ["MASTER_ADDR"] = job_env.hostnames[0]
        os.environ["MASTER_PORT"] = str(self.port)
        os.environ["RANK"] = str(job_env.global_rank)
        os.environ["LOCAL_RANK"] = str(job_env.local_rank)
        os.environ["WORLD_SIZE"] = str(job_env.num_tasks)

        register_omegaconf_resolvers()
        cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)
        cfg_resolved = OmegaConf.create(cfg_resolved)

        trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
        trainer.run()

    def __call__(self):
        job_env = submitit.JobEnvironment()
        self.setup_job_info(job_env.job_id, job_env.global_rank)
        try:
            self.run_trainer()
        except Exception as e:
            # Log the exception. Then raise it again (as what SubmititRunner currently does).
            message = format_exception(e)
            logging.error(message)
            raise e

    def setup_job_info(self, job_id, rank):
        """Set up slurm job info"""
        self.job_info = {
            "job_id": job_id,
            "rank": rank,
            "cluster": self.cfg.get("cluster", None),
            "experiment_log_dir": self.cfg.launcher.experiment_log_dir,
        }
        self.has_setup = True
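

# Make entries from $PYTHONPATH importable inside the launched job
# (submitit tasks may not inherit the parent environment's sys.path).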
def add_pythonpath_to_sys_path():
    if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]:
        return
    sys.path = os.environ["PYTHONPATH"].split(":") + sys.path
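

# Composes the config, dumps the raw and resolved configs to the experiment log
# dir, then either submits the job(s) to SLURM via submitit or runs locally.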
def main(args) -> None:
    cfg = compose(config_name=args.config)
    if cfg.launcher.experiment_log_dir is None:
        cfg.launcher.experiment_log_dir = os.path.join(
            os.getcwd(), "sam3_logs", args.config
        )
    print("###################### Train App Config ####################")
    print(OmegaConf.to_yaml(cfg))
    print("############################################################")

    add_pythonpath_to_sys_path()
    makedir(cfg.launcher.experiment_log_dir)
    with g_pathmgr.open(
        os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg))

    cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
    cfg_resolved = OmegaConf.create(cfg_resolved)

    with g_pathmgr.open(
        os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))

    submitit_conf = cfg.get("submitit", None)
    assert submitit_conf is not None, "Missing submitit config"

    experiment_log_dir = cfg.launcher.experiment_log_dir
    print(f"Experiment Log Dir:\n{experiment_log_dir}")
    submitit_dir = os.path.join(experiment_log_dir, "submitit_logs")

    # Prioritize cmd line args
    cfg.launcher.gpus_per_node = (
        args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
    )
    cfg.launcher.num_nodes = (
        args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
    )
    submitit_conf.use_cluster = (
        args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
    )
    if submitit_conf.use_cluster:
        executor = submitit.AutoExecutor(folder=submitit_dir)
        submitit_conf.partition = (
            args.partition
            if args.partition is not None
            else submitit_conf.get("partition", None)
        )
        submitit_conf.account = (
            args.account
            if args.account is not None
            else submitit_conf.get("account", None)
        )
        submitit_conf.qos = (
            args.qos if args.qos is not None else submitit_conf.get("qos", None)
        )
        job_kwargs = {
            "timeout_min": 60 * submitit_conf.timeout_hour,
            "name": (
                submitit_conf.name if hasattr(submitit_conf, "name") else args.config
            ),
            "slurm_partition": submitit_conf.partition,
            "gpus_per_node": cfg.launcher.gpus_per_node,
            "tasks_per_node": cfg.launcher.gpus_per_node,  # one task per GPU
            "cpus_per_task": submitit_conf.cpus_per_task,
            "nodes": cfg.launcher.num_nodes,
            "slurm_additional_parameters": {
                "exclude": " ".join(submitit_conf.get("exclude_nodes", [])),
            },
        }
        if "include_nodes" in submitit_conf:
            assert len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes, (
                "Not enough nodes"
            )
            job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
                submitit_conf["include_nodes"]
            )
        if submitit_conf.account is not None:
            job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
        if submitit_conf.qos is not None:
            job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos
        if submitit_conf.get("mem_gb", None) is not None:
            job_kwargs["mem_gb"] = submitit_conf.mem_gb
        elif submitit_conf.get("mem", None) is not None:
            job_kwargs["slurm_mem"] = submitit_conf.mem
        if submitit_conf.get("constraints", None) is not None:
            job_kwargs["slurm_constraint"] = submitit_conf.constraints
        if submitit_conf.get("comment", None) is not None:
            job_kwargs["slurm_comment"] = submitit_conf.comment
        # Supports only cpu-bind option within srun_args. New options can be added here
        if submitit_conf.get("srun_args", None) is not None:
            job_kwargs["slurm_srun_args"] = []
            if submitit_conf.srun_args.get("cpu_bind", None) is not None:
                job_kwargs["slurm_srun_args"].extend(
                    ["--cpu-bind", submitit_conf.srun_args.cpu_bind]
                )

        print("###################### SLURM Config ####################")
        print(job_kwargs)
        print("##########################################")
        executor.update_parameters(**job_kwargs)

        if (
            "job_array" in submitit_conf
            and submitit_conf.job_array.get("num_tasks", -1) > 0
        ):
            num_tasks = submitit_conf.job_array.num_tasks
            job_array_config_dir = os.path.join(
                cfg.launcher.experiment_log_dir, "job_array_configs"
            )
            makedir(job_array_config_dir)
            job_indices = range(num_tasks)
            ports = random.sample(
                range(submitit_conf.port_range[0], submitit_conf.port_range[1] + 1),
                k=len(job_indices),
            )
            jobs_runners_configs = []
            with executor.batch():
                task_index = 0
                for indices, main_port in tqdm(zip(job_indices, ports)):
                    curr_cfg = deepcopy(cfg)
                    curr_cfg.submitit.job_array["task_index"] = task_index
                    curr_cfg_resolved = handle_custom_resolving(curr_cfg)
                    runner = SubmititRunner(main_port, curr_cfg)
                    job = executor.submit(runner)
                    jobs_runners_configs.append(
                        (job, runner, curr_cfg, curr_cfg_resolved)
                    )
                    task_index += 1
            for job, runner, job_cfg, job_cfg_resolved in jobs_runners_configs:
                print("Submitit Job ID:", job.job_id)
                # Save job specific config
                job_array_config_file = os.path.join(
                    job_array_config_dir, "{}.config.yaml".format(job.job_id)
                )
                with g_pathmgr.open(job_array_config_file, "w") as f:
                    f.write(OmegaConf.to_yaml(job_cfg))
                job_array_config_resolved_file = os.path.join(
                    job_array_config_dir, "{}.config_resolved.yaml".format(job.job_id)
                )
                with g_pathmgr.open(job_array_config_resolved_file, "w") as f:
                    f.write(OmegaConf.to_yaml(job_cfg_resolved, resolve=True))
                runner.setup_job_info(job.job_id, rank=0)
                # runner.log_event(event_type=SlurmEvent.QUEUED)
        else:
            main_port = random.randint(
                submitit_conf.port_range[0], submitit_conf.port_range[1]
            )
            runner = SubmititRunner(main_port, cfg)
            job = executor.submit(runner)
            print(f"Submitit Job ID: {job.job_id}")
            runner.setup_job_info(job.job_id, rank=0)
    else:
        cfg.launcher.num_nodes = 1
        main_port = random.randint(
            submitit_conf.port_range[0], submitit_conf.port_range[1]
        )
        single_node_runner(cfg, main_port)
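

# Example (illustrative): launch a local single-GPU run with
#   python train.py -c configs/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 1
# or submit to SLURM with --use-cluster 1 plus --partition/--account/--qos as needed.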
if __name__ == "__main__":
    initialize_config_module("sam3.train", version_base="1.2")
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        required=True,
        type=str,
        help="path to config file (e.g. configs/roboflow_v100_full_ft_100_images.yaml)",
    )
    parser.add_argument(
        "--use-cluster",
        type=int,
        default=None,
        help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
    )
    parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
    parser.add_argument("--account", type=str, default=None, help="SLURM account")
    parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
    parser.add_argument(
        "--num-gpus", type=int, default=None, help="number of GPUs per node"
    )
    parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
    args = parser.parse_args()
    args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
    register_omegaconf_resolvers()
    main(args)