data_utils.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. """
  6. Misc functions, including distributed helpers.
  7. Mostly copy-paste from torchvision references.
  8. """
  9. from dataclasses import dataclass
  10. from typing import List, Optional, Tuple, Union
  11. import torch
  12. from PIL import Image as PILImage
  13. from tensordict import tensorclass
  14. @tensorclass
  15. class BatchedVideoMetaData:
  16. """
  17. This class represents metadata about a batch of videos.
  18. Attributes:
  19. unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id)
  20. frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch.
  21. """
  22. unique_objects_identifier: torch.LongTensor
  23. frame_orig_size: torch.LongTensor
  24. @tensorclass
  25. class BatchedVideoDatapoint:
  26. """
  27. This class represents a batch of videos with associated annotations and metadata.
  28. Attributes:
  29. img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
  30. obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch.
  31. masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
  32. metadata: An instance of BatchedVideoMetaData containing metadata about the batch.
  33. dict_key: A string key used to identify the batch.
  34. """
  35. img_batch: torch.FloatTensor
  36. obj_to_frame_idx: torch.IntTensor
  37. masks: torch.BoolTensor
  38. metadata: BatchedVideoMetaData
  39. dict_key: str
  40. def pin_memory(self, device=None):
  41. return self.apply(torch.Tensor.pin_memory, device=device)
  42. @property
  43. def num_frames(self) -> int:
  44. """
  45. Returns the number of frames per video.
  46. """
  47. return self.batch_size[0]
  48. @property
  49. def num_videos(self) -> int:
  50. """
  51. Returns the number of videos in the batch.
  52. """
  53. return self.img_batch.shape[1]
  54. @property
  55. def flat_obj_to_img_idx(self) -> torch.IntTensor:
  56. """
  57. Returns a flattened tensor containing the object to img index.
  58. The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW]
  59. """
  60. frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1)
  61. flat_idx = video_idx * self.num_frames + frame_idx
  62. return flat_idx
  63. @property
  64. def flat_img_batch(self) -> torch.FloatTensor:
  65. """
  66. Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW]
  67. """
  68. return self.img_batch.transpose(0, 1).flatten(0, 1)
  69. @dataclass
  70. class Object:
  71. # Id of the object in the media
  72. object_id: int
  73. # Index of the frame in the media (0 if single image)
  74. frame_index: int
  75. segment: Union[torch.Tensor, dict] # RLE dict or binary mask
  76. @dataclass
  77. class Frame:
  78. data: Union[torch.Tensor, PILImage.Image]
  79. objects: List[Object]
  80. @dataclass
  81. class VideoDatapoint:
  82. """Refers to an image/video and all its annotations"""
  83. frames: List[Frame]
  84. video_id: int
  85. size: Tuple[int, int]
  86. def collate_fn(
  87. batch: List[VideoDatapoint],
  88. dict_key,
  89. ) -> BatchedVideoDatapoint:
  90. """
  91. Args:
  92. batch: A list of VideoDatapoint instances.
  93. dict_key (str): A string key used to identify the batch.
  94. """
  95. img_batch = []
  96. for video in batch:
  97. img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)]
  98. img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4))
  99. T = img_batch.shape[0]
  100. # Prepare data structures for sequential processing. Per-frame processing but batched across videos.
  101. step_t_objects_identifier = [[] for _ in range(T)]
  102. step_t_frame_orig_size = [[] for _ in range(T)]
  103. step_t_masks = [[] for _ in range(T)]
  104. step_t_obj_to_frame_idx = [
  105. [] for _ in range(T)
  106. ] # List to store frame indices for each time step
  107. for video_idx, video in enumerate(batch):
  108. orig_video_id = video.video_id
  109. orig_frame_size = video.size
  110. for t, frame in enumerate(video.frames):
  111. objects = frame.objects
  112. for obj in objects:
  113. orig_obj_id = obj.object_id
  114. orig_frame_idx = obj.frame_index
  115. step_t_obj_to_frame_idx[t].append(
  116. torch.tensor([t, video_idx], dtype=torch.int)
  117. )
  118. step_t_masks[t].append(obj.segment.to(torch.bool))
  119. step_t_objects_identifier[t].append(
  120. torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx])
  121. )
  122. step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size))
  123. obj_to_frame_idx = torch.stack(
  124. [
  125. torch.stack(obj_to_frame_idx, dim=0)
  126. for obj_to_frame_idx in step_t_obj_to_frame_idx
  127. ],
  128. dim=0,
  129. )
  130. masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0)
  131. objects_identifier = torch.stack(
  132. [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0
  133. )
  134. frame_orig_size = torch.stack(
  135. [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0
  136. )
  137. return BatchedVideoDatapoint(
  138. img_batch=img_batch,
  139. obj_to_frame_idx=obj_to_frame_idx,
  140. masks=masks,
  141. metadata=BatchedVideoMetaData(
  142. unique_objects_identifier=objects_identifier,
  143. frame_orig_size=frame_orig_size,
  144. ),
  145. dict_key=dict_key,
  146. batch_size=[T],
  147. )