# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
import logging
import os
from dataclasses import dataclass
from typing import List, Optional

import pandas as pd
import torch
from iopath.common.file_io import g_pathmgr
from omegaconf.listconfig import ListConfig

from training.dataset.vos_segment_loader import (
    JSONSegmentLoader,
    MultiplePNGSegmentLoader,
    PalettisedPNGSegmentLoader,
    SA1BSegmentLoader,
)


@dataclass
class VOSFrame:
    frame_idx: int
    image_path: str
    data: Optional[torch.Tensor] = None
    is_conditioning_only: Optional[bool] = False


@dataclass
class VOSVideo:
    video_name: str
    video_id: int
    frames: List[VOSFrame]

    def __len__(self):
        return len(self.frames)


class VOSRawDataset:
    def __init__(self):
        pass

    def get_video(self, idx):
        raise NotImplementedError()
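
# All raw datasets below share the same contract: __len__() returns the number
# of videos, and get_video(idx) returns a (VOSVideo, segment_loader) pair, so
# that a downstream sampler can pick frames and look up their masks through
# the loader.
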

class PNGRawDataset(VOSRawDataset):
    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        is_palette=True,
        single_object_mode=False,
        truncate_video=-1,
        frames_sampling_mult=False,
    ):
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.sample_rate = sample_rate
        self.is_palette = is_palette
        self.single_object_mode = single_object_mode
        self.truncate_video = truncate_video

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        # Read and process excluded files if provided
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = [os.path.splitext(line.strip())[0] for line in f]
        else:
            excluded_files = []

        # Keep only the videos that are not in excluded_files
        self.video_names = sorted(
            [video_name for video_name in subset if video_name not in excluded_files]
        )

        if self.single_object_mode:
            # single object mode
            self.video_names = sorted(
                [
                    os.path.join(video_name, obj)
                    for video_name in self.video_names
                    for obj in os.listdir(os.path.join(self.gt_folder, video_name))
                ]
            )

        if frames_sampling_mult:
            video_names_mult = []
            for video_name in self.video_names:
                num_frames = len(os.listdir(os.path.join(self.img_folder, video_name)))
                video_names_mult.extend([video_name] * num_frames)
            self.video_names = video_names_mult

    def get_video(self, idx):
        """
        Given a video index, return the corresponding VOSVideo and its segment loader.
        """
        video_name = self.video_names[idx]

        if self.single_object_mode:
            video_frame_root = os.path.join(
                self.img_folder, os.path.dirname(video_name)
            )
        else:
            video_frame_root = os.path.join(self.img_folder, video_name)

        video_mask_root = os.path.join(self.gt_folder, video_name)

        if self.is_palette:
            segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
        else:
            segment_loader = MultiplePNGSegmentLoader(
                video_mask_root, self.single_object_mode
            )

        all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
        if self.truncate_video > 0:
            all_frames = all_frames[: self.truncate_video]
        frames = []
        for fpath in all_frames[:: self.sample_rate]:
            fid = int(os.path.basename(fpath).split(".")[0])
            frames.append(VOSFrame(fid, image_path=fpath))
        video = VOSVideo(video_name, idx, frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names)
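
# Example usage sketch for PNGRawDataset, kept as a comment. The DAVIS-style
# paths are placeholders, and segment_loader.load(frame_idx) is assumed to
# return the per-object masks for that frame:
#
#   dataset = PNGRawDataset(
#       img_folder="/path/to/JPEGImages/480p",
#       gt_folder="/path/to/Annotations/480p",
#   )
#   video, segment_loader = dataset.get_video(0)
#   for frame in video.frames:
#       masks = segment_loader.load(frame.frame_idx)
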

class SA1BRawDataset(VOSRawDataset):
    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        num_frames=1,
        mask_area_frac_thresh=1.1,  # no filtering by default
        uncertain_iou=-1,  # no filtering by default
    ):
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.num_frames = num_frames
        self.mask_area_frac_thresh = mask_area_frac_thresh
        self.uncertain_iou = uncertain_iou  # stability score

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)
            subset = [
                path.split(".")[0] for path in subset if path.endswith(".jpg")
            ]  # remove extension

        # Read and process excluded files if provided
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = [os.path.splitext(line.strip())[0] for line in f]
        else:
            excluded_files = []

        # Keep only the images that are not in excluded_files
        self.video_names = [
            video_name for video_name in subset if video_name not in excluded_files
        ]

    def get_video(self, idx):
        """
        Given an image index, return a single-image VOSVideo and its segment loader.
        """
        video_name = self.video_names[idx]

        video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
        video_mask_path = os.path.join(self.gt_folder, video_name + ".json")

        segment_loader = SA1BSegmentLoader(
            video_mask_path,
            mask_area_frac_thresh=self.mask_area_frac_thresh,
            video_frame_path=video_frame_path,
            uncertain_iou=self.uncertain_iou,
        )

        frames = []
        for frame_idx in range(self.num_frames):
            frames.append(VOSFrame(frame_idx, image_path=video_frame_path))
        video_name = video_name.split("_")[-1]  # filename is sa_{int}
        # video id needs to be image_id to be able to load correct annotation file during eval
        video = VOSVideo(video_name, int(video_name), frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names)
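
# Note: SA1BRawDataset above wraps each SA-1B image as a "video" of num_frames
# copies of the same frame, so static images can be fed through the same
# video-style pipeline as the true video datasets.
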

class JSONRawDataset(VOSRawDataset):
    """
    Dataset where the annotations are in the SA-V JSON format.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        rm_unannotated=True,
        ann_every=1,
        frames_fps=24,
    ):
        self.gt_folder = gt_folder
        self.img_folder = img_folder
        self.sample_rate = sample_rate
        self.rm_unannotated = rm_unannotated
        self.ann_every = ann_every
        self.frames_fps = frames_fps

        # Read and process excluded files if provided
        excluded_files = []
        if excluded_videos_list_txt is not None:
            if isinstance(excluded_videos_list_txt, str):
                excluded_videos_lists = [excluded_videos_list_txt]
            elif isinstance(excluded_videos_list_txt, ListConfig):
                excluded_videos_lists = list(excluded_videos_list_txt)
            else:
                raise NotImplementedError

            for excluded_videos_list_txt in excluded_videos_lists:
                with open(excluded_videos_list_txt, "r") as f:
                    excluded_files.extend(
                        [os.path.splitext(line.strip())[0] for line in f]
                    )
        excluded_files = set(excluded_files)

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        self.video_names = sorted(
            [video_name for video_name in subset if video_name not in excluded_files]
        )

    def get_video(self, video_idx):
        """
        Given a video index, return the corresponding VOSVideo and its segment loader.
        """
        video_name = self.video_names[video_idx]
        video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
        segment_loader = JSONSegmentLoader(
            video_json_path=video_json_path,
            ann_every=self.ann_every,
            frames_fps=self.frames_fps,
        )

        frame_ids = [
            int(os.path.splitext(frame_name)[0])
            for frame_name in sorted(
                os.listdir(os.path.join(self.img_folder, video_name))
            )
        ]

        frames = [
            VOSFrame(
                frame_id,
                image_path=os.path.join(
                    self.img_folder, f"{video_name}/%05d.jpg" % (frame_id)
                ),
            )
            for frame_id in frame_ids[:: self.sample_rate]
        ]

        if self.rm_unannotated:
            # Eliminate the frames that have not been annotated
            valid_frame_ids = [
                i * segment_loader.ann_every
                for i, annot in enumerate(segment_loader.frame_annots)
                if annot is not None and None not in annot
            ]
            frames = [f for f in frames if f.frame_idx in valid_frame_ids]

        video = VOSVideo(video_name, video_idx, frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names)
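
# Minimal smoke-test sketch. The layout assumed here (frames under
# <img_folder>/<video_name>/00000.jpg and annotations at
# <gt_folder>/<video_name>_manual.json) matches what JSONRawDataset expects;
# the paths themselves are placeholders to be swapped for a local SA-V copy.
if __name__ == "__main__":
    dataset = JSONRawDataset(
        img_folder="/path/to/sav/frames",
        gt_folder="/path/to/sav/annotations",
    )
    print(f"{len(dataset)} videos found")
    if len(dataset) > 0:
        video, segment_loader = dataset.get_video(0)
        print(f"video {video.video_name}: {len(video)} annotated frames")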