sam2_datasets.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import logging
  6. import math
  7. from typing import Callable, Iterable, List, Optional, Sequence
  8. import torch
  9. from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
  10. from torch.utils.data.distributed import DistributedSampler
  11. class MixedDataLoader:
  12. def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
  13. """
  14. Args:
  15. dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
  16. mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from
  17. """
  18. assert len(dataloaders) == mixing_prob.shape[0]
  19. self.dataloaders = dataloaders
  20. self.mixing_prob = mixing_prob
  21. # Iterator state
  22. self._iter_dls = None
  23. self._iter_mixing_prob = None
  24. self.random_generator = torch.Generator()
  25. def __len__(self):
  26. return sum([len(d) for d in self.dataloaders])
  27. def __iter__(self):
  28. # Synchronize dataloader seeds
  29. self.random_generator.manual_seed(42)
  30. self._iter_dls = [iter(loader) for loader in self.dataloaders]
  31. self._iter_mixing_prob = self.mixing_prob.clone()
  32. return self
  33. def __next__(self):
  34. """
  35. Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted.
  36. """
  37. if self._iter_dls is None:
  38. raise TypeError(f"{type(self).__name__} object is not an iterator")
  39. while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob.
  40. dataset_idx = self._iter_mixing_prob.multinomial(
  41. 1, generator=self.random_generator
  42. ).item()
  43. try:
  44. item = next(self._iter_dls[dataset_idx])
  45. return item
  46. except StopIteration:
  47. # No more iterations for this dataset, set it's mixing probability to zero and try again.
  48. self._iter_mixing_prob[dataset_idx] = 0
  49. except Exception as e:
  50. # log and raise any other unexpected error.
  51. logging.error(e)
  52. raise e
  53. # Exhausted all iterators
  54. raise StopIteration
  55. class TorchTrainMixedDataset:
  56. def __init__(
  57. self,
  58. datasets: List[Dataset],
  59. batch_sizes: List[int],
  60. num_workers: int,
  61. shuffle: bool,
  62. pin_memory: bool,
  63. drop_last: bool,
  64. collate_fn: Optional[Callable] = None,
  65. worker_init_fn: Optional[Callable] = None,
  66. phases_per_epoch: int = 1,
  67. dataset_prob: Optional[List[float]] = None,
  68. ) -> None:
  69. """
  70. Args:
  71. datasets (List[Dataset]): List of Datasets to be mixed.
  72. batch_sizes (List[int]): Batch sizes for each dataset in the list.
  73. num_workers (int): Number of workers per dataloader.
  74. shuffle (bool): Whether or not to shuffle data.
  75. pin_memory (bool): If True, use pinned memory when loading tensors from disk.
  76. drop_last (bool): Whether or not to drop the last batch of data.
  77. collate_fn (Callable): Function to merge a list of samples into a mini-batch.
  78. worker_init_fn (Callable): Function to init each dataloader worker.
  79. phases_per_epoch (int): Number of phases per epoch.
  80. dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
  81. """
  82. self.datasets = datasets
  83. self.batch_sizes = batch_sizes
  84. self.num_workers = num_workers
  85. self.shuffle = shuffle
  86. self.pin_memory = pin_memory
  87. self.drop_last = drop_last
  88. self.collate_fn = collate_fn
  89. self.worker_init_fn = worker_init_fn
  90. assert len(self.datasets) > 0
  91. for dataset in self.datasets:
  92. assert not isinstance(dataset, IterableDataset), "Not supported"
  93. # `RepeatFactorWrapper` requires calling set_epoch first to get its length
  94. self._set_dataset_epoch(dataset, 0)
  95. self.phases_per_epoch = phases_per_epoch
  96. self.chunks = [None] * len(datasets)
  97. if dataset_prob is None:
  98. # If not provided, assign each dataset a probability proportional to its length.
  99. dataset_lens = [
  100. (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
  101. for d, bs in zip(datasets, batch_sizes)
  102. ]
  103. total_len = sum(dataset_lens)
  104. dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
  105. else:
  106. assert len(dataset_prob) == len(datasets)
  107. dataset_prob = torch.tensor(dataset_prob)
  108. logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
  109. assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
  110. self.dataset_prob = dataset_prob
  111. def _set_dataset_epoch(self, dataset, epoch: int) -> None:
  112. if hasattr(dataset, "epoch"):
  113. dataset.epoch = epoch
  114. if hasattr(dataset, "set_epoch"):
  115. dataset.set_epoch(epoch)
  116. def get_loader(self, epoch) -> Iterable:
  117. dataloaders = []
  118. for d_idx, (dataset, batch_size) in enumerate(
  119. zip(self.datasets, self.batch_sizes)
  120. ):
  121. if self.phases_per_epoch > 1:
  122. # Major epoch that looops over entire dataset
  123. # len(main_epoch) == phases_per_epoch * len(epoch)
  124. main_epoch = epoch // self.phases_per_epoch
  125. # Phase with in the main epoch
  126. local_phase = epoch % self.phases_per_epoch
  127. # Start of new data-epoch or job is resumed after preemtion.
  128. if local_phase == 0 or self.chunks[d_idx] is None:
  129. # set seed for dataset epoch
  130. # If using RepeatFactorWrapper, this step currectly re-samples indices before chunking.
  131. self._set_dataset_epoch(dataset, main_epoch)
  132. # Separate random generator for subset sampling
  133. g = torch.Generator()
  134. g.manual_seed(main_epoch)
  135. self.chunks[d_idx] = torch.chunk(
  136. torch.randperm(len(dataset), generator=g),
  137. self.phases_per_epoch,
  138. )
  139. dataset = Subset(dataset, self.chunks[d_idx][local_phase])
  140. else:
  141. self._set_dataset_epoch(dataset, epoch)
  142. sampler = DistributedSampler(dataset, shuffle=self.shuffle)
  143. sampler.set_epoch(epoch)
  144. batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
  145. dataloaders.append(
  146. DataLoader(
  147. dataset,
  148. num_workers=self.num_workers,
  149. pin_memory=self.pin_memory,
  150. batch_sampler=batch_sampler,
  151. collate_fn=self.collate_fn,
  152. worker_init_fn=self.worker_init_fn,
  153. )
  154. )
  155. return MixedDataLoader(dataloaders, self.dataset_prob)