# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""
COCO prediction dumper for distributed training.

Handles collection and dumping of COCO-format predictions from models.
Supports distributed processing with multiple GPUs/processes.
"""
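
# Typical usage (illustrative sketch, not part of this module's API surface;
# `postprocessor` stands in for any object whose process_results() method
# returns COCO-style predictions keyed by image id, and update() forwards its
# arguments to that method):
#
#     dumper = PredictionDumper(
#         dump_dir="/tmp/coco_preds",
#         postprocessor=postprocessor,
#         maxdets=100,
#         iou_type="bbox",
#     )
#     dumper.update(outputs, targets)    # once per evaluation batch
#     metrics = dumper.compute_synced()  # once at the end, on all ranks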
import copy
import gc
import heapq
import json
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Optional

import pycocotools.mask as mask_utils
import torch
from iopath.common.file_io import g_pathmgr

from sam3.eval.coco_eval_offline import convert_to_xywh
from sam3.train.masks_ops import rle_encode
from sam3.train.utils.distributed import (
    all_gather,
    gather_to_rank_0_via_filesys,
    get_rank,
    is_main_process,
)


### Helper functions and classes
class HeapElement:
    """Utility class to make a heap with a custom comparator based on score."""

    def __init__(self, val):
        self.val = val

    def __lt__(self, other):
        return self.val["score"] < other.val["score"]
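
# Note: heapq implements a min-heap, so with HeapElement's __lt__ the root of
# each heap is always the lowest-scoring prediction; heappushpop() in
# gather_and_merge_predictions() below therefore evicts the weakest candidate
# and the heap retains the highest-scoring entries seen so far.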


class PredictionDumper:
    """
    Handles collection and dumping of COCO-format predictions from a model.

    This class processes model outputs through a postprocessor, converts them to
    COCO format, and saves them to disk. It supports distributed processing with
    multiple GPUs/processes.
    """

    def __init__(
        self,
        dump_dir: str,
        postprocessor,
        maxdets: int,
        iou_type: str,
        gather_pred_via_filesys: bool = False,
        merge_predictions: bool = False,
        pred_file_evaluators: Optional[Any] = None,
    ):
        """
        Initialize the PredictionDumper.

        Args:
            dump_dir: Directory to dump predictions.
            postprocessor: Module to convert the model's output into COCO format.
            maxdets: Maximum number of detections per image.
            iou_type: IoU type to evaluate; one of "bbox" or "segm".
            gather_pred_via_filesys: If True, use the filesystem for collective gathers across
                processes (requires a shared filesystem). Otherwise, use torch collective ops.
            merge_predictions: If True, merge predictions from all processes and dump them to
                a single file.
            pred_file_evaluators: Optional evaluators that compute metrics from the dumped
                prediction file; requires merge_predictions=True.
        """
        self.iou_type = iou_type
        self.maxdets = maxdets
        self.dump_dir = dump_dir
        self.postprocessor = postprocessor
        self.gather_pred_via_filesys = gather_pred_via_filesys
        self.merge_predictions = merge_predictions
        self.pred_file_evaluators = pred_file_evaluators
        if self.pred_file_evaluators is not None:
            assert merge_predictions, (
                "merge_predictions must be True if pred_file_evaluators are provided"
            )
        assert self.dump_dir is not None, "dump_dir must be provided"
        if is_main_process():
            os.makedirs(self.dump_dir, exist_ok=True)
            logging.info(f"Created prediction dump directory: {self.dump_dir}")
        # Initialize state
        self.reset()

    def update(self, *args, **kwargs):
        """
        Process and accumulate predictions from model outputs.

        Args:
            *args, **kwargs: Arguments passed to postprocessor.process_results().
        """
        predictions = self.postprocessor.process_results(*args, **kwargs)
        results = self.prepare(predictions, self.iou_type)
        self._dump(results)

    def _dump(self, results):
        """
        Add results to the dump list with precision rounding.

        Args:
            results: List of prediction dictionaries in COCO format.
        """
        dumped_results = copy.deepcopy(results)
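        # Rounding to 5 decimal places keeps the JSON dump compact; at this
        # precision the truncation is far below anything COCO metrics resolve.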
        for r in dumped_results:
            if "bbox" in r:
                r["bbox"] = [round(coord, 5) for coord in r["bbox"]]
            r["score"] = round(r["score"], 5)
        self.dump.extend(dumped_results)

    def synchronize_between_processes(self):
        """
        Synchronize predictions across all processes and save to disk.

        If merge_predictions is False, each rank dumps its own predictions to a
        separate JSON file. Otherwise, predictions are gathered from all ranks
        (via the filesystem if gather_pred_via_filesys is True, else via torch
        distributed collective ops), merged, and dumped to a single JSON file
        on the main process.

        Returns:
            Path to the dumped prediction file.
        """
        logging.info("Prediction Dumper: Synchronizing between processes")
        if not self.merge_predictions:
            dumped_file = (
                Path(self.dump_dir)
                / f"coco_predictions_{self.iou_type}_{get_rank()}.json"
            )
            logging.info(
                f"Prediction Dumper: Dumping local predictions to {dumped_file}"
            )
            with g_pathmgr.open(str(dumped_file), "w") as f:
                json.dump(self.dump, f)
        else:
            self.dump = self.gather_and_merge_predictions()
            dumped_file = Path(self.dump_dir) / f"coco_predictions_{self.iou_type}.json"
            if is_main_process():
                logging.info(
                    f"Prediction Dumper: Dumping merged predictions to {dumped_file}"
                )
                with g_pathmgr.open(str(dumped_file), "w") as f:
                    json.dump(self.dump, f)
        self.reset()
        return dumped_file
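
    # For example, with iou_type="bbox" and 8 ranks, merge_predictions=False
    # yields coco_predictions_bbox_0.json ... coco_predictions_bbox_7.json,
    # while merge_predictions=True yields a single coco_predictions_bbox.json
    # written by rank 0.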

    def gather_and_merge_predictions(self):
        """
        Gather predictions from all processes and merge them, keeping top predictions per image.

        This method collects predictions from all processes, then keeps only the top maxdets
        predictions per image based on score. It also deduplicates predictions by
        (image_id, category_id) across processes.

        Returns:
            List of merged prediction dictionaries.
        """
        logging.info("Prediction Dumper: Gathering predictions from all processes")
        gc.collect()
        if self.gather_pred_via_filesys:
            dump = gather_to_rank_0_via_filesys(self.dump)
        else:
            dump = all_gather(self.dump, force_cpu=True)
        # Combine predictions, keeping only top maxdets per image
        preds_by_image = defaultdict(list)
        seen_img_cat = set()
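        # Deduplication is across ranks only: cur_seen_img_cat is folded into
        # seen_img_cat after each rank's dump is processed, so an (image_id,
        # category_id) pair may repeat within one rank's predictions but is
        # dropped when it reappears in a later rank's dump.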
        for cur_dump in dump:
            cur_seen_img_cat = set()
            for p in cur_dump:
                image_id = p["image_id"]
                cat_id = p["category_id"]
                # Skip if we've already seen this image/category pair in a previous dump
                if (image_id, cat_id) in seen_img_cat:
                    continue
                cur_seen_img_cat.add((image_id, cat_id))
                # Use a min-heap to keep top predictions
                if len(preds_by_image[image_id]) < self.maxdets:
                    heapq.heappush(preds_by_image[image_id], HeapElement(p))
                else:
                    heapq.heappushpop(preds_by_image[image_id], HeapElement(p))
            seen_img_cat.update(cur_seen_img_cat)
        # Flatten the heap elements back to a list
        merged_dump = [
            h.val for cur_preds in preds_by_image.values() for h in cur_preds
        ]
        return merged_dump

    def compute_synced(self):
        """
        Synchronize predictions across processes and compute the metric summary.

        Returns:
            Dictionary of metrics from pred_file_evaluators on the main process,
            or a placeholder {"": 0.0} on other ranks or when no evaluators are set.
        """
        dumped_file = self.synchronize_between_processes()
        if not is_main_process():
            return {"": 0.0}
        meters = {}
        if self.pred_file_evaluators is not None:
            for evaluator in self.pred_file_evaluators:
                results = evaluator.evaluate(dumped_file)
                meters.update(results)
        if len(meters) == 0:
            meters = {"": 0.0}
        return meters

    def compute(self):
        """
        Compute without synchronization.

        Returns:
            Empty metric dictionary.
        """
        return {"": 0.0}

    def reset(self):
        """Reset internal state for a new evaluation round."""
        self.dump = []

    def prepare(self, predictions, iou_type):
        """
        Route predictions to the appropriate preparation method based on iou_type.

        Args:
            predictions: Dictionary mapping image IDs to prediction dictionaries.
            iou_type: Type of evaluation ("bbox", "segm").

        Returns:
            List of COCO-format prediction dictionaries.
        """
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        else:
            raise ValueError(f"Unknown iou type: {iou_type}")

    def prepare_for_coco_detection(self, predictions):
        """
        Convert predictions to COCO detection format.

        Args:
            predictions: Dictionary mapping image IDs to prediction dictionaries
                containing "boxes", "scores", and "labels".

        Returns:
            List of COCO-format detection dictionaries.
        """
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue
            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
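            # Emit one record per box in the standard COCO results format, e.g.
            # {"image_id": 42, "category_id": 3, "bbox": [x, y, w, h], "score": 0.91}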
            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    @torch.no_grad()
    def prepare_for_coco_segmentation(self, predictions):
        """
        Convert predictions to COCO segmentation format.

        Args:
            predictions: Dictionary mapping image IDs to prediction dictionaries
                containing "masks" or "masks_rle", "scores", and "labels".
                Optionally includes "boundaries" and "dilated_boundaries".

        Returns:
            List of COCO-format segmentation dictionaries with RLE-encoded masks.
        """
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            boxes = None
            if "boxes" in prediction:
                boxes = prediction["boxes"]
                boxes = convert_to_xywh(boxes).tolist()
                assert len(boxes) == len(scores)
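            # NB: "area" below is stored as a fraction of the image area
            # (mask pixels / (h * w)), not in absolute pixels as in standard
            # COCO annotations.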
            if "masks_rle" in prediction:
                rles = prediction["masks_rle"]
                areas = []
                for rle in rles:
                    cur_area = mask_utils.area(rle)
                    h, w = rle["size"]
                    areas.append(cur_area / (h * w))
            else:
                masks = prediction["masks"]
                masks = masks > 0.5
                h, w = masks.shape[-2:]
                areas = masks.flatten(1).sum(1) / (h * w)
                areas = areas.tolist()
                rles = rle_encode(masks.squeeze(1))
                # Memory cleanup
                del masks
                del prediction["masks"]
            assert len(areas) == len(rles) == len(scores)
            for k, rle in enumerate(rles):
                payload = {
                    "image_id": original_id,
                    "category_id": labels[k],
                    "segmentation": rle,
                    "score": scores[k],
                    "area": areas[k],
                }
                if boxes is not None:
                    payload["bbox"] = boxes[k]
                coco_results.append(payload)
        return coco_results