torch_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
from typing import Callable, Iterable, Optional

from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset


class TorchDataset:
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        enable_distributed_sampler: bool = True,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        # Only map-style datasets are supported: an IterableDataset cannot be
        # combined with a sampler.
        assert not isinstance(self.dataset, IterableDataset), "Not supported yet"
        if enable_distributed_sampler:
            # Shard the dataset across distributed ranks; shuffling is delegated
            # to the sampler rather than to the DataLoader.
            self.sampler = DistributedSampler(self.dataset, shuffle=self.shuffle)
        else:
            self.sampler = None

    def get_loader(self, epoch: int) -> Iterable:
        if self.sampler:
            # Re-seed the sampler so each epoch yields a different shuffling order.
            self.sampler.set_epoch(epoch)
        # Propagate the epoch to datasets that track it themselves.
        if hasattr(self.dataset, "epoch"):
            self.dataset.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            sampler=self.sampler,
            collate_fn=self.collate_fn,
            worker_init_fn=self.worker_init_fn,
        )
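

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): wrap a small
    # in-memory TensorDataset and iterate one epoch. enable_distributed_sampler
    # is turned off here because DistributedSampler requires
    # torch.distributed.init_process_group() to have been called first.
    import torch
    from torch.utils.data import TensorDataset

    data = TensorDataset(torch.arange(10).float(), torch.arange(10))
    wrapper = TorchDataset(
        dataset=data,
        batch_size=4,
        num_workers=0,
        shuffle=False,
        pin_memory=False,
        drop_last=False,
        enable_distributed_sampler=False,
    )
    for batch in wrapper.get_loader(epoch=0):
        print(batch)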