act_ckpt_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import inspect
from functools import wraps
from typing import Callable, TypeVar, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.utils._pytree import tree_map_only

# Type variables for better type hinting
T = TypeVar("T")
Module = TypeVar("Module", bound=nn.Module)


def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
    """
    Wraps a given module to enable or disable activation checkpointing.

    Activation checkpointing (gradient checkpointing) trades compute for memory by
    recomputing intermediate activations during the backward pass instead of storing
    them in memory during the forward pass.

    When activation checkpointing is enabled, the wrapper expects only keyword
    arguments and maps them to positional arguments based on the module's signature.

    Args:
        module: The module or function to wrap with activation checkpointing.

    Returns:
        A wrapped callable that supports activation checkpointing.

    Usage:
        The returned wrapper can be called with the same arguments as the original
        module, plus an `act_ckpt_enable` keyword argument to control activation
        checkpointing and an optional `use_reentrant` parameter.

    Example:
        ```python
        wrapped_module = activation_ckpt_wrapper(my_module)
        output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
        ```
    """
    @wraps(module)
    def act_ckpt_wrapper(
        *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
    ):
        if act_ckpt_enable:
            if len(args) > 0:
                raise ValueError(
                    "This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
                )
            # Get the signature of the target function/module
            callable_fn = module.forward if isinstance(module, nn.Module) else module
            sig = inspect.signature(callable_fn)
            # Create a mapping of parameter names to their default values
            param_defaults = {
                name: param.default for name, param in sig.parameters.items()
            }
            # Rebuild the positional argument list from the keyword arguments,
            # following the order of the wrapped callable's signature.
            args = []
            for p_name in param_defaults.keys():
                if p_name in kwargs:
                    args.append(kwargs.pop(p_name))
                elif param_defaults[p_name] is not inspect.Parameter.empty:
                    # Fall back to the default value when the parameter is absent from
                    # kwargs. Useful for primitive types or args that default to None.
                    args.append(param_defaults[p_name])
                elif (
                    sig.parameters[p_name].kind is not inspect.Parameter.VAR_KEYWORD
                ):  # Skip the **kwargs parameter
                    raise ValueError(f"Missing positional argument: {p_name}")
            # Scan the remaining kwargs for torch.Tensor values
            remaining_keys = list(kwargs.keys())
            for key in remaining_keys:
                if isinstance(kwargs[key], torch.Tensor):
                    # Replace the tensor with a sentinel string, assuming it is not
                    # required by the module. If it is required, the module's signature
                    # should be modified to accept it as a positional or keyword
                    # argument.
                    kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_"
            ret = checkpoint.checkpoint(
                module, *args, use_reentrant=use_reentrant, **kwargs
            )
        else:
            ret = module(*args, **kwargs)
        return ret

    return act_ckpt_wrapper
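

# Usage sketch: a minimal example assuming a toy `_ToyBlock` module (hypothetical,
# defined only for illustration). It shows how, with `act_ckpt_enable=True`, the
# wrapper maps keyword arguments onto `_ToyBlock.forward`'s positional signature
# before calling torch.utils.checkpoint.checkpoint, and how it falls back to a
# plain call when checkpointing is disabled.
def _example_activation_ckpt_usage() -> None:
    class _ToyBlock(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
            return self.proj(x) * scale

    wrapped = activation_ckpt_wrapper(_ToyBlock())
    x = torch.randn(2, 8, requires_grad=True)

    # Checkpointing on: keyword-only call; `x` and `scale` are matched to the
    # forward signature and passed positionally to checkpoint.checkpoint.
    out = wrapped(x=x, scale=0.5, act_ckpt_enable=True)
    out.sum().backward()

    # Checkpointing off: the call is forwarded to the module unchanged.
    out = wrapped(x, 0.5, act_ckpt_enable=False)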


def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
    """
    Clone the CUDA output tensors of a function so that downstream in-place
    operations do not mutate them.

    This wrapper is useful when working with torch.compile to prevent errors
    related to in-place operations on tensors.

    Args:
        f: The function whose CUDA tensor outputs should be cloned.

    Returns:
        A wrapped function that clones any CUDA tensor outputs.
    """

    @wraps(f)
    def wrapped(*args, **kwargs):
        outputs = f(*args, **kwargs)
        return tree_map_only(
            torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
        )

    return wrapped
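

# Usage sketch: a minimal example assuming a toy `_produce` function (hypothetical,
# defined only for illustration). clone_output_wrapper clones every CUDA tensor in
# the returned pytree, so in-place edits on the results do not alias tensors still
# referenced inside the wrapped function.
def _example_clone_output_usage() -> None:
    def _produce(x: torch.Tensor) -> dict:
        return {"double": x * 2, "same": x}

    safe_produce = clone_output_wrapper(_produce)

    if torch.cuda.is_available():
        x = torch.ones(4, device="cuda")
        out = safe_produce(x)
        # `out["same"]` is a clone, so this in-place update leaves `x` untouched.
        out["same"].add_(1)
        assert torch.all(x == 1)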