- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- import contextlib
- import fnmatch
- import gc
- import json
- import logging
- import math
- import os
- import time
- from collections import OrderedDict
- from dataclasses import dataclass, field
- from typing import Any, Dict, List, Mapping, Optional
- import numpy as np
- import torch
- import torch.distributed as dist
- import torch.nn as nn
- from hydra.utils import instantiate
- from iopath.common.file_io import g_pathmgr
- from sam3.model.data_misc import BatchedDatapoint
- from sam3.model.model_misc import SAM3Output
- from sam3.model.utils.misc import copy_data_to_device
- from sam3.train.optim.optimizer import construct_optimizer
- from sam3.train.utils.checkpoint_utils import (
- assert_skipped_parameters_are_frozen,
- exclude_params_matching_unix_pattern,
- load_state_dict_into_model,
- with_check_parameter_frozen,
- )
- from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
- from sam3.train.utils.logger import Logger, setup_logging
- from sam3.train.utils.train_utils import (
- AverageMeter,
- collect_dict_keys,
- DurationMeter,
- get_amp_type,
- get_machine_local_and_dist_rank,
- get_resume_checkpoint,
- human_readable_time,
- is_dist_avail_and_initialized,
- log_env_variables,
- makedir,
- MemMeter,
- Phase,
- ProgressMeter,
- set_seeds,
- setup_distributed_backend,
- )
- CORE_LOSS_KEY = "core_loss"
- def unwrap_ddp_if_wrapped(model):
- if isinstance(model, torch.nn.parallel.DistributedDataParallel):
- return model.module
- return model
- @dataclass
- class OptimAMPConf:
- enabled: bool = False
- amp_dtype: str = "float16"
- @dataclass
- class OptimConf:
- optimizer: Optional[torch.optim.Optimizer] = None
- options: Optional[Dict[str, Any]] = None
- param_group_modifiers: Optional[List] = None
- amp: Optional[Dict[str, Any]] = None
- gradient_clip: Any = None
- gradient_logger: Any = None
- def __post_init__(self):
- # amp
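- # Accepts either an OptimAMPConf or a plain mapping; e.g. (illustrative values)
- # amp={"enabled": True, "amp_dtype": "bfloat16"} is coerced into
- # OptimAMPConf(enabled=True, amp_dtype="bfloat16").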
- if not isinstance(self.amp, OptimAMPConf):
- if self.amp is None:
- self.amp = {}
- assert isinstance(self.amp, Mapping)
- self.amp = OptimAMPConf(**self.amp)
- @dataclass
- class DistributedConf:
- backend: Optional[str] = None # inferred from accelerator type
- comms_dtype: Optional[str] = None
- find_unused_parameters: bool = False
- timeout_mins: int = 30
- gradient_as_bucket_view: bool = False # PyTorch DDP default is False
- static_graph: bool = False # PyTorch DDP default is False
- @dataclass
- class CudaConf:
- cudnn_deterministic: bool = False
- cudnn_benchmark: bool = True
- allow_tf32: bool = False
- # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul
- matmul_allow_tf32: Optional[bool] = None
- # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn
- cudnn_allow_tf32: Optional[bool] = None
- @dataclass
- class CheckpointConf:
- save_dir: str
- save_freq: int
- save_list: List[int] = field(default_factory=list)
- model_weight_initializer: Any = None
- save_best_meters: Optional[List[str]] = None
- skip_saving_parameters: List[str] = field(default_factory=list)
- initialize_after_preemption: Optional[bool] = None
- # if not None, training will be resumed from this checkpoint
- resume_from: Optional[str] = None
- def infer_missing(self):
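- # If not set explicitly, re-initialize after preemption only when some
- # parameters are excluded from saved checkpoints (they would otherwise
- # come back uninitialized on resume).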
- if self.initialize_after_preemption is None:
- with_skip_saving = len(self.skip_saving_parameters) > 0
- self.initialize_after_preemption = with_skip_saving
- return self
- @dataclass
- class LoggingConf:
- log_dir: str
- log_freq: int # In iterations
- tensorboard_writer: Any
- log_level_primary: str = "INFO"
- log_level_secondary: str = "ERROR"
- log_scalar_frequency: int = 100
- log_visual_frequency: int = 100
- scalar_keys_to_log: Optional[Dict[str, Any]] = None
- log_batch_stats: bool = False
- wandb_writer: Optional[Any] = None
- class Trainer:
- """
- Trainer supporting DDP-based training.
- """
- EPSILON = 1e-8
- def __init__(
- self,
- *, # the order of these args can change at any time, so they are keyword-only
- data: Dict[str, Any],
- model: Dict[str, Any],
- logging: Dict[str, Any],
- checkpoint: Dict[str, Any],
- max_epochs: int,
- mode: str = "train",
- accelerator: str = "cuda",
- seed_value: int = 123,
- val_epoch_freq: int = 1,
- distributed: Optional[Dict[str, Any]] = None,
- cuda: Optional[Dict[str, Any]] = None,
- env_variables: Optional[Dict[str, Any]] = None,
- optim: Optional[Dict[str, Any]] = None,
- optim_overrides: Optional[List[Dict[str, Any]]] = None,
- meters: Optional[Dict[str, Any]] = None,
- loss: Optional[Dict[str, Any]] = None,
- skip_first_val: bool = False,
- skip_saving_ckpts: bool = False,
- empty_gpu_mem_cache_after_eval: bool = True,
- gradient_accumulation_steps: int = 1,
- ):
- self._setup_env_variables(env_variables)
- self._setup_timers()
- self.data_conf = data
- self.model_conf = model
- self.logging_conf = LoggingConf(**logging)
- self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing()
- self.max_epochs = max_epochs
- self.mode = mode
- self.val_epoch_freq = val_epoch_freq
- self.optim_conf = OptimConf(**optim) if optim is not None else OptimConf()
- self.meters_conf = meters
- self.loss_conf = loss
- self.gradient_accumulation_steps = gradient_accumulation_steps
- distributed = DistributedConf(**distributed or {})
- cuda = CudaConf(**cuda or {})
- self.where = 0.0
- self.skip_first_val = skip_first_val
- self.skip_saving_ckpts = skip_saving_ckpts
- self.empty_gpu_mem_cache_after_eval = empty_gpu_mem_cache_after_eval
- self._infer_distributed_backend_if_none(distributed, accelerator)
- self._setup_device(accelerator)
- self._setup_torch_dist_and_backend(cuda, distributed)
- makedir(self.logging_conf.log_dir)
- setup_logging(
- __name__,
- output_dir=self.logging_conf.log_dir,
- rank=self.rank,
- log_level_primary=self.logging_conf.log_level_primary,
- log_level_secondary=self.logging_conf.log_level_secondary,
- )
- set_seeds(seed_value, self.max_epochs, self.distributed_rank)
- log_env_variables()
- assert is_dist_avail_and_initialized(), (
- "Torch distributed needs to be initialized before calling the trainer."
- )
- self._setup_components() # Except Optimizer everything is setup here.
- self._move_to_device()
- self._construct_optimizers()
- self._setup_dataloaders()
- self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
- if self.checkpoint_conf.resume_from is not None:
- assert os.path.exists(self.checkpoint_conf.resume_from), (
- f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
- )
- dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
- if self.distributed_rank == 0 and not os.path.exists(dst):
- # Copy the "resume_from" checkpoint to the checkpoint folder
- # if there isn't already a checkpoint there to resume from
- makedir(self.checkpoint_conf.save_dir)
- g_pathmgr.copy(self.checkpoint_conf.resume_from, dst)
- barrier()
- self.load_checkpoint()
- self._setup_ddp_distributed_training(distributed, accelerator)
- barrier()
- def _setup_timers(self):
- """
- Initializes counters for elapsed time and eta.
- """
- self.start_time = time.time()
- self.ckpt_time_elapsed = 0
- self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0)
- def _get_meters(self, phase_filters=None):
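- # Flatten the nested {phase: {dataset_key: {name: meter}}} structure into a
- # flat dict keyed "{phase}_{dataset_key}/{name}", optionally filtered by phase.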
- if self.meters is None:
- return {}
- meters = {}
- for phase, phase_meters in self.meters.items():
- if phase_filters is not None and phase not in phase_filters:
- continue
- for key, key_meters in phase_meters.items():
- if key_meters is None:
- continue
- for name, meter in key_meters.items():
- meters[f"{phase}_{key}/{name}"] = meter
- return meters
- def _infer_distributed_backend_if_none(self, distributed_conf, accelerator):
- if distributed_conf.backend is None:
- distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo"
- def _setup_env_variables(self, env_variables_conf) -> None:
- if env_variables_conf is not None:
- for variable_name, value in env_variables_conf.items():
- os.environ[variable_name] = value
- def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None:
- if torch.cuda.is_available():
- torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
- torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
- torch.backends.cuda.matmul.allow_tf32 = (
- cuda_conf.matmul_allow_tf32
- if cuda_conf.matmul_allow_tf32 is not None
- else cuda_conf.allow_tf32
- )
- torch.backends.cudnn.allow_tf32 = (
- cuda_conf.cudnn_allow_tf32
- if cuda_conf.cudnn_allow_tf32 is not None
- else cuda_conf.allow_tf32
- )
- self.rank = setup_distributed_backend(
- distributed_conf.backend, distributed_conf.timeout_mins
- )
- def _setup_device(self, accelerator):
- self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
- if accelerator == "cuda":
- self.device = torch.device("cuda", self.local_rank)
- torch.cuda.set_device(self.local_rank)
- elif accelerator == "cpu":
- self.device = torch.device("cpu")
- else:
- raise ValueError(f"Unsupported accelerator: {accelerator}")
- def _setup_ddp_distributed_training(self, distributed_conf, accelerator):
- assert isinstance(self.model, torch.nn.Module)
- self.model = nn.parallel.DistributedDataParallel(
- self.model,
- device_ids=[self.local_rank] if accelerator == "cuda" else [],
- find_unused_parameters=distributed_conf.find_unused_parameters,
- gradient_as_bucket_view=distributed_conf.gradient_as_bucket_view,
- static_graph=distributed_conf.static_graph,
- )
- if distributed_conf.comms_dtype is not None: # noqa
- from torch.distributed.algorithms import ddp_comm_hooks
- amp_type = get_amp_type(distributed_conf.comms_dtype)
- if amp_type == torch.bfloat16:
- hook = ddp_comm_hooks.default_hooks.bf16_compress_hook
- logging.info("Enabling bfloat16 grad communication")
- else:
- hook = ddp_comm_hooks.default_hooks.fp16_compress_hook
- logging.info("Enabling fp16 grad communication")
- process_group = None
- self.model.register_comm_hook(process_group, hook)
- def _move_to_device(self):
- logging.info(
- f"Moving components to device {self.device} and local rank {self.local_rank}."
- )
- self.model.to(self.device)
- logging.info(
- f"Done moving components to device {self.device} and local rank {self.local_rank}."
- )
- def save_checkpoint(self, epoch, checkpoint_names=None):
- if self.skip_saving_ckpts:
- logging.info(
- "skip_saving_ckpts is set to True. So, no checkpoints have been saved."
- )
- return
- checkpoint_folder = self.checkpoint_conf.save_dir
- makedir(checkpoint_folder)
- if checkpoint_names is None:
- checkpoint_names = ["checkpoint"]
- if (
- self.checkpoint_conf.save_freq > 0
- and (int(epoch) % self.checkpoint_conf.save_freq == 0)
- ) or int(epoch) in self.checkpoint_conf.save_list:
- checkpoint_names.append(f"checkpoint_{int(epoch)}")
- checkpoint_paths = []
- for ckpt_name in checkpoint_names:
- checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))
- state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
- state_dict = exclude_params_matching_unix_pattern(
- patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
- )
- checkpoint = {
- "model": state_dict,
- "optimizer": self.optim.optimizer.state_dict(),
- "epoch": epoch,
- "loss": self.loss.state_dict(),
- "steps": self.steps,
- "time_elapsed": self.time_elapsed_meter.val,
- "best_meter_values": self.best_meter_values,
- }
- if self.optim_conf.amp.enabled:
- checkpoint["scaler"] = self.scaler.state_dict()
- # DDP checkpoints are only saved on rank 0 (all workers are identical)
- if self.distributed_rank != 0:
- return
- for checkpoint_path in checkpoint_paths:
- self._save_checkpoint(checkpoint, checkpoint_path)
- def _save_checkpoint(self, checkpoint, checkpoint_path):
- """
- Save a checkpoint while guarding against the job being killed in the middle
- of checkpoint saving (which corrupts the checkpoint file and ruins the
- entire training since usually only the last checkpoint is kept per run).
- We first save the new checkpoint to a temp file (with a '.tmp' suffix), and
- then move it to overwrite the old checkpoint_path.
- """
- checkpoint_path_tmp = f"{checkpoint_path}.tmp"
- with g_pathmgr.open(checkpoint_path_tmp, "wb") as f:
- torch.save(checkpoint, f)
- # after torch.save is completed, replace the old checkpoint with the new one
- if g_pathmgr.exists(checkpoint_path):
- # remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails)
- g_pathmgr.rm(checkpoint_path)
- success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path)
- assert success
- def load_checkpoint(self):
- ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
- if ckpt_path is None:
- self._init_model_state()
- else:
- if self.checkpoint_conf.initialize_after_preemption:
- self._call_model_initializer()
- self._load_resuming_checkpoint(ckpt_path)
- def _init_model_state(self):
- # Checking that parameters that won't be saved are indeed frozen
- # We do this check here before even saving the model to catch errors
- # as early as possible and not at the end of the first epoch
- assert_skipped_parameters_are_frozen(
- patterns=self.checkpoint_conf.skip_saving_parameters,
- model=self.model,
- )
- # Checking that parameters that won't be saved are initialized from
- # within the model definition, unless `initialize_after_preemption`
- # is explicitly set to `True`. If not, this is a bug, and after
- # preemption, the parameters matching `skip_saving_parameters` will have random values
- allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption
- with with_check_parameter_frozen(
- patterns=self.checkpoint_conf.skip_saving_parameters,
- model=self.model,
- disabled=allow_init_skip_parameters,
- ):
- self._call_model_initializer()
- def _call_model_initializer(self):
- model_weight_initializer = instantiate(
- self.checkpoint_conf.model_weight_initializer
- )
- if model_weight_initializer is not None:
- logging.info(
- f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}"
- )
- self.model = model_weight_initializer(model=self.model)
- def _load_resuming_checkpoint(self, ckpt_path: str):
- logging.info(f"Resuming training from {ckpt_path}")
- with g_pathmgr.open(ckpt_path, "rb") as f:
- checkpoint = torch.load(f, map_location="cpu")
- load_state_dict_into_model(
- model=self.model,
- state_dict=checkpoint["model"],
- ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
- )
- self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
- self.loss.load_state_dict(checkpoint["loss"], strict=True)
- self.epoch = checkpoint["epoch"]
- self.steps = checkpoint["steps"]
- self.ckpt_time_elapsed = checkpoint.get("time_elapsed")
- if self.optim_conf.amp.enabled and "scaler" in checkpoint:
- self.scaler.load_state_dict(checkpoint["scaler"])
- self.best_meter_values = checkpoint.get("best_meter_values", {})
- if "train_dataset" in checkpoint and self.train_dataset is not None:
- self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])
- def is_intermediate_val_epoch(self, epoch):
- skip_epoch = self.skip_first_val and epoch == 0
- return (
- epoch % self.val_epoch_freq == 0
- and epoch < self.max_epochs - 1
- and not skip_epoch
- )
- def _find_loss(self, key: str):
- if key in self.loss:
- return self.loss[key]
- assert key != "all", "Loss must be specified for key='all'"
- assert "default" in self.loss, (
- f"Key {key} not found in losss, and no default provided"
- )
- return self.loss["default"]
- def _find_meter(self, phase: str, key: str):
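- # Try an exact dataset-key match first, then fall back to meters registered
- # under Unix-style wildcard patterns (matched with fnmatch).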
- if key in self.meters[phase]:
- return self.meters[phase][key]
- for cand_key, meter in self.meters[phase].items():
- if fnmatch.fnmatch(key, cand_key):
- return meter
- return None
- def _step(
- self,
- batch: BatchedDatapoint,
- model: nn.Module,
- phase: str,
- ):
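- # Each incoming batch is assumed to be a single-entry mapping of
- # {dataset_key: batch}; the key selects which loss and meters to apply.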
- key, batch = batch.popitem()
- batch = copy_data_to_device(batch, self.device, non_blocking=True)
- find_stages = model(batch)
- find_targets = [
- unwrap_ddp_if_wrapped(model).back_convert(x) for x in batch.find_targets
- ]
- batch_size = len(batch.img_batch)
- loss = self._find_loss(key)(find_stages, find_targets)
- loss_str = f"Losses/{phase}_{key}_loss"
- loss_log_str = os.path.join("Step_Losses", loss_str)
- # loss contains multiple sub-components we wish to log
- step_losses = {}
- if isinstance(loss, dict):
- step_losses.update(
- {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()}
- )
- loss = self._log_loss_detailed_and_return_core_loss(
- loss, loss_log_str, self.steps[phase]
- )
- if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0:
- self.logger.log(
- loss_log_str,
- loss,
- self.steps[phase],
- )
- self.steps[phase] += 1
- ret_tuple = {loss_str: loss}, batch_size, step_losses
- if phase not in self.meters:
- return ret_tuple
- meters_dict = self._find_meter(phase, key)
- if meters_dict is None:
- return ret_tuple
- for _, meter in meters_dict.items():
- meter.update(
- find_stages=find_stages,
- find_metadatas=batch.find_metadatas,
- model=model,
- batch=batch,
- key=key,
- )
- # Cleanup memory
- if isinstance(find_stages, SAM3Output):
- for fs in find_stages:
- for k in list(fs.keys()):
- del fs[k]
- return ret_tuple
- def run(self):
- assert self.mode in ["train", "train_only", "val"]
- if self.mode == "train":
- if self.epoch > 0:
- logging.info(f"Resuming training from epoch: {self.epoch}")
- # resuming from a checkpoint
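- # Checkpoints are saved before validation, so after a preemption the val
- # epoch for the last completed train epoch may not have run yet; re-run it.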
- if self.is_intermediate_val_epoch(self.epoch - 1):
- logging.info("Running previous val epoch")
- self.epoch -= 1
- self.run_val()
- self.epoch += 1
- self.run_train()
- self.run_val()
- elif self.mode == "val":
- self.run_val()
- elif self.mode == "train_only":
- self.run_train()
- def _setup_dataloaders(self):
- self.train_dataset = None
- self.val_dataset = None
- if self.mode in ["train", "val"]:
- self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None))
- if self.mode in ["train", "train_only"]:
- self.train_dataset = instantiate(self.data_conf.train)
- def run_train(self):
- while self.epoch < self.max_epochs:
- dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
- barrier()
- outs = self.train_epoch(dataloader)
- self.logger.log_dict(outs, self.epoch) # Logged only on rank 0
- # log train to text file.
- if self.distributed_rank == 0:
- with g_pathmgr.open(
- os.path.join(self.logging_conf.log_dir, "train_stats.json"),
- "a",
- ) as f:
- f.write(json.dumps(outs) + "\n")
- # Save checkpoint before validating
- self.save_checkpoint(self.epoch + 1)
- del dataloader
- gc.collect()
- # Run val; skipped on the last epoch since it runs after the loop anyway
- if self.is_intermediate_val_epoch(self.epoch):
- self.run_val()
- if torch.cuda.is_available() and self.empty_gpu_mem_cache_after_eval:
- # release memory buffers held by the model during eval (which typically
- # involves a lot more frames in video grounding than during training)
- torch.cuda.empty_cache()
- if self.distributed_rank == 0:
- self.best_meter_values.update(self._get_trainer_state("train"))
- with g_pathmgr.open(
- os.path.join(self.logging_conf.log_dir, "best_stats.json"),
- "a",
- ) as f:
- f.write(json.dumps(self.best_meter_values) + "\n")
- self.epoch += 1
- # epoch was incremented in the loop, but the final val step runs outside the loop
- self.epoch -= 1
- def run_val(self):
- if not self.val_dataset:
- return
- dataloader = self.val_dataset.get_loader(epoch=int(self.epoch))
- outs = self.val_epoch(dataloader, phase=Phase.VAL)
- del dataloader
- gc.collect()
- self.logger.log_dict(outs, self.epoch) # Logged only on rank 0
- if self.distributed_rank == 0:
- with g_pathmgr.open(
- os.path.join(self.logging_conf.log_dir, "val_stats.json"),
- "a",
- ) as f:
- f.write(json.dumps(outs) + "\n")
- def val_epoch(self, val_loader, phase):
- batch_time = AverageMeter("Batch Time", self.device, ":.2f")
- data_time = AverageMeter("Data Time", self.device, ":.2f")
- mem = MemMeter("Mem (GB)", self.device, ":.2f")
- iters_per_epoch = len(val_loader)
- curr_phases = [phase]
- curr_models = [self.model]
- loss_names = []
- for p in curr_phases:
- for key in self.loss.keys():
- loss_names.append(f"Losses/{p}_{key}_loss")
- loss_mts = OrderedDict(
- [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
- )
- extra_loss_mts = {}
- for model in curr_models:
- model.eval()
- if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
- unwrap_ddp_if_wrapped(model).on_validation_epoch_start()
- progress = ProgressMeter(
- iters_per_epoch,
- [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
- self._get_meters(curr_phases),
- prefix="Val Epoch: [{}]".format(self.epoch),
- )
- end = time.time()
- for data_iter, batch in enumerate(val_loader):
- # measure data loading time
- data_time.update(time.time() - end)
- # batch = batch.to(self.device, non_blocking=True)
- # compute output
- with torch.no_grad():
- with torch.amp.autocast(
- device_type="cuda",
- enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
- dtype=(
- get_amp_type(self.optim_conf.amp.amp_dtype)
- if self.optim_conf
- else None
- ),
- ):
- for phase, model in zip(curr_phases, curr_models):
- loss_dict, batch_size, extra_losses = self._step(
- batch,
- model,
- phase,
- )
- assert len(loss_dict) == 1
- loss_key, loss = loss_dict.popitem()
- if loss_key in loss_mts:
- loss_mts[loss_key].update(loss.item(), batch_size)
- for k, v in extra_losses.items():
- if k not in extra_loss_mts:
- extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
- extra_loss_mts[k].update(v.item(), batch_size)
- # measure elapsed time
- batch_time.update(time.time() - end)
- end = time.time()
- self.time_elapsed_meter.update(
- time.time() - self.start_time + self.ckpt_time_elapsed
- )
- if torch.cuda.is_available():
- mem.update(reset_peak_usage=True)
- if data_iter % self.logging_conf.log_freq == 0:
- progress.display(data_iter)
- if data_iter % self.logging_conf.log_scalar_frequency == 0:
- # Log progress meters.
- for progress_meter in progress.meters:
- self.logger.log(
- os.path.join("Step_Stats", phase, progress_meter.name),
- progress_meter.val,
- self.steps[Phase.VAL],
- )
- if data_iter % 10 == 0:
- dist.barrier()
- self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
- self._log_timers(phase)
- for model in curr_models:
- if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
- unwrap_ddp_if_wrapped(model).on_validation_epoch_end()
- out_dict = self._log_meters_and_save_best_ckpts(curr_phases)
- for k, v in loss_mts.items():
- out_dict[k] = v.avg
- for k, v in extra_loss_mts.items():
- out_dict[k] = v.avg
- for phase in curr_phases:
- out_dict.update(self._get_trainer_state(phase))
- self._reset_meters(curr_phases)
- logging.info(f"Meters: {out_dict}")
- return out_dict
- def _get_trainer_state(self, phase):
- return {
- "Trainer/where": self.where,
- "Trainer/epoch": self.epoch,
- f"Trainer/steps_{phase}": self.steps[phase],
- }
- def train_epoch(self, train_loader):
- # Init stat meters
- batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
- data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
- mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
- data_times = []
- phase = Phase.TRAIN
- iters_per_epoch = len(train_loader)
- loss_names = []
- for batch_key in self.loss.keys():
- loss_names.append(f"Losses/{phase}_{batch_key}_loss")
- loss_mts = OrderedDict(
- [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
- )
- extra_loss_mts = {}
- progress = ProgressMeter(
- iters_per_epoch,
- [
- batch_time_meter,
- data_time_meter,
- mem_meter,
- self.time_elapsed_meter,
- *loss_mts.values(),
- ],
- self._get_meters([phase]),
- prefix="Train Epoch: [{}]".format(self.epoch),
- )
- # Model training loop
- self.model.train()
- end = time.time()
- for data_iter, batch in enumerate(train_loader):
- # measure data loading time
- data_time_meter.update(time.time() - end)
- data_times.append(data_time_meter.val)
- # batch = batch.to(
- # self.device, non_blocking=True
- # ) # move tensors in a tensorclass
- try:
- self._run_step(batch, phase, loss_mts, extra_loss_mts)
- # compute gradient and do optim step
- exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
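- # "where" is the fraction of total training completed, in [0, 1]; it is
- # what drives the parameter-group schedulers below.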
- self.where = float(exact_epoch) / self.max_epochs
- assert self.where <= 1 + self.EPSILON
- if self.where < 1.0:
- self.optim.step_schedulers(
- self.where, step=int(exact_epoch * iters_per_epoch)
- )
- else:
- logging.warning(
- f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]."
- )
- # Log schedulers
- if data_iter % self.logging_conf.log_scalar_frequency == 0:
- for j, param_group in enumerate(self.optim.optimizer.param_groups):
- for option in self.optim.schedulers[j]:
- optim_prefix = (
- "" + f"{j}_"
- if len(self.optim.optimizer.param_groups) > 1
- else ""
- )
- self.logger.log(
- os.path.join("Optim", f"{optim_prefix}", option),
- param_group[option],
- self.steps[phase],
- )
- # Clipping gradients and detecting diverging gradients
- if self.gradient_clipper is not None:
- self.scaler.unscale_(self.optim.optimizer)
- self.gradient_clipper(model=self.model)
- if self.gradient_logger is not None:
- self.gradient_logger(
- self.model, rank=self.distributed_rank, where=self.where
- )
- # Optimizer step: the scaler will make sure gradients are not
- # applied if the gradients are infinite
- self.scaler.step(self.optim.optimizer)
- self.scaler.update()
- # measure elapsed time
- batch_time_meter.update(time.time() - end)
- end = time.time()
- self.time_elapsed_meter.update(
- time.time() - self.start_time + self.ckpt_time_elapsed
- )
- mem_meter.update(reset_peak_usage=True)
- if data_iter % self.logging_conf.log_freq == 0:
- progress.display(data_iter)
- if data_iter % self.logging_conf.log_scalar_frequency == 0:
- # Log progress meters.
- for progress_meter in progress.meters:
- self.logger.log(
- os.path.join("Step_Stats", phase, progress_meter.name),
- progress_meter.val,
- self.steps[phase],
- )
- # Catching NaN/Inf errors in the loss
- except FloatingPointError as e:
- raise e
- self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
- self._log_timers(Phase.TRAIN)
- self._log_sync_data_times(Phase.TRAIN, data_times)
- out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])
- for k, v in loss_mts.items():
- out_dict[k] = v.avg
- for k, v in extra_loss_mts.items():
- out_dict[k] = v.avg
- out_dict.update(self._get_trainer_state(phase))
- logging.info(f"Losses and meters: {out_dict}")
- self._reset_meters([phase])
- return out_dict
- def _log_sync_data_times(self, phase, data_times):
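- # Max-reduce the per-step data-loading times across ranks so the logged
- # value reflects the slowest worker (the one gating each step).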
- data_times = all_reduce_max(torch.tensor(data_times)).tolist()
- steps = range(self.steps[phase] - len(data_times), self.steps[phase])
- for step, data_time in zip(steps, data_times):
- if step % self.logging_conf.log_scalar_frequency == 0:
- self.logger.log(
- os.path.join("Step_Stats", phase, "Data Time Synced"),
- data_time,
- step,
- )
- def _run_step(
- self,
- batch: BatchedDatapoint,
- phase: str,
- loss_mts: Dict[str, AverageMeter],
- extra_loss_mts: Dict[str, AverageMeter],
- raise_on_error: bool = True,
- ):
- """
- Run the forward / backward
- """
- # It's important to set grads to None, especially with Adam, since zero
- # grads would still update the model even when the step produces no
- # gradients
- self.optim.zero_grad(set_to_none=True)
- if self.gradient_accumulation_steps > 1:
- assert isinstance(batch, list), (
- f"Expected a list of batches, got {type(batch)}"
- )
- assert len(batch) == self.gradient_accumulation_steps, (
- f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
- )
- accum_steps = len(batch)
- else:
- accum_steps = 1
- batch = [batch]
- for i, chunked_batch in enumerate(batch):
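- # Skip DDP gradient all-reduce (model.no_sync) for all but the last
- # accumulation chunk; gradients are synchronized once on the final backward.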
- ddp_context = (
- self.model.no_sync()
- if i < accum_steps - 1
- else contextlib.nullcontext()
- )
- with ddp_context:
- with torch.amp.autocast(
- device_type="cuda",
- enabled=self.optim_conf.amp.enabled,
- dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
- ):
- loss_dict, batch_size, extra_losses = self._step(
- chunked_batch,
- self.model,
- phase,
- )
- assert len(loss_dict) == 1
- loss_key, loss = loss_dict.popitem()
- if not math.isfinite(loss.item()):
- error_msg = f"Loss is {loss.item()}, attempting to stop training"
- logging.error(error_msg)
- if raise_on_error:
- raise FloatingPointError(error_msg)
- else:
- return
- self.scaler.scale(loss).backward()
- loss_mts[loss_key].update(loss.item(), batch_size)
- for extra_loss_key, extra_loss in extra_losses.items():
- if extra_loss_key not in extra_loss_mts:
- extra_loss_mts[extra_loss_key] = AverageMeter(
- extra_loss_key, self.device, ":.2e"
- )
- extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size)
- def _log_meters_and_save_best_ckpts(self, phases: List[str]):
- logging.info("Synchronizing meters")
- out_dict = {}
- checkpoint_save_keys = []
- for key, meter in self._get_meters(phases).items():
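- # compute_synced() is expected to reduce meter state across ranks and return
- # a {subkey: value} dict; meters exposing is_better() also track best values.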
- meter_output = meter.compute_synced()
- is_better_check = getattr(meter, "is_better", None)
- for meter_subkey, meter_value in meter_output.items():
- out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value
- if is_better_check is None:
- continue
- tracked_meter_key = os.path.join(key, meter_subkey)
- if tracked_meter_key not in self.best_meter_values or is_better_check(
- meter_value,
- self.best_meter_values[tracked_meter_key],
- ):
- self.best_meter_values[tracked_meter_key] = meter_value
- if (
- self.checkpoint_conf.save_best_meters is not None
- and key in self.checkpoint_conf.save_best_meters
- ):
- checkpoint_save_keys.append(tracked_meter_key.replace("/", "_"))
- if len(checkpoint_save_keys) > 0:
- self.save_checkpoint(self.epoch + 1, checkpoint_save_keys)
- return out_dict
- def _log_timers(self, phase):
- time_remaining = 0
- epochs_remaining = self.max_epochs - self.epoch - 1
- val_epochs_remaining = sum(
- n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs)
- )
- # Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with
- # the end epoch.
- if (self.max_epochs - 1) % self.val_epoch_freq != 0:
- val_epochs_remaining += 1
- # Remove the current val run from estimate
- if phase == Phase.VAL:
- val_epochs_remaining -= 1
- time_remaining += (
- epochs_remaining * self.est_epoch_time[Phase.TRAIN]
- + val_epochs_remaining * self.est_epoch_time[Phase.VAL]
- )
- self.logger.log(
- os.path.join("Step_Stats", phase, self.time_elapsed_meter.name),
- self.time_elapsed_meter.val,
- self.steps[phase],
- )
- logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}")
- def _reset_meters(self, phases: List[str]) -> None:
- for meter in self._get_meters(phases).values():
- meter.reset()
- def _check_val_key_match(self, val_keys, phase):
- if val_keys is not None:
- # Check if there are any duplicates
- assert len(val_keys) == len(set(val_keys)), (
- f"Duplicate keys in val datasets, keys: {val_keys}"
- )
- # Check that the keys match the meter keys
- if self.meters_conf is not None and phase in self.meters_conf:
- assert set(val_keys) == set(self.meters_conf[phase].keys()), (
- f"Keys in val datasets do not match the keys in meters."
- f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}"
- f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}"
- )
- if self.loss_conf is not None:
- loss_keys = set(self.loss_conf.keys()) - set(["all"])
- if "default" not in loss_keys:
- for k in val_keys:
- assert k in loss_keys, (
- f"Error: key {k} is not defined in the losses, and no default is set"
- )
- def _setup_components(self):
- # Get the keys for all the val datasets, if any
- val_phase = Phase.VAL
- val_keys = None
- if self.data_conf.get(val_phase, None) is not None:
- val_keys = collect_dict_keys(self.data_conf[val_phase])
- # Additional checks on the sanity of the config for val datasets
- self._check_val_key_match(val_keys, phase=val_phase)
- logging.info("Setting up components: Model, loss, optim, meters etc.")
- self.epoch = 0
- self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}
- self.logger = Logger(self.logging_conf)
- self.model = instantiate(self.model_conf, _convert_="all")
- print_model_summary(self.model)
- self.loss = None
- if self.loss_conf:
- self.loss = {
- key: el # wrap_base_loss(el)
- for (key, el) in instantiate(self.loss_conf, _convert_="all").items()
- }
- self.loss = nn.ModuleDict(self.loss)
- self.meters = {}
- self.best_meter_values = {}
- if self.meters_conf:
- self.meters = instantiate(self.meters_conf, _convert_="all")
- self.scaler = torch.amp.GradScaler(
- self.device,
- enabled=self.optim_conf.amp.enabled if self.optim_conf else False,
- )
- self.gradient_clipper = (
- instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None
- )
- self.gradient_logger = (
- instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None
- )
- logging.info("Finished setting up components: Model, loss, optim, meters etc.")
- def _construct_optimizers(self):
- self.optim = construct_optimizer(
- self.model,
- self.optim_conf.optimizer,
- self.optim_conf.options,
- self.optim_conf.param_group_modifiers,
- )
- def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step):
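- # The loss dict is expected to carry the total under CORE_LOSS_KEY plus
- # per-term components; log the components and return the total for backprop.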
- core_loss = loss.pop(CORE_LOSS_KEY)
- if step % self.logging_conf.log_scalar_frequency == 0:
- for k in loss:
- log_str = os.path.join(loss_str, k)
- self.logger.log(log_str, loss[k], step)
- return core_loss
- def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
- """
- Prints the model and the number of parameters in the model.
- Multiple packages provide this info in a nice table format. However, they
- require an `input` (as they also record output sizes), and our models are
- complex, so a single input is restrictive.
- https://github.com/sksq96/pytorch-summary
- https://github.com/nmhkahn/torchsummaryX
- """
- if get_rank() != 0:
- return
- param_kwargs = {}
- trainable_parameters = sum(
- p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad
- )
- total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs))
- non_trainable_parameters = total_parameters - trainable_parameters
- logging.info("==" * 10)
- logging.info(f"Summary for model {type(model)}")
- logging.info(f"Model is {model}")
- logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
- logging.info(
- f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}"
- )
- logging.info(
- f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}"
- )
- logging.info("==" * 10)
- if log_dir:
- output_fpath = os.path.join(log_dir, "model.txt")
- with g_pathmgr.open(output_fpath, "w") as f:
- print(model, file=f)
- PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
- def get_human_readable_count(number: int) -> str:
- """
- Abbreviates an integer number with K, M, B, T for thousands, millions,
- billions and trillions, respectively.
- Examples:
- >>> get_human_readable_count(123)
- '123 '
- >>> get_human_readable_count(1234) # (one thousand)
- '1.2 K'
- >>> get_human_readable_count(2e6) # (two million)
- '2.0 M'
- >>> get_human_readable_count(3e9) # (three billion)
- '3.0 B'
- >>> get_human_readable_count(4e14) # (four hundred trillion)
- '400 T'
- >>> get_human_readable_count(5e15) # (more than trillion)
- '5,000 T'
- Args:
- number: a non-negative integer
- Return:
- A string formatted according to the pattern described above.
- """
- assert number >= 0
- labels = PARAMETER_NUM_UNITS
- num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
- num_groups = int(np.ceil(num_digits / 3))
- num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
- shift = -3 * (num_groups - 1)
- number = number * (10**shift)
- index = num_groups - 1
- if index < 1 or number >= 100:
- return f"{int(number):,d} {labels[index]}"
- else:
- return f"{number:,.1f} {labels[index]}"