# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import fnmatch
import inspect
import itertools
import logging
import types
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor


class Optimizer:
    def __init__(self, optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        self.step_schedulers(0.0, 0)

    def _validate_optimizer_schedulers(self):
        if self.schedulers is None:
            return
        for _, set_of_schedulers in enumerate(self.schedulers):
            for option, _ in set_of_schedulers.items():
                assert option in self.optimizer.defaults, (
                    "Optimizer option "
                    f"{option} not found in {self.optimizer}. Valid options are "
                    f"{self.optimizer.defaults.keys()}"
                )

    def step_schedulers(self, where: float, step: int) -> None:
        if self.schedulers is None:
            return
        for i, param_group in enumerate(self.optimizer.param_groups):
            for option, scheduler in self.schedulers[i].items():
                if "step" in inspect.signature(scheduler.__call__).parameters:
                    new_value = scheduler(step=step, where=where)
                elif (
                    hasattr(scheduler, "scheduler")
                    and "step"
                    in inspect.signature(scheduler.scheduler.__call__).parameters
                ):
                    # To handle ValueScaler wrappers
                    new_value = scheduler(step=step, where=where)
                else:
                    new_value = scheduler(where)
                param_group[option] = new_value

    def step(self, where, step, closure=None):
        self.step_schedulers(where, step)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        return self.optimizer.zero_grad(*args, **kwargs)
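
# Illustrative usage sketch (not part of this module; `model`, `total_steps`, and
# LinearLR below are hypothetical). The schedulers list is indexed by param group and
# maps an optimizer option to a callable taking `where`, the fraction of training
# completed in [0, 1]:
#
#   class LinearLR:
#       def __call__(self, where: float) -> float:
#           return 0.1 * (1.0 - where)
#
#   sgd = torch.optim.SGD(model.parameters(), lr=0.1)
#   optim = Optimizer(sgd, schedulers=[{"lr": LinearLR()}])
#   for step in range(total_steps):
#       optim.zero_grad()
#       compute_loss(model).backward()  # hypothetical loss function
#       optim.step(where=step / total_steps, step=step)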


def set_default_parameters(
    scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
) -> None:
    """Set up the "default" scheduler with the right parameters.

    Args:
        scheduler_cfgs: A list of scheduler configs, where each scheduler also
            specifies which parameters it applies to, based on the names of parameters
            or the class of the modules. At most one scheduler is allowed to skip this
            specification, which is used as a "default" specification for any remaining
            parameters.
        all_parameter_names: Names of all the parameters to consider.
    """
    constraints = [
        scheduler_cfg.parameter_names
        for scheduler_cfg in scheduler_cfgs
        if scheduler_cfg.parameter_names is not None
    ]
    if len(constraints) == 0:
        default_params = set(all_parameter_names)
    else:
        default_params = all_parameter_names - set.union(*constraints)
    default_count = 0
    for scheduler_cfg in scheduler_cfgs:
        if scheduler_cfg.parameter_names is None:
            scheduler_cfg.parameter_names = default_params
            default_count += 1
    assert default_count <= 1, "Only one scheduler per option can be default"
    if default_count == 0:
        # No default scheduler specified, add a default, but without any scheduler
        # for that option
        scheduler_cfgs.append({"parameter_names": default_params})


def name_constraints_to_parameters(
    param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
    """Return parameters which match the intersection of parameter constraints.

    Note that this returns the parameters themselves, not their names.

    Args:
        param_constraints: A list, with each element being a set of allowed parameters.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        A list containing the parameters which overlap with _each_ constraint set from
        param_constraints.
    """
    matching_names = set.intersection(*param_constraints)
    return [value for name, value in named_parameters.items() if name in matching_names]
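
# Worked example (illustrative): param_constraints == [{"a.weight", "a.bias"},
# {"a.weight", "b.weight"}] intersects to {"a.weight"}, so only that parameter's
# tensor is looked up in named_parameters and returned.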


def map_scheduler_cfgs_to_param_groups(
    all_scheduler_cfgs: Iterable[List[Dict]],
    named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
    """Produce parameter groups corresponding to all the scheduler configs.

    Takes all the scheduler configs, each of which applies to a specific optimizer
    option (like "lr" or "weight_decay") and has a set of parameter names which it
    applies to, and produces a final set of param groups where each param group
    covers all the options which apply to a particular set of parameters.

    Args:
        all_scheduler_cfgs: All the scheduler configs covering every option.
        named_parameters: Mapping from a parameter name to the parameter itself.
    Returns:
        Tuple of lists of schedulers and param_groups, where schedulers[i]
        applies to param_groups[i].
    """
    scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs)
    schedulers = []
    param_groups = []
    for scheduler_cfgs in scheduler_cfgs_per_param_group:
        param_constraints = [
            scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs
        ]
        matching_parameters = name_constraints_to_parameters(
            param_constraints, named_parameters
        )
        if len(matching_parameters) == 0:  # If no overlap of parameters, skip
            continue
        schedulers_for_group = {
            scheduler_cfg["option"]: scheduler_cfg["scheduler"]
            for scheduler_cfg in scheduler_cfgs
            if "option" in scheduler_cfg
        }
        schedulers.append(schedulers_for_group)
        param_groups.append({"params": matching_parameters})
    return schedulers, param_groups
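
# Worked example (illustrative): if "lr" has cfgs [A, B] covering disjoint parameter
# sets and "weight_decay" has a single cfg [C] covering everything, itertools.product
# yields the combinations (A, C) and (B, C); each combination with a non-empty
# intersection of parameter names becomes one param group carrying both an "lr" and a
# "weight_decay" scheduler.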


def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Check that the param groups are non-overlapping and cover all the parameters.

    Args:
        param_groups: List of all param groups
        model: Model to validate against. The check ensures that all the model
            parameters are part of param_groups
    """
    for pg in param_groups:
        # no param should be repeated within a group
        assert len(pg["params"]) == len(set(pg["params"]))
    parameters = [set(param_group["params"]) for param_group in param_groups]
    model_parameters = {parameter for _, parameter in model.named_parameters()}
    for p1, p2 in itertools.permutations(parameters, 2):
        assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
    assert set.union(*parameters) == model_parameters, (
        "Scheduler generated param_groups must include all parameters of the model."
        f" Found {len(set.union(*parameters))} params whereas model has"
        f" {len(model_parameters)} params"
    )


def unix_module_cls_pattern_to_parameter_names(
    filter_module_cls_names: List[str],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in filter_module_cls_names.

    Args:
        filter_module_cls_names: A list of filter strings containing class names, like
            ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]
        module_cls_to_param_names: Mapping from module classes to the parameter names
            they contain. See `get_module_cls_to_param_names`.
    """
    if filter_module_cls_names is None:
        return set()
    allowed_parameter_names = []
    for module_cls_name in filter_module_cls_names:
        module_cls = hydra.utils.get_class(module_cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not "
                "match any classes in the model"
            )
        matching_parameters = module_cls_to_param_names[module_cls]
        assert (
            len(matching_parameters) > 0
        ), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
        logging.info(
            f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
        )
        allowed_parameter_names.append(matching_parameters)
    return set.union(*allowed_parameter_names)
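
# Example (illustrative): ["torch.nn.LayerNorm"] is resolved to the nn.LayerNorm class
# via hydra.utils.get_class, and the result is the union of parameter names owned
# directly by LayerNorm modules in the model, e.g. {"norm.weight", "norm.bias"}.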


def unix_param_pattern_to_parameter_names(
    filter_param_names: Optional[List[str]],
    parameter_names: Dict[str, torch.Tensor],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in filter_param_names.

    Args:
        filter_param_names: A list of unix-style filter strings with optional
            wildcards, like ["block.2.*", "block.2.linear.weight"]
        parameter_names: The pool of parameter names to match the patterns against.
    """
    if filter_param_names is None:
        return set()
    allowed_parameter_names = []
    for param_name in filter_param_names:
        matching_parameters = set(fnmatch.filter(parameter_names, param_name))
        assert (
            len(matching_parameters) >= 1
        ), f"param_name {param_name} does not match any parameters in the model"
        logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
        allowed_parameter_names.append(matching_parameters)
    return set.union(*allowed_parameter_names)
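
# Example (illustrative): with parameter names {"block.1.weight", "block.2.weight",
# "block.2.bias"}, the pattern list ["block.2.*"] matches {"block.2.weight",
# "block.2.bias"}; a pattern that matches nothing trips the assertion above.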


def _unix_pattern_to_parameter_names(
    scheduler_cfg: DictConfig,
    parameter_names: Set[str],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in scheduler_cfg.

    Args:
        scheduler_cfg: The config for the scheduler
        parameter_names: The set of all parameter names which will be filtered
        module_cls_to_param_names: Mapping from module classes to the parameter names
            they contain. See `get_module_cls_to_param_names`.
    """
    if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
        return None
    return unix_param_pattern_to_parameter_names(
        scheduler_cfg.get("param_names"), parameter_names
    ).union(
        unix_module_cls_pattern_to_parameter_names(
            scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
        )
    )


def get_module_cls_to_param_names(
    model: nn.Module, param_allowlist: Optional[Set[str]] = None
) -> Dict[Type, Set[str]]:
    """Produce a mapping from all the module classes to the names of parameters they own.

    Only counts a parameter as part of the immediate parent module, i.e. recursive
    parents do not count.

    Args:
        model: Model to iterate over
        param_allowlist: If specified, only these param names will be processed
    """
    module_cls_to_params = {}
    for module_name, module in model.named_modules():
        module_cls = type(module)
        module_cls_to_params.setdefault(module_cls, set())
        for param_name, _ in module.named_parameters(recurse=False):
            full_param_name = get_full_parameter_name(module_name, param_name)
            if param_allowlist is None or full_param_name in param_allowlist:
                module_cls_to_params[module_cls].add(full_param_name)
    return module_cls_to_params
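
# Example (illustrative): for a model holding `linear = nn.Linear(2, 2)` and
# `norm = nn.LayerNorm(2)`, the mapping is roughly
# {type(model): set(),  # the root module owns no parameters directly
#  nn.Linear: {"linear.weight", "linear.bias"},
#  nn.LayerNorm: {"norm.weight", "norm.bias"}}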


def construct_optimizer(
    model: torch.nn.Module,
    optimizer_conf: Any,
    options_conf: Optional[Mapping[str, List]] = None,
    param_group_modifiers_conf: Optional[List[Callable]] = None,
    param_allowlist: Optional[Set[str]] = None,
    validate_param_groups=True,
) -> Optimizer:
    """
    Constructs a stochastic gradient descent, Adam, or AdamW optimizer (optionally
    with momentum), i.e. a torch.optim.Optimizer that can, based on the config,
    apply zero weight decay to BatchNorm and/or 1-D parameters, or skip updating
    them altogether.

    Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
    (LARS): https://arxiv.org/abs/1708.03888

    Args:
        model: model whose parameters will be optimized.
        optimizer_conf: Hydra config for a partial torch optimizer like SGD or
            AdamW, still missing the params argument, which this function provides
            to produce the final optimizer.
        options_conf: Hydra config mapping each optimizer option (e.g. "lr",
            "weight_decay") to a list of scheduler configs, each optionally
            restricted to a subset of parameters via `param_names` or
            `module_cls_names`.
        param_group_modifiers_conf: Optional user-specified functions which can modify
            the final scheduler configs before the optimizer's param groups are built.
        param_allowlist: The parameters to optimize. Parameters which are not part of
            this allowlist will be skipped.
        validate_param_groups: If enabled, validates that the produced param_groups don't
            overlap and cover all the model parameters.
    """
    if param_allowlist is None:
        param_allowlist = {name for name, _ in model.named_parameters()}

    named_parameters = {
        name: param
        for name, param in model.named_parameters()
        if name in param_allowlist
    }

    if not options_conf:
        optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
        return Optimizer(optimizer)

    all_parameter_names = {
        name for name, _ in model.named_parameters() if name in param_allowlist
    }
    module_cls_to_all_param_names = get_module_cls_to_param_names(
        model, param_allowlist
    )

    scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
    all_scheduler_cfgs = []
    for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
        for config in scheduler_cfgs:
            config.option = option
            config.parameter_names = _unix_pattern_to_parameter_names(
                config, all_parameter_names, module_cls_to_all_param_names
            )
        set_default_parameters(scheduler_cfgs, all_parameter_names)
        all_scheduler_cfgs.append(scheduler_cfgs)

    if param_group_modifiers_conf:
        for custom_param_modifier in param_group_modifiers_conf:
            custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
            all_scheduler_cfgs = custom_param_modifier(
                scheduler_cfgs=all_scheduler_cfgs, model=model
            )
    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        all_scheduler_cfgs, named_parameters
    )
    if validate_param_groups:
        validate_param_group_params(param_groups, model)
    optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
    return Optimizer(optimizer, schedulers)
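
# Illustrative Hydra config sketch for `optimizer_conf` / `options_conf` (hypothetical;
# the exact scheduler classes and values depend on the surrounding training config):
#
#   optimizer:
#     _target_: torch.optim.AdamW
#   options:
#     lr:
#       - scheduler:
#           _target_: fvcore.common.param_scheduler.CosineParamScheduler
#           start_value: 1.0e-4
#           end_value: 1.0e-6
#     weight_decay:
#       - scheduler:
#           _target_: fvcore.common.param_scheduler.ConstantParamScheduler
#           value: 0.0
#         module_cls_names: ["torch.nn.LayerNorm"]   # no decay for LayerNorm params
#       - scheduler:
#           _target_: fvcore.common.param_scheduler.ConstantParamScheduler
#           value: 0.1                               # default for remaining params
#
# construct_optimizer(model, cfg.optimizer, cfg.options) then produces one param group
# per distinct combination of these constraints.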


def get_full_parameter_name(module_name, param_name):
    if module_name == "":
        return param_name
    return f"{module_name}.{param_name}"


class GradientClipper:
    """
    Gradient clipping utility that works with DDP.
    """

    def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
        assert isinstance(max_norm, (int, float)) or max_norm is None
        self.max_norm = max_norm if max_norm is None else float(max_norm)
        self.norm_type = norm_type

    def __call__(self, model: nn.Module):
        if self.max_norm is None:
            return  # no-op
        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
        )
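
# Illustrative usage (hypothetical training loop): clip gradients after backward and
# before the optimizer step; passing max_norm=None turns the clipper into a no-op.
#
#   clipper = GradientClipper(max_norm=1.0, norm_type=2)
#   loss.backward()
#   clipper(model)
#   optim.step(where=where, step=step)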


class ValueScaler:
    def __init__(self, scheduler, mult_val: float):
        self.scheduler = scheduler
        self.mult_val = mult_val

    def __call__(self, *args, **kwargs):
        val = self.scheduler(*args, **kwargs)
        return val * self.mult_val
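
# Example (illustrative): ValueScaler(base_lr_scheduler, 0.5)(where) returns half of
# the wrapped scheduler's value; layer_decay_param_modifier below uses this to scale
# a shared "lr" scheduler per layer.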


def rgetattr(obj, rattrs: Optional[str] = None):
    """
    Like getattr(), but supports dotted notation for nested objects.
    rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2
    """
    if rattrs is None:
        return obj
    attrs = rattrs.split(".")
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj
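
# Example (illustrative, hypothetical attribute path): rgetattr(model, "encoder.trunk")
# is equivalent to model.encoder.trunk, and rgetattr(model, None) returns model itself.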


def layer_decay_param_modifier(
    scheduler_cfgs: List[List[Dict]],
    model,
    layer_decay_value: float,
    layer_decay_min: Optional[float] = None,
    apply_to: Optional[str] = None,
    overrides: List[Dict] = (),
) -> List[List[Dict]]:
    """
    Args:
        scheduler_cfgs: a list of omegaconf.ListConfigs.
            Each element in the list is an omegaconf.DictConfig with the following structure
            {
                "scheduler": <some fvcore scheduler>
                "option": <value>, possible options are "lr", "weight_decay" etc.
                "parameter_names": Set of str indicating param names that this scheduler applies to
            }
        model: a model that implements a method `get_layer_id` mapping a layer name to
            an integer, and a method `get_num_layers`.
            Alternatively, use the apply_to argument to select a specific component of the model.
        layer_decay_value: the multiplicative decay factor applied per layer.
        layer_decay_min: minimum value for the per-layer decay scale.
        apply_to: optional arg to select which component of the model to apply the
            layer decay modifier to.
        overrides: to manually override the lr for specific patterns. A list of dicts,
            each with keys "pattern" and "value".
    Returns:
        scheduler_cfgs: same structure as the input, with the "lr" schedulers wrapped
            in per-layer ValueScalers.
    """
    model = rgetattr(model, apply_to)
    num_layers = model.get_num_layers() + 1
    layer_decays = [
        layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
    ]
    if layer_decay_min is not None:
        layer_decays = [max(val, layer_decay_min) for val in layer_decays]
    final_scheduler_cfgs = []
    # scheduler_cfgs is a list of lists
    for scheduler_cfg_group in scheduler_cfgs:
        curr_cfg_group = []
        # scheduler_cfg_group is a list of dictionaries
        for scheduler_cfg in scheduler_cfg_group:
            if scheduler_cfg["option"] != "lr":
                curr_cfg_group.append(scheduler_cfg)
                continue
            # Need sorted so that the list of parameter names is deterministic and consistent
            # across re-runs of this job. Else it was causing issues with loading the optimizer
            # state during a job restart (D38591759)
            parameter_names = sorted(scheduler_cfg["parameter_names"])
            # Only want one cfg group per layer
            layer_cfg_groups = {}
            for param_name in parameter_names:
                layer_id = num_layers
                this_scale = layer_decays[layer_id]
                if param_name.startswith(apply_to):
                    layer_id = model.get_layer_id(param_name)
                    this_scale = layer_decays[layer_id]
                # Overrides
                for override in overrides:
                    if fnmatch.fnmatchcase(param_name, override["pattern"]):
                        this_scale = float(override["value"])
                        layer_id = override["pattern"]
                        break
                if layer_id not in layer_cfg_groups:
                    curr_param = {
                        "option": scheduler_cfg["option"],
                        "scheduler": ValueScaler(
                            scheduler_cfg["scheduler"], this_scale
                        ),
                        "parameter_names": {param_name},
                    }
                else:
                    curr_param = layer_cfg_groups[layer_id]
                    curr_param["parameter_names"].add(param_name)
                layer_cfg_groups[layer_id] = curr_param
            for layer_cfg in layer_cfg_groups.values():
                curr_cfg_group.append(layer_cfg)
        final_scheduler_cfgs.append(curr_cfg_group)
    return final_scheduler_cfgs
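
# Worked example (illustrative): with layer_decay_value=0.9 and a trunk reporting
# get_num_layers() == 3, num_layers == 4 and the per-layer scales are
# [0.9**4, 0.9**3, 0.9**2, 0.9**1, 0.9**0] ~= [0.656, 0.729, 0.81, 0.9, 1.0];
# parameters outside `apply_to` fall back to layer_id == num_layers, i.e. scale 1.0.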