misc.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. from collections import defaultdict
  4. from dataclasses import fields, is_dataclass
  5. from typing import Any, Mapping, Protocol, runtime_checkable
  6. import torch
  7. def _is_named_tuple(x) -> bool:
  8. return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
  9. @runtime_checkable
  10. class _CopyableData(Protocol):
  11. def to(self, device: torch.device, *args: Any, **kwargs: Any):
  12. """Copy data to the specified device"""
  13. ...
  def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
      """Function that recursively copies data to a torch.device.

      Supported containers are rebuilt with their own type, and their contents
      are copied recursively:

      - named tuples are rebuilt from their fields passed as keyword arguments;
      - lists and tuples are rebuilt from an iterable of copied elements;
      - ``defaultdict`` instances keep their ``default_factory``;
      - other ``Mapping`` instances are rebuilt from a dict of copied values
        (keys are not copied);
      - dataclass instances (not dataclass classes) are rebuilt through
        ``__init__`` from their ``init=True`` fields. Afterwards, their
        ``init=False`` fields are copied and assigned onto the new instance.
        This also works for frozen dataclasses. An ``init=False`` field that
        was never set on ``data`` is left unset on the copy.

      Any other object with a ``to`` attribute is copied with
      ``data.to(device, *args, **kwargs)``. Everything else is returned
      unchanged.

      Args:
          data: The data to copy to device
          device: The device to which the data should be copied
          args: positional arguments that will be passed to the `to` call
          kwargs: keyword arguments that will be passed to the `to` call

      Returns:
          The data on the correct device
      """
      if _is_named_tuple(data):
          # Must come before the tuple branch: a named tuple's constructor takes
          # one argument per field, not a single iterable.
          return type(data)(
              **copy_data_to_device(data._asdict(), device, *args, **kwargs)
          )
      elif isinstance(data, (list, tuple)):
          return type(data)(
              copy_data_to_device(e, device, *args, **kwargs) for e in data
          )
      elif isinstance(data, defaultdict):
          # Must come before the Mapping branch: defaultdict's constructor takes
          # default_factory as its first positional argument.
          return type(data)(
              data.default_factory,
              {
                  k: copy_data_to_device(v, device, *args, **kwargs)
                  for k, v in data.items()
              },
          )
      elif isinstance(data, Mapping):
          return type(data)(
              {
                  k: copy_data_to_device(v, device, *args, **kwargs)
                  for k, v in data.items()
              }
          )
      elif is_dataclass(data) and not isinstance(data, type):
          # is_dataclass() is also True for dataclass classes; only instances
          # are rebuilt.
          return _copy_dataclass_to_device(data, device, *args, **kwargs)
      elif isinstance(data, _CopyableData):
          return data.to(device, *args, **kwargs)
      return data


  def _copy_dataclass_to_device(
      data, device: torch.device, *args: Any, **kwargs: Any
  ):
      """Rebuild dataclass instance ``data`` with every field copied to ``device``.

      ``init=True`` fields are passed to the constructor. ``init=False`` fields
      are copied afterwards and assigned onto the new instance.

      Frozen dataclasses reject normal attribute assignment by raising
      ``FrozenInstanceError``. In that case the value is stored with
      ``object.__setattr__``, which is the mechanism the dataclasses docs
      describe for initializing frozen instances.

      An ``init=False`` field that was never assigned on ``data`` is skipped
      rather than raising ``AttributeError``.
      """
      from dataclasses import FrozenInstanceError

      init_kwargs = {}
      non_init_fields = []
      for field in fields(data):
          if field.init:
              init_kwargs[field.name] = copy_data_to_device(
                  getattr(data, field.name), device, *args, **kwargs
              )
          else:
              non_init_fields.append(field)

      new_data = type(data)(**init_kwargs)
      for field in non_init_fields:
          # init=False fields without a default may never have been assigned.
          if not hasattr(data, field.name):
              continue
          value = copy_data_to_device(
              getattr(data, field.name), device, *args, **kwargs
          )
          try:
              setattr(new_data, field.name, value)
          except FrozenInstanceError:
              # Only frozen dataclasses reach this fallback; others keep any
              # custom __setattr__ they define.
              object.__setattr__(new_data, field.name, value)
      return new_data