# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
from copy import deepcopy

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


class VOSDataset(VisionDataset):
    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
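        """
        A torch dataset of VOS videos: wraps a VOSRawDataset and a VOSSampler
        and yields transformed VideoDatapoint samples.

        Args:
            transforms: list of callables applied to every VideoDatapoint.
            training: if True, failed video loads retry on a random index
                instead of raising.
            video_dataset: raw source of videos and their segment loaders.
            sampler: selects the frames and object ids for each datapoint.
            multiplier: per-video repeat factor for over-sampling this dataset.
            always_target: if True, an object with no ground-truth segment in a
                frame gets a zero mask; if False, it is dropped from that frame.
            target_segments_available: whether ground-truth segments exist for
                this dataset.
        """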
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        # Per-video repeat factors (all scaled by `multiplier`); a repeat-factor
        # sampler can use these to over-sample videos from this dataset.
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
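        """
        Load one datapoint, retrying up to MAX_RETRIES times on failure.
        During training a failed load falls back to a freshly sampled random
        index; at validation time the exception is re-raised.
        """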
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if self.training:
                    logging.warning(
                        f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                    )
                    # Retry on a different, randomly chosen video
                    idx = random.randrange(0, len(self.video_dataset))
                else:
                    # Shouldn't fail to load a val video
                    raise e

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms: one Frame per
        sampled frame, each holding one Object (mask plus metadata) per sampled
        object id.
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)


def load_images(frames):
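    """
    Return an RGB PIL image for every frame, loading from file when the frame
    carries no data. Images for repeated paths are served from a per-call
    cache; each cache hit is deep-copied so in-place transforms downstream
    cannot alter the cached original.
    """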
    all_images = []
    cache = {}
    for frame in frames:
        if frame.data is None:
            # Load the frame rgb data from file
            path = frame.image_path
            if path in cache:
                all_images.append(deepcopy(all_images[cache[path]]))
                continue
            with g_pathmgr.open(path, "rb") as fopen:
                all_images.append(PILImage.open(fopen).convert("RGB"))
            cache[path] = len(all_images) - 1
        else:
            # The frame rgb data has already been loaded
            # Convert it to a PILImage
            all_images.append(tensor_2_PIL(frame.data))

    return all_images


def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
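    """
    Convert a CHW float tensor (values assumed to lie in [0, 1]) into an RGB
    PIL image.
    """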
    data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    data = data.astype(np.uint8)
    return PILImage.fromarray(data)
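

# Example usage: a minimal sketch of wiring the pieces together. The concrete
# raw-dataset and sampler classes and their arguments below are assumptions
# (any VOSRawDataset / VOSSampler implementations from training.dataset work):
#
#   from training.dataset.vos_raw_dataset import PNGRawDataset
#   from training.dataset.vos_sampler import RandomUniformSampler
#
#   dataset = VOSDataset(
#       transforms=[],
#       training=True,
#       video_dataset=PNGRawDataset(img_folder=..., gt_folder=...),
#       sampler=RandomUniformSampler(num_frames=8, max_num_objects=3),
#       multiplier=1,
#   )
#   datapoint = dataset[0]  # VideoDatapoint with frames, masks, video_id, size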