# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

import logging
from contextlib import contextmanager
from functools import wraps

import torch

__all__ = ["retry_if_cuda_oom"]

@contextmanager
def _ignore_torch_cuda_oom():
    """
    A context manager that ignores the CUDA OOM exception raised by PyTorch.
    """
    try:
        yield
    except RuntimeError as e:
        # NOTE: the error message may change across PyTorch versions.
        if "CUDA out of memory. " in str(e):
            pass
        else:
            raise
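

# Illustrative sketch (not part of the original module; the operands are
# hypothetical): with the helper above, an OOM-prone attempt simply falls
# through instead of crashing, e.g.
#
#     with _ignore_torch_cuda_oom():
#         result = torch.cdist(a, b)  # hypothetical op that may OOM
#     # execution continues here if the allocation failed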


def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering PyTorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by converting the inputs to CPU tensors.
    In this case, it expects the function to dispatch to a CPU implementation.
    The return values may become CPU tensors as well, and it is the user's
    responsibility to convert them back to CUDA tensors if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments.

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def maybe_to_cpu(x):
        # Treat anything with a CUDA `.device` and a `.to` method as a GPU tensor.
        try:
            like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
        except AttributeError:
            like_gpu_tensor = False
        if like_gpu_tensor:
            return x.to(device="cpu")
        else:
            return x

    @wraps(func)
    def wrapped(*args, **kwargs):
        # First attempt: if it raises CUDA OOM, the context manager swallows
        # the error and execution falls through to the next attempt.
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Second attempt: clear the CUDA cache and retry.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Final attempt on CPU. This slows down the code significantly,
        # therefore log a notice.
        logger = logging.getLogger(__name__)
        logger.info(
            "Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func))
        )
        new_args = (maybe_to_cpu(x) for x in args)
        new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    return wrapped
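

# A minimal usage sketch (not part of the original module; the function and
# shapes below are hypothetical). It wraps a memory-hungry function so that a
# CUDA OOM triggers the retry logic above; the result may end up on the CPU.
if __name__ == "__main__":

    def _pairwise_distances(a, b):
        # Materializes a large (N, N) distance matrix; can OOM for big N.
        return torch.cdist(a, b)

    if torch.cuda.is_available():
        a = torch.randn(8192, 256, device="cuda")
        b = torch.randn(8192, 256, device="cuda")
        # On OOM: retries after torch.cuda.empty_cache(), then with CPU inputs.
        out = retry_if_cuda_oom(_pairwise_distances)(a, b)
        print("result device:", out.device)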