# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import fnmatch
import inspect
import itertools
import logging
import types
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor


class Optimizer:
    def __init__(self, optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        self.step_schedulers(0.0, 0)

    def _validate_optimizer_schedulers(self):
        if self.schedulers is None:
            return
        for _, set_of_schedulers in enumerate(self.schedulers):
            for option, _ in set_of_schedulers.items():
                assert option in self.optimizer.defaults, (
                    "Optimizer option "
                    f"{option} not found in {self.optimizer}. Valid options are "
                    f"{self.optimizer.defaults.keys()}"
                )

    def step_schedulers(self, where: float, step: int) -> None:
        if self.schedulers is None:
            return
        for i, param_group in enumerate(self.optimizer.param_groups):
            for option, scheduler in self.schedulers[i].items():
                if "step" in inspect.signature(scheduler.__call__).parameters:
                    new_value = scheduler(step=step, where=where)
                elif (
                    hasattr(scheduler, "scheduler")
                    and "step"
                    in inspect.signature(scheduler.scheduler.__call__).parameters
                ):
                    # To handle ValueScaler wrappers
                    new_value = scheduler(step=step, where=where)
                else:
                    new_value = scheduler(where)
                param_group[option] = new_value

    def step(self, where, step, closure=None):
        self.step_schedulers(where, step)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        return self.optimizer.zero_grad(*args, **kwargs)
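

# Illustrative usage sketch (not part of the original module): how the Optimizer
# wrapper above is typically driven. `_LinearDecay` is an assumed stand-in for an
# fvcore-style param scheduler, i.e. any object whose __call__ accepts the
# training progress `where` in [0, 1].
class _LinearDecay:
    def __init__(self, start: float, end: float) -> None:
        self.start = start
        self.end = end

    def __call__(self, where: float) -> float:
        # Linearly interpolate between start and end over training progress.
        return self.start + (self.end - self.start) * where


def _example_optimizer_usage() -> None:
    model = nn.Linear(4, 2)
    base = torch.optim.SGD(model.parameters(), lr=0.1)
    # One scheduler dict per param group, keyed by optimizer option.
    wrapped = Optimizer(base, schedulers=[{"lr": _LinearDecay(0.1, 0.0)}])
    wrapped.zero_grad()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    # `where` is the fraction of training completed, `step` the global step.
    wrapped.step(where=0.5, step=100)
    assert abs(base.param_groups[0]["lr"] - 0.05) < 1e-8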


def set_default_parameters(
    scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
) -> None:
    """Set up the "default" scheduler with the right parameters.

    Args:
        scheduler_cfgs: A list of scheduler configs, where each scheduler also
            specifies which parameters it applies to, based on the names of parameters
            or the class of the modules. At most one scheduler is allowed to skip this
            specification, which is used as a "default" specification for any remaining
            parameters.
        all_parameter_names: Names of all the parameters to consider.
    """
    constraints = [
        scheduler_cfg.parameter_names
        for scheduler_cfg in scheduler_cfgs
        if scheduler_cfg.parameter_names is not None
    ]
    if len(constraints) == 0:
        default_params = set(all_parameter_names)
    else:
        default_params = all_parameter_names - set.union(*constraints)
    default_count = 0
    for scheduler_cfg in scheduler_cfgs:
        if scheduler_cfg.parameter_names is None:
            scheduler_cfg.parameter_names = default_params
            default_count += 1
    assert default_count <= 1, "Only one scheduler per option can be default"
    if default_count == 0:
        # No default scheduler specified, add a default, but without any scheduler
        # for that option
        scheduler_cfgs.append({"parameter_names": default_params})
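

# Hypothetical sketch (not part of the original module) of how
# set_default_parameters fills in the "default" entry. types.SimpleNamespace
# stands in for the omegaconf.DictConfig objects used by the real configs,
# since only attribute access is needed here.
def _example_set_default_parameters() -> None:
    all_names = {"trunk.weight", "trunk.bias", "head.weight"}
    cfgs = [
        types.SimpleNamespace(parameter_names={"trunk.weight", "trunk.bias"}),
        types.SimpleNamespace(parameter_names=None),  # the "default" config
    ]
    set_default_parameters(cfgs, all_names)
    # The default config now owns every parameter not claimed by the others.
    assert cfgs[1].parameter_names == {"head.weight"}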


def name_constraints_to_parameters(
    param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
    """Return parameters which match the intersection of parameter constraints.

    Note that this returns the parameters themselves, not their names.

    Args:
        param_constraints: A list, with each element being a set of allowed parameters.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        A list containing the parameters which overlap with _each_ constraint set from
        param_constraints.
    """
    matching_names = set.intersection(*param_constraints)
    return [value for name, value in named_parameters.items() if name in matching_names]
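

# Minimal sketch (assumed example, not in the original file): intersecting two
# name constraints and pulling out the matching parameter tensors.
def _example_name_constraints_to_parameters() -> None:
    model = nn.Linear(4, 2)
    named = dict(model.named_parameters())  # {"weight": ..., "bias": ...}
    params = name_constraints_to_parameters([{"weight", "bias"}, {"weight"}], named)
    # Only "weight" satisfies both constraint sets.
    assert len(params) == 1 and params[0] is named["weight"]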


def map_scheduler_cfgs_to_param_groups(
    all_scheduler_cfgs: Iterable[List[Dict]],
    named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
    """Produce parameter groups corresponding to all the scheduler configs.

    Takes all the scheduler configs, each of which applies to a specific optimizer
    option (like "lr" or "weight_decay") and has a set of parameter names which it
    applies to, and produces a final set of param groups where each param group
    covers all the options which apply to a particular set of parameters.

    Args:
        all_scheduler_cfgs: All the scheduler configs covering every option.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        Tuple of lists of schedulers and param_groups, where schedulers[i]
        applies to param_groups[i].
    """
    scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs)
    schedulers = []
    param_groups = []
    for scheduler_cfgs in scheduler_cfgs_per_param_group:
        param_constraints = [
            scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs
        ]
        matching_parameters = name_constraints_to_parameters(
            param_constraints, named_parameters
        )
        if len(matching_parameters) == 0:  # If no overlap of parameters, skip
            continue
        schedulers_for_group = {
            scheduler_cfg["option"]: scheduler_cfg["scheduler"]
            for scheduler_cfg in scheduler_cfgs
            if "option" in scheduler_cfg
        }
        schedulers.append(schedulers_for_group)
        param_groups.append({"params": matching_parameters})
    return schedulers, param_groups
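

# Hypothetical sketch (not part of the original module): turning per-option
# scheduler configs into param groups. Plain dicts and lambdas stand in for the
# hydra-instantiated configs and fvcore schedulers used in practice.
def _example_map_scheduler_cfgs_to_param_groups() -> None:
    model = nn.Linear(4, 2)
    named = dict(model.named_parameters())
    all_names = set(named)  # {"weight", "bias"}
    lr_cfgs = [
        {
            "option": "lr",
            "scheduler": lambda where: 0.1,
            "parameter_names": all_names,
        }
    ]
    wd_cfgs = [
        {
            "option": "weight_decay",
            "scheduler": lambda where: 0.0,
            "parameter_names": {"weight"},
        },
        {
            "option": "weight_decay",
            "scheduler": lambda where: 0.01,
            "parameter_names": {"bias"},
        },
    ]
    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        [lr_cfgs, wd_cfgs], named
    )
    # Two groups: one for "weight", one for "bias", each carrying both options.
    assert len(param_groups) == 2
    validate_param_group_params(param_groups, model)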


def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Check that the param groups are non-overlapping and cover all the parameters.

    Args:
        param_groups: List of all param groups
        model: Model to validate against. The check ensures that all the model
            parameters are part of param_groups
    """
    for pg in param_groups:
        # no param should be repeated within a group
        assert len(pg["params"]) == len(set(pg["params"]))
    parameters = [set(param_group["params"]) for param_group in param_groups]
    model_parameters = {parameter for _, parameter in model.named_parameters()}
    for p1, p2 in itertools.permutations(parameters, 2):
        assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
    assert set.union(*parameters) == model_parameters, (
        "Scheduler generated param_groups must include all parameters of the model."
        f" Found {len(set.union(*parameters))} params whereas model has"
        f" {len(model_parameters)} params"
    )


def unix_module_cls_pattern_to_parameter_names(
    filter_module_cls_names: List[str],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in filter_module_cls_names.

    Args:
        filter_module_cls_names: A list of filter strings containing class names, like
            ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]
        module_cls_to_param_names: Mapping from module classes to the parameter names
            they contain. See `get_module_cls_to_param_names`.
    """
    if filter_module_cls_names is None:
        return set()
    allowed_parameter_names = []
    for module_cls_name in filter_module_cls_names:
        module_cls = hydra.utils.get_class(module_cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not "
                "match any classes in the model"
            )
        matching_parameters = module_cls_to_param_names[module_cls]
        assert (
            len(matching_parameters) > 0
        ), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
        logging.info(
            f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
        )
        allowed_parameter_names.append(matching_parameters)
    return set.union(*allowed_parameter_names)


def unix_param_pattern_to_parameter_names(
    filter_param_names: Optional[List[str]],
    parameter_names: Dict[str, torch.Tensor],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in filter_param_names.

    Args:
        filter_param_names: A list of unix-style filter strings with optional
            wildcards, like ["block.2.*", "block.2.linear.weight"]
        parameter_names: Mapping from parameter names to the parameters themselves;
            only the names (keys) are used for matching.
    """
    if filter_param_names is None:
        return set()
    allowed_parameter_names = []
    for param_name in filter_param_names:
        matching_parameters = set(fnmatch.filter(parameter_names, param_name))
        assert (
            len(matching_parameters) >= 1
        ), f"param_name {param_name} does not match any parameters in the model"
        logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
        allowed_parameter_names.append(matching_parameters)
    return set.union(*allowed_parameter_names)
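

# Minimal sketch (assumed example, not in the original file): unix-style
# wildcard matching of parameter names via fnmatch.
def _example_unix_param_pattern() -> None:
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    named = dict(model.named_parameters())  # "0.weight", "0.bias", "1.weight", "1.bias"
    matched = unix_param_pattern_to_parameter_names(["0.*", "1.bias"], named)
    assert matched == {"0.weight", "0.bias", "1.bias"}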


def _unix_pattern_to_parameter_names(
    scheduler_cfg: DictConfig,
    parameter_names: Set[str],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in scheduler_cfg.

    Args:
        scheduler_cfg: The config for the scheduler
        parameter_names: The set of all parameter names which will be filtered
    """
    if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
        return None
    return unix_param_pattern_to_parameter_names(
        scheduler_cfg.get("param_names"), parameter_names
    ).union(
        unix_module_cls_pattern_to_parameter_names(
            scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
        )
    )


def get_module_cls_to_param_names(
    model: nn.Module, param_allowlist: Optional[Set[str]] = None
) -> Dict[Type, Set[str]]:
    """Produce a mapping from all the module classes to the names of the parameters
    they own.

    Only counts a parameter as part of the immediate parent module, i.e. recursive
    parents do not count.

    Args:
        model: Model to iterate over
        param_allowlist: If specified, only these param names will be processed
    """
    module_cls_to_params = {}
    for module_name, module in model.named_modules():
        module_cls = type(module)
        module_cls_to_params.setdefault(module_cls, set())
        for param_name, _ in module.named_parameters(recurse=False):
            full_param_name = get_full_parameter_name(module_name, param_name)
            if param_allowlist is None or full_param_name in param_allowlist:
                module_cls_to_params[module_cls].add(full_param_name)
    return module_cls_to_params
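

# Hypothetical sketch (not part of the original module): collecting per-class
# parameter names and filtering by class name, e.g. to give LayerNorm parameters
# their own weight-decay schedule.
def _example_module_cls_filtering() -> None:
    model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 2))
    cls_to_params = get_module_cls_to_param_names(model)
    assert cls_to_params[nn.LayerNorm] == {"1.weight", "1.bias"}
    # hydra.utils.get_class resolves the dotted class name back to nn.LayerNorm.
    matched = unix_module_cls_pattern_to_parameter_names(
        ["torch.nn.LayerNorm"], cls_to_params
    )
    assert matched == {"1.weight", "1.bias"}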


def construct_optimizer(
    model: torch.nn.Module,
    optimizer_conf: Any,
    options_conf: Optional[Mapping[str, List]] = None,
    param_group_modifiers_conf: Optional[List[Callable]] = None,
    param_allowlist: Optional[Set[str]] = None,
    validate_param_groups=True,
) -> Optimizer:
    """
    Constructs a stochastic gradient descent (SGD) or ADAM/ADAMW optimizer with
    momentum, i.e., constructs a torch.optim.Optimizer with zero-weight-decay
    batchnorm and/or no-update 1-D parameter support, based on the config.

    Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
    (LARS): https://arxiv.org/abs/1708.03888

    Args:
        model: model to perform stochastic gradient descent
            optimization or ADAM optimization.
        optimizer_conf: Hydra config consisting of a partial torch optimizer like SGD or
            ADAM, still missing the params argument which this function provides to
            produce the final optimizer
        options_conf: Hydra config of per-option scheduler configs (e.g. for "lr" or
            "weight_decay"), each specifying which parameters it applies to.
        param_group_modifiers_conf: Optional user-specified functions which can modify
            the final scheduler configs before the optimizer's param groups are built
        param_allowlist: The parameters to optimize. Parameters which are not part of
            this allowlist will be skipped.
        validate_param_groups: If enabled, validates that the produced param_groups don't
            overlap and cover all the model parameters.
    """
    if param_allowlist is None:
        param_allowlist = {name for name, _ in model.named_parameters()}

    named_parameters = {
        name: param
        for name, param in model.named_parameters()
        if name in param_allowlist
    }

    if not options_conf:
        optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
        return Optimizer(optimizer)

    all_parameter_names = {
        name for name, _ in model.named_parameters() if name in param_allowlist
    }
    module_cls_to_all_param_names = get_module_cls_to_param_names(
        model, param_allowlist
    )

    scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
    all_scheduler_cfgs = []
    for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
        for config in scheduler_cfgs:
            config.option = option
            config.parameter_names = _unix_pattern_to_parameter_names(
                config, all_parameter_names, module_cls_to_all_param_names
            )
        set_default_parameters(scheduler_cfgs, all_parameter_names)
        all_scheduler_cfgs.append(scheduler_cfgs)

    if param_group_modifiers_conf:
        for custom_param_modifier in param_group_modifiers_conf:
            custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
            all_scheduler_cfgs = custom_param_modifier(
                scheduler_cfgs=all_scheduler_cfgs, model=model
            )
    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        all_scheduler_cfgs, named_parameters
    )
    if validate_param_groups:
        validate_param_group_params(param_groups, model)
    optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
    return Optimizer(optimizer, schedulers)
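

# Hypothetical usage sketch (not part of the original module): the simplest path
# through construct_optimizer, with no per-option schedulers. The `_target_`-style
# dict is a plain hydra config; a real training setup would normally also pass
# options_conf with lr/weight_decay schedulers.
def _example_construct_optimizer() -> None:
    model = nn.Linear(4, 2)
    optimizer_conf = {"_target_": "torch.optim.SGD", "lr": 0.1, "momentum": 0.9}
    wrapped = construct_optimizer(model, optimizer_conf)
    assert isinstance(wrapped.optimizer, torch.optim.SGD)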


def get_full_parameter_name(module_name, param_name):
    if module_name == "":
        return param_name
    return f"{module_name}.{param_name}"


class GradientClipper:
    """
    Gradient clipping utility that works with DDP
    """

    def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
        assert isinstance(max_norm, (int, float)) or max_norm is None
        self.max_norm = max_norm if max_norm is None else float(max_norm)
        self.norm_type = norm_type

    def __call__(self, model: nn.Module):
        if self.max_norm is None:
            return  # no-op
        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
        )
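

# Minimal sketch (assumed example, not in the original file): clip gradients
# after backward() and before optimizer.step().
def _example_gradient_clipping() -> None:
    model = nn.Linear(4, 2)
    clipper = GradientClipper(max_norm=1.0, norm_type=2)
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    clipper(model)  # in-place clipping of model.parameters() gradients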


class ValueScaler:
    def __init__(self, scheduler, mult_val: float):
        self.scheduler = scheduler
        self.mult_val = mult_val

    def __call__(self, *args, **kwargs):
        val = self.scheduler(*args, **kwargs)
        return val * self.mult_val


def rgetattr(obj, rattrs: Optional[str] = None):
    """
    Like getattr(), but supports dotted notation for nested objects.
    rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2
    """
    if rattrs is None:
        return obj
    attrs = rattrs.split(".")
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj
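

# Minimal sketch (assumed example, not in the original file): dotted attribute
# lookup, as used by layer_decay_param_modifier's `apply_to` argument.
def _example_rgetattr() -> None:
    model = nn.Sequential(nn.Linear(4, 4))
    assert rgetattr(model, "0.weight") is model[0].weight
    assert rgetattr(model) is model  # the None path returns the object itself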


def layer_decay_param_modifier(
    scheduler_cfgs: List[List[Dict]],
    model,
    layer_decay_value: float,
    layer_decay_min: Optional[float] = None,
    apply_to: Optional[str] = None,
    overrides: List[Dict] = (),
) -> List[List[Dict]]:
    """
    Args:
        - scheduler_cfgs: a list of omegaconf.ListConfigs.
            Each element in the list is an omegaconf.DictConfig with the following structure
            {
                "scheduler": <some fvcore scheduler>
                "option": <value> possible options are "lr", "weight_decay" etc.
                "parameter_names": Set of str indicating param names that this scheduler applies to
            }
        - model: a model that implements a method `get_layer_id` that maps layer_name to an integer
            and a method `get_num_layers`. Alternatively, use the apply_to argument to select a
            specific component of the model.
        - layer_decay_value: float
        - layer_decay_min: min val for layer decay
        - apply_to: optional arg to select which component of the model to apply the layer decay
            modifier to
        - overrides: to manually override the lr scale for specific patterns; a list of dicts,
            each with keys "pattern" and "value".
    Returns:
        - scheduler_configs: same structure as the input, elements can be modified
    """
    model = rgetattr(model, apply_to)
    num_layers = model.get_num_layers() + 1
    layer_decays = [
        layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
    ]
    if layer_decay_min is not None:
        layer_decays = [max(val, layer_decay_min) for val in layer_decays]
    final_scheduler_cfgs = []
    # scheduler_cfgs is a list of lists
    for scheduler_cfg_group in scheduler_cfgs:
        curr_cfg_group = []
        # scheduler_cfg_group is a list of dictionaries
        for scheduler_cfg in scheduler_cfg_group:
            if scheduler_cfg["option"] != "lr":
                curr_cfg_group.append(scheduler_cfg)
                continue
            # Need sorted so that the list of parameter names is deterministic and consistent
            # across re-runs of this job. Else it was causing issues with loading the optimizer
            # state during a job restart (D38591759)
            parameter_names = sorted(scheduler_cfg["parameter_names"])
            # Only want one cfg group per layer
            layer_cfg_groups = {}
            for param_name in parameter_names:
                layer_id = num_layers
                this_scale = layer_decays[layer_id]
                # When apply_to is None, the decay is applied to every parameter of the model.
                if apply_to is None or param_name.startswith(apply_to):
                    layer_id = model.get_layer_id(param_name)
                    this_scale = layer_decays[layer_id]
                # Overrides
                for override in overrides:
                    if fnmatch.fnmatchcase(param_name, override["pattern"]):
                        this_scale = float(override["value"])
                        layer_id = override["pattern"]
                        break
                if layer_id not in layer_cfg_groups:
                    curr_param = {
                        "option": scheduler_cfg["option"],
                        "scheduler": ValueScaler(
                            scheduler_cfg["scheduler"], this_scale
                        ),
                        "parameter_names": {param_name},
                    }
                else:
                    curr_param = layer_cfg_groups[layer_id]
                    curr_param["parameter_names"].add(param_name)
                layer_cfg_groups[layer_id] = curr_param
            for layer_cfg in layer_cfg_groups.values():
                curr_cfg_group.append(layer_cfg)
        final_scheduler_cfgs.append(curr_cfg_group)
    return final_scheduler_cfgs
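

# Hypothetical end-to-end sketch (not part of the original module) of the
# layer-decay modifier. `_ToyBackbone`/`_ToyModel` are assumed stand-ins for a
# model exposing get_num_layers()/get_layer_id() on its `trunk` component.
class _ToyBackbone(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

    def get_num_layers(self) -> int:
        return len(self.blocks)

    def get_layer_id(self, param_name: str) -> int:
        # e.g. "trunk.blocks.1.weight" -> layer id 2 (id 0 is left for embeddings)
        return int(param_name.split(".")[2]) + 1


class _ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.trunk = _ToyBackbone()
        self.head = nn.Linear(4, 2)


def _example_layer_decay_param_modifier() -> None:
    model = _ToyModel()
    all_names = {name for name, _ in model.named_parameters()}
    cfgs = [
        [{"option": "lr", "scheduler": lambda where: 1e-3, "parameter_names": all_names}]
    ]
    out = layer_decay_param_modifier(
        scheduler_cfgs=cfgs,
        model=model,
        layer_decay_value=0.9,
        apply_to="trunk",
    )
    # The single lr config is split into one config per layer id, each with a
    # ValueScaler-wrapped scheduler carrying that layer's decay factor.
    assert len(out) == 1 and len(out[0]) > 1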