utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
  5. """Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
from typing import Iterable

import torch
from torch.utils.data import (
    ConcatDataset as TorchConcatDataset,
    Dataset,
    Subset as TorchSubset,
)


class ConcatDataset(TorchConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__(datasets)

        self.repeat_factors = torch.cat([d.repeat_factors for d in datasets])

    def set_epoch(self, epoch: int):
        for dataset in self.datasets:
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)


class Subset(TorchSubset):
    def __init__(self, dataset, indices) -> None:
        super(Subset, self).__init__(dataset, indices)

        self.repeat_factors = dataset.repeat_factors[indices]
        assert len(indices) == len(self.repeat_factors)


# Adapted from Detectron2
class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.
    The underlying dataset must have a `repeat_factors` member indicating the per-image repeat factor.
    Set it to all ones to disable repeat factor sampling.
    """

    def __init__(self, dataset, seed: int = 0):
        self.dataset = dataset
        self.epoch_ids = None
        self._seed = seed

        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: list of dataset indices to use in one epoch. Each index
                is repeated based on its calculated repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training.
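        # For example, a repeat factor of 1.3 is rounded up to 2 with probability
        # 0.3 and down to 1 with probability 0.7, giving 1.3 copies in expectation.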
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Construct a list of indices in which we repeat images as specified.
        indices = []
        for dataset_index, rep_factor in enumerate(rep_factors):
            indices.extend([dataset_index] * int(rep_factor.item()))
        return torch.tensor(indices, dtype=torch.int64)

    def __len__(self):
        if self.epoch_ids is None:
            # Raise an error here instead of returning the original len(self.dataset),
            # to avoid accidentally using the unwrapped length. Otherwise it's
            # error-prone, since the length changes to len(self.epoch_ids) after
            # set_epoch is called.
            raise RuntimeError("please call set_epoch first to get wrapped length")
            # return len(self.dataset)

        return len(self.epoch_ids)

    def set_epoch(self, epoch: int):
        g = torch.Generator()
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        if self.epoch_ids is None:
            raise RuntimeError(
                "Repeat ids haven't been computed. Did you forget to call set_epoch?"
            )

        return self.dataset[self.epoch_ids[idx]]
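

# Illustrative usage sketch (an assumption about intended usage, not part of the
# original module): any dataset exposing a per-sample `repeat_factors` tensor can
# be wrapped; `_ToyDataset` below is a hypothetical stand-in.
if __name__ == "__main__":

    class _ToyDataset(Dataset):
        def __init__(self):
            self.samples = list(range(4))
            # Per-sample repeat factors; all ones would disable oversampling.
            self.repeat_factors = torch.tensor([1.0, 1.5, 2.0, 1.0])

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            return self.samples[idx]

    wrapped = RepeatFactorWrapper(_ToyDataset(), seed=0)
    wrapped.set_epoch(0)  # must be called before len() or indexing
    print(len(wrapped))  # 5 or 6 here, depending on the stochastic rounding
    print([wrapped[i] for i in range(len(wrapped))])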