# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import logging
from contextlib import contextmanager
from functools import wraps

import torch

__all__ = ["retry_if_cuda_oom"]
  8. @contextmanager
  9. def _ignore_torch_cuda_oom():
  10. """
  11. A context which ignores CUDA OOM exception from pytorch.
  12. """
  13. try:
  14. yield
  15. except RuntimeError as e:
  16. # NOTE: the string may change?
  17. if "CUDA out of memory. " in str(e):
  18. pass
  19. else:
  20. raise
  21. def retry_if_cuda_oom(func):
  22. """
  23. Makes a function retry itself after encountering
  24. pytorch's CUDA OOM error.
  25. It will first retry after calling `torch.cuda.empty_cache()`.
  26. If that still fails, it will then retry by trying to convert inputs to CPUs.
  27. In this case, it expects the function to dispatch to CPU implementation.
  28. The return values may become CPU tensors as well and it's user's
  29. responsibility to convert it back to CUDA tensor if needed.
  30. Args:
  31. func: a stateless callable that takes tensor-like objects as arguments
  32. Returns:
  33. a callable which retries `func` if OOM is encountered.
  34. Examples:
  35. ::
  36. output = retry_if_cuda_oom(some_torch_function)(input1, input2)
  37. # output may be on CPU even if inputs are on GPU
  38. Note:
  39. 1. When converting inputs to CPU, it will only look at each argument and check
  40. if it has `.device` and `.to` for conversion. Nested structures of tensors
  41. are not supported.
  42. 2. Since the function might be called more than once, it has to be
  43. stateless.
  44. """
  45. def maybe_to_cpu(x):
  46. try:
  47. like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
  48. except AttributeError:
  49. like_gpu_tensor = False
  50. if like_gpu_tensor:
  51. return x.to(device="cpu")
  52. else:
  53. return x
  54. @wraps(func)
  55. def wrapped(*args, **kwargs):
  56. with _ignore_torch_cuda_oom():
  57. return func(*args, **kwargs)
  58. # Clear cache and retry
  59. torch.cuda.empty_cache()
  60. with _ignore_torch_cuda_oom():
  61. return func(*args, **kwargs)
  62. # Try on CPU. This slows down the code significantly, therefore print a notice.
  63. logger = logging.getLogger(__name__)
  64. logger.info(
  65. "Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func))
  66. )
  67. new_args = (maybe_to_cpu(x) for x in args)
  68. new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
  69. return func(*new_args, **new_kwargs)
  70. return wrapped