vos_sampler.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
import random
from dataclasses import dataclass
from typing import Any, List

from training.dataset.vos_segment_loader import LazySegments
  9. MAX_RETRIES = 1000
  10. @dataclass
  11. class SampledFramesAndObjects:
  12. frames: List[int]
  13. object_ids: List[int]
  14. class VOSSampler:
  15. def __init__(self, sort_frames=True):
  16. # frames are ordered by frame id when sort_frames is True
  17. self.sort_frames = sort_frames
  18. def sample(self, video):
  19. raise NotImplementedError()
  20. class RandomUniformSampler(VOSSampler):
  21. def __init__(
  22. self,
  23. num_frames,
  24. max_num_objects,
  25. reverse_time_prob=0.0,
  26. ):
  27. self.num_frames = num_frames
  28. self.max_num_objects = max_num_objects
  29. self.reverse_time_prob = reverse_time_prob
  30. def sample(self, video, segment_loader, epoch=None):
  31. for retry in range(MAX_RETRIES):
  32. if len(video.frames) < self.num_frames:
  33. raise Exception(
  34. f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
  35. )
  36. start = random.randrange(0, len(video.frames) - self.num_frames + 1)
  37. frames = [video.frames[start + step] for step in range(self.num_frames)]
  38. if random.uniform(0, 1) < self.reverse_time_prob:
  39. # Reverse time
  40. frames = frames[::-1]
  41. # Get first frame object ids
  42. visible_object_ids = []
  43. loaded_segms = segment_loader.load(frames[0].frame_idx)
  44. if isinstance(loaded_segms, LazySegments):
  45. # LazySegments for SA1BRawDataset
  46. visible_object_ids = list(loaded_segms.keys())
  47. else:
  48. for object_id, segment in segment_loader.load(
  49. frames[0].frame_idx
  50. ).items():
  51. if segment.sum():
  52. visible_object_ids.append(object_id)
  53. # First frame needs to have at least a target to track
  54. if len(visible_object_ids) > 0:
  55. break
  56. if retry >= MAX_RETRIES - 1:
  57. raise Exception("No visible objects")
  58. object_ids = random.sample(
  59. visible_object_ids,
  60. min(len(visible_object_ids), self.max_num_objects),
  61. )
  62. return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
  63. class EvalSampler(VOSSampler):
  64. """
  65. VOS Sampler for evaluation: sampling all the frames and all the objects in a video
  66. """
  67. def __init__(
  68. self,
  69. ):
  70. super().__init__()
  71. def sample(self, video, segment_loader, epoch=None):
  72. """
  73. Sampling all the frames and all the objects
  74. """
  75. if self.sort_frames:
  76. # ordered by frame id
  77. frames = sorted(video.frames, key=lambda x: x.frame_idx)
  78. else:
  79. # use the original order
  80. frames = video.frames
  81. object_ids = segment_loader.load(frames[0].frame_idx).keys()
  82. if len(object_ids) == 0:
  83. raise Exception("First frame of the video has no objects")
  84. return SampledFramesAndObjects(frames=frames, object_ids=object_ids)