# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Dataset class for modulated detection"""
import json
import os
import random
import sys
import traceback
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.data
import torchvision
from decord import cpu, VideoReader
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from PIL.Image import DecompressionBombError
from sam3.model.box_ops import box_xywh_to_xyxy
from torchvision.datasets.vision import VisionDataset

from .coco_json_loaders import COCO_FROM_JSON


@dataclass
class InferenceMetadata:
    """Metadata required for postprocessing"""

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator.
    # This is used for our own "class agnostic" evaluation.
    coco_image_id: int
    # Id in the original dataset, such that we can use the original evaluator
    original_image_id: int
    # Original category id (if we want to use the original evaluator)
    original_category_id: int
    # Size of the raw image (height, width)
    original_size: Tuple[int, int]
    # Id of the object in the media
    object_id: int
    # Index of the frame in the media (0 if single image)
    frame_index: int
    # Whether it is for conditioning only, e.g., the 0-th frame in TA is for conditioning,
    # as we assume GT is available in frame 0.
    is_conditioning_only: Optional[bool] = False


@dataclass
class FindQuery:
    query_text: str
    image_id: int
    # In case of a find query, the list of object ids that have to be predicted
    object_ids_output: List[int]
    # This is "instance exhaustivity":
    # true iff all instances are separable and annotated.
    # See below the slightly different "pixel exhaustivity".
    is_exhaustive: bool
    # The order in which the queries are processed (only meaningful for video)
    query_processing_order: int = 0
    # Input geometry, initially in denormalized XYXY format. Then
    # 1. converted to normalized CxCyWH by the Normalize transform
    input_bbox: Optional[torch.Tensor] = None
    input_bbox_label: Optional[torch.Tensor] = None
    # Only for the PVS task
    input_points: Optional[torch.Tensor] = None
    semantic_target: Optional[torch.Tensor] = None
    # Pixel exhaustivity: true iff the union of all segments (including crowds)
    # covers every pixel belonging to the target class.
    # Note that instance exhaustivity implies pixel exhaustivity.
    is_pixel_exhaustive: Optional[bool] = None
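

# Illustration of the two exhaustivity flags above (a hypothetical annotation, not
# taken from any real dataset): for the query "car", is_exhaustive=True means every
# individual car is separately annotated; is_pixel_exhaustive=True only requires
# that the annotated segments (possibly including crowd regions) cover every car
# pixel. Instance exhaustivity therefore implies pixel exhaustivity, but not the
# other way around.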


@dataclass
class FindQueryLoaded(FindQuery):
    # Must have a default value since FindQuery has entries with default values
    inference_metadata: Optional[InferenceMetadata] = None


@dataclass
class Object:
    # Initially in denormalized XYXY format, gets converted to normalized CxCyWH by the Normalize transform
    bbox: torch.Tensor
    area: float
    # Id of the object in the media
    object_id: Optional[int] = -1
    # Index of the frame in the media (0 if single image)
    frame_index: Optional[int] = -1
    segment: Optional[Union[torch.Tensor, dict]] = None  # RLE dict or binary mask
    is_crowd: bool = False
    source: Optional[str] = None


@dataclass
class Image:
    data: Union[torch.Tensor, PILImage.Image]
    objects: List[Object]
    size: Tuple[int, int]  # (height, width)
    # For blurring augmentation
    blurring_mask: Optional[Dict[str, Any]] = None


@dataclass
class Datapoint:
    """Refers to an image/video and all its annotations"""

    find_queries: List[FindQueryLoaded]
    images: List[Image]
    raw_images: Optional[List[PILImage.Image]] = None
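

# A minimal sketch of how the dataclasses above fit together (the values are made
# up purely for illustration; `pil_img` stands for any PIL image):
#
#     obj = Object(bbox=torch.tensor([10.0, 20.0, 50.0, 80.0]), area=1800.0, object_id=0, frame_index=0)
#     img = Image(data=pil_img, objects=[obj], size=(480, 640))
#     meta = InferenceMetadata(coco_image_id=1, original_image_id=1, original_category_id=-1,
#                              original_size=(480, 640), object_id=0, frame_index=0)
#     query = FindQueryLoaded(query_text="a dog", image_id=0, object_ids_output=[0],
#                             is_exhaustive=True, inference_metadata=meta)
#     datapoint = Datapoint(find_queries=[query], images=[img], raw_images=[pil_img])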


class CustomCocoDetectionAPI(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes an input sample
            and its target as entry and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        load_segmentation: bool,
        fix_fname: bool = False,
        training: bool = True,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: Optional[int] = None,
    ) -> None:
        super().__init__(root)
        self.annFile = annFile
        self.use_caching = use_caching
        self.zstd_dict_path = zstd_dict_path
        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.load_segmentation = load_segmentation
        self.fix_fname = fix_fname
        self.filter_query = filter_query
        self.coco = None
        self.coco_json_loader = coco_json_loader
        self.limit_ids = limit_ids
        self.set_sharded_annotation_file(0)
        self.training = training
        self.blurring_masks_path = blurring_masks_path

    def _load_images(
        self, datapoint_id: int, img_ids_to_load: Optional[Set[int]] = None
    ) -> Tuple[List[Tuple[int, PILImage.Image]], List[Dict[str, Any]]]:
        all_images = []
        all_img_metadata = []
        for current_meta in self.coco.loadImagesFromDatapoint(datapoint_id):
            img_id = current_meta["id"]
            if img_ids_to_load is not None and img_id not in img_ids_to_load:
                continue
            if self.fix_fname:
                current_meta["file_name"] = current_meta["file_name"].split("/")[-1]
            path = current_meta["file_name"]
            if self.blurring_masks_path is not None:
                mask_fname = os.path.basename(path).replace(".jpg", "-mask.json")
                mask_path = os.path.join(self.blurring_masks_path, mask_fname)
                if os.path.exists(mask_path):
                    with open(mask_path, "r") as fopen:
                        current_meta["blurring_mask"] = json.load(fopen)
            all_img_metadata.append(current_meta)
            path = os.path.join(self.root, path)
            try:
  155. if ".mp4" in path and path[-4:] == ".mp4":
  156. # Going to load a video frame
                    video_path, frame = path.split("@")
                    video = VideoReader(video_path, ctx=cpu(0))
                    # Convert to PIL image
                    all_images.append(
                        (
                            img_id,
                            torchvision.transforms.ToPILImage()(
                                video[int(frame)].asnumpy()
                            ),
                        )
                    )
                else:
                    with g_pathmgr.open(path, "rb") as fopen:
                        all_images.append((img_id, PILImage.open(fopen).convert("RGB")))
            except FileNotFoundError as e:
                print(f"File not found: {path} from dataset: {self.annFile}")
                raise e
        return all_images, all_img_metadata

    def set_curr_epoch(self, epoch: int):
        self.curr_epoch = epoch

    def set_epoch(self, epoch: int):
        pass

    def set_sharded_annotation_file(self, data_epoch: int):
        if self.coco is not None:
            return
        assert g_pathmgr.isfile(self.annFile), (
            f"please provide valid annotation file. Missing: {self.annFile}"
        )
        annFile = g_pathmgr.get_local_path(self.annFile)
        if self.coco is not None:
            del self.coco
        self.coco = self.coco_json_loader(annFile)
        # Use a torch tensor here to optimize memory usage when using several dataloaders
        ids_list = list(sorted(self.coco.getDatapointIds()))
        if self.limit_ids is not None:
            local_random = random.Random(len(ids_list))
            local_random.shuffle(ids_list)
            ids_list = ids_list[: self.limit_ids]
        self.ids = torch.as_tensor(ids_list, dtype=torch.long)

    def __getitem__(self, index: int) -> Datapoint:
        return self._load_datapoint(index)

    def _load_datapoint(self, index: int) -> Datapoint:
        """A separate method for easy overriding in subclasses."""
        id = self.ids[index].item()
        pil_images, img_metadata = self._load_images(id)
        queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id)
        return self.load_queries(pil_images, annotations, queries, img_metadata)

    def load_queries(self, pil_images, annotations, queries, img_metadata):
        """Transform the raw images and queries into a Datapoint sample."""
        images: List[Image] = []
        id2index_img = {}
        id2index_obj = {}
        id2index_find_query = {}
        id2imsize = {}
        assert len(pil_images) == len(img_metadata)
        for i in range(len(pil_images)):
            w, h = pil_images[i][1].size
            blurring_mask = None
            if "blurring_mask" in img_metadata[i]:
                blurring_mask = img_metadata[i]["blurring_mask"]
            images.append(
                Image(
                    data=pil_images[i][1],
                    objects=[],
                    size=(h, w),
                    blurring_mask=blurring_mask,
                )
            )
            id2index_img[pil_images[i][0]] = i
            id2imsize[pil_images[i][0]] = (h, w)
        for annotation in annotations:
            image_id = id2index_img[annotation["image_id"]]
            bbox = box_xywh_to_xyxy(torch.as_tensor(annotation["bbox"])).view(1, 4)
            h, w = id2imsize[annotation["image_id"]]
            bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
            bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
            segment = None
            if self.load_segmentation and "segmentation" in annotation:
                # We're not decoding the RLE here, a transform will do it lazily later
                segment = annotation["segmentation"]
            images[image_id].objects.append(
                Object(
                    bbox=bbox[0],
                    area=annotation["area"],
                    object_id=(
                        annotation["object_id"] if "object_id" in annotation else -1
                    ),
                    frame_index=(
                        annotation["frame_index"] if "frame_index" in annotation else -1
                    ),
                    segment=segment,
                    is_crowd=(
                        annotation["is_crowd"] if "is_crowd" in annotation else None
                    ),
                    source=annotation["source"] if "source" in annotation else "",
                )
            )
            id2index_obj[annotation["id"]] = len(images[image_id].objects) - 1
        find_queries = []
        stage2num_queries = Counter()
        for i, query in enumerate(queries):
            stage2num_queries[query["query_processing_order"]] += 1
            id2index_find_query[query["id"]] = i
        # Sanity check: all the stages should have the same number of queries
        if len(stage2num_queries) == 0:
            num_queries_per_stage = 0
        else:
            num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
        for stage, num_queries in stage2num_queries.items():
            assert num_queries == num_queries_per_stage, (
                f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
            )
        for query in queries:
            h, w = id2imsize[query["image_id"]]
            if (
                "input_box" in query
                and query["input_box"] is not None
                and len(query["input_box"]) > 0
            ):
                bbox = box_xywh_to_xyxy(torch.as_tensor(query["input_box"])).view(-1, 4)
                bbox[:, 0::2].mul_(w).clamp_(min=0, max=w)
                bbox[:, 1::2].mul_(h).clamp_(min=0, max=h)
                if "input_box_label" in query and query["input_box_label"] is not None:
                    bbox_label = torch.as_tensor(
                        query["input_box_label"], dtype=torch.long
                    ).view(-1)
                    assert len(bbox_label) == len(bbox)
                else:
                    # Assume the boxes are positives
                    bbox_label = torch.ones(len(bbox), dtype=torch.long)
            else:
                bbox = None
                bbox_label = None
            if "input_points" in query and query["input_points"] is not None:
                points = torch.as_tensor(query["input_points"]).view(1, -1, 3)
                points[:, :, 0:1].mul_(w).clamp_(min=0, max=w)
                points[:, :, 1:2].mul_(h).clamp_(min=0, max=h)
            else:
                points = None
            try:
                original_image_id = int(
                    img_metadata[id2index_img[query["image_id"]]]["original_img_id"]
                )
            except ValueError:
                original_image_id = -1
            try:
                img_metadata_query = img_metadata[id2index_img[query["image_id"]]]
                coco_image_id = (
                    int(img_metadata_query["coco_img_id"])
                    if "coco_img_id" in img_metadata_query
                    else query["id"]
                )
            except KeyError:
                coco_image_id = -1
            try:
                original_category_id = int(query["original_cat_id"])
            except (ValueError, KeyError):
                original_category_id = -1
            # For evaluation, we associate the ids of the object to be tracked to the query
            if query["object_ids_output"]:
                obj_id = query["object_ids_output"][0]
                obj_idx = id2index_obj[obj_id]
                image_idx = id2index_img[query["image_id"]]
                object_id = images[image_idx].objects[obj_idx].object_id
                frame_index = images[image_idx].objects[obj_idx].frame_index
            else:
                object_id = -1
                frame_index = -1
            find_queries.append(
                FindQueryLoaded(
                    # id=query["id"],
                    # query_type=qtype,
                    query_text=(
                        query["query_text"] if query["query_text"] is not None else ""
                    ),
                    image_id=id2index_img[query["image_id"]],
                    input_bbox=bbox,
                    input_bbox_label=bbox_label,
                    input_points=points,
                    object_ids_output=[
                        id2index_obj[obj_id] for obj_id in query["object_ids_output"]
                    ],
                    is_exhaustive=query["is_exhaustive"],
                    is_pixel_exhaustive=(
                        query["is_pixel_exhaustive"]
                        if "is_pixel_exhaustive" in query
                        else (
                            query["is_exhaustive"] if query["is_exhaustive"] else None
                        )
                    ),
                    query_processing_order=query["query_processing_order"],
                    inference_metadata=InferenceMetadata(
                        coco_image_id=-1 if self.training else coco_image_id,
                        original_image_id=(-1 if self.training else original_image_id),
                        frame_index=frame_index,
                        original_category_id=original_category_id,
                        original_size=(h, w),
                        object_id=object_id,
                    ),
                )
            )
        return Datapoint(
            find_queries=find_queries,
            images=images,
            raw_images=[p[1] for p in pil_images],
        )

    def __len__(self) -> int:
        return len(self.ids)
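

# A minimal sketch (hypothetical, not part of this module) of the subclassing hook
# mentioned in `_load_datapoint`, e.g. to post-process every loaded datapoint:
#
#     class NoRawImagesDataset(CustomCocoDetectionAPI):
#         def _load_datapoint(self, index: int) -> Datapoint:
#             datapoint = super()._load_datapoint(index)
#             datapoint.raw_images = None  # drop the raw PIL copies to save memory
#             return datapoint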


class Sam3ImageDataset(CustomCocoDetectionAPI):
    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        max_ann_per_img: int,
        multiplier: int,
        training: bool,
        load_segmentation: bool = False,
        max_train_queries: int = 81,
        max_val_queries: int = 300,
        fix_fname: bool = False,
        is_sharded_annotation_dir: bool = False,
        blurring_masks_path: Optional[str] = None,
        use_caching: bool = True,
        zstd_dict_path=None,
        filter_query=None,
        coco_json_loader: Callable = COCO_FROM_JSON,
        limit_ids: Optional[int] = None,
    ):
        super(Sam3ImageDataset, self).__init__(
            img_folder,
            ann_file,
            fix_fname=fix_fname,
            load_segmentation=load_segmentation,
            training=training,
            blurring_masks_path=blurring_masks_path,
            use_caching=use_caching,
            zstd_dict_path=zstd_dict_path,
            filter_query=filter_query,
            coco_json_loader=coco_json_loader,
            limit_ids=limit_ids,
        )
        self._transforms = transforms
        self.training = training
        self.max_ann_per_img = max_ann_per_img
        self.max_train_queries = max_train_queries
        self.max_val_queries = max_val_queries
        self.repeat_factors = torch.ones(len(self.ids), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.ids)}")
        self._MAX_RETRIES = 100

    def __getitem__(self, idx):
        return self.__orig_getitem__(idx)

    def __orig_getitem__(self, idx):
        for _ in range(self._MAX_RETRIES):
            try:
                datapoint = super(Sam3ImageDataset, self).__getitem__(idx)
                # This could be done better by filtering the offending find queries.
                # However, this requires care:
                # - Delete any find/get query that may depend on the deleted one
                # - Re-compute the indexes in the pointers to account for the deleted finds
                for q in datapoint.find_queries:
                    if len(q.object_ids_output) > self.max_ann_per_img:
                        raise DecompressionBombError(
                            f"Too many outputs ({len(q.object_ids_output)})"
                        )
                max_queries = (
                    self.max_train_queries if self.training else self.max_val_queries
                )
                if len(datapoint.find_queries) > max_queries:
                    raise DecompressionBombError(
                        f"Too many find queries ({len(datapoint.find_queries)})"
                    )
                if len(datapoint.find_queries) == 0:
                    raise DecompressionBombError("No find queries")
                for transform in self._transforms:
                    datapoint = transform(datapoint, epoch=self.curr_epoch)
                break
            except (DecompressionBombError, OSError, ValueError) as error:
                sys.stderr.write(f"ERROR: got loading error on datapoint {idx}\n")
                sys.stderr.write(f"Exception: {error}\n")
                sys.stderr.write(traceback.format_exc())
                idx = (idx + 1) % len(self)
        else:
            raise RuntimeError(
                f"Failed {self._MAX_RETRIES} times trying to load an image."
            )
        return datapoint
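

# Example usage (a rough sketch; the folder/file paths and the empty transform list
# below are placeholders, not real assets or the project's actual transform pipeline):
#
#     dataset = Sam3ImageDataset(
#         img_folder="/path/to/images",          # hypothetical image root
#         ann_file="/path/to/annotations.json",  # hypothetical annotation file
#         transforms=[],                         # normally the resize/normalize transforms
#         max_ann_per_img=100,
#         multiplier=1,
#         training=False,
#         load_segmentation=True,
#     )
#     datapoint = dataset[0]  # a Datapoint with images, objects, and find queries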