optimizer.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import fnmatch
  4. import inspect
  5. import itertools
  6. import logging
  7. import types
  8. from typing import (
  9. Any,
  10. Callable,
  11. Dict,
  12. Iterable,
  13. List,
  14. Mapping,
  15. Optional,
  16. Set,
  17. Tuple,
  18. Type,
  19. Union,
  20. )
  21. import hydra
  22. import torch
  23. import torch.nn as nn
  24. from omegaconf import DictConfig
  25. from torch import Tensor
  26. class Optimizer:
  27. def __init__(self, optimizer, schedulers=None) -> None:
  28. self.optimizer = optimizer
  29. self.schedulers = schedulers
  30. self._validate_optimizer_schedulers()
  31. self.step_schedulers(0.0, 0)
  32. def _validate_optimizer_schedulers(self):
  33. if self.schedulers is None:
  34. return
  35. for _, set_of_schedulers in enumerate(self.schedulers):
  36. for option, _ in set_of_schedulers.items():
  37. assert option in self.optimizer.defaults, (
  38. "Optimizer option "
  39. f"{option} not found in {self.optimizer}. Valid options are "
  40. f"{self.optimizer.defaults.keys()}"
  41. )
  42. def step_schedulers(self, where: float, step: int) -> None:
  43. if self.schedulers is None:
  44. return
  45. for i, param_group in enumerate(self.optimizer.param_groups):
  46. for option, scheduler in self.schedulers[i].items():
  47. if "step" in inspect.signature(scheduler.__call__).parameters:
  48. new_value = scheduler(step=step, where=where)
  49. elif (
  50. hasattr(scheduler, "scheduler")
  51. and "step"
  52. in inspect.signature(scheduler.scheduler.__call__).parameters
  53. ):
  54. # To handle ValueScaler wrappers
  55. new_value = scheduler(step=step, where=where)
  56. else:
  57. new_value = scheduler(where)
  58. param_group[option] = new_value
  59. def step(self, where, step, closure=None):
  60. self.step_schedulers(where, step)
  61. return self.optimizer.step(closure)
  62. def zero_grad(self, *args, **kwargs):
  63. return self.optimizer.zero_grad(*args, **kwargs)
  64. def set_default_parameters(
  65. scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
  66. ) -> None:
  67. """Set up the "default" scheduler with the right parameters.
  68. Args:
  69. scheduler_cgfs: A list of scheduler configs, where each scheduler also
  70. specifies which parameters it applies to, based on the names of parameters
  71. or the class of the modules. At most one scheduler is allowed to skip this
  72. specification, which is used as a "default" specification for any remaining
  73. parameters.
  74. all_parameter_names: Names of all the parameters to consider.
  75. """
  76. constraints = [
  77. scheduler_cfg.parameter_names
  78. for scheduler_cfg in scheduler_cfgs
  79. if scheduler_cfg.parameter_names is not None
  80. ]
  81. if len(constraints) == 0:
  82. default_params = set(all_parameter_names)
  83. else:
  84. default_params = all_parameter_names - set.union(*constraints)
  85. default_count = 0
  86. for scheduler_cfg in scheduler_cfgs:
  87. if scheduler_cfg.parameter_names is None:
  88. scheduler_cfg.parameter_names = default_params
  89. default_count += 1
  90. assert default_count <= 1, "Only one scheduler per option can be default"
  91. if default_count == 0:
  92. # No default scheduler specified, add a default, but without any scheduler
  93. # for that option
  94. scheduler_cfgs.append({"parameter_names": default_params})
  95. def name_constraints_to_parameters(
  96. param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
  97. ) -> List[torch.nn.Parameter]:
  98. """Return parameters which match the intersection of parameter constraints.
  99. Note that this returns the parameters themselves, not their names.
  100. Args:
  101. param_constraints: A list, with each element being a set of allowed parameters.
  102. named_parameters: Mapping from a parameter name to the parameter itself.
  103. Returns:
  104. A list containing the parameters which overlap with _each_ constraint set from
  105. param_constraints.
  106. """
  107. matching_names = set.intersection(*param_constraints)
  108. return [value for name, value in named_parameters.items() if name in matching_names]
  109. def map_scheduler_cfgs_to_param_groups(
  110. all_scheduler_cfgs: Iterable[List[Dict]],
  111. named_parameters: Dict[str, Tensor],
  112. ) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
  113. """Produce parameter groups corresponding to all the scheduler configs.
  114. Takes all the scheduler configs, each of which applies to a specific optimizer
  115. option (like "lr" or "weight_decay") and has a set of parameter names which it
  116. applies to, and produces a final set of param groups where each param group
  117. covers all the options which apply to a particular set of parameters.
  118. Args:
  119. all_scheduler_cfgs: All the scheduler configs covering every option.
  120. named_parameters: Mapping from a parameter name to the parameter itself.
  121. Returns:
  122. Tuple of lists of schedulers and param_groups, where schedulers[i]
  123. applies to param_groups[i].
  124. """
  125. scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs)
  126. schedulers = []
  127. param_groups = []
  128. for scheduler_cfgs in scheduler_cfgs_per_param_group:
  129. param_constraints = [
  130. scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs
  131. ]
  132. matching_parameters = name_constraints_to_parameters(
  133. param_constraints, named_parameters
  134. )
  135. if len(matching_parameters) == 0: # If no overlap of parameters, skip
  136. continue
  137. schedulers_for_group = {
  138. scheduler_cfg["option"]: scheduler_cfg["scheduler"]
  139. for scheduler_cfg in scheduler_cfgs
  140. if "option" in scheduler_cfg
  141. }
  142. schedulers.append(schedulers_for_group)
  143. param_groups.append({"params": matching_parameters})
  144. return schedulers, param_groups
  145. def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
  146. """Check that the param groups are non-overlapping and cover all the parameters.
  147. Args:
  148. param_groups: List of all param groups
  149. model: Model to validate against. The check ensures that all the model
  150. parameters are part of param_groups
  151. """
  152. for pg in param_groups:
  153. # no param should be repeated within a group
  154. assert len(pg["params"]) == len(set(pg["params"]))
  155. parameters = [set(param_group["params"]) for param_group in param_groups]
  156. model_parameters = {parameter for _, parameter in model.named_parameters()}
  157. for p1, p2 in itertools.permutations(parameters, 2):
  158. assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
  159. assert set.union(*parameters) == model_parameters, (
  160. "Scheduler generated param_groups must include all parameters of the model."
  161. f" Found {len(set.union(*parameters))} params whereas model has"
  162. f" {len(model_parameters)} params"
  163. )
  164. def unix_module_cls_pattern_to_parameter_names(
  165. filter_module_cls_names: List[str],
  166. module_cls_to_param_names: Dict[Type, str],
  167. ) -> Union[None, Set[str]]:
  168. """Returns param names which pass the filters specified in filter_module_cls_names.
  169. Args:
  170. filter_module_cls_names: A list of filter strings containing class names, like
  171. ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]
  172. module_cls_to_param_names: Mapping from module classes to the parameter names
  173. they contain. See `get_module_cls_to_param_names`.
  174. """
  175. if filter_module_cls_names is None:
  176. return set()
  177. allowed_parameter_names = []
  178. for module_cls_name in filter_module_cls_names:
  179. module_cls = hydra.utils.get_class(module_cls_name)
  180. if module_cls not in module_cls_to_param_names:
  181. raise AssertionError(
  182. f"module_cls_name {module_cls_name} does not "
  183. "match any classes in the model"
  184. )
  185. matching_parameters = module_cls_to_param_names[module_cls]
  186. assert len(matching_parameters) > 0, (
  187. f"module_cls_name {module_cls_name} does not contain any parameters in the model"
  188. )
  189. logging.info(
  190. f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
  191. )
  192. allowed_parameter_names.append(matching_parameters)
  193. return set.union(*allowed_parameter_names)
  194. def unix_param_pattern_to_parameter_names(
  195. filter_param_names: Optional[List[str]],
  196. parameter_names: Dict[str, torch.Tensor],
  197. ) -> Union[None, Set[str]]:
  198. """Returns param names which pass the filters specified in filter_param_names.
  199. Args:
  200. filter_param_names: A list of unix-style filter strings with optional
  201. wildcards, like ["block.2.*", "block.2.linear.weight"]
  202. module_cls_to_param_names: Mapping from module classes to the parameter names
  203. they contain. See `get_module_cls_to_param_names`.
  204. """
  205. if filter_param_names is None:
  206. return set()
  207. allowed_parameter_names = []
  208. for param_name in filter_param_names:
  209. matching_parameters = set(fnmatch.filter(parameter_names, param_name))
  210. assert len(matching_parameters) >= 1, (
  211. f"param_name {param_name} does not match any parameters in the model"
  212. )
  213. logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
  214. allowed_parameter_names.append(matching_parameters)
  215. return set.union(*allowed_parameter_names)
  216. def _unix_pattern_to_parameter_names(
  217. scheduler_cfg: DictConfig,
  218. parameter_names: Set[str],
  219. module_cls_to_param_names: Dict[Type, str],
  220. ) -> Union[None, Set[str]]:
  221. """Returns param names which pass the filters specified in scheduler_cfg.
  222. Args:
  223. scheduler_cfg: The config for the scheduler
  224. parameter_names: The set of all parameter names which will be filtered
  225. """
  226. if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
  227. return None
  228. return unix_param_pattern_to_parameter_names(
  229. scheduler_cfg.get("param_names"), parameter_names
  230. ).union(
  231. unix_module_cls_pattern_to_parameter_names(
  232. scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
  233. )
  234. )
  235. def get_module_cls_to_param_names(
  236. model: nn.Module, param_allowlist: Set[str] = None
  237. ) -> Dict[Type, str]:
  238. """Produce a mapping from all the modules classes to the names of parames they own.
  239. Only counts a parameter as part of the immediate parent module, i.e. recursive
  240. parents do not count.
  241. Args:
  242. model: Model to iterate over
  243. param_allowlist: If specified, only these param names will be processed
  244. """
  245. module_cls_to_params = {}
  246. for module_name, module in model.named_modules():
  247. module_cls = type(module)
  248. module_cls_to_params.setdefault(module_cls, set())
  249. for param_name, _ in module.named_parameters(recurse=False):
  250. full_param_name = get_full_parameter_name(module_name, param_name)
  251. if param_allowlist is None or full_param_name in param_allowlist:
  252. module_cls_to_params[module_cls].add(full_param_name)
  253. return module_cls_to_params
  254. def construct_optimizer(
  255. model: torch.nn.Module,
  256. optimizer_conf: Any,
  257. options_conf: Mapping[str, List] = None,
  258. param_group_modifiers_conf: List[Callable] = None,
  259. param_allowlist: Optional[Set[str]] = None,
  260. validate_param_groups=True,
  261. ) -> Optimizer:
  262. """
  263. Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer
  264. with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay
  265. Batchnorm and/or no-update 1-D parameters support, based on the config.
  266. Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
  267. (LARS): https://arxiv.org/abs/1708.03888
  268. Args:
  269. model: model to perform stochastic gradient descent
  270. optimization or ADAM optimization.
  271. optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or
  272. ADAM, still missing the params argument which this function provides to
  273. produce the final optimizer
  274. param_group_modifiers_conf: Optional user specified functions which can modify
  275. the final scheduler configs before the optimizer's param groups are built
  276. param_allowlist: The parameters to optimize. Parameters which are not part of
  277. this allowlist will be skipped.
  278. validate_param_groups: If enabled, valides that the produced param_groups don't
  279. overlap and cover all the model parameters.
  280. """
  281. if param_allowlist is None:
  282. param_allowlist = {name for name, _ in model.named_parameters()}
  283. named_parameters = {
  284. name: param
  285. for name, param in model.named_parameters()
  286. if name in param_allowlist
  287. }
  288. if not options_conf:
  289. optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
  290. return Optimizer(optimizer)
  291. all_parameter_names = {
  292. name for name, _ in model.named_parameters() if name in param_allowlist
  293. }
  294. module_cls_to_all_param_names = get_module_cls_to_param_names(
  295. model, param_allowlist
  296. )
  297. scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
  298. all_scheduler_cfgs = []
  299. for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
  300. for config in scheduler_cfgs:
  301. config.option = option
  302. config.parameter_names = _unix_pattern_to_parameter_names(
  303. config, all_parameter_names, module_cls_to_all_param_names
  304. )
  305. set_default_parameters(scheduler_cfgs, all_parameter_names)
  306. all_scheduler_cfgs.append(scheduler_cfgs)
  307. if param_group_modifiers_conf:
  308. for custom_param_modifier in param_group_modifiers_conf:
  309. custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
  310. all_scheduler_cfgs = custom_param_modifier(
  311. scheduler_cfgs=all_scheduler_cfgs, model=model
  312. )
  313. schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
  314. all_scheduler_cfgs, named_parameters
  315. )
  316. if validate_param_groups:
  317. validate_param_group_params(param_groups, model)
  318. optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
  319. return Optimizer(optimizer, schedulers)
  320. def get_full_parameter_name(module_name, param_name):
  321. if module_name == "":
  322. return param_name
  323. return f"{module_name}.{param_name}"
  324. class GradientClipper:
  325. """
  326. Gradient clipping utils that works for DDP
  327. """
  328. def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
  329. assert isinstance(max_norm, (int, float)) or max_norm is None
  330. self.max_norm = max_norm if max_norm is None else float(max_norm)
  331. self.norm_type = norm_type
  332. def __call__(self, model: nn.Module):
  333. if self.max_norm is None:
  334. return # no-op
  335. nn.utils.clip_grad_norm_(
  336. model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
  337. )
  338. class ValueScaler:
  339. def __init__(self, scheduler, mult_val: float):
  340. self.scheduler = scheduler
  341. self.mult_val = mult_val
  342. def __call__(self, *args, **kwargs):
  343. val = self.scheduler(*args, **kwargs)
  344. return val * self.mult_val
  345. def rgetattr(obj, rattrs: str = None):
  346. """
  347. Like getattr(), but supports dotted notation for nested objects.
  348. rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2
  349. """
  350. if rattrs is None:
  351. return obj
  352. attrs = rattrs.split(".")
  353. for attr in attrs:
  354. obj = getattr(obj, attr)
  355. return obj
  356. def layer_decay_param_modifier(
  357. scheduler_cfgs: List[List[Dict]],
  358. model,
  359. layer_decay_value: float,
  360. layer_decay_min: Optional[float] = None,
  361. apply_to: Optional[str] = None,
  362. overrides: List[Dict] = (),
  363. ) -> List[List[Dict]]:
  364. """
  365. Args
  366. - scheduler_cfgs: a list of omegaconf.ListConfigs.
  367. Each element in the list is a omegaconfg.DictConfig with the following structure
  368. {
  369. "scheduler": <some fvcore scheduler>
  370. "option": <value> possible options are "lr", "weight_decay" etc.
  371. "parameter_names": Set of str indicating param names that this scheduler applies to
  372. }
  373. - model: a model that implements a method `get_layer_id` that maps layer_name to an integer and
  374. and a method get_num_layers.
  375. Alternatively, use apply_to argument to select a specific component of the model.
  376. - layer_decay_value: float
  377. - layer_decay_min: min val for layer decay
  378. - apply_to: optional arg to select which component of the model to apply the the layer decay modifier to
  379. - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value".
  380. Returns
  381. - scheduler_configs: same structure as the input, elements can be modified
  382. """
  383. model = rgetattr(model, apply_to)
  384. num_layers = model.get_num_layers() + 1
  385. layer_decays = [
  386. layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
  387. ]
  388. if layer_decay_min is not None:
  389. layer_decays = [max(val, layer_decay_min) for val in layer_decays]
  390. final_scheduler_cfgs = []
  391. # scheduler_cfgs is a list of lists
  392. for scheduler_cfg_group in scheduler_cfgs:
  393. curr_cfg_group = []
  394. # scheduler_cfg_group is a list of dictionaries
  395. for scheduler_cfg in scheduler_cfg_group:
  396. if scheduler_cfg["option"] != "lr":
  397. curr_cfg_group.append(scheduler_cfg)
  398. continue
  399. # Need sorted so that the list of parameter names is deterministic and consistent
  400. # across re-runs of this job. Else it was causing issues with loading the optimizer
  401. # state during a job restart
  402. parameter_names = sorted(scheduler_cfg["parameter_names"])
  403. # Only want one cfg group per layer
  404. layer_cfg_groups = {}
  405. for param_name in parameter_names:
  406. layer_id = num_layers
  407. this_scale = layer_decays[layer_id]
  408. if param_name.startswith(apply_to):
  409. layer_id = model.get_layer_id(param_name)
  410. this_scale = layer_decays[layer_id]
  411. # Overrides
  412. for override in overrides:
  413. if fnmatch.fnmatchcase(param_name, override["pattern"]):
  414. this_scale = float(override["value"])
  415. layer_id = override["pattern"]
  416. break
  417. if layer_id not in layer_cfg_groups:
  418. curr_param = {
  419. "option": scheduler_cfg["option"],
  420. "scheduler": ValueScaler(
  421. scheduler_cfg["scheduler"], this_scale
  422. ),
  423. "parameter_names": {param_name},
  424. }
  425. else:
  426. curr_param = layer_cfg_groups[layer_id]
  427. curr_param["parameter_names"].add(param_name)
  428. layer_cfg_groups[layer_id] = curr_param
  429. for layer_cfg in layer_cfg_groups.values():
  430. curr_cfg_group.append(layer_cfg)
  431. final_scheduler_cfgs.append(curr_cfg_group)
  432. return final_scheduler_cfgs