# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

import contextlib
import fnmatch
import gc
import json
import logging
import math
import os
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr

from sam3.model.data_misc import BatchedDatapoint
from sam3.model.model_misc import SAM3Output
from sam3.model.utils.misc import copy_data_to_device
from sam3.train.optim.optimizer import construct_optimizer
from sam3.train.utils.checkpoint_utils import (
    assert_skipped_parameters_are_frozen,
    exclude_params_matching_unix_pattern,
    load_state_dict_into_model,
    with_check_parameter_frozen,
)
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
from sam3.train.utils.logger import Logger, setup_logging
from sam3.train.utils.train_utils import (
    AverageMeter,
    collect_dict_keys,
    DurationMeter,
    get_amp_type,
    get_machine_local_and_dist_rank,
    get_resume_checkpoint,
    human_readable_time,
    is_dist_avail_and_initialized,
    log_env_variables,
    makedir,
    MemMeter,
    Phase,
    ProgressMeter,
    set_seeds,
    setup_distributed_backend,
)

CORE_LOSS_KEY = "core_loss"
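# Loss modules used by the Trainer return a dict of sub-losses; the entry stored under
# CORE_LOSS_KEY is the one that gets backpropagated (see
# `_log_loss_detailed_and_return_core_loss` below), while the remaining entries are only logged.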


def unwrap_ddp_if_wrapped(model):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model


@dataclass
class OptimAMPConf:
    enabled: bool = False
    amp_dtype: str = "float16"


@dataclass
class OptimConf:
    optimizer: torch.optim.Optimizer = None
    options: Optional[Dict[str, Any]] = None
    param_group_modifiers: Optional[List] = None
    amp: Optional[Dict[str, Any]] = None
    gradient_clip: Any = None
    gradient_logger: Any = None

    def __post_init__(self):
        # amp
        if not isinstance(self.amp, OptimAMPConf):
            if self.amp is None:
                self.amp = {}
            assert isinstance(self.amp, Mapping)
            self.amp = OptimAMPConf(**self.amp)
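
# Illustrative note: `OptimConf.amp` accepts either an OptimAMPConf instance or a plain
# mapping, which `__post_init__` normalizes, e.g.
#     OptimConf(amp={"enabled": True, "amp_dtype": "bfloat16"})
# ends up with amp == OptimAMPConf(enabled=True, amp_dtype="bfloat16").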


@dataclass
class DistributedConf:
    backend: Optional[str] = None  # inferred from accelerator type
    comms_dtype: Optional[str] = None
    find_unused_parameters: bool = False
    timeout_mins: int = 30
    gradient_as_bucket_view: bool = False  # PyTorch DDP default is False
    static_graph: bool = False  # PyTorch DDP default is False


@dataclass
class CudaConf:
    cudnn_deterministic: bool = False
    cudnn_benchmark: bool = True
    allow_tf32: bool = False
    # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul
    matmul_allow_tf32: Optional[bool] = None
    # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn
    cudnn_allow_tf32: Optional[bool] = None


@dataclass
class CheckpointConf:
    save_dir: str
    save_freq: int
    save_list: List[int] = field(default_factory=list)
    model_weight_initializer: Any = None
    save_best_meters: List[str] = None
    skip_saving_parameters: List[str] = field(default_factory=list)
    initialize_after_preemption: Optional[bool] = None
    # if not None, training will be resumed from this checkpoint
    resume_from: Optional[str] = None

    def infer_missing(self):
        if self.initialize_after_preemption is None:
            with_skip_saving = len(self.skip_saving_parameters) > 0
            self.initialize_after_preemption = with_skip_saving
        return self
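
# Note on CheckpointConf: when `initialize_after_preemption` is left as None, `infer_missing()`
# defaults it to True whenever `skip_saving_parameters` is non-empty, so that parameters
# excluded from the checkpoint are re-initialized via `model_weight_initializer` when a
# preempted run resumes (see `Trainer.load_checkpoint`).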


@dataclass
class LoggingConf:
    log_dir: str
    log_freq: int  # In iterations
    tensorboard_writer: Any
    log_level_primary: str = "INFO"
    log_level_secondary: str = "ERROR"
    log_scalar_frequency: int = 100
    log_visual_frequency: int = 100
    scalar_keys_to_log: Optional[Dict[str, Any]] = None
    log_batch_stats: bool = False
    wandb_writer: Optional[Any] = None


class Trainer:
    """
    Trainer supporting the DDP training strategy.
    """

    EPSILON = 1e-8
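
    # Illustrative usage sketch (the dict configs are Hydra-style and their exact schemas are
    # defined by the config files elsewhere in the repo, not here):
    #
    #     trainer = Trainer(
    #         data={...}, model={...}, logging={...}, checkpoint={...},
    #         max_epochs=40, mode="train", optim={...}, loss={...}, meters={...},
    #     )
    #     trainer.run()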

    def __init__(
        self,
        *,  # the order of these args can change at any time, so they are keyword-only
        data: Dict[str, Any],
        model: Dict[str, Any],
        logging: Dict[str, Any],
        checkpoint: Dict[str, Any],
        max_epochs: int,
        mode: str = "train",
        accelerator: str = "cuda",
        seed_value: int = 123,
        val_epoch_freq: int = 1,
        distributed: Dict[str, bool] = None,
        cuda: Dict[str, bool] = None,
        env_variables: Optional[Dict[str, Any]] = None,
        optim: Optional[Dict[str, Any]] = None,
        optim_overrides: Optional[List[Dict[str, Any]]] = None,
        meters: Optional[Dict[str, Any]] = None,
        loss: Optional[Dict[str, Any]] = None,
        skip_first_val: bool = False,
        skip_saving_ckpts: bool = False,
        empty_gpu_mem_cache_after_eval: bool = True,
        gradient_accumulation_steps: int = 1,
    ):
        self._setup_env_variables(env_variables)
        self._setup_timers()

        self.data_conf = data
        self.model_conf = model
        self.logging_conf = LoggingConf(**logging)
        self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing()
        self.max_epochs = max_epochs
        self.mode = mode
        self.val_epoch_freq = val_epoch_freq
        self.optim_conf = OptimConf(**optim) if optim is not None else OptimConf()
        self.meters_conf = meters
        self.loss_conf = loss
        self.gradient_accumulation_steps = gradient_accumulation_steps
        distributed = DistributedConf(**distributed or {})
        cuda = CudaConf(**cuda or {})
        self.where = 0.0
        self.skip_first_val = skip_first_val
        self.skip_saving_ckpts = skip_saving_ckpts
        self.empty_gpu_mem_cache_after_eval = empty_gpu_mem_cache_after_eval

        self._infer_distributed_backend_if_none(distributed, accelerator)
        self._setup_device(accelerator)
        self._setup_torch_dist_and_backend(cuda, distributed)

        makedir(self.logging_conf.log_dir)
        setup_logging(
            __name__,
            output_dir=self.logging_conf.log_dir,
            rank=self.rank,
            log_level_primary=self.logging_conf.log_level_primary,
            log_level_secondary=self.logging_conf.log_level_secondary,
        )

        set_seeds(seed_value, self.max_epochs, self.distributed_rank)
        log_env_variables()

        assert is_dist_avail_and_initialized(), (
            "Torch distributed needs to be initialized before calling the trainer."
        )

        self._setup_components()  # Except Optimizer everything is setup here.
        self._move_to_device()
        self._construct_optimizers()
        self._setup_dataloaders()

        self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")

        if self.checkpoint_conf.resume_from is not None:
            assert os.path.exists(self.checkpoint_conf.resume_from), (
                f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
            )
            dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
            if self.distributed_rank == 0 and not os.path.exists(dst):
                # Copy the "resume_from" checkpoint to the checkpoint folder
                # if there is not a checkpoint to resume from already there
                makedir(self.checkpoint_conf.save_dir)
                g_pathmgr.copy(self.checkpoint_conf.resume_from, dst)
            barrier()

        self.load_checkpoint()
        self._setup_ddp_distributed_training(distributed, accelerator)
        barrier()

    def _setup_timers(self):
        """
        Initializes counters for elapsed time and eta.
        """
        self.start_time = time.time()
        self.ckpt_time_elapsed = 0
        self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0)

    def _get_meters(self, phase_filters=None):
        if self.meters is None:
            return {}
        meters = {}
        for phase, phase_meters in self.meters.items():
            if phase_filters is not None and phase not in phase_filters:
                continue
            for key, key_meters in phase_meters.items():
                if key_meters is None:
                    continue
                for name, meter in key_meters.items():
                    meters[f"{phase}_{key}/{name}"] = meter
        return meters

    def _infer_distributed_backend_if_none(self, distributed_conf, accelerator):
        if distributed_conf.backend is None:
            distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo"

    def _setup_env_variables(self, env_variables_conf) -> None:
        if env_variables_conf is not None:
            for variable_name, value in env_variables_conf.items():
                os.environ[variable_name] = value

    def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None:
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
            torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
            torch.backends.cuda.matmul.allow_tf32 = (
                cuda_conf.matmul_allow_tf32
                if cuda_conf.matmul_allow_tf32 is not None
                else cuda_conf.allow_tf32
            )
            torch.backends.cudnn.allow_tf32 = (
                cuda_conf.cudnn_allow_tf32
                if cuda_conf.cudnn_allow_tf32 is not None
                else cuda_conf.allow_tf32
            )

        self.rank = setup_distributed_backend(
            distributed_conf.backend, distributed_conf.timeout_mins
        )

    def _setup_device(self, accelerator):
        self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
        if accelerator == "cuda":
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.local_rank)
        elif accelerator == "cpu":
            self.device = torch.device("cpu")
        else:
            raise ValueError(f"Unsupported accelerator: {accelerator}")

    def _setup_ddp_distributed_training(self, distributed_conf, accelerator):
        assert isinstance(self.model, torch.nn.Module)

        self.model = nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[self.local_rank] if accelerator == "cuda" else [],
            find_unused_parameters=distributed_conf.find_unused_parameters,
            gradient_as_bucket_view=distributed_conf.gradient_as_bucket_view,
            static_graph=distributed_conf.static_graph,
        )
        if distributed_conf.comms_dtype is not None:  # noqa
            from torch.distributed.algorithms import ddp_comm_hooks

            amp_type = get_amp_type(distributed_conf.comms_dtype)
            if amp_type == torch.bfloat16:
                hook = ddp_comm_hooks.default_hooks.bf16_compress_hook
                logging.info("Enabling bfloat16 grad communication")
            else:
                hook = ddp_comm_hooks.default_hooks.fp16_compress_hook
                logging.info("Enabling fp16 grad communication")
            process_group = None
            self.model.register_comm_hook(process_group, hook)

    def _move_to_device(self):
        logging.info(
            f"Moving components to device {self.device} and local rank {self.local_rank}."
        )
        self.model.to(self.device)
        logging.info(
            f"Done moving components to device {self.device} and local rank {self.local_rank}."
        )

    def save_checkpoint(self, epoch, checkpoint_names=None):
        if self.skip_saving_ckpts:
            logging.info(
                "skip_saving_ckpts is set to True, so no checkpoints will be saved."
            )
            return
        checkpoint_folder = self.checkpoint_conf.save_dir
        makedir(checkpoint_folder)
        if checkpoint_names is None:
            checkpoint_names = ["checkpoint"]
            if (
                self.checkpoint_conf.save_freq > 0
                and (int(epoch) % self.checkpoint_conf.save_freq == 0)
            ) or int(epoch) in self.checkpoint_conf.save_list:
                checkpoint_names.append(f"checkpoint_{int(epoch)}")

        checkpoint_paths = []
        for ckpt_name in checkpoint_names:
            checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))

        state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
        state_dict = exclude_params_matching_unix_pattern(
            patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
        )

        checkpoint = {
            "model": state_dict,
            "optimizer": self.optim.optimizer.state_dict(),
            "epoch": epoch,
            "loss": self.loss.state_dict(),
            "steps": self.steps,
            "time_elapsed": self.time_elapsed_meter.val,
            "best_meter_values": self.best_meter_values,
        }
        if self.optim_conf.amp.enabled:
            checkpoint["scaler"] = self.scaler.state_dict()

        # DDP checkpoints are only saved on rank 0 (all workers are identical)
        if self.distributed_rank != 0:
            return

        for checkpoint_path in checkpoint_paths:
            self._save_checkpoint(checkpoint, checkpoint_path)

    def _save_checkpoint(self, checkpoint, checkpoint_path):
        """
        Save a checkpoint while guarding against the job being killed in the middle
        of checkpoint saving (which corrupts the checkpoint file and ruins the
        entire training since usually only the last checkpoint is kept per run).

        We first save the new checkpoint to a temp file (with a '.tmp' suffix), and
        then move it to overwrite the old checkpoint_path.
        """
        checkpoint_path_tmp = f"{checkpoint_path}.tmp"
        with g_pathmgr.open(checkpoint_path_tmp, "wb") as f:
            torch.save(checkpoint, f)
        # after torch.save is completed, replace the old checkpoint with the new one
        if g_pathmgr.exists(checkpoint_path):
            # remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails)
            g_pathmgr.rm(checkpoint_path)
        success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path)
        assert success

    def load_checkpoint(self):
        ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
        if ckpt_path is None:
            self._init_model_state()
        else:
            if self.checkpoint_conf.initialize_after_preemption:
                self._call_model_initializer()
            self._load_resuming_checkpoint(ckpt_path)

    def _init_model_state(self):
        # Checking that parameters that won't be saved are indeed frozen
        # We do this check here before even saving the model to catch errors
        # as early as possible and not at the end of the first epoch
        assert_skipped_parameters_are_frozen(
            patterns=self.checkpoint_conf.skip_saving_parameters,
            model=self.model,
        )

        # Checking that parameters that won't be saved are initialized from
        # within the model definition, unless `initialize_after_preemption`
        # is explicitly set to `True`. If not, this is a bug, and after
        # preemption, the `skip_saving_parameters` will have random values
        allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption
        with with_check_parameter_frozen(
            patterns=self.checkpoint_conf.skip_saving_parameters,
            model=self.model,
            disabled=allow_init_skip_parameters,
        ):
            self._call_model_initializer()

    def _call_model_initializer(self):
        model_weight_initializer = instantiate(
            self.checkpoint_conf.model_weight_initializer
        )
        if model_weight_initializer is not None:
            logging.info(
                f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}"
            )
            self.model = model_weight_initializer(model=self.model)

    def _load_resuming_checkpoint(self, ckpt_path: str):
        logging.info(f"Resuming training from {ckpt_path}")

        with g_pathmgr.open(ckpt_path, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        load_state_dict_into_model(
            model=self.model,
            state_dict=checkpoint["model"],
            ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
        )

        self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
        self.loss.load_state_dict(checkpoint["loss"], strict=True)
        self.epoch = checkpoint["epoch"]
        self.steps = checkpoint["steps"]
        self.ckpt_time_elapsed = checkpoint.get("time_elapsed")

        if self.optim_conf.amp.enabled and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.best_meter_values = checkpoint.get("best_meter_values", {})

        if "train_dataset" in checkpoint and self.train_dataset is not None:
            self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])

    def is_intermediate_val_epoch(self, epoch):
        skip_epoch = self.skip_first_val and epoch == 0
        return (
            epoch % self.val_epoch_freq == 0
            and epoch < self.max_epochs - 1
            and not skip_epoch
        )

    def _find_loss(self, key: str):
        if key in self.loss:
            return self.loss[key]
        assert key != "all", "Loss must be specified for key='all'"
        assert "default" in self.loss, (
            f"Key {key} not found in losses, and no default provided"
        )
        return self.loss["default"]
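
    # `_find_meter` first looks for an exact key match and then falls back to Unix-style
    # wildcard matching via `fnmatch`, so a meter registered under a pattern such as
    # "my_dataset_*" (hypothetical name) would also be picked up for the key "my_dataset_v2".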
    def _find_meter(self, phase: str, key: str):
        if key in self.meters[phase]:
            return self.meters[phase][key]
        for cand_key, meter in self.meters[phase].items():
            if fnmatch.fnmatch(key, cand_key):
                return meter
        return None

    def _step(
        self,
        batch: BatchedDatapoint,
        model: nn.Module,
        phase: str,
    ):
        key, batch = batch.popitem()
        batch = copy_data_to_device(batch, self.device, non_blocking=True)
        find_stages = model(batch)
        find_targets = [
            unwrap_ddp_if_wrapped(model).back_convert(x) for x in batch.find_targets
        ]
        batch_size = len(batch.img_batch)
        loss = self._find_loss(key)(find_stages, find_targets)
        loss_str = f"Losses/{phase}_{key}_loss"
        loss_log_str = os.path.join("Step_Losses", loss_str)

        # loss contains multiple sub-components we wish to log
        step_losses = {}
        if isinstance(loss, dict):
            step_losses.update(
                {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()}
            )
            loss = self._log_loss_detailed_and_return_core_loss(
                loss, loss_log_str, self.steps[phase]
            )

        if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0:
            self.logger.log(
                loss_log_str,
                loss,
                self.steps[phase],
            )

        self.steps[phase] += 1

        ret_tuple = {loss_str: loss}, batch_size, step_losses

        if phase not in self.meters:
            return ret_tuple

        meters_dict = self._find_meter(phase, key)
        if meters_dict is None:
            return ret_tuple

        for _, meter in meters_dict.items():
            meter.update(
                find_stages=find_stages,
                find_metadatas=batch.find_metadatas,
                model=model,
                batch=batch,
                key=key,
            )

        # Cleanup memory
        if isinstance(find_stages, SAM3Output):
            for fs in find_stages:
                for k in list(fs.keys()):
                    del fs[k]

        return ret_tuple

    def run(self):
        assert self.mode in ["train", "train_only", "val"]
        if self.mode == "train":
            if self.epoch > 0:
                logging.info(f"Resuming training from epoch: {self.epoch}")
                # resuming from a checkpoint
                if self.is_intermediate_val_epoch(self.epoch - 1):
                    logging.info("Running previous val epoch")
                    self.epoch -= 1
                    self.run_val()
                    self.epoch += 1
            self.run_train()
            self.run_val()
        elif self.mode == "val":
            self.run_val()
        elif self.mode == "train_only":
            self.run_train()

    def _setup_dataloaders(self):
        self.train_dataset = None
        self.val_dataset = None

        if self.mode in ["train", "val"]:
            self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None))

        if self.mode in ["train", "train_only"]:
            self.train_dataset = instantiate(self.data_conf.train)

    def run_train(self):
        while self.epoch < self.max_epochs:
            dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
            barrier()
            outs = self.train_epoch(dataloader)
            self.logger.log_dict(outs, self.epoch)  # Logged only on rank 0

            # log train stats to a text file
            if self.distributed_rank == 0:
                with g_pathmgr.open(
                    os.path.join(self.logging_conf.log_dir, "train_stats.json"),
                    "a",
                ) as f:
                    f.write(json.dumps(outs) + "\n")

            # Save checkpoint before validating
            self.save_checkpoint(self.epoch + 1)

            del dataloader
            gc.collect()

            # Run val; skip it on the last epoch since it will run after the loop anyway
            if self.is_intermediate_val_epoch(self.epoch):
                self.run_val()
                if torch.cuda.is_available() and self.empty_gpu_mem_cache_after_eval:
                    # release memory buffers held by the model during eval (which typically
                    # involves a lot more frames in video grounding than during training)
                    torch.cuda.empty_cache()

            if self.distributed_rank == 0:
                self.best_meter_values.update(self._get_trainer_state("train"))
                with g_pathmgr.open(
                    os.path.join(self.logging_conf.log_dir, "best_stats.json"),
                    "a",
                ) as f:
                    f.write(json.dumps(self.best_meter_values) + "\n")

            self.epoch += 1
        # epoch was incremented in the loop but the val step runs out of the loop
        self.epoch -= 1

    def run_val(self):
        if not self.val_dataset:
            return

        dataloader = self.val_dataset.get_loader(epoch=int(self.epoch))
        outs = self.val_epoch(dataloader, phase=Phase.VAL)
        del dataloader
        gc.collect()
        self.logger.log_dict(outs, self.epoch)  # Logged only on rank 0

        if self.distributed_rank == 0:
            with g_pathmgr.open(
                os.path.join(self.logging_conf.log_dir, "val_stats.json"),
                "a",
            ) as f:
                f.write(json.dumps(outs) + "\n")

    def val_epoch(self, val_loader, phase):
        batch_time = AverageMeter("Batch Time", self.device, ":.2f")
        data_time = AverageMeter("Data Time", self.device, ":.2f")
        mem = MemMeter("Mem (GB)", self.device, ":.2f")

        iters_per_epoch = len(val_loader)

        curr_phases = [phase]
        curr_models = [self.model]

        loss_names = []
        for p in curr_phases:
            for key in self.loss.keys():
                loss_names.append(f"Losses/{p}_{key}_loss")

        loss_mts = OrderedDict(
            [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
        )
        extra_loss_mts = {}

        for model in curr_models:
            model.eval()
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_start()

        progress = ProgressMeter(
            iters_per_epoch,
            [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
            self._get_meters(curr_phases),
            prefix="Val Epoch: [{}]".format(self.epoch),
        )

        end = time.time()

        for data_iter, batch in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            # batch = batch.to(self.device, non_blocking=True)

            # compute output
            with torch.no_grad():
                with torch.amp.autocast(
                    device_type="cuda",
                    enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
                    dtype=(
                        get_amp_type(self.optim_conf.amp.amp_dtype)
                        if self.optim_conf
                        else None
                    ),
                ):
                    for phase, model in zip(curr_phases, curr_models):
                        loss_dict, batch_size, extra_losses = self._step(
                            batch,
                            model,
                            phase,
                        )

                        assert len(loss_dict) == 1
                        loss_key, loss = loss_dict.popitem()

                        if loss_key in loss_mts:
                            loss_mts[loss_key].update(loss.item(), batch_size)

                        for k, v in extra_losses.items():
                            if k not in extra_loss_mts:
                                extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
                            extra_loss_mts[k].update(v.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            self.time_elapsed_meter.update(
                time.time() - self.start_time + self.ckpt_time_elapsed
            )

            if torch.cuda.is_available():
                mem.update(reset_peak_usage=True)

            if data_iter % self.logging_conf.log_freq == 0:
                progress.display(data_iter)

            if data_iter % self.logging_conf.log_scalar_frequency == 0:
                # Log progress meters.
                for progress_meter in progress.meters:
                    self.logger.log(
                        os.path.join("Step_Stats", phase, progress_meter.name),
                        progress_meter.val,
                        self.steps[Phase.VAL],
                    )

            if data_iter % 10 == 0:
                dist.barrier()

        self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
        self._log_timers(phase)
        for model in curr_models:
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_end()

        out_dict = self._log_meters_and_save_best_ckpts(curr_phases)

        for k, v in loss_mts.items():
            out_dict[k] = v.avg
        for k, v in extra_loss_mts.items():
            out_dict[k] = v.avg

        for phase in curr_phases:
            out_dict.update(self._get_trainer_state(phase))
        self._reset_meters(curr_phases)
        logging.info(f"Meters: {out_dict}")
        return out_dict

    def _get_trainer_state(self, phase):
        return {
            "Trainer/where": self.where,
            "Trainer/epoch": self.epoch,
            f"Trainer/steps_{phase}": self.steps[phase],
        }
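
    # `train_epoch` runs one pass over `train_loader`. During the pass, `self.where` tracks the
    # fraction of total training completed (exact_epoch / max_epochs, in [0, 1]) and is the
    # value the option schedulers are stepped with.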
    def train_epoch(self, train_loader):
        # Init stat meters
        batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
        data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
        mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
        data_times = []
        phase = Phase.TRAIN

        iters_per_epoch = len(train_loader)

        loss_names = []
        for batch_key in self.loss.keys():
            loss_names.append(f"Losses/{phase}_{batch_key}_loss")

        loss_mts = OrderedDict(
            [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
        )
        extra_loss_mts = {}

        progress = ProgressMeter(
            iters_per_epoch,
            [
                batch_time_meter,
                data_time_meter,
                mem_meter,
                self.time_elapsed_meter,
                *loss_mts.values(),
            ],
            self._get_meters([phase]),
            prefix="Train Epoch: [{}]".format(self.epoch),
        )

        # Model training loop
        self.model.train()
        end = time.time()

        for data_iter, batch in enumerate(train_loader):
            # measure data loading time
            data_time_meter.update(time.time() - end)
            data_times.append(data_time_meter.val)
            # batch = batch.to(
            #     self.device, non_blocking=True
            # )  # move tensors in a tensorclass

            try:
                self._run_step(batch, phase, loss_mts, extra_loss_mts)

                # compute gradient and do optim step
                exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
                self.where = float(exact_epoch) / self.max_epochs
                assert self.where <= 1 + self.EPSILON
                if self.where < 1.0:
                    self.optim.step_schedulers(
                        self.where, step=int(exact_epoch * iters_per_epoch)
                    )
                else:
                    logging.warning(
                        f"Skipping scheduler update since the training is at the end, i.e., {self.where} of [0,1]."
                    )

                # Log schedulers
                if data_iter % self.logging_conf.log_scalar_frequency == 0:
                    for j, param_group in enumerate(self.optim.optimizer.param_groups):
                        for option in self.optim.schedulers[j]:
                            optim_prefix = (
                                "" + f"{j}_"
                                if len(self.optim.optimizer.param_groups) > 1
                                else ""
                            )
                            self.logger.log(
                                os.path.join("Optim", f"{optim_prefix}", option),
                                param_group[option],
                                self.steps[phase],
                            )

                # Clipping gradients and detecting diverging gradients
                if self.gradient_clipper is not None:
                    self.scaler.unscale_(self.optim.optimizer)
                    self.gradient_clipper(model=self.model)

                if self.gradient_logger is not None:
                    self.gradient_logger(
                        self.model, rank=self.distributed_rank, where=self.where
                    )

                # Optimizer step: the scaler will make sure gradients are not
                # applied if the gradients are infinite
                self.scaler.step(self.optim.optimizer)
                self.scaler.update()

                # measure elapsed time
                batch_time_meter.update(time.time() - end)
                end = time.time()

                self.time_elapsed_meter.update(
                    time.time() - self.start_time + self.ckpt_time_elapsed
                )

                mem_meter.update(reset_peak_usage=True)
                if data_iter % self.logging_conf.log_freq == 0:
                    progress.display(data_iter)

                if data_iter % self.logging_conf.log_scalar_frequency == 0:
                    # Log progress meters.
                    for progress_meter in progress.meters:
                        self.logger.log(
                            os.path.join("Step_Stats", phase, progress_meter.name),
                            progress_meter.val,
                            self.steps[phase],
                        )

            # Catching NaN/Inf errors in the loss
            except FloatingPointError as e:
                raise e

        self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
        self._log_timers(Phase.TRAIN)
        self._log_sync_data_times(Phase.TRAIN, data_times)

        out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])

        for k, v in loss_mts.items():
            out_dict[k] = v.avg
        for k, v in extra_loss_mts.items():
            out_dict[k] = v.avg

        out_dict.update(self._get_trainer_state(phase))
        logging.info(f"Losses and meters: {out_dict}")
        self._reset_meters([phase])
        return out_dict

    def _log_sync_data_times(self, phase, data_times):
        data_times = all_reduce_max(torch.tensor(data_times)).tolist()
        steps = range(self.steps[phase] - len(data_times), self.steps[phase])
        for step, data_time in zip(steps, data_times):
            if step % self.logging_conf.log_scalar_frequency == 0:
                self.logger.log(
                    os.path.join("Step_Stats", phase, "Data Time Synced"),
                    data_time,
                    step,
                )
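
    # `_run_step` expects a single batch when `gradient_accumulation_steps == 1`, and a list of
    # exactly `gradient_accumulation_steps` micro-batches otherwise. For all but the last
    # micro-batch, the backward pass runs under DDP's `no_sync()` so gradients are only
    # all-reduced once per optimizer step.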
    def _run_step(
        self,
        batch: BatchedDatapoint,
        phase: str,
        loss_mts: Dict[str, AverageMeter],
        extra_loss_mts: Dict[str, AverageMeter],
        raise_on_error: bool = True,
    ):
        """
        Run the forward / backward
        """

        # it's important to set grads to None, especially with Adam since 0
        # grads will also update a model even if the step doesn't produce
        # gradients
        self.optim.zero_grad(set_to_none=True)

        if self.gradient_accumulation_steps > 1:
            assert isinstance(batch, list), (
                f"Expected a list of batches, got {type(batch)}"
            )
            assert len(batch) == self.gradient_accumulation_steps, (
                f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
            )
            accum_steps = len(batch)
        else:
            accum_steps = 1
            batch = [batch]

        for i, chunked_batch in enumerate(batch):
            ddp_context = (
                self.model.no_sync()
                if i < accum_steps - 1
                else contextlib.nullcontext()
            )
            with ddp_context:
                with torch.amp.autocast(
                    device_type="cuda",
                    enabled=self.optim_conf.amp.enabled,
                    dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
                ):
                    loss_dict, batch_size, extra_losses = self._step(
                        chunked_batch,
                        self.model,
                        phase,
                    )

                assert len(loss_dict) == 1
                loss_key, loss = loss_dict.popitem()

                if not math.isfinite(loss.item()):
                    error_msg = f"Loss is {loss.item()}, attempting to stop training"
                    logging.error(error_msg)
                    if raise_on_error:
                        raise FloatingPointError(error_msg)
                    else:
                        return

                self.scaler.scale(loss).backward()

            loss_mts[loss_key].update(loss.item(), batch_size)
            for extra_loss_key, extra_loss in extra_losses.items():
                if extra_loss_key not in extra_loss_mts:
                    extra_loss_mts[extra_loss_key] = AverageMeter(
                        extra_loss_key, self.device, ":.2e"
                    )
                extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size)

    def _log_meters_and_save_best_ckpts(self, phases: List[str]):
        logging.info("Synchronizing meters")
        out_dict = {}
        checkpoint_save_keys = []
        for key, meter in self._get_meters(phases).items():
            meter_output = meter.compute_synced()
            is_better_check = getattr(meter, "is_better", None)

            for meter_subkey, meter_value in meter_output.items():
                out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value

                if is_better_check is None:
                    continue

                tracked_meter_key = os.path.join(key, meter_subkey)
                if tracked_meter_key not in self.best_meter_values or is_better_check(
                    meter_value,
                    self.best_meter_values[tracked_meter_key],
                ):
                    self.best_meter_values[tracked_meter_key] = meter_value

                    if (
                        self.checkpoint_conf.save_best_meters is not None
                        and key in self.checkpoint_conf.save_best_meters
                    ):
                        checkpoint_save_keys.append(tracked_meter_key.replace("/", "_"))

        if len(checkpoint_save_keys) > 0:
            self.save_checkpoint(self.epoch + 1, checkpoint_save_keys)

        return out_dict

    def _log_timers(self, phase):
        time_remaining = 0
        epochs_remaining = self.max_epochs - self.epoch - 1
        val_epochs_remaining = sum(
            n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs)
        )

        # Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with
        # the end epoch.
        if (self.max_epochs - 1) % self.val_epoch_freq != 0:
            val_epochs_remaining += 1

        # Remove the current val run from estimate
        if phase == Phase.VAL:
            val_epochs_remaining -= 1

        time_remaining += (
            epochs_remaining * self.est_epoch_time[Phase.TRAIN]
            + val_epochs_remaining * self.est_epoch_time[Phase.VAL]
        )

        self.logger.log(
            os.path.join("Step_Stats", phase, self.time_elapsed_meter.name),
            self.time_elapsed_meter.val,
            self.steps[phase],
        )

        logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}")

    def _reset_meters(self, phases: List[str]) -> None:
        for meter in self._get_meters(phases).values():
            meter.reset()

    def _check_val_key_match(self, val_keys, phase):
        if val_keys is not None:
            # Check if there are any duplicates
            assert len(val_keys) == len(set(val_keys)), (
                f"Duplicate keys in val datasets, keys: {val_keys}"
            )

            # Check that the keys match the meter keys
            if self.meters_conf is not None and phase in self.meters_conf:
                assert set(val_keys) == set(self.meters_conf[phase].keys()), (
                    f"Keys in val datasets do not match the keys in meters."
                    f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}"
                    f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}"
                )

            if self.loss_conf is not None:
                loss_keys = set(self.loss_conf.keys()) - set(["all"])
                if "default" not in loss_keys:
                    for k in val_keys:
                        assert k in loss_keys, (
                            f"Error: key {k} is not defined in the losses, and no default is set"
                        )

    def _setup_components(self):
        # Get the keys for all the val datasets, if any
        val_phase = Phase.VAL
        val_keys = None
        if self.data_conf.get(val_phase, None) is not None:
            val_keys = collect_dict_keys(self.data_conf[val_phase])
        # Additional checks on the sanity of the config for val datasets
        self._check_val_key_match(val_keys, phase=val_phase)

        logging.info("Setting up components: Model, loss, optim, meters etc.")
        self.epoch = 0
        self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}

        self.logger = Logger(self.logging_conf)

        self.model = instantiate(self.model_conf, _convert_="all")
        print_model_summary(self.model)

        self.loss = None
        if self.loss_conf:
            self.loss = {
                key: el  # wrap_base_loss(el)
                for (key, el) in instantiate(self.loss_conf, _convert_="all").items()
            }
            self.loss = nn.ModuleDict(self.loss)

        self.meters = {}
        self.best_meter_values = {}
        if self.meters_conf:
            self.meters = instantiate(self.meters_conf, _convert_="all")

        self.scaler = torch.amp.GradScaler(
            self.device,
            enabled=self.optim_conf.amp.enabled if self.optim_conf else False,
        )

        self.gradient_clipper = (
            instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None
        )
        self.gradient_logger = (
            instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None
        )

        logging.info("Finished setting up components: Model, loss, optim, meters etc.")

    def _construct_optimizers(self):
        self.optim = construct_optimizer(
            self.model,
            self.optim_conf.optimizer,
            self.optim_conf.options,
            self.optim_conf.param_group_modifiers,
        )

    def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step):
        core_loss = loss.pop(CORE_LOSS_KEY)
        if step % self.logging_conf.log_scalar_frequency == 0:
            for k in loss:
                log_str = os.path.join(loss_str, k)
                self.logger.log(log_str, loss[k], step)
        return core_loss


def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
    """
    Prints the model and the number of parameters in the model.
    # Multiple packages provide this info in a nice table format
    # However, they need us to provide an `input` (as they also write down the output sizes)
    # Our models are complex, and a single input is restrictive.
    # https://github.com/sksq96/pytorch-summary
    # https://github.com/nmhkahn/torchsummaryX
    """
    if get_rank() != 0:
        return
    param_kwargs = {}
    trainable_parameters = sum(
        p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad
    )
    total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs))
    non_trainable_parameters = total_parameters - trainable_parameters
    logging.info("==" * 10)
    logging.info(f"Summary for model {type(model)}")
    logging.info(f"Model is {model}")
    logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
    logging.info(
        f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}"
    )
    logging.info(
        f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}"
    )
    logging.info("==" * 10)

    if log_dir:
        output_fpath = os.path.join(log_dir, "model.txt")
        with g_pathmgr.open(output_fpath, "w") as f:
            print(model, file=f)


PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]


def get_human_readable_count(number: int) -> str:
    """
    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.
    Examples:
        >>> get_human_readable_count(123)
        '123 '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.2 K'
        >>> get_human_readable_count(2e6)   # (two million)
        '2.0 M'
        >>> get_human_readable_count(3e9)   # (three billion)
        '3.0 B'
        >>> get_human_readable_count(4e14)  # (four hundred trillion)
        '400 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'
    Args:
        number: a positive integer number
    Return:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = PARAMETER_NUM_UNITS
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    shift = -3 * (num_groups - 1)
    number = number * (10**shift)
    index = num_groups - 1
    if index < 1 or number >= 100:
        return f"{int(number):,d} {labels[index]}"
    else:
        return f"{number:,.1f} {labels[index]}"