trainer.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import gc
  6. import json
  7. import logging
  8. import math
  9. import os
  10. import time
  11. from collections import OrderedDict
  12. from dataclasses import dataclass, field
  13. from typing import Any, Dict, List, Mapping, Optional
  14. import numpy as np
  15. import torch
  16. import torch.distributed as dist
  17. import torch.nn as nn
  18. from hydra.utils import instantiate
  19. from iopath.common.file_io import g_pathmgr
  20. from training.optimizer import construct_optimizer
  21. from training.utils.checkpoint_utils import (
  22. assert_skipped_parameters_are_frozen,
  23. exclude_params_matching_unix_pattern,
  24. load_state_dict_into_model,
  25. with_check_parameter_frozen,
  26. )
  27. from training.utils.data_utils import BatchedVideoDatapoint
  28. from training.utils.distributed import all_reduce_max, barrier, get_rank
  29. from training.utils.logger import Logger, setup_logging
  30. from training.utils.train_utils import (
  31. AverageMeter,
  32. collect_dict_keys,
  33. DurationMeter,
  34. get_amp_type,
  35. get_machine_local_and_dist_rank,
  36. get_resume_checkpoint,
  37. human_readable_time,
  38. is_dist_avail_and_initialized,
  39. log_env_variables,
  40. makedir,
  41. MemMeter,
  42. Phase,
  43. ProgressMeter,
  44. set_seeds,
  45. setup_distributed_backend,
  46. )
  47. CORE_LOSS_KEY = "core_loss"
  48. def unwrap_ddp_if_wrapped(model):
  49. if isinstance(model, torch.nn.parallel.DistributedDataParallel):
  50. return model.module
  51. return model
  52. @dataclass
  53. class OptimAMPConf:
  54. enabled: bool = False
  55. amp_dtype: str = "float16"
  56. @dataclass
  57. class OptimConf:
  58. optimizer: torch.optim.Optimizer = None
  59. options: Optional[Dict[str, Any]] = None
  60. param_group_modifiers: Optional[List] = None
  61. amp: Optional[Dict[str, Any]] = None
  62. gradient_clip: Any = None
  63. gradient_logger: Any = None
  64. def __post_init__(self):
  65. # amp
  66. if not isinstance(self.amp, OptimAMPConf):
  67. if self.amp is None:
  68. self.amp = {}
  69. assert isinstance(self.amp, Mapping)
  70. self.amp = OptimAMPConf(**self.amp)
  71. @dataclass
  72. class DistributedConf:
  73. backend: Optional[str] = None # inferred from accelerator type
  74. comms_dtype: Optional[str] = None
  75. find_unused_parameters: bool = False
  76. timeout_mins: int = 30
  77. @dataclass
  78. class CudaConf:
  79. cudnn_deterministic: bool = False
  80. cudnn_benchmark: bool = True
  81. allow_tf32: bool = False
  82. # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul
  83. matmul_allow_tf32: Optional[bool] = None
  84. # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn
  85. cudnn_allow_tf32: Optional[bool] = None
  86. @dataclass
  87. class CheckpointConf:
  88. save_dir: str
  89. save_freq: int
  90. save_list: List[int] = field(default_factory=list)
  91. model_weight_initializer: Any = None
  92. save_best_meters: List[str] = None
  93. skip_saving_parameters: List[str] = field(default_factory=list)
  94. initialize_after_preemption: Optional[bool] = None
  95. # if not None, training will be resumed from this checkpoint
  96. resume_from: Optional[str] = None
  97. def infer_missing(self):
  98. if self.initialize_after_preemption is None:
  99. with_skip_saving = len(self.skip_saving_parameters) > 0
  100. self.initialize_after_preemption = with_skip_saving
  101. return self
  102. @dataclass
  103. class LoggingConf:
  104. log_dir: str
  105. log_freq: int # In iterations
  106. tensorboard_writer: Any
  107. log_level_primary: str = "INFO"
  108. log_level_secondary: str = "ERROR"
  109. log_scalar_frequency: int = 100
  110. log_visual_frequency: int = 100
  111. scalar_keys_to_log: Optional[Dict[str, Any]] = None
  112. log_batch_stats: bool = False
  113. class Trainer:
  114. """
  115. Trainer supporting the DDP training strategies.
  116. """
  117. EPSILON = 1e-8
    def __init__(
        self,
        *,  # the order of these args can change at any time, so they are keyword-only
        data: Dict[str, Any],
        model: Dict[str, Any],
        logging: Dict[str, Any],
        checkpoint: Dict[str, Any],
        max_epochs: int,
        mode: str = "train",
        accelerator: str = "cuda",
        seed_value: int = 123,
        val_epoch_freq: int = 1,
        distributed: Optional[Dict[str, Any]] = None,
        cuda: Optional[Dict[str, Any]] = None,
        env_variables: Optional[Dict[str, Any]] = None,
        optim: Optional[Dict[str, Any]] = None,
        optim_overrides: Optional[List[Dict[str, Any]]] = None,
        meters: Optional[Dict[str, Any]] = None,
        loss: Optional[Dict[str, Any]] = None,
    ):
        """Set up the environment, device, process group, components, optimizer,
        dataloaders and checkpoint state, then wrap the model in DDP.

        The statements below are order dependent: the device and process group
        exist before components are built, the optimizer is constructed after the
        model is on its device, and the checkpoint is loaded into the bare model
        before it is wrapped in DistributedDataParallel.

        Args:
            data: data config; `_setup_dataloaders` instantiates its `train` entry
                and its `Phase.VAL` entry depending on `mode`.
            model: model config, stored as `self.model_conf`.
            logging: kwargs for `LoggingConf`. This parameter shadows the
                `logging` module inside this method only.
            checkpoint: kwargs for `CheckpointConf`.
            max_epochs: number of training epochs.
            mode: "train", "train_only" or "val" (checked in `run`).
            accelerator: "cuda" or "cpu"; `_setup_device` raises ValueError otherwise.
            seed_value: forwarded to `set_seeds` together with `max_epochs` and the rank.
            val_epoch_freq: intermediate validation runs on epochs divisible by this.
            distributed: kwargs for `DistributedConf` (None -> defaults).
            cuda: kwargs for `CudaConf` (None -> defaults).
            env_variables: environment variables exported before any other setup.
            optim: kwargs for `OptimConf`; None leaves `self.optim_conf` as None.
            optim_overrides: not read in this method.
            meters: meters config, stored as `self.meters_conf`.
            loss: loss config, stored as `self.loss_conf`.
        """
        # Environment variables and timers first, so everything after sees them.
        self._setup_env_variables(env_variables)
        self._setup_timers()

        self.data_conf = data
        self.model_conf = model
        self.logging_conf = LoggingConf(**logging)
        self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing()
        self.max_epochs = max_epochs
        self.mode = mode
        self.val_epoch_freq = val_epoch_freq
        self.optim_conf = OptimConf(**optim) if optim is not None else None
        self.meters_conf = meters
        self.loss_conf = loss
        distributed = DistributedConf(**distributed or {})
        cuda = CudaConf(**cuda or {})
        # Training progress as a fraction of max_epochs; updated in train_epoch.
        self.where = 0.0

        self._infer_distributed_backend_if_none(distributed, accelerator)

        # Sets self.local_rank / self.distributed_rank / self.device.
        self._setup_device(accelerator)

        # Applies cuDNN/TF32 flags and sets self.rank.
        self._setup_torch_dist_and_backend(cuda, distributed)

        makedir(self.logging_conf.log_dir)
        setup_logging(
            __name__,
            output_dir=self.logging_conf.log_dir,
            rank=self.rank,
            log_level_primary=self.logging_conf.log_level_primary,
            log_level_secondary=self.logging_conf.log_level_secondary,
        )

        set_seeds(seed_value, self.max_epochs, self.distributed_rank)
        log_env_variables()

        assert (
            is_dist_avail_and_initialized()
        ), "Torch distributed needs to be initialized before calling the trainer."

        self._setup_components()  # Except Optimizer everything is setup here.
        self._move_to_device()
        self._construct_optimizers()
        self._setup_dataloaders()

        self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")

        if self.checkpoint_conf.resume_from is not None:
            assert os.path.exists(
                self.checkpoint_conf.resume_from
            ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
            dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
            if self.distributed_rank == 0 and not os.path.exists(dst):
                # Copy the "resume_from" checkpoint to the checkpoint folder
                # if there is not a checkpoint to resume from already there
                makedir(self.checkpoint_conf.save_dir)
                g_pathmgr.copy(self.checkpoint_conf.resume_from, dst)
            # All ranks wait until rank 0 has finished the copy.
            barrier()

        # Load into the unwrapped model, then wrap it for DDP.
        self.load_checkpoint()
        self._setup_ddp_distributed_training(distributed, accelerator)
        barrier()
  188. def _setup_timers(self):
  189. """
  190. Initializes counters for elapsed time and eta.
  191. """
  192. self.start_time = time.time()
  193. self.ckpt_time_elapsed = 0
  194. self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0)
  195. def _get_meters(self, phase_filters=None):
  196. if self.meters is None:
  197. return {}
  198. meters = {}
  199. for phase, phase_meters in self.meters.items():
  200. if phase_filters is not None and phase not in phase_filters:
  201. continue
  202. for key, key_meters in phase_meters.items():
  203. if key_meters is None:
  204. continue
  205. for name, meter in key_meters.items():
  206. meters[f"{phase}_{key}/{name}"] = meter
  207. return meters
  208. def _infer_distributed_backend_if_none(self, distributed_conf, accelerator):
  209. if distributed_conf.backend is None:
  210. distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo"
  211. def _setup_env_variables(self, env_variables_conf) -> None:
  212. if env_variables_conf is not None:
  213. for variable_name, value in env_variables_conf.items():
  214. os.environ[variable_name] = value
  215. def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None:
  216. if torch.cuda.is_available():
  217. torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
  218. torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
  219. torch.backends.cuda.matmul.allow_tf32 = (
  220. cuda_conf.matmul_allow_tf32
  221. if cuda_conf.matmul_allow_tf32 is not None
  222. else cuda_conf.allow_tf32
  223. )
  224. torch.backends.cudnn.allow_tf32 = (
  225. cuda_conf.cudnn_allow_tf32
  226. if cuda_conf.cudnn_allow_tf32 is not None
  227. else cuda_conf.allow_tf32
  228. )
  229. self.rank = setup_distributed_backend(
  230. distributed_conf.backend, distributed_conf.timeout_mins
  231. )
  232. def _setup_device(self, accelerator):
  233. self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
  234. if accelerator == "cuda":
  235. self.device = torch.device("cuda", self.local_rank)
  236. torch.cuda.set_device(self.local_rank)
  237. elif accelerator == "cpu":
  238. self.device = torch.device("cpu")
  239. else:
  240. raise ValueError(f"Unsupported accelerator: {accelerator}")
  241. def _setup_ddp_distributed_training(self, distributed_conf, accelerator):
  242. assert isinstance(self.model, torch.nn.Module)
  243. self.model = nn.parallel.DistributedDataParallel(
  244. self.model,
  245. device_ids=[self.local_rank] if accelerator == "cuda" else [],
  246. find_unused_parameters=distributed_conf.find_unused_parameters,
  247. )
  248. if distributed_conf.comms_dtype is not None: # noqa
  249. from torch.distributed.algorithms import ddp_comm_hooks
  250. amp_type = get_amp_type(distributed_conf.comms_dtype)
  251. if amp_type == torch.bfloat16:
  252. hook = ddp_comm_hooks.default_hooks.bf16_compress_hook
  253. logging.info("Enabling bfloat16 grad communication")
  254. else:
  255. hook = ddp_comm_hooks.default_hooks.fp16_compress_hook
  256. logging.info("Enabling fp16 grad communication")
  257. process_group = None
  258. self.model.register_comm_hook(process_group, hook)
  259. def _move_to_device(self):
  260. logging.info(
  261. f"Moving components to device {self.device} and local rank {self.local_rank}."
  262. )
  263. self.model.to(self.device)
  264. logging.info(
  265. f"Done moving components to device {self.device} and local rank {self.local_rank}."
  266. )
  267. def save_checkpoint(self, epoch, checkpoint_names=None):
  268. checkpoint_folder = self.checkpoint_conf.save_dir
  269. makedir(checkpoint_folder)
  270. if checkpoint_names is None:
  271. checkpoint_names = ["checkpoint"]
  272. if (
  273. self.checkpoint_conf.save_freq > 0
  274. and (int(epoch) % self.checkpoint_conf.save_freq == 0)
  275. ) or int(epoch) in self.checkpoint_conf.save_list:
  276. checkpoint_names.append(f"checkpoint_{int(epoch)}")
  277. checkpoint_paths = []
  278. for ckpt_name in checkpoint_names:
  279. checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))
  280. state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
  281. state_dict = exclude_params_matching_unix_pattern(
  282. patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
  283. )
  284. checkpoint = {
  285. "model": state_dict,
  286. "optimizer": self.optim.optimizer.state_dict(),
  287. "epoch": epoch,
  288. "loss": self.loss.state_dict(),
  289. "steps": self.steps,
  290. "time_elapsed": self.time_elapsed_meter.val,
  291. "best_meter_values": self.best_meter_values,
  292. }
  293. if self.optim_conf.amp.enabled:
  294. checkpoint["scaler"] = self.scaler.state_dict()
  295. # DDP checkpoints are only saved on rank 0 (all workers are identical)
  296. if self.distributed_rank != 0:
  297. return
  298. for checkpoint_path in checkpoint_paths:
  299. self._save_checkpoint(checkpoint, checkpoint_path)
  300. def _save_checkpoint(self, checkpoint, checkpoint_path):
  301. """
  302. Save a checkpoint while guarding against the job being killed in the middle
  303. of checkpoint saving (which corrupts the checkpoint file and ruins the
  304. entire training since usually only the last checkpoint is kept per run).
  305. We first save the new checkpoint to a temp file (with a '.tmp' suffix), and
  306. and move it to overwrite the old checkpoint_path.
  307. """
  308. checkpoint_path_tmp = f"{checkpoint_path}.tmp"
  309. with g_pathmgr.open(checkpoint_path_tmp, "wb") as f:
  310. torch.save(checkpoint, f)
  311. # after torch.save is completed, replace the old checkpoint with the new one
  312. if g_pathmgr.exists(checkpoint_path):
  313. # remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails)
  314. g_pathmgr.rm(checkpoint_path)
  315. success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path)
  316. assert success
  317. def load_checkpoint(self):
  318. ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
  319. if ckpt_path is None:
  320. self._init_model_state()
  321. else:
  322. if self.checkpoint_conf.initialize_after_preemption:
  323. self._call_model_initializer()
  324. self._load_resuming_checkpoint(ckpt_path)
  325. def _init_model_state(self):
  326. # Checking that parameters that won't be saved are indeed frozen
  327. # We do this check here before even saving the model to catch errors
  328. # are early as possible and not at the end of the first epoch
  329. assert_skipped_parameters_are_frozen(
  330. patterns=self.checkpoint_conf.skip_saving_parameters,
  331. model=self.model,
  332. )
  333. # Checking that parameters that won't be saved are initialized from
  334. # within the model definition, unless `initialize_after_preemption`
  335. # is explicitly set to `True`. If not, this is a bug, and after
  336. # preemption, the `skip_saving_parameters` will have random values
  337. allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption
  338. with with_check_parameter_frozen(
  339. patterns=self.checkpoint_conf.skip_saving_parameters,
  340. model=self.model,
  341. disabled=allow_init_skip_parameters,
  342. ):
  343. self._call_model_initializer()
  344. def _call_model_initializer(self):
  345. model_weight_initializer = instantiate(
  346. self.checkpoint_conf.model_weight_initializer
  347. )
  348. if model_weight_initializer is not None:
  349. logging.info(
  350. f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}"
  351. )
  352. self.model = model_weight_initializer(model=self.model)
  353. def _load_resuming_checkpoint(self, ckpt_path: str):
  354. logging.info(f"Resuming training from {ckpt_path}")
  355. with g_pathmgr.open(ckpt_path, "rb") as f:
  356. checkpoint = torch.load(f, map_location="cpu")
  357. load_state_dict_into_model(
  358. model=self.model,
  359. state_dict=checkpoint["model"],
  360. ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
  361. )
  362. self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
  363. self.loss.load_state_dict(checkpoint["loss"], strict=True)
  364. self.epoch = checkpoint["epoch"]
  365. self.steps = checkpoint["steps"]
  366. self.ckpt_time_elapsed = checkpoint.get("time_elapsed")
  367. if self.optim_conf.amp.enabled and "scaler" in checkpoint:
  368. self.scaler.load_state_dict(checkpoint["scaler"])
  369. self.best_meter_values = checkpoint.get("best_meter_values", {})
  370. if "train_dataset" in checkpoint and self.train_dataset is not None:
  371. self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])
  372. def is_intermediate_val_epoch(self, epoch):
  373. return epoch % self.val_epoch_freq == 0 and epoch < self.max_epochs - 1
  374. def _step(
  375. self,
  376. batch: BatchedVideoDatapoint,
  377. model: nn.Module,
  378. phase: str,
  379. ):
  380. outputs = model(batch)
  381. targets = batch.masks
  382. batch_size = len(batch.img_batch)
  383. key = batch.dict_key # key for dataset
  384. loss = self.loss[key](outputs, targets)
  385. loss_str = f"Losses/{phase}_{key}_loss"
  386. loss_log_str = os.path.join("Step_Losses", loss_str)
  387. # loss contains multiple sub-components we wish to log
  388. step_losses = {}
  389. if isinstance(loss, dict):
  390. step_losses.update(
  391. {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()}
  392. )
  393. loss = self._log_loss_detailed_and_return_core_loss(
  394. loss, loss_log_str, self.steps[phase]
  395. )
  396. if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0:
  397. self.logger.log(
  398. loss_log_str,
  399. loss,
  400. self.steps[phase],
  401. )
  402. self.steps[phase] += 1
  403. ret_tuple = {loss_str: loss}, batch_size, step_losses
  404. if phase in self.meters and key in self.meters[phase]:
  405. meters_dict = self.meters[phase][key]
  406. if meters_dict is not None:
  407. for _, meter in meters_dict.items():
  408. meter.update(
  409. find_stages=outputs,
  410. find_metadatas=batch.metadata,
  411. )
  412. return ret_tuple
  413. def run(self):
  414. assert self.mode in ["train", "train_only", "val"]
  415. if self.mode == "train":
  416. if self.epoch > 0:
  417. logging.info(f"Resuming training from epoch: {self.epoch}")
  418. # resuming from a checkpoint
  419. if self.is_intermediate_val_epoch(self.epoch - 1):
  420. logging.info("Running previous val epoch")
  421. self.epoch -= 1
  422. self.run_val()
  423. self.epoch += 1
  424. self.run_train()
  425. self.run_val()
  426. elif self.mode == "val":
  427. self.run_val()
  428. elif self.mode == "train_only":
  429. self.run_train()
  430. def _setup_dataloaders(self):
  431. self.train_dataset = None
  432. self.val_dataset = None
  433. if self.mode in ["train", "val"]:
  434. self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None))
  435. if self.mode in ["train", "train_only"]:
  436. self.train_dataset = instantiate(self.data_conf.train)
    def run_train(self):
        """Train from ``self.epoch`` until ``max_epochs``.

        Per epoch:
        1. Build the loader.
        2. Train one epoch.
        3. Log stats, and on rank 0 append them to ``train_stats.json``.
        4. Save a checkpoint labeled ``epoch + 1`` *before* validating.
        5. Run an intermediate validation when scheduled.
        6. On rank 0, append the best meter values to ``best_stats.json``.

        After the loop ``self.epoch`` is decremented by one, so the caller's final
        validation (see ``run``) is reported under the last trained epoch.
        """
        while self.epoch < self.max_epochs:
            dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
            # All ranks wait until every loader is ready before training.
            barrier()
            outs = self.train_epoch(dataloader)
            self.logger.log_dict(outs, self.epoch)  # Logged only on rank 0

            # log train to text file.
            if self.distributed_rank == 0:
                with g_pathmgr.open(
                    os.path.join(self.logging_conf.log_dir, "train_stats.json"),
                    "a",
                ) as f:
                    f.write(json.dumps(outs) + "\n")

            # Save checkpoint before validating
            self.save_checkpoint(self.epoch + 1)

            del dataloader
            gc.collect()

            # Run val, not running on last epoch since will run after the
            # loop anyway
            if self.is_intermediate_val_epoch(self.epoch):
                self.run_val()

            if self.distributed_rank == 0:
                # Record the current trainer state next to the best meter values.
                self.best_meter_values.update(self._get_trainer_state("train"))
                with g_pathmgr.open(
                    os.path.join(self.logging_conf.log_dir, "best_stats.json"),
                    "a",
                ) as f:
                    f.write(json.dumps(self.best_meter_values) + "\n")

            self.epoch += 1
        # epoch was incremented in the loop but the val step runs out of the loop
        self.epoch -= 1
  468. def run_val(self):
  469. if not self.val_dataset:
  470. return
  471. dataloader = self.val_dataset.get_loader(epoch=int(self.epoch))
  472. outs = self.val_epoch(dataloader, phase=Phase.VAL)
  473. del dataloader
  474. gc.collect()
  475. self.logger.log_dict(outs, self.epoch) # Logged only on rank 0
  476. if self.distributed_rank == 0:
  477. with g_pathmgr.open(
  478. os.path.join(self.logging_conf.log_dir, "val_stats.json"),
  479. "a",
  480. ) as f:
  481. f.write(json.dumps(outs) + "\n")
    def val_epoch(self, val_loader, phase):
        """Run one evaluation pass over ``val_loader`` under ``torch.no_grad``.

        Args:
            val_loader: iterable of batches; ``len(val_loader)`` is taken as the
                number of iterations.
            phase: phase name (``run_val`` passes ``Phase.VAL``); used for loss
                names, meter selection and log keys.

        Returns:
            Dict of synced meter values, average losses and trainer state.
        """
        batch_time = AverageMeter("Batch Time", self.device, ":.2f")
        data_time = AverageMeter("Data Time", self.device, ":.2f")
        mem = MemMeter("Mem (GB)", self.device, ":.2f")

        iters_per_epoch = len(val_loader)

        # Parallel lists: one model per evaluated phase (currently a single pair).
        curr_phases = [phase]
        curr_models = [self.model]

        loss_names = []
        for p in curr_phases:
            for key in self.loss.keys():
                loss_names.append(f"Losses/{p}_{key}_loss")

        loss_mts = OrderedDict(
            [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
        )
        extra_loss_mts = {}

        for model in curr_models:
            model.eval()
            # Optional hook, called only if the unwrapped model defines it.
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_start()

        progress = ProgressMeter(
            iters_per_epoch,
            [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
            self._get_meters(curr_phases),
            prefix="Val Epoch: [{}]".format(self.epoch),
        )

        end = time.time()

        for data_iter, batch in enumerate(val_loader):

            # measure data loading time
            data_time.update(time.time() - end)

            batch = batch.to(self.device, non_blocking=True)

            # compute output
            with torch.no_grad():
                # AMP follows optim_conf; with no optim_conf autocast is disabled.
                with torch.cuda.amp.autocast(
                    enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
                    dtype=(
                        get_amp_type(self.optim_conf.amp.amp_dtype)
                        if self.optim_conf
                        else None
                    ),
                ):
                    # NOTE: this loop variable shadows the `phase` argument; with
                    # the single-element curr_phases it has the same value.
                    for phase, model in zip(curr_phases, curr_models):
                        loss_dict, batch_size, extra_losses = self._step(
                            batch,
                            model,
                            phase,
                        )

                        assert len(loss_dict) == 1
                        loss_key, loss = loss_dict.popitem()

                        loss_mts[loss_key].update(loss.item(), batch_size)

                        # Meters for loss sub-components are created lazily.
                        for k, v in extra_losses.items():
                            if k not in extra_loss_mts:
                                extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
                            extra_loss_mts[k].update(v.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            self.time_elapsed_meter.update(
                time.time() - self.start_time + self.ckpt_time_elapsed
            )

            if torch.cuda.is_available():
                mem.update(reset_peak_usage=True)

            if data_iter % self.logging_conf.log_freq == 0:
                progress.display(data_iter)

            if data_iter % self.logging_conf.log_scalar_frequency == 0:
                # Log progress meters.
                for progress_meter in progress.meters:
                    self.logger.log(
                        os.path.join("Step_Stats", phase, progress_meter.name),
                        progress_meter.val,
                        self.steps[Phase.VAL],
                    )

            # Collective barrier every 10 iterations.
            # NOTE(review): assumes every rank iterates the same number of
            # batches, otherwise this can deadlock — confirm with the sampler.
            if data_iter % 10 == 0:
                dist.barrier()

        self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
        self._log_timers(phase)
        for model in curr_models:
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_end()

        out_dict = self._log_meters_and_save_best_ckpts(curr_phases)

        for k, v in loss_mts.items():
            out_dict[k] = v.avg
        for k, v in extra_loss_mts.items():
            out_dict[k] = v.avg

        for phase in curr_phases:
            out_dict.update(self._get_trainer_state(phase))
        self._reset_meters(curr_phases)
        logging.info(f"Meters: {out_dict}")
        return out_dict
  570. def _get_trainer_state(self, phase):
  571. return {
  572. "Trainer/where": self.where,
  573. "Trainer/epoch": self.epoch,
  574. f"Trainer/steps_{phase}": self.steps[phase],
  575. }
  576. def train_epoch(self, train_loader):
  577. # Init stat meters
  578. batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
  579. data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
  580. mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
  581. data_times = []
  582. phase = Phase.TRAIN
  583. iters_per_epoch = len(train_loader)
  584. loss_names = []
  585. for batch_key in self.loss.keys():
  586. loss_names.append(f"Losses/{phase}_{batch_key}_loss")
  587. loss_mts = OrderedDict(
  588. [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
  589. )
  590. extra_loss_mts = {}
  591. progress = ProgressMeter(
  592. iters_per_epoch,
  593. [
  594. batch_time_meter,
  595. data_time_meter,
  596. mem_meter,
  597. self.time_elapsed_meter,
  598. *loss_mts.values(),
  599. ],
  600. self._get_meters([phase]),
  601. prefix="Train Epoch: [{}]".format(self.epoch),
  602. )
  603. # Model training loop
  604. self.model.train()
  605. end = time.time()
  606. for data_iter, batch in enumerate(train_loader):
  607. # measure data loading time
  608. data_time_meter.update(time.time() - end)
  609. data_times.append(data_time_meter.val)
  610. batch = batch.to(
  611. self.device, non_blocking=True
  612. ) # move tensors in a tensorclass
  613. try:
  614. self._run_step(batch, phase, loss_mts, extra_loss_mts)
  615. # compute gradient and do optim step
  616. exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
  617. self.where = float(exact_epoch) / self.max_epochs
  618. assert self.where <= 1 + self.EPSILON
  619. if self.where < 1.0:
  620. self.optim.step_schedulers(
  621. self.where, step=int(exact_epoch * iters_per_epoch)
  622. )
  623. else:
  624. logging.warning(
  625. f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]."
  626. )
  627. # Log schedulers
  628. if data_iter % self.logging_conf.log_scalar_frequency == 0:
  629. for j, param_group in enumerate(self.optim.optimizer.param_groups):
  630. for option in self.optim.schedulers[j]:
  631. optim_prefix = (
  632. "" + f"{j}_"
  633. if len(self.optim.optimizer.param_groups) > 1
  634. else ""
  635. )
  636. self.logger.log(
  637. os.path.join("Optim", f"{optim_prefix}", option),
  638. param_group[option],
  639. self.steps[phase],
  640. )
  641. # Clipping gradients and detecting diverging gradients
  642. if self.gradient_clipper is not None:
  643. self.scaler.unscale_(self.optim.optimizer)
  644. self.gradient_clipper(model=self.model)
  645. if self.gradient_logger is not None:
  646. self.gradient_logger(
  647. self.model, rank=self.distributed_rank, where=self.where
  648. )
  649. # Optimizer step: the scaler will make sure gradients are not
  650. # applied if the gradients are infinite
  651. self.scaler.step(self.optim.optimizer)
  652. self.scaler.update()
  653. # measure elapsed time
  654. batch_time_meter.update(time.time() - end)
  655. end = time.time()
  656. self.time_elapsed_meter.update(
  657. time.time() - self.start_time + self.ckpt_time_elapsed
  658. )
  659. mem_meter.update(reset_peak_usage=True)
  660. if data_iter % self.logging_conf.log_freq == 0:
  661. progress.display(data_iter)
  662. if data_iter % self.logging_conf.log_scalar_frequency == 0:
  663. # Log progress meters.
  664. for progress_meter in progress.meters:
  665. self.logger.log(
  666. os.path.join("Step_Stats", phase, progress_meter.name),
  667. progress_meter.val,
  668. self.steps[phase],
  669. )
  670. # Catching NaN/Inf errors in the loss
  671. except FloatingPointError as e:
  672. raise e
  673. self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
  674. self._log_timers(Phase.TRAIN)
  675. self._log_sync_data_times(Phase.TRAIN, data_times)
  676. out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])
  677. for k, v in loss_mts.items():
  678. out_dict[k] = v.avg
  679. for k, v in extra_loss_mts.items():
  680. out_dict[k] = v.avg
  681. out_dict.update(self._get_trainer_state(phase))
  682. logging.info(f"Losses and meters: {out_dict}")
  683. self._reset_meters([phase])
  684. return out_dict
  685. def _log_sync_data_times(self, phase, data_times):
  686. data_times = all_reduce_max(torch.tensor(data_times)).tolist()
  687. steps = range(self.steps[phase] - len(data_times), self.steps[phase])
  688. for step, data_time in zip(steps, data_times):
  689. if step % self.logging_conf.log_scalar_frequency == 0:
  690. self.logger.log(
  691. os.path.join("Step_Stats", phase, "Data Time Synced"),
  692. data_time,
  693. step,
  694. )
  695. def _run_step(
  696. self,
  697. batch: BatchedVideoDatapoint,
  698. phase: str,
  699. loss_mts: Dict[str, AverageMeter],
  700. extra_loss_mts: Dict[str, AverageMeter],
  701. raise_on_error: bool = True,
  702. ):
  703. """
  704. Run the forward / backward
  705. """
  706. # it's important to set grads to None, especially with Adam since 0
  707. # grads will also update a model even if the step doesn't produce
  708. # gradients
  709. self.optim.zero_grad(set_to_none=True)
  710. with torch.cuda.amp.autocast(
  711. enabled=self.optim_conf.amp.enabled,
  712. dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
  713. ):
  714. loss_dict, batch_size, extra_losses = self._step(
  715. batch,
  716. self.model,
  717. phase,
  718. )
  719. assert len(loss_dict) == 1
  720. loss_key, loss = loss_dict.popitem()
  721. if not math.isfinite(loss.item()):
  722. error_msg = f"Loss is {loss.item()}, attempting to stop training"
  723. logging.error(error_msg)
  724. if raise_on_error:
  725. raise FloatingPointError(error_msg)
  726. else:
  727. return
  728. self.scaler.scale(loss).backward()
  729. loss_mts[loss_key].update(loss.item(), batch_size)
  730. for extra_loss_key, extra_loss in extra_losses.items():
  731. if extra_loss_key not in extra_loss_mts:
  732. extra_loss_mts[extra_loss_key] = AverageMeter(
  733. extra_loss_key, self.device, ":.2e"
  734. )
  735. extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size)
  736. def _log_meters_and_save_best_ckpts(self, phases: List[str]):
  737. logging.info("Synchronizing meters")
  738. out_dict = {}
  739. checkpoint_save_keys = []
  740. for key, meter in self._get_meters(phases).items():
  741. meter_output = meter.compute_synced()
  742. is_better_check = getattr(meter, "is_better", None)
  743. for meter_subkey, meter_value in meter_output.items():
  744. out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value
  745. if is_better_check is None:
  746. continue
  747. tracked_meter_key = os.path.join(key, meter_subkey)
  748. if tracked_meter_key not in self.best_meter_values or is_better_check(
  749. meter_value,
  750. self.best_meter_values[tracked_meter_key],
  751. ):
  752. self.best_meter_values[tracked_meter_key] = meter_value
  753. if (
  754. self.checkpoint_conf.save_best_meters is not None
  755. and key in self.checkpoint_conf.save_best_meters
  756. ):
  757. checkpoint_save_keys.append(tracked_meter_key.replace("/", "_"))
  758. if len(checkpoint_save_keys) > 0:
  759. self.save_checkpoint(self.epoch + 1, checkpoint_save_keys)
  760. return out_dict
  761. def _log_timers(self, phase):
  762. time_remaining = 0
  763. epochs_remaining = self.max_epochs - self.epoch - 1
  764. val_epochs_remaining = sum(
  765. n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs)
  766. )
  767. # Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with
  768. # the end epoch.
  769. if (self.max_epochs - 1) % self.val_epoch_freq != 0:
  770. val_epochs_remaining += 1
  771. # Remove the current val run from estimate
  772. if phase == Phase.VAL:
  773. val_epochs_remaining -= 1
  774. time_remaining += (
  775. epochs_remaining * self.est_epoch_time[Phase.TRAIN]
  776. + val_epochs_remaining * self.est_epoch_time[Phase.VAL]
  777. )
  778. self.logger.log(
  779. os.path.join("Step_Stats", phase, self.time_elapsed_meter.name),
  780. self.time_elapsed_meter.val,
  781. self.steps[phase],
  782. )
  783. logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}")
  784. def _reset_meters(self, phases: str) -> None:
  785. for meter in self._get_meters(phases).values():
  786. meter.reset()
  787. def _check_val_key_match(self, val_keys, phase):
  788. if val_keys is not None:
  789. # Check if there are any duplicates
  790. assert len(val_keys) == len(
  791. set(val_keys)
  792. ), f"Duplicate keys in val datasets, keys: {val_keys}"
  793. # Check that the keys match the meter keys
  794. if self.meters_conf is not None and phase in self.meters_conf:
  795. assert set(val_keys) == set(self.meters_conf[phase].keys()), (
  796. f"Keys in val datasets do not match the keys in meters."
  797. f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}"
  798. f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}"
  799. )
  800. if self.loss_conf is not None:
  801. loss_keys = set(self.loss_conf.keys()) - set(["all"])
  802. assert all([k in loss_keys for k in val_keys]), (
  803. f"Keys in val datasets do not match the keys in losses."
  804. f"\nMissing in losses: {set(val_keys) - loss_keys}"
  805. f"\nMissing in val datasets: {loss_keys - set(val_keys)}"
  806. )
  807. def _setup_components(self):
  808. # Get the keys for all the val datasets, if any
  809. val_phase = Phase.VAL
  810. val_keys = None
  811. if self.data_conf.get(val_phase, None) is not None:
  812. val_keys = collect_dict_keys(self.data_conf[val_phase])
  813. # Additional checks on the sanity of the config for val datasets
  814. self._check_val_key_match(val_keys, phase=val_phase)
  815. logging.info("Setting up components: Model, loss, optim, meters etc.")
  816. self.epoch = 0
  817. self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}
  818. self.logger = Logger(self.logging_conf)
  819. self.model = instantiate(self.model_conf, _convert_="all")
  820. print_model_summary(self.model)
  821. self.loss = None
  822. if self.loss_conf:
  823. self.loss = {
  824. key: el # wrap_base_loss(el)
  825. for (key, el) in instantiate(self.loss_conf, _convert_="all").items()
  826. }
  827. self.loss = nn.ModuleDict(self.loss)
  828. self.meters = {}
  829. self.best_meter_values = {}
  830. if self.meters_conf:
  831. self.meters = instantiate(self.meters_conf, _convert_="all")
  832. self.scaler = torch.amp.GradScaler(
  833. self.device,
  834. enabled=self.optim_conf.amp.enabled if self.optim_conf else False,
  835. )
  836. self.gradient_clipper = (
  837. instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None
  838. )
  839. self.gradient_logger = (
  840. instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None
  841. )
  842. logging.info("Finished setting up components: Model, loss, optim, meters etc.")
  843. def _construct_optimizers(self):
  844. self.optim = construct_optimizer(
  845. self.model,
  846. self.optim_conf.optimizer,
  847. self.optim_conf.options,
  848. self.optim_conf.param_group_modifiers,
  849. )
  850. def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step):
  851. core_loss = loss.pop(CORE_LOSS_KEY)
  852. if step % self.logging_conf.log_scalar_frequency == 0:
  853. for k in loss:
  854. log_str = os.path.join(loss_str, k)
  855. self.logger.log(log_str, loss[k], step)
  856. return core_loss
  857. def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
  858. """
  859. Prints the model and the number of parameters in the model.
  860. # Multiple packages provide this info in a nice table format
  861. # However, they need us to provide an `input` (as they also write down the output sizes)
  862. # Our models are complex, and a single input is restrictive.
  863. # https://github.com/sksq96/pytorch-summary
  864. # https://github.com/nmhkahn/torchsummaryX
  865. """
  866. if get_rank() != 0:
  867. return
  868. param_kwargs = {}
  869. trainable_parameters = sum(
  870. p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad
  871. )
  872. total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs))
  873. non_trainable_parameters = total_parameters - trainable_parameters
  874. logging.info("==" * 10)
  875. logging.info(f"Summary for model {type(model)}")
  876. logging.info(f"Model is {model}")
  877. logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
  878. logging.info(
  879. f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}"
  880. )
  881. logging.info(
  882. f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}"
  883. )
  884. logging.info("==" * 10)
  885. if log_dir:
  886. output_fpath = os.path.join(log_dir, "model.txt")
  887. with g_pathmgr.open(output_fpath, "w") as f:
  888. print(model, file=f)
  889. PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
  890. def get_human_readable_count(number: int) -> str:
  891. """
  892. Abbreviates an integer number with K, M, B, T for thousands, millions,
  893. billions and trillions, respectively.
  894. Examples:
  895. >>> get_human_readable_count(123)
  896. '123 '
  897. >>> get_human_readable_count(1234) # (one thousand)
  898. '1.2 K'
  899. >>> get_human_readable_count(2e6) # (two million)
  900. '2.0 M'
  901. >>> get_human_readable_count(3e9) # (three billion)
  902. '3.0 B'
  903. >>> get_human_readable_count(4e14) # (four hundred trillion)
  904. '400 T'
  905. >>> get_human_readable_count(5e15) # (more than trillion)
  906. '5,000 T'
  907. Args:
  908. number: a positive integer number
  909. Return:
  910. A string formatted according to the pattern described above.
  911. """
  912. assert number >= 0
  913. labels = PARAMETER_NUM_UNITS
  914. num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
  915. num_groups = int(np.ceil(num_digits / 3))
  916. num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
  917. shift = -3 * (num_groups - 1)
  918. number = number * (10**shift)
  919. index = num_groups - 1
  920. if index < 1 or number >= 100:
  921. return f"{int(number):,d} {labels[index]}"
  922. else:
  923. return f"{number:,.1f} {labels[index]}"