# sam3_video_dataset.py
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import copy
import io
import json
import logging
import math
import os
import pickle
import random
import sys
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torchvision

# from decord import cpu, VideoReader
from iopath.common.file_io import PathManager
from PIL import Image as PILImage

from .sam3_image_dataset import Datapoint, Sam3ImageDataset

SEED = 42


class VideoGroundingDataset(Sam3ImageDataset):
    def __init__(
        self,
        num_stages_sample: int = 4,
        stage_stride_min: int = 1,
        stage_stride_max: int = 5,
        random_reverse_time_axis: bool = True,
        is_tiling_single_image: bool = False,
        # By default, we remove those find queries with geometric inputs (input_box
        # or input_points) when creating synthetic videos from frames (since they are
        # not *video-level* text prompts). If we need them later, we can sample them
        # on-the-fly via transforms or inside the model.
        tile_img_keep_find_queries_with_geo_inputs: bool = False,
        tile_img_keep_get_queries: bool = False,
        # the maximum number of find queries (for each frame) to keep in a video; if
        # the datapoint contains more queries per frame than this limit, we subsample
        # them to avoid OOM errors
        max_query_num: int = -1,  # the default -1 means no limit
        # whether to override the "is_exhaustive" flag of the loaded find queries to
        # True (by default, our video datasets are ingested with is_exhaustive=False,
        # since YTVIS-format annotations don't include an "is_exhaustive" flag; this
        # means that unmatched (negative) detection queries or tracking queries do not
        # receive a classification loss given that we have weak_loss=True in
        # IABCEMdetr -- this could lead to false positives for both image detection
        # and video association)
        override_query_is_exhaustive_to_true: bool = False,
        # the maximum number of masklets in a video; if the datapoint contains more
        # masklets than this limit, we skip the datapoint to avoid OOM errors (this is
        # useful for training with large videos that contain many objects)
        max_masklet_num_in_video: int = 300,  # 300 masklets is usually OK to avoid OOM
        **kwargs,
    ):
        """
        Loading video grounding data.

        Video frame sampling parameters (for training only):
        - num_stages_sample: number of frames to sample from the video during training
        - stage_stride_min: minimum stride between sampled frames during training
        - stage_stride_max: maximum stride between sampled frames during training (if
          it's greater than stage_stride_min, the actual stride is sampled uniformly
          between min and max; during inference, we always use all frames in the
          video with stride=1)
        - random_reverse_time_axis: whether to randomly invert the video's temporal
          axis (i.e. playing it backwards) during training
        """
        super().__init__(**kwargs)
        assert num_stages_sample >= 1
        assert stage_stride_min >= 1
        assert stage_stride_max >= stage_stride_min
        self.num_stages_sample = num_stages_sample
        self.stage_stride_min = stage_stride_min
        self.stage_stride_max = stage_stride_max
        self.random_reverse_time_axis = random_reverse_time_axis
        self.is_tiling_single_image = is_tiling_single_image
        self.tile_img_keep_find_queries_with_geo_inputs = (
            tile_img_keep_find_queries_with_geo_inputs
        )
        self.tile_img_keep_get_queries = tile_img_keep_get_queries
        self.max_query_num = max_query_num
        self.override_query_is_exhaustive_to_true = override_query_is_exhaustive_to_true
        self.max_masklet_num_in_video = max_masklet_num_in_video
        self.rng = random.Random()
        self.set_curr_epoch(0)

    def set_curr_epoch(self, epoch: int):
        super().set_curr_epoch(epoch)
        self.rng.seed(SEED + epoch)
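        # (re-seeding with SEED + epoch keeps frame sampling deterministic within an
        # epoch while still varying it across epochs)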

    def _load_datapoint(self, index: int) -> Datapoint:
        id = self.ids[index].item()
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)
        # we subsample the video frames during training
        if self.training and not self.is_tiling_single_image:
            # pick a random stride for sampling query stages (`randint` includes both ends)
            stage_stride = self.rng.randint(
                self.stage_stride_min, self.stage_stride_max
            )
            stage_ids_to_keep = self._sample_stage_ids(
                queries, self.num_stages_sample, stage_stride
            )
            # filter the queries and annotations to keep only the selected stages
            # (also remap the stage ids so that they are contiguous and start from 0)
            reverse_time_axis = (
                self.rng.random() < 0.5 if self.random_reverse_time_axis else False
            )
            queries, annotations, kept_img_ids = self._filter_query_and_anns(
                queries,
                annotations,
                stage_ids_to_keep,
                remap_stage_id=True,
                reverse_time_axis=reverse_time_axis,
            )
            pil_images, img_metadata = self._load_images(id, kept_img_ids)
            if reverse_time_axis:
                # reverse the temporal ordering of the images and their metadata
                # so that the image order matches the query order
                pil_images = pil_images[::-1]
                img_metadata = img_metadata[::-1]
        else:
            pil_images, img_metadata = self._load_images(id)
        # check that all the images have the same size (expected, since they are
        # frames from the same video)
        assert all(p.size == pil_images[0][1].size for _, p in pil_images)
        queries.sort(key=lambda q: q["query_processing_order"])
        if self.override_query_is_exhaustive_to_true:
            for query in queries:
                query["is_exhaustive"] = True
        datapoint = self.load_queries(pil_images, annotations, queries, img_metadata)
        # skip datapoints with too many masklets to avoid OOM errors
        num_masklets_in_video = len(datapoint.images[0].objects)
        if num_masklets_in_video > self.max_masklet_num_in_video > 0:
            logging.warning(
                f"Datapoint {id} has ({num_masklets_in_video=}), exceeding "
                f"the maximum allowed ({self.max_masklet_num_in_video}). "
                "Skipping this datapoint."
            )
            next_index = (index + 1) % len(self)
            return self._load_datapoint(next_index)  # move to the next datapoint
        if self.is_tiling_single_image:
            datapoint = self._tile_single_image_data(datapoint, self.num_stages_sample)
        if self.max_query_num > 0:
            datapoint = self._subsample_queries(datapoint, self.max_query_num)
        # ensure that all find queries have the same processing order as their image id
        for query in datapoint.find_queries:
            assert query.image_id == query.query_processing_order, (
                f"find query has inconsistent image_id and "
                f"query_processing_order: {query.image_id=} vs "
                f"{query.query_processing_order=}"
            )
        return datapoint
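    # A sketch of a loaded datapoint (hypothetical sizes): with num_stages_sample=4
    # and one text prompt per frame, `datapoint.images` holds 4 same-sized frames
    # and `datapoint.find_queries` holds 4 queries whose image_id and
    # query_processing_order both run over 0..3.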

    def _sample_stage_ids(self, queries, num_stages_sample, stage_stride):
        """Sample a subset of stage ids from all queries."""
        # Later we can perhaps turn this into a Sampler class to be more flexible.
        all_stage_ids = sorted(set(q["query_processing_order"] for q in queries))
        num_stages_total = len(all_stage_ids)
        if num_stages_total < num_stages_sample:
            raise ValueError("Not enough stages to sample")
        # the difference in index between the first and the last sampled stage ids
        b_e_gap = (num_stages_sample - 1) * stage_stride
        if b_e_gap > num_stages_total - 1:
            # In this case, it's not possible to sample with the provided stride,
            # so we use the maximum possible stride instead.
            prev_stage_stride = stage_stride
            stage_stride = math.floor((num_stages_total - 1) / (num_stages_sample - 1))
            logging.info(
                f"lowering stride from {prev_stage_stride} to {stage_stride} to "
                f"sample {num_stages_sample} stages (from {num_stages_total} total)"
            )
            b_e_gap = (num_stages_sample - 1) * stage_stride
        # randomly select a starting stage id (`randint` includes both ends)
        b_max = len(all_stage_ids) - 1 - b_e_gap
        b = self.rng.randint(0, b_max)
        e = b + b_e_gap
        stage_ids_to_keep = all_stage_ids[b : e + 1 : stage_stride]
        return stage_ids_to_keep
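    # A worked example of the stride fallback above (hypothetical numbers): with
    # num_stages_total=10, num_stages_sample=4 and stage_stride=5, the requested
    # begin-to-end gap is (4 - 1) * 5 = 15 > 9, so the stride is lowered to
    # floor(9 / 3) = 3; the gap becomes 9, b_max = 0, and the sampled stage
    # indices are [0, 3, 6, 9].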

    def _filter_query_and_anns(
        self, queries, annotations, stage_ids_to_keep, remap_stage_id, reverse_time_axis
    ):
        """Filter queries and annotations to only keep those in `stage_ids_to_keep`."""
        stage_ids_to_keep = set(stage_ids_to_keep)
        kept_img_ids = set()
        kept_stage_ids = set()
        # Filter queries -- keep those queries with stage_id in `stage_ids_to_keep`
        filtered_queries = []
        for query in queries:
            input_box = query.get("input_box", None)
            input_points = query.get("input_points", None)
            has_geo_input = input_box is not None or input_points is not None
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            stage_id = query["query_processing_order"]
            if stage_id in stage_ids_to_keep:
                kept_img_ids.add(query["image_id"])
                kept_stage_ids.add(stage_id)
                filtered_queries.append(query)
        # Check that all frames in `stage_ids_to_keep` are present after filtering
        all_frames_present = kept_stage_ids == stage_ids_to_keep
        assert all_frames_present, f"{kept_stage_ids=} vs {stage_ids_to_keep=}"
        if remap_stage_id:
            # Remap the kept stage ids to be contiguous and start from 0
            old_stage_ids = sorted(kept_stage_ids, reverse=reverse_time_axis)
            stage_id_old2new = {old: new for new, old in enumerate(old_stage_ids)}
            for query in filtered_queries:
                ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
                ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
                assert ptr_x_is_empty and ptr_y_is_empty, (
                    "Remapping stage ids is not supported for queries with "
                    "non-empty ptr_x or ptr_y pointers"
                )
                query["query_processing_order"] = stage_id_old2new[
                    query["query_processing_order"]
                ]
        # Filter annotations -- keep those annotations with image_id in `kept_img_ids`
        filtered_annotations = [
            ann for ann in annotations if ann["image_id"] in kept_img_ids
        ]
        return filtered_queries, filtered_annotations, kept_img_ids
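    # For example (hypothetical stage ids): keeping stages {2, 4, 6} remaps them as
    # {2: 0, 4: 1, 6: 2} when reverse_time_axis=False, and as {6: 0, 4: 1, 2: 2}
    # when reverse_time_axis=True (the caller then also reverses the frames so that
    # image order matches query order).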

    def _tile_single_image_data(self, datapoint: Datapoint, num_stages_sample: int):
        """
        Tile a single image and its queries to simulate video frames. The output is a
        datapoint with *identical video frames* (i.e. the same static image) and needs
        further transforms (e.g. affine) to get video frames with different content.
        """
        # tile `images: List[Image]`
        assert len(datapoint.images) == 1, "Expected only one single image"
        tiled_images = [
            copy.deepcopy(datapoint.images[0]) for _ in range(num_stages_sample)
        ]
        for stage_id, img in enumerate(tiled_images):
            for obj in img.objects:
                obj.frame_index = stage_id
        # tile `raw_images: Optional[List[PILImage.Image]] = None`
        tiled_raw_images = None
        if datapoint.raw_images is not None:
            assert len(datapoint.raw_images) == 1, "Expected only one single image"
            tiled_raw_images = [
                datapoint.raw_images[0].copy() for _ in range(num_stages_sample)
            ]
        # tile `find_queries: List[FindQueryLoaded]`
        tiled_find_queries_per_stage = [[] for _ in range(num_stages_sample)]
        for query in datapoint.find_queries:
            assert query.image_id == 0
            assert query.query_processing_order == 0
            # check and make sure that a query doesn't contain pointers or references
            # to other queries (which cannot be tiled)
            assert query.ptr_x is None and query.ptr_y is None
            assert query.ptr_mem is None
            # assert query.wkdata_qid is None
            # assert query.other_positive_qids is None
            # assert query.negative_qids is None
            has_geo_input = (
                query.input_bbox is not None or query.input_points is not None
            )
            if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs:
                continue
            for stage_id in range(num_stages_sample):
                # copy the query and update the image_id
                new_query = copy.deepcopy(query)
                new_query.image_id = stage_id
                new_query.query_processing_order = stage_id
                if new_query.inference_metadata is not None:
                    new_query.inference_metadata.frame_index = stage_id
                tiled_find_queries_per_stage[stage_id].append(new_query)
        tiled_find_queries = sum(tiled_find_queries_per_stage, [])
        # tile `get_queries: List[GetQuery]` -- we skip them for now (since they
        # involve a pointer to a find query that is complicated to tile, and there is
        # no imminent use case for them in the video grounding task)
        if self.tile_img_keep_get_queries:
            raise NotImplementedError("Tiling get queries is not implemented yet")
        else:
            tiled_get_queries = []
        return Datapoint(
            images=tiled_images,
            raw_images=tiled_raw_images,
            find_queries=tiled_find_queries,
            get_queries=tiled_get_queries,
        )
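    # A sketch of the tiling result (hypothetical counts): one image with 2 find
    # queries and num_stages_sample=3 yields 3 identical frames and 6 find queries,
    # where each tiled query on frame t has image_id == query_processing_order == t.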

    def _subsample_queries(self, datapoint: Datapoint, max_query_num: int):
        """Subsample to keep at most `max_query_num` queries per frame in a datapoint."""
        # aggregate the find queries per stage
        num_frames = max(q.query_processing_order for q in datapoint.find_queries) + 1
        find_queries_per_stage = [[] for _ in range(num_frames)]
        for query in datapoint.find_queries:
            find_queries_per_stage[query.query_processing_order].append(query)
        # verify that all the stages have the same number of queries
        num_queries_per_stage = len(find_queries_per_stage[0])
        for queries in find_queries_per_stage:
            assert len(queries) == num_queries_per_stage
        if max_query_num <= 0 or num_queries_per_stage <= max_query_num:
            return datapoint
        # subsample the queries to keep only `max_query_num` queries per frame
        sampled_inds = self.rng.sample(range(num_queries_per_stage), max_query_num)
        sampled_find_queries_per_stage = [
            [queries[idx] for idx in sampled_inds] for queries in find_queries_per_stage
        ]
        sampled_find_queries = sum(sampled_find_queries_per_stage, [])
        return Datapoint(
            images=datapoint.images,
            raw_images=datapoint.raw_images,
            find_queries=sampled_find_queries,
            get_queries=datapoint.get_queries,
        )
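

# A minimal, self-contained sketch (not part of the training pipeline): it replays
# the stage-sampling arithmetic from `_sample_stage_ids` on synthetic stage ids.
# The numbers below (10 total stages, 4 samples, requested stride 5) are hypothetical.
if __name__ == "__main__":
    rng = random.Random(SEED)
    all_stage_ids = list(range(10))
    num_stages_sample, stage_stride = 4, 5
    b_e_gap = (num_stages_sample - 1) * stage_stride
    if b_e_gap > len(all_stage_ids) - 1:
        # the requested stride doesn't fit; fall back to the maximum possible one
        stage_stride = math.floor((len(all_stage_ids) - 1) / (num_stages_sample - 1))
        b_e_gap = (num_stages_sample - 1) * stage_stride
    b = rng.randint(0, len(all_stage_ids) - 1 - b_e_gap)
    print(all_stage_ids[b : b + b_e_gap + 1 : stage_stride])  # -> [0, 3, 6, 9]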