# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import fnmatch
import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
)

import numpy as np
import torch
import torch.nn as nn

from iopath.common.file_io import g_pathmgr
from torch.jit._script import RecursiveScriptModule

def unix_pattern_to_parameter_names(
    constraints: List[str], all_parameter_names: Sequence[str]
) -> Set[str]:
    """
    Go through the list of parameter names and select those that match
    any of the provided constraints (unix shell-style patterns).
    """
    parameter_names = []
    for param_name in constraints:
        matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
        assert (
            len(matching_parameters) > 0
        ), f"pattern {param_name} does not match any parameter in the given names."
        parameter_names.append(matching_parameters)
    return set.union(*parameter_names)

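# A minimal usage sketch (the parameter names below are hypothetical):
#
#     unix_pattern_to_parameter_names(
#         ["encoder.*.weight"],
#         ["encoder.0.weight", "encoder.0.bias", "head.weight"],
#     )
#     # -> {"encoder.0.weight"}
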
def filter_params_matching_unix_pattern(
    patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Keep in the state dictionary only the parameters matching the provided unix patterns

    Args:
        patterns: the list of unix patterns to include
        state_dict: the dictionary to filter

    Returns:
        A new state dictionary
    """
    if len(patterns) == 0:
        return {}

    all_keys = list(state_dict.keys())
    included_keys = unix_pattern_to_parameter_names(patterns, all_keys)
    return {k: state_dict[k] for k in included_keys}

def exclude_params_matching_unix_pattern(
    patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Remove from the state dictionary the parameters matching the provided unix patterns

    Args:
        patterns: the list of unix patterns to exclude
        state_dict: the dictionary to filter

    Returns:
        A new state dictionary
    """
    if len(patterns) == 0:
        return state_dict

    all_keys = list(state_dict.keys())
    excluded_keys = unix_pattern_to_parameter_names(patterns, all_keys)
    return {k: v for k, v in state_dict.items() if k not in excluded_keys}

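# The two filters above are complementary. A small sketch with hypothetical keys:
#
#     sd = {"backbone.weight": torch.zeros(1), "head.weight": torch.zeros(1)}
#     filter_params_matching_unix_pattern(["backbone.*"], sd)   # {"backbone.weight": ...}
#     exclude_params_matching_unix_pattern(["backbone.*"], sd)  # {"head.weight": ...}
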
def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]):
    """
    Returns a vector of per-tensor sums, ordered by sorted key, used as a
    cheap fingerprint to detect parameter updates between two snapshots.
    """
    keys = []
    trace = []
    for k, v in state_dict.items():
        keys.append(k)
        trace.append(v.sum().item())
    trace = np.array(trace)[np.argsort(keys)]
    return trace

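# For instance, for {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0])}
# the summary is np.array([3.0, 3.0]): per-tensor sums ordered by sorted key.
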
def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]):
    """
    Verifies that all the parameters matching the provided patterns
    are frozen - this acts as a safeguard when ignoring parameters
    during checkpoint saving: if those parameters were in fact trainable,
    their updates would be silently lost.
    """
    if not patterns:
        return

    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    non_frozen_keys = {
        n
        for n, p in model.named_parameters()
        if n in frozen_state_dict and p.requires_grad
    }
    if non_frozen_keys:
        raise ValueError(
            f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}"
        )

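# Usage sketch (a hypothetical two-parameter model; `weight` is frozen, `bias` is not):
#
#     model = nn.Linear(4, 2)
#     model.weight.requires_grad_(False)
#     assert_skipped_parameters_are_frozen(model, ["weight"])  # passes
#     assert_skipped_parameters_are_frozen(model, ["bias"])    # raises ValueError
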
@contextlib.contextmanager
def with_check_parameter_frozen(
    model: nn.Module, patterns: List[str], disabled: bool = True
):
    """
    Context manager that inspects a model around a piece of code
    and verifies whether the model has been updated by this piece of code

    The function will raise an exception if the model has been updated
    on at least one of the parameters that match one of the patterns

    Args:
        model: the model that might have been updated
        patterns: unix patterns selecting the parameters we want to observe
        disabled: if True, the check is skipped entirely
    """
    if not patterns or disabled:
        yield
        return

    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    summary_before = _get_state_dict_summary(frozen_state_dict)

    yield

    frozen_state_dict = filter_params_matching_unix_pattern(
        patterns=patterns, state_dict=model.state_dict()
    )
    summary_after = _get_state_dict_summary(frozen_state_dict)

    if not np.allclose(summary_before, summary_after, atol=1e-6):
        raise ValueError(
            """
            The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`.
            You can resolve this error by either initializing those parameters from within the model definition
            or by setting the flag `trainer.checkpoint.initialize_after_preemption` to True.
            """
        )

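# Usage sketch (assuming `model` has parameters matching "encoder.*" and that
# `initialize_weights` exists in the caller; the check only runs when
# `disabled=False`):
#
#     with with_check_parameter_frozen(model, ["encoder.*"], disabled=False):
#         initialize_weights(model)  # must not touch parameters matching "encoder.*"
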
class CkptExcludeKernel:
    """
    Removes the keys from the given model state_dict that match the key_pattern.

    Args:
        key_pattern: Patterns used to select the keys in the state_dict
            that are eligible for this kernel.
    """

    def __init__(self, key_pattern: List[str]):
        self.key_pattern = key_pattern

    def __call__(self, state_dict: Dict):
        """
        Args:
            state_dict: A dictionary representing the given checkpoint's state dict.
        """
        if len(self.key_pattern) == 0:
            return state_dict
        exclude_keys = unix_pattern_to_parameter_names(
            self.key_pattern, state_dict.keys()
        )
        return {k: v for k, v in state_dict.items() if k not in exclude_keys}

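# Usage sketch (hypothetical pattern, assuming `checkpoint_state_dict` is a
# loaded state dict; drops all keys under the "head." prefix):
#
#     kernel = CkptExcludeKernel(key_pattern=["head.*"])
#     filtered = kernel(state_dict=checkpoint_state_dict)
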
def load_checkpoint(
    path_list: List[str],
    pick_recursive_keys: Optional[List[str]] = None,
    map_location: str = "cpu",
) -> Any:
    """
    Loads a checkpoint from the first path in `path_list` that exists.

    Args:
        path_list: A list of paths which may contain the checkpoint. Each element
            is tried (in order) until a file that exists is found. That file is then
            used to read the checkpoint.
        pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None.
            For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"]
        map_location (str): a function, torch.device, string or a dict specifying how to
            remap storage locations

    Returns: The loaded checkpoint (or the sub-dictionary selected by pick_recursive_keys).
    """
    path_exists = False
    for path in path_list:
        if g_pathmgr.exists(path):
            path_exists = True
            break

    if not path_exists:
        raise ValueError(f"No path exists in {path_list}")

    with g_pathmgr.open(path, "rb") as f:
        checkpoint = torch.load(f, map_location=map_location)

    logging.info(f"Loaded checkpoint from {path}")
    if pick_recursive_keys is not None:
        for key in pick_recursive_keys:
            checkpoint = checkpoint[key]
    return checkpoint

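# Usage sketch (hypothetical paths; falls back to the second path if the first
# does not exist, then drills into checkpoint["state_dict"]):
#
#     ckpt = load_checkpoint(
#         ["/tmp/ckpt_latest.pt", "/tmp/ckpt_backup.pt"],
#         pick_recursive_keys=["state_dict"],
#     )
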
def get_state_dict(checkpoint, ckpt_state_dict_keys):
    if isinstance(checkpoint, RecursiveScriptModule):
        # This is a torchscript JIT model
        return checkpoint.state_dict()
    pre_train_dict = checkpoint
    for i, key in enumerate(ckpt_state_dict_keys):
        if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or (
            isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict)
        ):
            key_str = "".join(f'["{k}"]' for k in ckpt_state_dict_keys[:i])
            raise KeyError(
                f"'{key}' not found in checkpoint{key_str} "
                f"with keys: {pre_train_dict.keys()}"
            )
        pre_train_dict = pre_train_dict[key]
    return pre_train_dict

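# For a checkpoint saved as {"state_dict": {...}} (the usual layout assumed by
# `load_checkpoint_and_apply_kernels` below):
#
#     sd = get_state_dict(checkpoint, ("state_dict",))  # == checkpoint["state_dict"]
#
# except that a wrong key raises a KeyError naming the path traversed so far.
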
def load_checkpoint_and_apply_kernels(
    checkpoint_path: str,
    checkpoint_kernels: Optional[List[Callable]] = None,
    ckpt_state_dict_keys: Tuple[str, ...] = ("state_dict",),
    map_location: str = "cpu",
) -> Dict[str, torch.Tensor]:
    """
    Performs checkpoint loading with a variety of pre-processing kernels applied in
    sequence.

    Args:
        checkpoint_path (str): Path to the checkpoint.
        checkpoint_kernels (List[Callable]): A list of checkpoint processing kernels
            to apply in the specified order. Supported kernels include `CkptIncludeKernel`,
            `CkptExcludeKernel`, etc. These kernels are applied in the
            given order.
        ckpt_state_dict_keys (Tuple[str, ...]): Keys containing the model state dict.
        map_location (str): a function, torch.device, string or a dict specifying how to
            remap storage locations

    Returns: The state dict with the matching pre-trained weights loaded.
    """
    assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format(
        checkpoint_path
    )

    # Load the checkpoint on CPU to avoid GPU mem spike.
    with g_pathmgr.open(checkpoint_path, "rb") as f:
        checkpoint = torch.load(f, map_location=map_location)

    pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys)

    # Not logging into info etc since it's a huge log
    logging.debug(
        "Loaded Checkpoint State Dict pre-kernel application: %s"
        % str(", ".join(list(pre_train_dict.keys())))
    )

    # Apply kernels
    if checkpoint_kernels is not None:
        for f in checkpoint_kernels:
            pre_train_dict = f(state_dict=pre_train_dict)

    logging.debug(
        "Loaded Checkpoint State Dict Post-kernel application %s"
        % str(", ".join(list(pre_train_dict.keys())))
    )

    return pre_train_dict

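# Usage sketch (hypothetical path and pattern):
#
#     state_dict = load_checkpoint_and_apply_kernels(
#         "/tmp/pretrained.pt",
#         checkpoint_kernels=[CkptExcludeKernel(key_pattern=["head.*"])],
#     )
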
def check_load_state_dict_errors(
    missing_keys,
    unexpected_keys,
    strict: bool,
    ignore_missing_keys: Optional[List[str]] = None,
    ignore_unexpected_keys: Optional[List[str]] = None,
):
    if ignore_missing_keys is not None and len(ignore_missing_keys) > 0:
        ignored_keys = unix_pattern_to_parameter_names(
            ignore_missing_keys, missing_keys
        )
        missing_keys = [key for key in missing_keys if key not in ignored_keys]

    if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0:
        ignored_unexpected_keys = unix_pattern_to_parameter_names(
            ignore_unexpected_keys, unexpected_keys
        )
        unexpected_keys = [
            key for key in unexpected_keys if key not in ignored_unexpected_keys
        ]

    err = "State key mismatch."
    if unexpected_keys:
        err += f" Unexpected keys: {unexpected_keys}."
    if missing_keys:
        err += f" Missing keys: {missing_keys}."

    # Only raise when a mismatch remains after filtering; unexpected keys
    # always fail, while missing keys only fail in strict mode.
    if unexpected_keys or missing_keys:
        logging.warning(err)
        if unexpected_keys or strict:
            raise KeyError(err)

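# Usage sketch: a missing key that matches an ignore pattern does not fail,
# even with strict=True (hypothetical keys):
#
#     check_load_state_dict_errors(
#         missing_keys=["head.weight"],
#         unexpected_keys=[],
#         strict=True,
#         ignore_missing_keys=["head.*"],
#     )
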
def load_state_dict_into_model(
    state_dict: Dict,
    model: nn.Module,
    strict: bool = True,
    ignore_missing_keys: Optional[List[str]] = None,
    ignore_unexpected_keys: Optional[List[str]] = None,
    checkpoint_kernels: Optional[List[Callable]] = None,
):
    """
    Loads a state dict into the given model.

    Args:
        state_dict: A dictionary containing the model's
            state dict, or a subset if strict is False
        model: Model to load the checkpoint weights into
        strict: raise if the state_dict has missing state keys
        ignore_missing_keys: unix pattern of missing keys to ignore
        ignore_unexpected_keys: unix pattern of unexpected keys to ignore
        checkpoint_kernels: processing kernels applied to the state dict
            before loading it
    """
    # Apply kernels
    if checkpoint_kernels is not None:
        for f in checkpoint_kernels:
            state_dict = f(state_dict=state_dict)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    check_load_state_dict_errors(
        missing_keys,
        unexpected_keys,
        strict=strict,
        ignore_missing_keys=ignore_missing_keys,
        ignore_unexpected_keys=ignore_unexpected_keys,
    )
    return model

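# End-to-end usage sketch (hypothetical path and patterns), combining the
# helpers above:
#
#     model = load_state_dict_into_model(
#         state_dict=load_checkpoint_and_apply_kernels("/tmp/pretrained.pt"),
#         model=model,
#         strict=False,
#         ignore_missing_keys=["head.*"],
#     )
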