# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

from collections import defaultdict
from dataclasses import fields, FrozenInstanceError, is_dataclass
from typing import Any, Mapping, Protocol, runtime_checkable

import torch


def _is_named_tuple(x) -> bool:
    """Return True if ``x`` is a tuple exposing ``_asdict`` and ``_fields``.

    This is a duck-typed check for ``collections.namedtuple`` /
    ``typing.NamedTuple`` instances, which need keyword-based reconstruction
    rather than the single-iterable constructor plain tuples accept.
    """
    return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")


@runtime_checkable
class _CopyableData(Protocol):
    """Structural type for any object that exposes a ``to`` method.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only verifies
    that a ``to`` attribute is present; the signature is not checked.
    """

    def to(self, device: torch.device, *args: Any, **kwargs: Any):
        """Copy data to the specified device"""
        ...


def _copy_dataclass_to_device(
    data, device: torch.device, *args: Any, **kwargs: Any
):
    """Rebuild the dataclass instance ``data`` with every field copied to ``device``."""
    data_fields = fields(data)

    # Fields accepted by ``__init__`` are passed through the constructor so that
    # ``__post_init__`` (if any) runs on the new instance.
    init_kwargs = {
        f.name: copy_data_to_device(getattr(data, f.name), device, *args, **kwargs)
        for f in data_fields
        if f.init
    }
    new_data = type(data)(**init_kwargs)

    # ``init=False`` fields cannot go through the constructor; copy them from
    # the source instance so the result mirrors its current state.
    for f in data_fields:
        if f.init:
            continue
        value = copy_data_to_device(getattr(data, f.name), device, *args, **kwargs)
        try:
            setattr(new_data, f.name, value)
        except FrozenInstanceError:
            # Frozen dataclasses reject normal attribute assignment; bypass the
            # generated ``__setattr__`` the same way ``__post_init__`` must.
            object.__setattr__(new_data, f.name, value)
    return new_data


def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Function that recursively copies data to a torch.device.

    Containers are rebuilt with the same type and their elements are processed
    recursively. Handled, in order of precedence: named tuples, lists/tuples,
    ``defaultdict`` (the ``default_factory`` is preserved), other ``Mapping``
    types, dataclass instances (including frozen ones), and finally any object
    with a ``to`` method. Anything else is returned unchanged.

    Note that a dataclass instance is rebuilt field by field even if it also
    defines ``to``, because the dataclass check comes first.

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """
    if _is_named_tuple(data):
        # Named tuples take one argument per field, so rebuild via keywords.
        return type(data)(
            **copy_data_to_device(data._asdict(), device, *args, **kwargs)
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data)
    elif isinstance(data, defaultdict):
        # Must precede the generic Mapping branch: the first positional
        # argument of defaultdict is the factory, not the contents.
        return type(data)(
            data.default_factory,
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            },
        )
    elif isinstance(data, Mapping):
        return type(data)(
            {
                k: copy_data_to_device(v, device, *args, **kwargs)
                for k, v in data.items()
            }
        )
    elif is_dataclass(data) and not isinstance(data, type):
        # ``is_dataclass`` is also true for dataclass *classes*; only instances
        # carry data to copy.
        return _copy_dataclass_to_device(data, device, *args, **kwargs)
    elif isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)
    return data