| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- from collections import defaultdict
- from dataclasses import fields, is_dataclass
- from typing import Any, Mapping, Protocol, runtime_checkable
- import torch
- def _is_named_tuple(x) -> bool:
- return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
- @runtime_checkable
- class _CopyableData(Protocol):
- def to(self, device: torch.device, *args: Any, **kwargs: Any):
- """Copy data to the specified device"""
- ...
def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Function that recursively copies data to a torch.device.

    Containers are rebuilt with the same type and their elements are copied
    recursively; leaf objects that expose a ``to`` method are moved by calling
    ``data.to(device, *args, **kwargs)``; anything else is returned unchanged.

    Supported containers, checked in this order:
        * named tuples (rebuilt from their ``_asdict()`` mapping),
        * lists and plain tuples,
        * ``defaultdict`` (the ``default_factory`` is preserved),
        * any other ``Mapping``,
        * dataclass instances (including frozen ones).

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """
    # Imported locally so this function does not depend on the module-level
    # import list being extended.
    from dataclasses import FrozenInstanceError

    def _copy(value):
        # Recurse with the same device and `to` arguments.
        return copy_data_to_device(value, device, *args, **kwargs)

    # Named tuples must be handled before the generic tuple branch: their
    # constructor takes one argument per field, not a single iterable.
    if _is_named_tuple(data):
        return type(data)(**_copy(data._asdict()))

    if isinstance(data, (list, tuple)):
        return type(data)(_copy(element) for element in data)

    # defaultdict must be handled before the generic Mapping branch because
    # its constructor takes the default factory as first argument.
    if isinstance(data, defaultdict):
        return type(data)(
            data.default_factory,
            {key: _copy(value) for key, value in data.items()},
        )

    if isinstance(data, Mapping):
        return type(data)({key: _copy(value) for key, value in data.items()})

    # `is_dataclass` is also true for dataclass *types*; only instances are copied.
    if is_dataclass(data) and not isinstance(data, type):
        data_fields = fields(data)
        # Fields accepted by __init__ are passed through the constructor.
        new_data_class = type(data)(
            **{
                field.name: _copy(getattr(data, field.name))
                for field in data_fields
                if field.init
            }
        )
        # Fields excluded from __init__ are restored by attribute assignment.
        for field in data_fields:
            if field.init:
                continue
            copied = _copy(getattr(data, field.name))
            try:
                setattr(new_data_class, field.name, copied)
            except FrozenInstanceError:
                # Frozen dataclasses reject normal assignment; bypass the
                # frozen __setattr__ instead of failing the whole copy.
                object.__setattr__(new_data_class, field.name, copied)
        return new_data_class

    if isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)

    # Not a known container and no `to` method: return as-is.
    return data
|