# coco_json_loaders.py (14 KB)
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import json
  4. from collections import defaultdict
  5. from typing import Dict, List, Tuple
  6. import torch
  7. from pycocotools import mask as mask_util
  8. # ============================================================================
  9. # Utility Functions
  10. # ============================================================================
  def convert_boxlist_to_normalized_tensor(box_list, image_width, image_height):
      """
      Convert a list of bounding boxes to a normalized PyTorch tensor.

      Columns 0 and 2 are divided by ``image_width`` and columns 1 and 3 by
      ``image_height``. The scaling is therefore the same whether a box is
      given as [x_min, y_min, x_max, y_max] or as [x, y, width, height].

      Args:
          box_list (list of list or tuples): Boxes with four values each.
              May be empty.
          image_width (int or float): Width of the image.
          image_height (int or float): Height of the image.

      Returns:
          torch.Tensor: float32 tensor of shape (N, 4) with values clamped to
          [0, 1]. An empty ``box_list`` yields a tensor of shape (0, 4).
      """
      boxes = torch.tensor(box_list, dtype=torch.float32)
      if boxes.numel() == 0:
          # torch.tensor([]) is 1-D, so the column indexing below would raise;
          # return an empty 2-D result instead.
          return boxes.reshape(0, 4)

      boxes[:, [0, 2]] /= image_width  # horizontal coordinates / extents
      boxes[:, [1, 3]] /= image_height  # vertical coordinates / extents
      return boxes.clamp(0, 1)
  def load_coco_and_group_by_image(json_path: str) -> Tuple[List[Dict], Dict[int, str]]:
      """
      Load a COCO JSON file and group its annotations by image.

      Args:
          json_path (str): Path to COCO JSON file.

      Returns:
          Tuple containing:
          - List of dicts with 'image' and 'annotations' keys, one per image,
            ordered by ascending image id. Images without annotations get an
            empty list.
          - Dict mapping category IDs to category names.

      Raises:
          KeyError: If the file lacks the 'images', 'annotations' or
              'categories' sections, or an entry lacks a required key.
      """
      # Read as UTF-8 explicitly: without it, decoding follows the platform
      # locale and non-ASCII category names can fail or be garbled.
      with open(json_path, "r", encoding="utf-8") as f:
          coco = json.load(f)

      images_by_id = {img["id"]: img for img in coco["images"]}

      anns_by_image = defaultdict(list)
      for ann in coco["annotations"]:
          anns_by_image[ann["image_id"]].append(ann)

      # Only images listed under 'images' are emitted; annotations that point
      # at an unknown image id are not returned.
      grouped = [
          {
              "image": images_by_id[image_id],
              "annotations": anns_by_image.get(image_id, []),
          }
          for image_id in sorted(images_by_id)
      ]

      cat_id_to_name = {cat["id"]: cat["name"] for cat in coco["categories"]}
      return grouped, cat_id_to_name
  def ann_to_rle(segm, im_info: Dict) -> Dict:
      """
      Turn a segmentation (polygons, uncompressed RLE or RLE) into RLE.

      Args:
          segm: Segmentation data (polygon list or RLE dict)
          im_info (dict): Image info containing 'height' and 'width'

      Returns:
          RLE encoded segmentation
      """
      height = im_info["height"]
      width = im_info["width"]

      if isinstance(segm, list):
          # Polygon input: encode every part, then merge into a single mask.
          polygon_rles = mask_util.frPyObjects(segm, height, width)
          return mask_util.merge(polygon_rles)

      if isinstance(segm["counts"], list):
          # Uncompressed RLE: 'counts' is still a plain list of run lengths.
          return mask_util.frPyObjects(segm, height, width)

      # Anything else is taken to be RLE already and handed back untouched.
      return segm
  72. # ============================================================================
  73. # COCO Training API
  74. # ============================================================================
  class COCO_FROM_JSON:
      """
      COCO training API for loading annotations from a COCO-format JSON file.

      Annotations are grouped per image and one text query is created per
      category. Categories are split into chunks; each (image, category chunk)
      pair is one datapoint, numbered image-major:
      ``idx = image_index * num_chunks + chunk_index``.
      """

      # Fields of a query; copied and filled in for every category.
      _QUERY_TEMPLATE = {
          "id": None,
          "original_cat_id": None,
          "object_ids_output": None,
          "query_text": None,
          "query_processing_order": 0,
          "ptr_x_query_id": None,
          "ptr_y_query_id": None,
          "image_id": 0,  # Single image per datapoint
          "input_box": None,
          "input_box_label": None,
          "input_points": None,
          "is_exhaustive": True,
      }

      # Fields of an annotation; copied and filled in for every instance.
      _ANNOT_TEMPLATE = {
          "image_id": 0,
          "bbox": None,  # Normalized bbox in xywh
          "area": None,  # Product of the normalized bbox width and height
          "segmentation": None,  # RLE encoded, None when the source has none
          "object_id": None,
          "is_crowd": None,
          "id": None,
      }

      def __init__(
          self,
          annotation_file,
          prompts=None,
          include_negatives=True,
          category_chunk_size=None,
      ):
          """
          Initialize the COCO training API.

          Args:
              annotation_file (str): Path to COCO JSON annotation file
              prompts: Optional custom prompts for categories. Either a list of
                  dicts with "id" and "name" keys, or a string holding the
                  Python literal of such a list.
              include_negatives (bool): Whether to include negative examples
                  (categories with no instances)
              category_chunk_size (int, optional): Maximum number of categories
                  queried per datapoint. Defaults to all categories in a single
                  chunk.

          Raises:
              ValueError: If ``category_chunk_size`` is given and is below 1.
              AssertionError: If the number of prompts differs from the number
                  of categories.
          """
          self._raw_data, self._cat_idx_to_text = load_coco_and_group_by_image(
              annotation_file
          )
          self._sorted_cat_ids = sorted(self._cat_idx_to_text)
          self.prompts = None
          self.include_negatives = include_negatives

          num_categories = len(self._sorted_cat_ids)
          if category_chunk_size is None:
              category_chunk_size = num_categories
          elif category_chunk_size < 1:
              raise ValueError("category_chunk_size must be a positive integer")
          self.category_chunk_size = category_chunk_size

          # max(..., 1) keeps range() valid for a dataset with no categories,
          # which then simply has no chunks and no datapoints.
          step = max(category_chunk_size, 1)
          self.category_chunks = [
              self._sorted_cat_ids[i : i + step]
              for i in range(0, num_categories, step)
          ]

          if prompts is not None:
              if isinstance(prompts, str):
                  # SECURITY(review): eval() executes arbitrary code. Only pass
                  # prompt strings from trusted configuration; consider
                  # ast.literal_eval if plain literals are all that is needed.
                  prompts = eval(prompts)
              self.prompts = {
                  int(loc_dict["id"]): loc_dict["name"] for loc_dict in prompts
              }
              assert len(self.prompts) == len(self._sorted_cat_ids), (
                  "Number of prompts must match number of categories"
              )

      def getDatapointIds(self):
          """Return all datapoint indices for training."""
          return list(range(len(self._raw_data) * len(self.category_chunks)))

      def loadQueriesAndAnnotationsFromDatapoint(self, idx):
          """
          Load queries and annotations for a specific datapoint.

          Args:
              idx (int): Datapoint index

          Returns:
              Tuple of (queries, annotations) lists. There is one query per
              category of the datapoint's chunk (categories without instances
              are skipped when ``include_negatives`` is False); each query's
              "object_ids_output" lists the ids of its annotations.
          """
          img_idx, chunk_idx = divmod(idx, len(self.category_chunks))
          cat_chunk = self.category_chunks[chunk_idx]

          datapoint = self._raw_data[img_idx]
          image_info = datapoint["image"]
          width, height = image_info["width"], image_info["height"]

          # Group annotations by category
          cat_id_to_anns = defaultdict(list)
          for ann in datapoint["annotations"]:
              cat_id_to_anns[ann["category_id"]].append(ann)

          queries = []
          annotations = []
          for cat_id in cat_chunk:
              anns = cat_id_to_anns.get(cat_id, [])
              if not anns and not self.include_negatives:
                  continue

              cur_ann_ids = []
              # Create annotations for this category
              for ann in anns:
                  annotation = dict(self._ANNOT_TEMPLATE)
                  # Ids are positions in the per-datapoint annotation list.
                  annotation["id"] = len(annotations)
                  annotation["object_id"] = annotation["id"]
                  # Datasets that omit 'iscrowd' are treated as non-crowd.
                  annotation["is_crowd"] = ann.get("iscrowd", 0)

                  bbox = convert_boxlist_to_normalized_tensor(
                      [ann["bbox"]], width, height
                  )[0]
                  annotation["area"] = (bbox[2] * bbox[3]).item()
                  annotation["bbox"] = bbox

                  segm = ann.get("segmentation")
                  if segm is not None and segm != []:
                      annotation["segmentation"] = ann_to_rle(
                          segm, im_info=image_info
                      )

                  annotations.append(annotation)
                  cur_ann_ids.append(annotation["id"])

              # Create query for this category
              query = dict(self._QUERY_TEMPLATE)
              query["id"] = len(queries)
              query["original_cat_id"] = cat_id
              query["query_text"] = (
                  self._cat_idx_to_text[cat_id]
                  if self.prompts is None
                  else self.prompts[cat_id]
              )
              query["object_ids_output"] = cur_ann_ids
              queries.append(query)

          return queries, annotations

      def loadImagesFromDatapoint(self, idx):
          """
          Load image information for a specific datapoint.

          Args:
              idx (int): Datapoint index

          Returns:
              List containing image info dict
          """
          img_idx = idx // len(self.category_chunks)
          img_data = self._raw_data[img_idx]["image"]
          return [
              {
                  "id": 0,
                  "file_name": img_data["file_name"],
                  "original_img_id": img_data["id"],
                  "coco_img_id": img_data["id"],
              }
          ]
  223. # ============================================================================
  224. # SAM3 Evaluation APIs
  225. # ============================================================================
  class SAM3_EVAL_API_FROM_JSON_NP:
      """
      SAM3 evaluation API for loading noun phrase queries from JSON.

      Each entry of the file's "images" list is one datapoint carrying a single
      text query; no ground-truth annotations are loaded.
      """

      def __init__(self, annotation_file):
          """
          Initialize the SAM3 evaluation API.

          Args:
              annotation_file (str): Path to SAM3 JSON annotation file
          """
          # Read as UTF-8 explicitly so non-ASCII noun phrases do not depend on
          # the platform's default locale encoding.
          with open(annotation_file, "r", encoding="utf-8") as f:
              data = json.load(f)
          self._image_data = data["images"]

      def getDatapointIds(self):
          """Return all datapoint indices."""
          return list(range(len(self._image_data)))

      def loadQueriesAndAnnotationsFromDatapoint(self, idx):
          """
          Load queries and annotations for a specific datapoint.

          Args:
              idx (int): Datapoint index

          Returns:
              Tuple of (queries, annotations) lists: exactly one query built
              from the entry's "queried_category" and "text_input" fields, and
              an empty annotation list.
          """
          cur_img_data = self._image_data[idx]
          query = {
              "id": 0,  # the only query of this datapoint
              "original_cat_id": int(cur_img_data["queried_category"]),
              "object_ids_output": [],
              "query_text": cur_img_data["text_input"],
              "query_processing_order": 0,
              "ptr_x_query_id": None,
              "ptr_y_query_id": None,
              "image_id": 0,
              "input_box": None,
              "input_box_label": None,
              "input_points": None,
              "is_exhaustive": True,
          }
          return [query], []

      def loadImagesFromDatapoint(self, idx):
          """
          Load image information for a specific datapoint.

          Args:
              idx (int): Datapoint index

          Returns:
              List containing image info dict
          """
          img_data = self._image_data[idx]
          return [
              {
                  "id": 0,
                  "file_name": img_data["file_name"],
                  "original_img_id": img_data["id"],
                  "coco_img_id": img_data["id"],
              }
          ]
  class SAM3_VEVAL_API_FROM_JSON_NP:
      """
      SAM3 video evaluation API for loading noun phrase queries from JSON.

      Each entry of the file's "videos" list is one datapoint. For every noun
      phrase paired with the video, one query is created per frame.
      """

      def __init__(self, annotation_file):
          """
          Initialize the SAM3 video evaluation API.

          Args:
              annotation_file (str): Path to SAM3 video JSON annotation file

          Raises:
              AssertionError: If the file has no "video_np_pairs" section, or a
                  pair's "noun_phrase" differs from its category's name.
          """
          # Read as UTF-8 explicitly so non-ASCII noun phrases do not depend on
          # the platform's default locale encoding.
          with open(annotation_file, "r", encoding="utf-8") as f:
              data = json.load(f)
          assert "video_np_pairs" in data, "Incorrect data format"

          self._video_data = data["videos"]
          self._video_id_to_np_ids = defaultdict(list)
          self._cat_id_to_np = {
              cat_dict["id"]: cat_dict["name"] for cat_dict in data["categories"]
          }

          for video_np_dict in data["video_np_pairs"]:
              self._video_id_to_np_ids[video_np_dict["video_id"]].append(
                  video_np_dict["category_id"]
              )
              # The pair duplicates the category name; check both agree.
              assert (
                  self._cat_id_to_np[video_np_dict["category_id"]]
                  == video_np_dict["noun_phrase"]
              ), "Category name does not match text input"

      def getDatapointIds(self):
          """Return all datapoint indices."""
          return list(range(len(self._video_data)))

      def loadQueriesAndAnnotationsFromDatapoint(self, idx):
          """
          Load queries and annotations for a specific video datapoint.

          Args:
              idx (int): Datapoint index

          Returns:
              Tuple of (queries, annotations) lists. Queries are ordered by
              noun phrase, then by frame; the annotation list is always empty.
          """
          cur_vid_data = self._video_data[idx]
          queries = []
          annotations = []

          # .get() rather than [] so that looking up a video without any noun
          # phrase does not insert an empty entry into the defaultdict.
          all_np_ids = self._video_id_to_np_ids.get(cur_vid_data["id"], [])
          for np_id in all_np_ids:
              text_input = self._cat_id_to_np[np_id]
              for frame_idx, _ in enumerate(cur_vid_data["file_names"]):
                  queries.append(
                      {
                          "id": len(queries),
                          "original_cat_id": np_id,
                          "object_ids_output": [],
                          "query_text": text_input,
                          # Frames are processed in order of their index.
                          "query_processing_order": frame_idx,
                          "ptr_x_query_id": None,
                          "ptr_y_query_id": None,
                          "image_id": frame_idx,
                          "input_box": None,
                          "input_box_label": None,
                          "input_points": None,
                          "is_exhaustive": True,
                      }
                  )
          return queries, annotations

      def loadImagesFromDatapoint(self, idx):
          """
          Load image information for a specific video datapoint.

          Args:
              idx (int): Datapoint index

          Returns:
              List containing image info dicts for all frames; "id" is the
              frame index and both image-id fields carry the video id.
          """
          video_data = self._video_data[idx]
          return [
              {
                  "id": i,
                  "file_name": file_name,
                  "original_img_id": video_data["id"],
                  "coco_img_id": video_data["id"],
              }
              for i, file_name in enumerate(video_data["file_names"])
          ]
  378. return images