vos_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
from copy import deepcopy

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


class VOSDataset(VisionDataset):
    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if self.training:
                    # During training, fall back to a randomly resampled video
                    logging.warning(
                        f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                    )
                    idx = random.randrange(0, len(self.video_dataset))
                else:
                    # Shouldn't fail to load a val video
                    raise e

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)
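

# --- Illustrative sketch (not part of the original file) ---------------------
# How `repeat_factors` is consumed lives outside this module; one plausible
# pattern (an assumption, not necessarily the repo's actual training loop) is
# to feed it to torch's WeightedRandomSampler so videos with a higher
# `multiplier` are drawn proportionally more often. `collate_fn=lambda b: b`
# keeps the VideoDatapoint samples as a plain list, since the default collate
# cannot stack them.
def make_weighted_loader(dataset: VOSDataset, batch_size: int = 2):
    from torch.utils.data import DataLoader, WeightedRandomSampler

    weighted_sampler = WeightedRandomSampler(
        weights=dataset.repeat_factors,
        num_samples=len(dataset),
        replacement=True,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=weighted_sampler,
        collate_fn=lambda batch: batch,
    )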


def load_images(frames):
    all_images = []
    cache = {}
    for frame in frames:
        if frame.data is None:
            # Load the frame rgb data from file
            path = frame.image_path
            if path in cache:
                all_images.append(deepcopy(all_images[cache[path]]))
                continue
            with g_pathmgr.open(path, "rb") as fopen:
                all_images.append(PILImage.open(fopen).convert("RGB"))
            cache[path] = len(all_images) - 1
        else:
            # The frame rgb data has already been loaded
            # Convert it to a PILImage
            all_images.append(tensor_2_PIL(frame.data))
    return all_images


def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    data = data.astype(np.uint8)
    return PILImage.fromarray(data)
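

# --- Illustrative self-check (not part of the original file) -----------------
# tensor_2_PIL assumes a float CHW tensor with values in [0, 1] (e.g. the
# output of torchvision's ToTensor()); it scales to [0, 255], converts to
# uint8, and reorders to HWC for PIL. A minimal sanity check:
if __name__ == "__main__":
    _dummy = torch.rand(3, 4, 5)  # CHW float in [0, 1]
    _img = tensor_2_PIL(_dummy)
    assert _img.size == (5, 4)  # PIL reports size as (width, height)
    print("tensor_2_PIL OK:", _img.mode, _img.size)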