| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
import random
from dataclasses import dataclass
from typing import Any, List

from training.dataset.vos_segment_loader import LazySegments
# Upper bound on how many random frame windows RandomUniformSampler.sample
# draws while looking for one whose first frame has a visible object.
MAX_RETRIES: int = 1_000
@dataclass
class SampledFramesAndObjects:
    """Output of a sampler: which frames to load and which objects to track."""

    # Frame records, not integer indices: the samplers store entries of
    # ``video.frames`` here and later read ``.frame_idx`` from them.
    frames: List[Any]
    # Ids of the objects selected for tracking across ``frames``.
    object_ids: List[int]
class VOSSampler:
    """Base class for frame/object samplers; subclasses implement ``sample``."""

    def __init__(self, sort_frames=True):
        # frames are ordered by frame id when sort_frames is True
        self.sort_frames = sort_frames

    def sample(self, video, segment_loader=None, epoch=None):
        """Select frames and objects from ``video``.

        The signature mirrors the subclasses' ``sample(video, segment_loader,
        epoch=None)`` so calls made through the base type are well-formed.
        ``segment_loader`` and ``epoch`` are optional to keep existing
        ``sample(video)`` calls working.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()
class RandomUniformSampler(VOSSampler):
    """Sample a random window of consecutive frames and a random subset of objects.

    Each call draws ``num_frames`` consecutive entries of ``video.frames``
    starting at a uniformly random offset, optionally reverses them, and picks
    up to ``max_num_objects`` ids among the objects visible in the window's
    first frame. Windows whose first frame has no visible object are re-drawn,
    up to ``MAX_RETRIES`` times.
    """

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        """
        Args:
            num_frames: number of consecutive frames returned per sample.
            max_num_objects: maximum number of object ids returned per sample.
            reverse_time_prob: probability of returning the window in reverse
                temporal order.
        """
        # Frames are returned in video order (or reversed), never re-sorted by
        # frame id, so record that instead of leaving ``sort_frames`` unset.
        super().__init__(sort_frames=False)
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    @staticmethod
    def _visible_object_ids(segment_loader, frame_idx):
        """Return the ids of objects considered visible in frame ``frame_idx``."""
        loaded_segms = segment_loader.load(frame_idx)
        if isinstance(loaded_segms, LazySegments):
            # LazySegments for SA1BRawDataset: every key is taken as visible,
            # without the per-mask emptiness check applied below.
            return list(loaded_segms.keys())
        # Reuse the mapping loaded above instead of loading the frame again.
        return [
            object_id
            for object_id, segment in loaded_segms.items()
            if segment.sum()
        ]

    def sample(self, video, segment_loader, epoch=None):
        """Draw a random frame window and the objects to track in it.

        Args:
            video: object exposing ``frames`` (a sequence of frame records with
                a ``frame_idx`` attribute) and ``video_name``.
            segment_loader: object whose ``load(frame_idx)`` returns either a
                ``LazySegments`` or a mapping of object id -> segment supporting
                ``.sum()``.
            epoch: unused; accepted for interface compatibility.

        Returns:
            SampledFramesAndObjects with ``num_frames`` frames and up to
            ``max_num_objects`` object ids visible in the first frame.

        Raises:
            Exception: if the video has fewer than ``num_frames`` frames, or if
                no drawn window has a visible object after ``MAX_RETRIES`` tries.
        """
        # The video length does not change between retries: check it once.
        if len(video.frames) < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
            )
        last_start = len(video.frames) - self.num_frames

        for _ in range(MAX_RETRIES):
            start = random.randrange(0, last_start + 1)
            frames = [video.frames[start + step] for step in range(self.num_frames)]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Reverse time
                frames = frames[::-1]

            visible_object_ids = self._visible_object_ids(
                segment_loader, frames[0].frame_idx
            )
            # First frame needs to have at least a target to track
            if visible_object_ids:
                break
        else:
            # Every retry produced a first frame without visible objects.
            raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(
        self,
        sort_frames=True,
    ):
        """
        Args:
            sort_frames: if True (the default, as before), frames are returned
                ordered by ``frame_idx``; otherwise in ``video.frames`` order.
        """
        super().__init__(sort_frames=sort_frames)

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects.

        Args:
            video: object exposing ``frames`` (frame records with a
                ``frame_idx`` attribute) and ``video_name``.
            segment_loader: object whose ``load(frame_idx)`` returns a mapping
                keyed by object id.
            epoch: unused; accepted for interface compatibility.

        Returns:
            SampledFramesAndObjects with every frame and every object id found
            in the first returned frame.

        Raises:
            Exception: if the video has no frames, or its first frame has no
                objects.
        """
        if len(video.frames) == 0:
            raise Exception(f"Video {video.video_name} has no frames to sample")
        if self.sort_frames:
            # ordered by frame id
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # use the original order; copy so both branches return a fresh
            # list rather than aliasing video.frames
            frames = list(video.frames)
        # Materialise the ids: a keys view (e.g. dict_keys) keeps a reference to
        # the whole loaded mapping, segments included, for as long as the
        # result is held.
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if len(object_ids) == 0:
            raise Exception("First frame of the video has no objects")
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|