# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""

from dataclasses import dataclass
from typing import List, Tuple, Union

import torch
from PIL import Image as PILImage
from tensordict import tensorclass


@tensorclass
class BatchedVideoMetaData:
    """
    This class represents metadata about a batch of videos.

    Attributes:
        unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Each row is (video_id, obj_id, frame_id).
        frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch.
    """

    unique_objects_identifier: torch.LongTensor
    frame_orig_size: torch.LongTensor
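

# Illustrative example (hypothetical values, not from a real dataset): a
# `unique_objects_identifier` row of [12, 3, 7] refers to object 3 of video 12,
# annotated on frame 7 of that video.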


@tensorclass
class BatchedVideoDatapoint:
    """
    This class represents a batch of videos with associated annotations and metadata.

    Attributes:
        img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
        obj_to_frame_idx: A [TxOx2] tensor of (frame_idx, video_idx) pairs locating the frame each object belongs to within img_batch, where O is the number of objects per time step.
        masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
        metadata: An instance of BatchedVideoMetaData containing metadata about the batch.
        dict_key: A string key used to identify the batch.
    """

    img_batch: torch.FloatTensor
    obj_to_frame_idx: torch.IntTensor
    masks: torch.BoolTensor
    metadata: BatchedVideoMetaData
    dict_key: str

    def pin_memory(self, device=None):
        """
        Pin all tensors in the batch so host-to-device copies can be asynchronous.
        """
        return self.apply(torch.Tensor.pin_memory, device=device)

    @property
    def num_frames(self) -> int:
        """
        Returns the number of frames per video.
        """
        return self.batch_size[0]

    @property
    def num_videos(self) -> int:
        """
        Returns the number of videos in the batch.
        """
        return self.img_batch.shape[1]

    @property
    def flat_obj_to_img_idx(self) -> torch.IntTensor:
        """
        Returns a [TxO] tensor of flat frame indices, one per object.
        The flat index can be used to access a flattened img_batch of shape
        [(B*T)xCxHxW] (see flat_img_batch).
        """
        frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1)
        flat_idx = video_idx * self.num_frames + frame_idx
        return flat_idx

    @property
    def flat_img_batch(self) -> torch.FloatTensor:
        """
        Returns a flattened img_batch of shape [(B*T)xCxHxW], with the video
        dimension leading so that the indices from flat_obj_to_img_idx apply.
        """
        return self.img_batch.transpose(0, 1).flatten(0, 1)
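

# Usage sketch (hedged; `batch` is any BatchedVideoDatapoint, e.g. one produced
# by collate_fn below). The two flat views above are consistent with each
# other, so per-object frame images can be gathered with plain advanced
# indexing:
#
#   per_obj_imgs = batch.flat_img_batch[batch.flat_obj_to_img_idx]
#   # per_obj_imgs has shape [TxOxCxHxW]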


@dataclass
class Object:
    # Id of the object in the media
    object_id: int
    # Index of the frame in the media (0 if single image)
    frame_index: int
    segment: Union[torch.Tensor, dict]  # RLE dict or binary mask


@dataclass
class Frame:
    data: Union[torch.Tensor, PILImage.Image]
    objects: List[Object]


@dataclass
class VideoDatapoint:
    """Refers to an image/video and all its annotations"""

    frames: List[Frame]
    video_id: int
    size: Tuple[int, int]
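

# Illustrative example (hypothetical tensors `img` and `mask`): a single image
# with one annotated object is represented as a one-frame video, e.g.
#
#   VideoDatapoint(
#       frames=[Frame(data=img, objects=[Object(object_id=0, frame_index=0, segment=mask)])],
#       video_id=0,
#       size=(H, W),
#   )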


def collate_fn(
    batch: List[VideoDatapoint],
    dict_key: str,
) -> BatchedVideoDatapoint:
    """
    Collates a list of VideoDatapoint instances into a single BatchedVideoDatapoint.

    Args:
        batch: A list of VideoDatapoint instances.
        dict_key (str): A string key used to identify the batch.
    """
    img_batch = []
    for video in batch:
        img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)]

    img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4))
    T = img_batch.shape[0]
    # Prepare per-time-step buffers: frames are processed sequentially over T,
    # but batched across videos. Assumes every time step has at least one object.
    step_t_objects_identifier = [[] for _ in range(T)]
    step_t_frame_orig_size = [[] for _ in range(T)]

    step_t_masks = [[] for _ in range(T)]
    step_t_obj_to_frame_idx = [
        [] for _ in range(T)
    ]  # List to store frame indices for each time step

    for video_idx, video in enumerate(batch):
        orig_video_id = video.video_id
        orig_frame_size = video.size
        for t, frame in enumerate(video.frames):
            objects = frame.objects
            for obj in objects:
                orig_obj_id = obj.object_id
                orig_frame_idx = obj.frame_index
                step_t_obj_to_frame_idx[t].append(
                    torch.tensor([t, video_idx], dtype=torch.int)
                )
                # `segment` is expected to be a decoded binary mask tensor by collate time
                step_t_masks[t].append(obj.segment.to(torch.bool))
                step_t_objects_identifier[t].append(
                    torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx])
                )
                step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size))

    obj_to_frame_idx = torch.stack(
        [
            torch.stack(obj_to_frame_idx, dim=0)
            for obj_to_frame_idx in step_t_obj_to_frame_idx
        ],
        dim=0,
    )
    masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0)
    objects_identifier = torch.stack(
        [torch.stack(ids, dim=0) for ids in step_t_objects_identifier], dim=0
    )
    frame_orig_size = torch.stack(
        [torch.stack(sizes, dim=0) for sizes in step_t_frame_orig_size], dim=0
    )
    return BatchedVideoDatapoint(
        img_batch=img_batch,
        obj_to_frame_idx=obj_to_frame_idx,
        masks=masks,
        metadata=BatchedVideoMetaData(
            unique_objects_identifier=objects_identifier,
            frame_orig_size=frame_orig_size,
        ),
        dict_key=dict_key,
        batch_size=[T],
    )
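

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch with made-up shapes (two 2-frame
    # videos, one object per frame, 3x8x8 images), not part of the real
    # training pipeline.
    from functools import partial

    def _make_video(video_id: int) -> VideoDatapoint:
        frames = [
            Frame(
                data=torch.rand(3, 8, 8),
                objects=[
                    Object(
                        object_id=0,
                        frame_index=t,
                        segment=torch.zeros(8, 8, dtype=torch.bool),
                    )
                ],
            )
            for t in range(2)
        ]
        return VideoDatapoint(frames=frames, video_id=video_id, size=(8, 8))

    demo_batch = collate_fn([_make_video(0), _make_video(1)], dict_key="demo")
    assert demo_batch.img_batch.shape == (2, 2, 3, 8, 8)  # TxBxCxHxW
    assert demo_batch.masks.shape == (2, 2, 8, 8)  # TxOxHxW
    assert demo_batch.flat_img_batch.shape == (4, 3, 8, 8)  # (B*T)xCxHxW

    # With a torch DataLoader, dict_key can be bound up front, e.g.:
    #   DataLoader(dataset, batch_size=2, collate_fn=partial(collate_fn, dict_key="train"))
    collate_train = partial(collate_fn, dict_key="train")
    assert collate_train([_make_video(2)]).dict_key == "train"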