# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Callable, Iterable, List, Optional, Sequence

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler


class MixedDataLoader:
    def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
        """
        Args:
            dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
            mixing_prob (torch.FloatTensor): Probability of sampling from each dataloader.
        """
        assert len(dataloaders) == mixing_prob.shape[0]
        self.dataloaders = dataloaders
        self.mixing_prob = mixing_prob
        # Iterator state
        self._iter_dls = None
        self._iter_mixing_prob = None
        self.random_generator = torch.Generator()

    def __len__(self):
        return sum([len(d) for d in self.dataloaders])

    def __iter__(self):
        # Synchronize dataloader seeds
        self.random_generator.manual_seed(42)
        self._iter_dls = [iter(loader) for loader in self.dataloaders]
        self._iter_mixing_prob = self.mixing_prob.clone()
        return self
    def __next__(self):
        """
        Sample from one of the dataloaders according to the mixing probabilities.
        If a dataloader is exhausted, we keep sampling from the remaining loaders
        until all of them are exhausted.
        """
        if self._iter_dls is None:
            raise TypeError(f"{type(self).__name__} object is not an iterator")

        while self._iter_mixing_prob.any():  # at least one dataloader with non-zero prob.
            dataset_idx = self._iter_mixing_prob.multinomial(
                1, generator=self.random_generator
            ).item()
            try:
                item = next(self._iter_dls[dataset_idx])
                return item
            except StopIteration:
                # This dataset is exhausted; set its mixing probability to zero and try again.
                self._iter_mixing_prob[dataset_idx] = 0
            except Exception as e:
                # Log and re-raise any other unexpected error.
                logging.error(e)
                raise e

        # Exhausted all iterators
        raise StopIteration
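

# Illustrative usage sketch (not part of the original module; the datasets, tensor
# sizes, and probabilities below are made up). MixedDataLoader draws each batch from
# one of the wrapped DataLoaders, chosen according to `mixing_prob`.
def _mixed_dataloader_example() -> None:
    from torch.utils.data import TensorDataset

    loader_a = DataLoader(TensorDataset(torch.zeros(8)), batch_size=2)  # 4 batches
    loader_b = DataLoader(TensorDataset(torch.ones(8)), batch_size=4)  # 2 batches
    # Sample from loader_a ~75% of the time and loader_b ~25% of the time.
    mixed = MixedDataLoader([loader_a, loader_b], torch.tensor([0.75, 0.25]))
    for batch in mixed:  # 6 batches in total; order decided by the sampling above
        print(batch)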


class TorchTrainMixedDataset:
    def __init__(
        self,
        datasets: List[Dataset],
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, use pinned memory when loading tensors from disk.
            drop_last (bool): Whether or not to drop the last batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch.
            dataset_prob (List[float]): Probability of sampling from each dataloader; should sum to 1.0.
        """
        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn

        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # `RepeatFactorWrapper` requires calling set_epoch first to get its length
            self._set_dataset_epoch(dataset, 0)
        self.phases_per_epoch = phases_per_epoch
        self.chunks = [None] * len(datasets)

        if dataset_prob is None:
            # If not provided, assign each dataset a probability proportional to its number of batches per epoch.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
        else:
            assert len(dataset_prob) == len(datasets)
            dataset_prob = torch.tensor(dataset_prob)

        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob
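        # Worked example (illustrative numbers): with two datasets of 1000 and 3000
        # samples, batch sizes of 10 and 30, and drop_last=True, both yield 100
        # batches per epoch, so the default mixing probabilities are [0.5, 0.5].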

    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

    def get_loader(self, epoch) -> Iterable:
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Major epoch that loops over the entire dataset;
                # one major epoch spans `phases_per_epoch` phase epochs.
                main_epoch = epoch // self.phases_per_epoch

                # Phase within the main epoch
                local_phase = epoch % self.phases_per_epoch

                # Start of a new data epoch, or the job was resumed after preemption.
                if local_phase == 0 or self.chunks[d_idx] is None:
                    # Set the seed for the dataset epoch.
                    # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
                    self._set_dataset_epoch(dataset, main_epoch)

                    # Separate random generator for subset sampling
                    g = torch.Generator()
                    g.manual_seed(main_epoch)
                    self.chunks[d_idx] = torch.chunk(
                        torch.randperm(len(dataset), generator=g),
                        self.phases_per_epoch,
                    )

                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                self._set_dataset_epoch(dataset, epoch)
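            # Example of the phase logic above (illustrative): with phases_per_epoch=2,
            # epochs 0 and 1 consume the two disjoint halves of one permutation
            # (main_epoch 0), and epoch 2 starts a fresh permutation (main_epoch 1).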
            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)

            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )
        return MixedDataLoader(dataloaders, self.dataset_prob)
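

if __name__ == "__main__":
    # Minimal sketch (assumed setup, not from the original file): build a mixed loader
    # over two toy TensorDatasets. DistributedSampler needs an initialized process
    # group, so a single-process "gloo" group is created here; in real training this
    # is normally handled by the launcher (e.g. torchrun).
    import os

    import torch.distributed as dist
    from torch.utils.data import TensorDataset

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    mixed_dataset = TorchTrainMixedDataset(
        datasets=[
            TensorDataset(torch.arange(16, dtype=torch.float32)),
            TensorDataset(torch.arange(100, 132, dtype=torch.float32)),
        ],
        batch_sizes=[4, 8],
        num_workers=0,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
        dataset_prob=[0.5, 0.5],  # explicit probabilities; must sum to exactly 1.0
    )
    for batch in mixed_dataset.get_loader(epoch=0):
        print(batch)

    dist.destroy_process_group()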