vos_raw_dataset.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import glob
  6. import logging
  7. import os
  8. from dataclasses import dataclass
  9. from typing import List, Optional
  10. import pandas as pd
  11. import torch
  12. from iopath.common.file_io import g_pathmgr
  13. from omegaconf.listconfig import ListConfig
  14. from training.dataset.vos_segment_loader import (
  15. JSONSegmentLoader,
  16. MultiplePNGSegmentLoader,
  17. PalettisedPNGSegmentLoader,
  18. SA1BSegmentLoader,
  19. )
  20. @dataclass
  21. class VOSFrame:
  22. frame_idx: int
  23. image_path: str
  24. data: Optional[torch.Tensor] = None
  25. is_conditioning_only: Optional[bool] = False
  26. @dataclass
  27. class VOSVideo:
  28. video_name: str
  29. video_id: int
  30. frames: List[VOSFrame]
  31. def __len__(self):
  32. return len(self.frames)
  33. class VOSRawDataset:
  34. def __init__(self):
  35. pass
  36. def get_video(self, idx):
  37. raise NotImplementedError()
  38. class PNGRawDataset(VOSRawDataset):
  39. def __init__(
  40. self,
  41. img_folder,
  42. gt_folder,
  43. file_list_txt=None,
  44. excluded_videos_list_txt=None,
  45. sample_rate=1,
  46. is_palette=True,
  47. single_object_mode=False,
  48. truncate_video=-1,
  49. frames_sampling_mult=False,
  50. ):
  51. self.img_folder = img_folder
  52. self.gt_folder = gt_folder
  53. self.sample_rate = sample_rate
  54. self.is_palette = is_palette
  55. self.single_object_mode = single_object_mode
  56. self.truncate_video = truncate_video
  57. # Read the subset defined in file_list_txt
  58. if file_list_txt is not None:
  59. with g_pathmgr.open(file_list_txt, "r") as f:
  60. subset = [os.path.splitext(line.strip())[0] for line in f]
  61. else:
  62. subset = os.listdir(self.img_folder)
  63. # Read and process excluded files if provided
  64. if excluded_videos_list_txt is not None:
  65. with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
  66. excluded_files = [os.path.splitext(line.strip())[0] for line in f]
  67. else:
  68. excluded_files = []
  69. # Check if it's not in excluded_files
  70. self.video_names = sorted(
  71. [video_name for video_name in subset if video_name not in excluded_files]
  72. )
  73. if self.single_object_mode:
  74. # single object mode
  75. self.video_names = sorted(
  76. [
  77. os.path.join(video_name, obj)
  78. for video_name in self.video_names
  79. for obj in os.listdir(os.path.join(self.gt_folder, video_name))
  80. ]
  81. )
  82. if frames_sampling_mult:
  83. video_names_mult = []
  84. for video_name in self.video_names:
  85. num_frames = len(os.listdir(os.path.join(self.img_folder, video_name)))
  86. video_names_mult.extend([video_name] * num_frames)
  87. self.video_names = video_names_mult
  88. def get_video(self, idx):
  89. """
  90. Given a VOSVideo object, return the mask tensors.
  91. """
  92. video_name = self.video_names[idx]
  93. if self.single_object_mode:
  94. video_frame_root = os.path.join(
  95. self.img_folder, os.path.dirname(video_name)
  96. )
  97. else:
  98. video_frame_root = os.path.join(self.img_folder, video_name)
  99. video_mask_root = os.path.join(self.gt_folder, video_name)
  100. if self.is_palette:
  101. segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
  102. else:
  103. segment_loader = MultiplePNGSegmentLoader(
  104. video_mask_root, self.single_object_mode
  105. )
  106. all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
  107. if self.truncate_video > 0:
  108. all_frames = all_frames[: self.truncate_video]
  109. frames = []
  110. for _, fpath in enumerate(all_frames[:: self.sample_rate]):
  111. fid = int(os.path.basename(fpath).split(".")[0])
  112. frames.append(VOSFrame(fid, image_path=fpath))
  113. video = VOSVideo(video_name, idx, frames)
  114. return video, segment_loader
  115. def __len__(self):
  116. return len(self.video_names)
  117. class SA1BRawDataset(VOSRawDataset):
  118. def __init__(
  119. self,
  120. img_folder,
  121. gt_folder,
  122. file_list_txt=None,
  123. excluded_videos_list_txt=None,
  124. num_frames=1,
  125. mask_area_frac_thresh=1.1, # no filtering by default
  126. uncertain_iou=-1, # no filtering by default
  127. ):
  128. self.img_folder = img_folder
  129. self.gt_folder = gt_folder
  130. self.num_frames = num_frames
  131. self.mask_area_frac_thresh = mask_area_frac_thresh
  132. self.uncertain_iou = uncertain_iou # stability score
  133. # Read the subset defined in file_list_txt
  134. if file_list_txt is not None:
  135. with g_pathmgr.open(file_list_txt, "r") as f:
  136. subset = [os.path.splitext(line.strip())[0] for line in f]
  137. else:
  138. subset = os.listdir(self.img_folder)
  139. subset = [
  140. path.split(".")[0] for path in subset if path.endswith(".jpg")
  141. ] # remove extension
  142. # Read and process excluded files if provided
  143. if excluded_videos_list_txt is not None:
  144. with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
  145. excluded_files = [os.path.splitext(line.strip())[0] for line in f]
  146. else:
  147. excluded_files = []
  148. # Check if it's not in excluded_files and it exists
  149. self.video_names = [
  150. video_name for video_name in subset if video_name not in excluded_files
  151. ]
  152. def get_video(self, idx):
  153. """
  154. Given a VOSVideo object, return the mask tensors.
  155. """
  156. video_name = self.video_names[idx]
  157. video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
  158. video_mask_path = os.path.join(self.gt_folder, video_name + ".json")
  159. segment_loader = SA1BSegmentLoader(
  160. video_mask_path,
  161. mask_area_frac_thresh=self.mask_area_frac_thresh,
  162. video_frame_path=video_frame_path,
  163. uncertain_iou=self.uncertain_iou,
  164. )
  165. frames = []
  166. for frame_idx in range(self.num_frames):
  167. frames.append(VOSFrame(frame_idx, image_path=video_frame_path))
  168. video_name = video_name.split("_")[-1] # filename is sa_{int}
  169. # video id needs to be image_id to be able to load correct annotation file during eval
  170. video = VOSVideo(video_name, int(video_name), frames)
  171. return video, segment_loader
  172. def __len__(self):
  173. return len(self.video_names)
  174. class JSONRawDataset(VOSRawDataset):
  175. """
  176. Dataset where the annotation in the format of SA-V json files
  177. """
  178. def __init__(
  179. self,
  180. img_folder,
  181. gt_folder,
  182. file_list_txt=None,
  183. excluded_videos_list_txt=None,
  184. sample_rate=1,
  185. rm_unannotated=True,
  186. ann_every=1,
  187. frames_fps=24,
  188. ):
  189. self.gt_folder = gt_folder
  190. self.img_folder = img_folder
  191. self.sample_rate = sample_rate
  192. self.rm_unannotated = rm_unannotated
  193. self.ann_every = ann_every
  194. self.frames_fps = frames_fps
  195. # Read and process excluded files if provided
  196. excluded_files = []
  197. if excluded_videos_list_txt is not None:
  198. if isinstance(excluded_videos_list_txt, str):
  199. excluded_videos_lists = [excluded_videos_list_txt]
  200. elif isinstance(excluded_videos_list_txt, ListConfig):
  201. excluded_videos_lists = list(excluded_videos_list_txt)
  202. else:
  203. raise NotImplementedError
  204. for excluded_videos_list_txt in excluded_videos_lists:
  205. with open(excluded_videos_list_txt, "r") as f:
  206. excluded_files.extend(
  207. [os.path.splitext(line.strip())[0] for line in f]
  208. )
  209. excluded_files = set(excluded_files)
  210. # Read the subset defined in file_list_txt
  211. if file_list_txt is not None:
  212. with g_pathmgr.open(file_list_txt, "r") as f:
  213. subset = [os.path.splitext(line.strip())[0] for line in f]
  214. else:
  215. subset = os.listdir(self.img_folder)
  216. self.video_names = sorted(
  217. [video_name for video_name in subset if video_name not in excluded_files]
  218. )
  219. def get_video(self, video_idx):
  220. """
  221. Given a VOSVideo object, return the mask tensors.
  222. """
  223. video_name = self.video_names[video_idx]
  224. video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
  225. segment_loader = JSONSegmentLoader(
  226. video_json_path=video_json_path,
  227. ann_every=self.ann_every,
  228. frames_fps=self.frames_fps,
  229. )
  230. frame_ids = [
  231. int(os.path.splitext(frame_name)[0])
  232. for frame_name in sorted(
  233. os.listdir(os.path.join(self.img_folder, video_name))
  234. )
  235. ]
  236. frames = [
  237. VOSFrame(
  238. frame_id,
  239. image_path=os.path.join(
  240. self.img_folder, f"{video_name}/%05d.jpg" % (frame_id)
  241. ),
  242. )
  243. for frame_id in frame_ids[:: self.sample_rate]
  244. ]
  245. if self.rm_unannotated:
  246. # Eliminate the frames that have not been annotated
  247. valid_frame_ids = [
  248. i * segment_loader.ann_every
  249. for i, annot in enumerate(segment_loader.frame_annots)
  250. if annot is not None and None not in annot
  251. ]
  252. frames = [f for f in frames if f.frame_idx in valid_frame_ids]
  253. video = VOSVideo(video_name, video_idx, frames)
  254. return video, segment_loader
  255. def __len__(self):
  256. return len(self.video_names)