# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import inspect
from functools import wraps
from typing import Callable, TypeVar, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.utils._pytree import tree_map_only

# Type variables for better type hinting
T = TypeVar("T")
Module = TypeVar("Module", bound=nn.Module)


def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
    """
    Wraps a given module to enable or disable activation checkpointing.

    Activation checkpointing (gradient checkpointing) trades compute for memory by
    recomputing intermediate activations during the backward pass instead of storing
    them in memory during the forward pass.

    When activation checkpointing is enabled, the wrapper expects only keyword
    arguments and maps them to positional arguments based on the module's signature.

    Args:
        module: The module or function to wrap with activation checkpointing

    Returns:
        A wrapped callable that supports activation checkpointing

    Usage:
        The returned wrapper can be called with the same arguments as the original
        module, plus an additional `act_ckpt_enable` keyword argument to control
        activation checkpointing and an optional `use_reentrant` parameter.

    Example:
        ```python
        wrapped_module = activation_ckpt_wrapper(my_module)
        output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
        ```
    """

    @wraps(module)
    def act_ckpt_wrapper(
        *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
    ):
        if act_ckpt_enable:
            if len(args) > 0:
                raise ValueError(
                    "This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
                )
            # Get the signature of the target function/module
            callable_fn = module.forward if isinstance(module, nn.Module) else module
            sig = inspect.signature(callable_fn)
            # Create a mapping of parameter names to their default values
            param_defaults = {
                name: param.default for name, param in sig.parameters.items()
            }
            args = []
            for p_name in param_defaults.keys():
                if p_name in kwargs:
                    args.append(kwargs.pop(p_name))
                elif param_defaults[p_name] is not inspect.Parameter.empty:
                    # Set arg to default value if it's not in kwargs. Useful for
                    # primitive types or args that default to None
                    args.append(param_defaults[p_name])
                elif (
                    sig.parameters[p_name].kind is not inspect.Parameter.VAR_KEYWORD
                ):  # Skip **kwargs parameter
                    raise ValueError(f"Missing positional argument: {p_name}")
            # Scan remaining kwargs for torch.Tensor values
            remaining_keys = list(kwargs.keys())
            for key in remaining_keys:
                if isinstance(kwargs[key], torch.Tensor):
                    # Replace the tensor with a sentinel string, assuming it's not
                    # required by the module. If it is required, the module's signature
                    # should be modified to accept it as a positional or keyword argument.
                    kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_"
            ret = checkpoint.checkpoint(
                module, *args, use_reentrant=use_reentrant, **kwargs
            )
        else:
            ret = module(*args, **kwargs)
        return ret

    return act_ckpt_wrapper
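

# The function below is an illustrative usage sketch added for documentation, not
# part of the original API. It assumes a toy nn.Sequential layer; the names,
# shapes, and sizes are hypothetical.
def _example_activation_ckpt_usage() -> torch.Tensor:
    """Sketch: wrap a small module and call it with keyword arguments only."""
    layer = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
    wrapped_layer = activation_ckpt_wrapper(layer)
    x = torch.randn(4, 16, requires_grad=True)
    # Keyword arguments are mapped onto the positional parameters of `layer.forward`.
    out = wrapped_layer(input=x, act_ckpt_enable=True)
    # Activations are recomputed during backward instead of being kept from forward.
    out.sum().backward()
    return out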


def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
    """
    Clone the CUDA output tensors of a function so that subsequent in-place
    operations on them are safe.

    This wrapper is useful when working with torch.compile to prevent errors
    related to in-place operations on output tensors.

    Args:
        f: The function whose CUDA tensor outputs should be cloned

    Returns:
        A wrapped function that clones any CUDA tensor outputs
    """

    @wraps(f)
    def wrapped(*args, **kwargs):
        outputs = f(*args, **kwargs)
        return tree_map_only(
            torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
        )

    return wrapped
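

if __name__ == "__main__":
    # Minimal smoke-test sketch, added for illustration and not part of the
    # original module. `split_halves` is a hypothetical function that returns
    # views of its input; on CUDA the wrapper clones those views so later
    # in-place edits cannot touch the original storage, while on CPU the
    # tensors pass through unchanged.
    @clone_output_wrapper
    def split_halves(t: torch.Tensor):
        half = t.shape[-1] // 2
        return t[..., :half], t[..., half:]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    a, b = split_halves(torch.randn(2, 8, device=device))
    print(a.shape, b.shape, a.is_cuda)

    _example_activation_ckpt_usage()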