# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import inspect
from functools import wraps
from typing import Callable, TypeVar, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.utils._pytree import tree_map_only

# Type variables for better type hinting
T = TypeVar("T")
Module = TypeVar("Module", bound=nn.Module)


def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
    """
    Wraps a given module to enable or disable activation checkpointing.

    Activation checkpointing (gradient checkpointing) trades compute for memory
    by recomputing intermediate activations during the backward pass instead of
    storing them in memory during the forward pass.

    When activation checkpointing is enabled, the wrapper expects only keyword
    arguments, and it maps these to positional arguments based on the module's
    signature.

    Args:
        module: The module or function to wrap with activation checkpointing

    Returns:
        A wrapped callable that supports activation checkpointing

    Usage:
        The returned wrapper function can be called with the same arguments as
        the original module, with an additional `act_ckpt_enable` keyword
        argument to control activation checkpointing and an optional
        `use_reentrant` parameter.

    Example:
        ```python
        wrapped_module = activation_ckpt_wrapper(my_module)
        output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
        ```
    """

    @wraps(module)
    def act_ckpt_wrapper(
        *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
    ):
        if act_ckpt_enable:
            if len(args) > 0:
                raise ValueError(
                    "This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
                )

            # Get the signature of the target function/module
            callable_fn = module.forward if isinstance(module, nn.Module) else module
            sig = inspect.signature(callable_fn)

            # Create a mapping of parameter names to their default values
            param_defaults = {
                name: param.default for name, param in sig.parameters.items()
            }

            args = []
            for p_name in param_defaults.keys():
                if p_name in kwargs:
                    args.append(kwargs.pop(p_name))
                elif param_defaults[p_name] is not inspect.Parameter.empty:
                    # Fall back to the default value if the argument is not in kwargs.
                    # Useful for primitive types or args that default to None.
                    args.append(param_defaults[p_name])
                elif (
                    sig.parameters[p_name].kind is not inspect.Parameter.VAR_KEYWORD
                ):  # Skip the **kwargs parameter
                    raise ValueError(f"Missing positional argument: {p_name}")

            # Scan remaining kwargs for torch.Tensor values
            remaining_keys = list(kwargs.keys())
            for key in remaining_keys:
                if isinstance(kwargs[key], torch.Tensor):
                    # Replace the tensor in kwargs with a placeholder string, assuming the
                    # module does not require it. If it is required, the module's signature
                    # should be modified to accept it as a positional or keyword argument.
                    kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_"

            ret = checkpoint.checkpoint(
                module, *args, use_reentrant=use_reentrant, **kwargs
            )
        else:
            ret = module(*args, **kwargs)
        return ret

    return act_ckpt_wrapper


def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
    """
    Clone the CUDA output tensors of a function to avoid in-place operations.

    This wrapper is useful when working with torch.compile to prevent errors
    related to in-place operations on tensors.

    Args:
        f: The function whose CUDA tensor outputs should be cloned

    Returns:
        A wrapped function that clones any CUDA tensor outputs
    """

    @wraps(f)
    def wrapped(*args, **kwargs):
        outputs = f(*args, **kwargs)
        return tree_map_only(
            torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
        )

    return wrapped
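

# Minimal usage sketch (illustrative only, not part of the original module):
# `_ExampleBlock` is a hypothetical nn.Module used to show how the two wrappers
# compose: activation checkpointing over the forward pass, plus cloning of CUDA
# outputs for torch.compile friendliness. With `act_ckpt_enable=True`, arguments
# must be passed by keyword so the wrapper can map them onto the module's
# positional signature.
if __name__ == "__main__":

    class _ExampleBlock(nn.Module):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
            return torch.relu(self.proj(x)) * scale

    block = _ExampleBlock()
    wrapped = clone_output_wrapper(activation_ckpt_wrapper(block))

    x = torch.randn(4, 8, requires_grad=True)
    # `scale` is mapped to its positional slot; activations of `block` are
    # recomputed during backward instead of being stored in the forward pass.
    out = wrapped(x=x, scale=0.5, act_ckpt_enable=True)
    out.sum().backward()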