ytvis_coco_wrapper.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-unsafe

import copy
import json
import logging
from typing import Optional

import numpy as np
import pycocotools.mask as mask_util
from pycocotools.coco import COCO
from typing_extensions import override


class YTVIS(COCO):
    """
    Helper class for reading YT-VIS annotations.
    """

    @override
    def __init__(
        self, annotation_file: Optional[str] = None, ignore_gt_cats: bool = True
    ):
        """
        Args:
            annotation_file: Path to the annotation file.
            ignore_gt_cats: If True, we ignore the ground-truth categories and
                replace them with a single dummy "object" category. This is
                useful for Phrase AP evaluation.
        """
        self.ignore_gt_cats = ignore_gt_cats
        super().__init__(annotation_file=annotation_file)

    @override
    def createIndex(self):
        # We rename some keys to match the COCO format before creating the index.
        if "annotations" in self.dataset:
            for ann in self.dataset["annotations"]:
                if "video_id" in ann:
                    ann["image_id"] = int(ann.pop("video_id"))
                if self.ignore_gt_cats:
                    ann["category_id"] = -1
                else:
                    ann["category_id"] = int(ann["category_id"])
                if "bboxes" in ann:
                    # Note that in some datasets we load under this YTVIS class,
                    # some "bboxes" could be None for frames where the GT object
                    # is invisible, so we replace them with [0, 0, 0, 0].
                    ann["bboxes"] = [
                        bbox if bbox is not None else [0, 0, 0, 0]
                        for bbox in ann["bboxes"]
                    ]
                if "areas" in ann:
                    # Similar to "bboxes", some areas could be None for frames
                    # where the GT object is invisible, so we replace them with 0.
                    areas = [a if a is not None else 0 for a in ann["areas"]]
                    # Compute the average area of the tracklet across the video.
                    ann["area"] = np.mean(areas)
        if "videos" in self.dataset:
            for vid in self.dataset["videos"]:
                vid["id"] = int(vid["id"])
            self.dataset["images"] = self.dataset.pop("videos")
        if self.ignore_gt_cats:
            self.dataset["categories"] = [
                {"supercategory": "object", "id": -1, "name": "object"}
            ]
        else:
            for cat in self.dataset["categories"]:
                cat["id"] = int(cat["id"])
        super().createIndex()
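
    # For reference, a minimal sketch of the raw layout that createIndex
    # normalizes (an assumed YT-VIS-style JSON, not an authoritative spec):
    #
    #   {
    #     "videos": [{"id": 1, "width": 1280, "height": 720, ...}],
    #     "annotations": [
    #       {"video_id": 1, "category_id": 3,
    #        "bboxes": [[x, y, w, h], None, ...],  # None = object invisible
    #        "areas": [1234.0, None, ...],
    #        "segmentations": [...]}               # per-frame masks
    #     ],
    #     "categories": [{"id": 3, "name": "person"}]
    #   }
    #
    # After normalization, "videos" is renamed to "images", "video_id" to
    # "image_id", ids are cast to int, and each annotation gains a scalar
    # "area" holding the tracklet's average per-frame area.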

    @override
    def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
        if len(areaRng) > 0:
            logging.warning(
                "Note that we filter out objects based on their *average* area "
                "across the video, not their per-frame area"
            )
        # ann["area"] is the tracklet's average area (see createIndex), so the
        # base-class area filter operates on average area, as warned above.
        return super().getAnnIds(
            imgIds=imgIds, catIds=catIds, areaRng=areaRng, iscrowd=iscrowd
        )

    @override
    def showAnns(self, anns, draw_bbox=False):
        raise NotImplementedError("Showing annotations is not supported")

    @override
    def loadRes(self, resFile):
        # Adapted from COCO.loadRes to support tracklets/masklets.
        res = YTVIS(ignore_gt_cats=self.ignore_gt_cats)
        res.dataset["images"] = [img for img in self.dataset["images"]]

        if isinstance(resFile, str):
            with open(resFile) as f:
                anns = json.load(f)
        elif isinstance(resFile, np.ndarray):
            anns = self.loadNumpyAnnotations(resFile)
        else:
            anns = resFile
        assert isinstance(anns, list), "results is not an array of objects"
        annsImgIds = [ann["image_id"] for ann in anns]
        assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
            "Results do not correspond to the current dataset"
        )

        if "bboxes" in anns[0] and anns[0]["bboxes"] != []:
            res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
            for ann_id, ann in enumerate(anns):
                # As in createIndex, replace None boxes (invisible frames) with
                # degenerate [0, 0, 0, 0] boxes.
                bbs = [
                    (bb if bb is not None else [0, 0, 0, 0]) for bb in ann["bboxes"]
                ]
                xxyy = [[bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] for bb in bbs]
                if "segmentations" not in ann:
                    # Convert each box into a 4-corner rectangle polygon in
                    # COCO's flattened [x1, y1, x2, y2, ...] format.
                    ann["segmentations"] = [
                        [[x1, y1, x1, y2, x2, y2, x2, y1]] for (x1, x2, y1, y2) in xxyy
                    ]
                ann["areas"] = [bb[2] * bb[3] for bb in bbs]
                # NOTE: We also compute the average area of a tracklet across
                # the video, allowing us to compute area-based mAP.
                ann["area"] = np.mean(ann["areas"])
                ann["id"] = ann_id + 1
                ann["iscrowd"] = 0
        elif "segmentations" in anns[0]:
            res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
            for ann_id, ann in enumerate(anns):
                ann["bboxes"] = [
                    mask_util.toBbox(segm) for segm in ann["segmentations"]
                ]
                if "areas" not in ann:
                    ann["areas"] = [
                        mask_util.area(segm) for segm in ann["segmentations"]
                    ]
                # NOTE: We also compute the average area of a tracklet across
                # the video, allowing us to compute area-based mAP.
                ann["area"] = np.mean(ann["areas"])
                ann["id"] = ann_id + 1
                ann["iscrowd"] = 0

        res.dataset["annotations"] = anns
        res.createIndex()
        return res
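
    # A minimal sketch of a results list accepted by loadRes (hypothetical
    # values; "size"/"counts" form a pycocotools RLE):
    #
    #   [{"image_id": 1, "category_id": -1, "score": 0.9,
    #     "segmentations": [{"size": [720, 1280], "counts": "..."}]}]
    #
    # i.e. one entry per predicted tracklet, with one RLE mask (or one box in
    # "bboxes") per frame of the video.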

    @override
    def download(self, tarDir=None, imgIds=[]):
        raise NotImplementedError

    @override
    def loadNumpyAnnotations(self, data):
        raise NotImplementedError("We don't support numpy annotations for now")

    @override
    def annToRLE(self, ann):
        raise NotImplementedError("We expect masks to already be in RLE format")

    @override
    def annToMask(self, ann):
        raise NotImplementedError("We expect masks to already be in RLE format")
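

if __name__ == "__main__":
    # A minimal usage sketch; the annotation and result paths below are
    # hypothetical placeholders, not files shipped with this module.
    gt = YTVIS("annotations/ytvis_val.json", ignore_gt_cats=True)
    dt = gt.loadRes("results/predictions.json")
    vid_ids = gt.getImgIds()  # one "image" id per video after createIndex
    ann_ids = gt.getAnnIds(imgIds=vid_ids)
    tracklets = gt.loadAnns(ann_ids)  # each has per-frame "bboxes"/"segmentations"
    print(f"{len(vid_ids)} videos, {len(tracklets)} GT tracklets")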