vos_segment_loader.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import glob
  6. import json
  7. import os
  8. import numpy as np
  9. import pandas as pd
  10. import torch
  11. from PIL import Image as PILImage
  12. try:
  13. from pycocotools import mask as mask_utils
  14. except:
  15. pass
  16. class JSONSegmentLoader:
  17. def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
  18. # Annotations in the json are provided every ann_every th frame
  19. self.ann_every = ann_every
  20. # Ids of the objects to consider when sampling this video
  21. self.valid_obj_ids = valid_obj_ids
  22. with open(video_json_path, "r") as f:
  23. data = json.load(f)
  24. if isinstance(data, list):
  25. self.frame_annots = data
  26. elif isinstance(data, dict):
  27. masklet_field_name = "masklet" if "masklet" in data else "masks"
  28. self.frame_annots = data[masklet_field_name]
  29. if "fps" in data:
  30. if isinstance(data["fps"], list):
  31. annotations_fps = int(data["fps"][0])
  32. else:
  33. annotations_fps = int(data["fps"])
  34. assert frames_fps % annotations_fps == 0
  35. self.ann_every = frames_fps // annotations_fps
  36. else:
  37. raise NotImplementedError
  38. def load(self, frame_id, obj_ids=None):
  39. assert frame_id % self.ann_every == 0
  40. rle_mask = self.frame_annots[frame_id // self.ann_every]
  41. valid_objs_ids = set(range(len(rle_mask)))
  42. if self.valid_obj_ids is not None:
  43. # Remove the masklets that have been filtered out for this video
  44. valid_objs_ids &= set(self.valid_obj_ids)
  45. if obj_ids is not None:
  46. # Only keep the objects that have been sampled
  47. valid_objs_ids &= set(obj_ids)
  48. valid_objs_ids = sorted(list(valid_objs_ids))
  49. # Construct rle_masks_filtered that only contains the rle masks we are interested in
  50. id_2_idx = {}
  51. rle_mask_filtered = []
  52. for obj_id in valid_objs_ids:
  53. if rle_mask[obj_id] is not None:
  54. id_2_idx[obj_id] = len(rle_mask_filtered)
  55. rle_mask_filtered.append(rle_mask[obj_id])
  56. else:
  57. id_2_idx[obj_id] = None
  58. # Decode the masks
  59. raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
  60. 2, 0, 1
  61. ) # (num_obj, h, w)
  62. segments = {}
  63. for obj_id in valid_objs_ids:
  64. if id_2_idx[obj_id] is None:
  65. segments[obj_id] = None
  66. else:
  67. idx = id_2_idx[obj_id]
  68. segments[obj_id] = raw_segments[idx]
  69. return segments
  70. def get_valid_obj_frames_ids(self, num_frames_min=None):
  71. # For each object, find all the frames with a valid (not None) mask
  72. num_objects = len(self.frame_annots[0])
  73. # The result dict associates each obj_id with the id of its valid frames
  74. res = {obj_id: [] for obj_id in range(num_objects)}
  75. for annot_idx, annot in enumerate(self.frame_annots):
  76. for obj_id in range(num_objects):
  77. if annot[obj_id] is not None:
  78. res[obj_id].append(int(annot_idx * self.ann_every))
  79. if num_frames_min is not None:
  80. # Remove masklets that have less than num_frames_min valid masks
  81. for obj_id, valid_frames in list(res.items()):
  82. if len(valid_frames) < num_frames_min:
  83. res.pop(obj_id)
  84. return res
  85. class PalettisedPNGSegmentLoader:
  86. def __init__(self, video_png_root):
  87. """
  88. SegmentLoader for datasets with masks stored as palettised PNGs.
  89. video_png_root: the folder contains all the masks stored in png
  90. """
  91. self.video_png_root = video_png_root
  92. # build a mapping from frame id to their PNG mask path
  93. # note that in some datasets, the PNG paths could have more
  94. # than 5 digits, e.g. "00000000.png" instead of "00000.png"
  95. png_filenames = os.listdir(self.video_png_root)
  96. self.frame_id_to_png_filename = {}
  97. for filename in png_filenames:
  98. frame_id, _ = os.path.splitext(filename)
  99. self.frame_id_to_png_filename[int(frame_id)] = filename
  100. def load(self, frame_id):
  101. """
  102. load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
  103. Args:
  104. frame_id: int, define the mask path
  105. Return:
  106. binary_segments: dict
  107. """
  108. # check the path
  109. mask_path = os.path.join(
  110. self.video_png_root, self.frame_id_to_png_filename[frame_id]
  111. )
  112. # load the mask
  113. masks = PILImage.open(mask_path).convert("P")
  114. masks = np.array(masks)
  115. object_id = pd.unique(masks.flatten())
  116. object_id = object_id[object_id != 0] # remove background (0)
  117. # convert into N binary segmentation masks
  118. binary_segments = {}
  119. for i in object_id:
  120. bs = masks == i
  121. binary_segments[i] = torch.from_numpy(bs)
  122. return binary_segments
  123. def __len__(self):
  124. return
  125. class MultiplePNGSegmentLoader:
  126. def __init__(self, video_png_root, single_object_mode=False):
  127. """
  128. video_png_root: the folder contains all the masks stored in png
  129. single_object_mode: whether to load only a single object at a time
  130. """
  131. self.video_png_root = video_png_root
  132. self.single_object_mode = single_object_mode
  133. # read a mask to know the resolution of the video
  134. if self.single_object_mode:
  135. tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
  136. else:
  137. tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
  138. tmp_mask = np.array(PILImage.open(tmp_mask_path))
  139. self.H = tmp_mask.shape[0]
  140. self.W = tmp_mask.shape[1]
  141. if self.single_object_mode:
  142. self.obj_id = (
  143. int(video_png_root.split("/")[-1]) + 1
  144. ) # offset by 1 as bg is 0
  145. else:
  146. self.obj_id = None
  147. def load(self, frame_id):
  148. if self.single_object_mode:
  149. return self._load_single_png(frame_id)
  150. else:
  151. return self._load_multiple_pngs(frame_id)
  152. def _load_single_png(self, frame_id):
  153. """
  154. load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
  155. Args:
  156. frame_id: int, define the mask path
  157. Return:
  158. binary_segments: dict
  159. """
  160. mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
  161. binary_segments = {}
  162. if os.path.exists(mask_path):
  163. mask = np.array(PILImage.open(mask_path))
  164. else:
  165. # if png doesn't exist, empty mask
  166. mask = np.zeros((self.H, self.W), dtype=bool)
  167. binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
  168. return binary_segments
  169. def _load_multiple_pngs(self, frame_id):
  170. """
  171. load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
  172. Args:
  173. frame_id: int, define the mask path
  174. Return:
  175. binary_segments: dict
  176. """
  177. # get the path
  178. all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
  179. num_objects = len(all_objects)
  180. assert num_objects > 0
  181. # load the masks
  182. binary_segments = {}
  183. for obj_folder in all_objects:
  184. # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
  185. obj_id = int(obj_folder.split("/")[-1])
  186. obj_id = obj_id + 1 # offset 1 as bg is 0
  187. mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
  188. if os.path.exists(mask_path):
  189. mask = np.array(PILImage.open(mask_path))
  190. else:
  191. mask = np.zeros((self.H, self.W), dtype=bool)
  192. binary_segments[obj_id] = torch.from_numpy(mask > 0)
  193. return binary_segments
  194. def __len__(self):
  195. return
  196. class LazySegments:
  197. """
  198. Only decodes segments that are actually used.
  199. """
  200. def __init__(self):
  201. self.segments = {}
  202. self.cache = {}
  203. def __setitem__(self, key, item):
  204. self.segments[key] = item
  205. def __getitem__(self, key):
  206. if key in self.cache:
  207. return self.cache[key]
  208. rle = self.segments[key]
  209. mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
  210. self.cache[key] = mask
  211. return mask
  212. def __contains__(self, key):
  213. return key in self.segments
  214. def __len__(self):
  215. return len(self.segments)
  216. def keys(self):
  217. return self.segments.keys()
  218. class SA1BSegmentLoader:
  219. def __init__(
  220. self,
  221. video_mask_path,
  222. mask_area_frac_thresh=1.1,
  223. video_frame_path=None,
  224. uncertain_iou=-1,
  225. ):
  226. with open(video_mask_path, "r") as f:
  227. self.frame_annots = json.load(f)
  228. if mask_area_frac_thresh <= 1.0:
  229. # Lazily read frame
  230. orig_w, orig_h = PILImage.open(video_frame_path).size
  231. area = orig_w * orig_h
  232. self.frame_annots = self.frame_annots["annotations"]
  233. rle_masks = []
  234. for frame_annot in self.frame_annots:
  235. if not frame_annot["area"] > 0:
  236. continue
  237. if ("uncertain_iou" in frame_annot) and (
  238. frame_annot["uncertain_iou"] < uncertain_iou
  239. ):
  240. # uncertain_iou is stability score
  241. continue
  242. if (
  243. mask_area_frac_thresh <= 1.0
  244. and (frame_annot["area"] / area) >= mask_area_frac_thresh
  245. ):
  246. continue
  247. rle_masks.append(frame_annot["segmentation"])
  248. self.segments = LazySegments()
  249. for i, rle in enumerate(rle_masks):
  250. self.segments[i] = rle
  251. def load(self, frame_idx):
  252. return self.segments