coco_reindex.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. """
  4. Self-contained COCO JSON re-indexing function that creates temporary files.
  5. """
  6. import json
  7. import os
  8. import tempfile
  9. from pathlib import Path
  10. from typing import Any, Dict, List, Optional, Tuple
  11. def reindex_coco_to_temp(input_json_path: str) -> Optional[str]:
  12. """
  13. Convert 0-indexed COCO JSON file to 1-indexed and save to temporary location.
  14. Args:
  15. input_json_path: Path to the input COCO JSON file
  16. Returns:
  17. Path to the new 1-indexed JSON file in temporary directory, or None if no conversion needed
  18. Raises:
  19. FileNotFoundError: If input file doesn't exist
  20. json.JSONDecodeError: If input file is not valid JSON
  21. ValueError: If input file is not a valid COCO format
  22. """
  23. def is_coco_json(data: Dict[str, Any]) -> bool:
  24. """Check if data appears to be a COCO format file."""
  25. if not isinstance(data, dict):
  26. return False
  27. # A COCO file should have at least one of these keys
  28. coco_keys = {"images", "annotations", "categories"}
  29. return any(key in data for key in coco_keys)
  30. def check_zero_indexed(data: Dict[str, Any]) -> Tuple[bool, bool, bool]:
  31. """
  32. Check if annotations, images, or categories start from index 0.
  33. Returns:
  34. Tuple of (annotations_zero_indexed, images_zero_indexed, categories_zero_indexed)
  35. """
  36. annotations_zero = False
  37. images_zero = False
  38. categories_zero = False
  39. # Check annotations
  40. annotations = data.get("annotations", [])
  41. if annotations and any(ann.get("id", -1) == 0 for ann in annotations):
  42. annotations_zero = True
  43. # Check images
  44. images = data.get("images", [])
  45. if images and any(img.get("id", -1) == 0 for img in images):
  46. images_zero = True
  47. # Check categories
  48. categories = data.get("categories", [])
  49. if categories and any(cat.get("id", -1) == 0 for cat in categories):
  50. categories_zero = True
  51. return annotations_zero, images_zero, categories_zero
  52. def reindex_coco_data(data: Dict[str, Any]) -> Dict[str, Any]:
  53. """Convert 0-indexed COCO data to 1-indexed."""
  54. modified_data = data.copy()
  55. annotations_zero, images_zero, categories_zero = check_zero_indexed(data)
  56. # Create ID mapping for consistency
  57. image_id_mapping = {}
  58. category_id_mapping = {}
  59. # Process images first (since annotations reference image IDs)
  60. if images_zero and "images" in modified_data:
  61. for img in modified_data["images"]:
  62. old_id = img["id"]
  63. new_id = old_id + 1
  64. image_id_mapping[old_id] = new_id
  65. img["id"] = new_id
  66. # Process categories (since annotations reference category IDs)
  67. if categories_zero and "categories" in modified_data:
  68. for cat in modified_data["categories"]:
  69. old_id = cat["id"]
  70. new_id = old_id + 1
  71. category_id_mapping[old_id] = new_id
  72. cat["id"] = new_id
  73. # Process annotations
  74. if "annotations" in modified_data:
  75. for ann in modified_data["annotations"]:
  76. # Update annotation ID if needed
  77. if annotations_zero:
  78. ann["id"] = ann["id"] + 1
  79. # Update image_id reference if images were reindexed
  80. if images_zero and ann.get("image_id") is not None:
  81. old_image_id = ann["image_id"]
  82. if old_image_id in image_id_mapping:
  83. ann["image_id"] = image_id_mapping[old_image_id]
  84. # Update category_id reference if categories were reindexed
  85. if categories_zero and ann.get("category_id") is not None:
  86. old_category_id = ann["category_id"]
  87. if old_category_id in category_id_mapping:
  88. ann["category_id"] = category_id_mapping[old_category_id]
  89. return modified_data
  90. # Validate input path
  91. if not os.path.exists(input_json_path):
  92. raise FileNotFoundError(f"Input file not found: {input_json_path}")
  93. # Load and validate JSON data
  94. try:
  95. with open(input_json_path, "r", encoding="utf-8") as f:
  96. data = json.load(f)
  97. except json.JSONDecodeError as e:
  98. raise json.JSONDecodeError(f"Invalid JSON in {input_json_path}: {e}")
  99. # Validate COCO format
  100. if not is_coco_json(data):
  101. raise ValueError(
  102. f"File does not appear to be in COCO format: {input_json_path}"
  103. )
  104. # Check if reindexing is needed
  105. annotations_zero, images_zero, categories_zero = check_zero_indexed(data)
  106. if not (annotations_zero or images_zero or categories_zero):
  107. # No conversion needed - just copy to temp location
  108. input_path = Path(input_json_path)
  109. temp_dir = tempfile.mkdtemp()
  110. temp_filename = f"{input_path.stem}_1_indexed{input_path.suffix}"
  111. temp_path = os.path.join(temp_dir, temp_filename)
  112. with open(temp_path, "w", encoding="utf-8") as f:
  113. json.dump(data, f, indent=2, ensure_ascii=False)
  114. return temp_path
  115. # Perform reindexing
  116. modified_data = reindex_coco_data(data)
  117. # Create temporary file
  118. input_path = Path(input_json_path)
  119. temp_dir = tempfile.mkdtemp()
  120. temp_filename = f"{input_path.stem}_1_indexed{input_path.suffix}"
  121. temp_path = os.path.join(temp_dir, temp_filename)
  122. # Write modified data to temporary file
  123. with open(temp_path, "w", encoding="utf-8") as f:
  124. json.dump(modified_data, f, indent=2, ensure_ascii=False)
  125. return temp_path
  126. # Example usage and test function
  127. def test_reindex_function():
  128. """Test the reindex function with a sample COCO file."""
  129. # Create a test COCO file
  130. test_data = {
  131. "info": {"description": "Test COCO dataset", "version": "1.0", "year": 2023},
  132. "images": [
  133. {"id": 0, "width": 640, "height": 480, "file_name": "test1.jpg"},
  134. {"id": 1, "width": 640, "height": 480, "file_name": "test2.jpg"},
  135. ],
  136. "categories": [
  137. {"id": 0, "name": "person", "supercategory": "person"},
  138. {"id": 1, "name": "car", "supercategory": "vehicle"},
  139. ],
  140. "annotations": [
  141. {
  142. "id": 0,
  143. "image_id": 0,
  144. "category_id": 0,
  145. "bbox": [100, 100, 50, 75],
  146. "area": 3750,
  147. "iscrowd": 0,
  148. },
  149. {
  150. "id": 1,
  151. "image_id": 1,
  152. "category_id": 1,
  153. "bbox": [200, 150, 120, 80],
  154. "area": 9600,
  155. "iscrowd": 0,
  156. },
  157. ],
  158. }
  159. # Create temporary test file
  160. with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
  161. json.dump(test_data, f, indent=2)
  162. test_file_path = f.name
  163. try:
  164. # Test the function
  165. result_path = reindex_coco_to_temp(test_file_path)
  166. print(f"Original file: {test_file_path}")
  167. print(f"Converted file: {result_path}")
  168. # Load and display the result
  169. with open(result_path, "r") as f:
  170. result_data = json.load(f)
  171. print("\nConverted data sample:")
  172. print(f"First image ID: {result_data['images'][0]['id']}")
  173. print(f"First category ID: {result_data['categories'][0]['id']}")
  174. print(f"First annotation ID: {result_data['annotations'][0]['id']}")
  175. print(f"First annotation image_id: {result_data['annotations'][0]['image_id']}")
  176. print(
  177. f"First annotation category_id: {result_data['annotations'][0]['category_id']}"
  178. )
  179. # Clean up
  180. os.unlink(result_path)
  181. os.rmdir(os.path.dirname(result_path))
  182. finally:
  183. # Clean up test file
  184. os.unlink(test_file_path)
  185. if __name__ == "__main__":
  186. test_reindex_function()