# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import json
import os

import numpy as np
import pandas as pd
import torch
from PIL import Image as PILImage

try:
    from pycocotools import mask as mask_utils
except ImportError:
    # pycocotools is an optional dependency: only the RLE-based loaders
    # (JSONSegmentLoader.load, LazySegments.__getitem__) require it.
    # NOTE: narrowed from a bare `except:` which also swallowed
    # KeyboardInterrupt/SystemExit. If the import fails, `mask_utils`
    # stays undefined and RLE decoding raises NameError at call time,
    # matching the original behavior.
    pass
class JSONSegmentLoader:
    """Loads per-frame RLE mask annotations for one video from a JSON file.

    The JSON is either a plain list of per-frame annotations, or a dict with
    a "masklet"/"masks" field (and optionally an "fps" field used to derive
    the annotation stride from the video frame rate).
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        # Annotations are provided every `ann_every`-th video frame.
        self.ann_every = ann_every
        # Object ids allowed for sampling in this video (None = all).
        self.valid_obj_ids = valid_obj_ids
        with open(video_json_path, "r") as f:
            data = json.load(f)
        if isinstance(data, list):
            self.frame_annots = data
        elif isinstance(data, dict):
            field = "masklet" if "masklet" in data else "masks"
            self.frame_annots = data[field]
            if "fps" in data:
                fps_value = data["fps"]
                annotations_fps = int(
                    fps_value[0] if isinstance(fps_value, list) else fps_value
                )
                # The video fps must be an integer multiple of the annotation fps.
                assert frames_fps % annotations_fps == 0
                self.ann_every = frames_fps // annotations_fps
        else:
            raise NotImplementedError

    def load(self, frame_id, obj_ids=None):
        """Decode and return {obj_id: (H, W) mask tensor or None} for a frame."""
        assert frame_id % self.ann_every == 0
        frame_rles = self.frame_annots[frame_id // self.ann_every]
        # Start from every object present in this frame's annotation list.
        keep_ids = set(range(len(frame_rles)))
        if self.valid_obj_ids is not None:
            # Drop masklets filtered out for this video.
            keep_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            # Restrict to the objects that were sampled.
            keep_ids &= set(obj_ids)
        keep_ids = sorted(keep_ids)
        # Collect only the non-None RLEs, remembering each object's slot.
        id_to_slot = {}
        kept_rles = []
        for obj_id in keep_ids:
            rle = frame_rles[obj_id]
            if rle is None:
                id_to_slot[obj_id] = None
            else:
                id_to_slot[obj_id] = len(kept_rles)
                kept_rles.append(rle)
        # Decode all kept RLEs in one call -> (num_obj, h, w).
        decoded = torch.from_numpy(mask_utils.decode(kept_rles)).permute(2, 0, 1)
        segments = {}
        for obj_id in keep_ids:
            slot = id_to_slot[obj_id]
            segments[obj_id] = None if slot is None else decoded[slot]
        return segments

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """Map each obj_id to the list of frame ids where its mask is not None.

        If `num_frames_min` is given, objects with fewer valid frames are dropped.
        """
        num_objects = len(self.frame_annots[0])
        res = {}
        for obj_id in range(num_objects):
            res[obj_id] = [
                int(annot_idx * self.ann_every)
                for annot_idx, annot in enumerate(self.frame_annots)
                if annot[obj_id] is not None
            ]
        if num_frames_min is not None:
            res = {
                obj_id: frames
                for obj_id, frames in res.items()
                if len(frames) >= num_frames_min
            }
        return res
class PalettisedPNGSegmentLoader:
    def __init__(self, video_png_root):
        """
        SegmentLoader for datasets with masks stored as palettised PNGs.
        video_png_root: the folder contains all the masks stored in png
        """
        self.video_png_root = video_png_root
        # Map frame id -> PNG filename. Filenames are not assumed to be
        # zero-padded to a fixed width (some datasets use more than 5 digits),
        # so the id is parsed from the filename stem instead of formatted.
        self.frame_id_to_png_filename = {
            int(os.path.splitext(filename)[0]): filename
            for filename in os.listdir(self.video_png_root)
        }

    def load(self, frame_id):
        """
        load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict
        """
        # Resolve the PNG path for this frame.
        png_name = self.frame_id_to_png_filename[frame_id]
        mask_path = os.path.join(self.video_png_root, png_name)
        # Load as a palettised image; each palette index is one object id.
        palette = np.array(PILImage.open(mask_path).convert("P"))
        object_ids = pd.unique(palette.flatten())
        object_ids = object_ids[object_ids != 0]  # remove background (0)
        # One binary mask per object id.
        return {
            obj_id: torch.from_numpy(palette == obj_id) for obj_id in object_ids
        }

    def __len__(self):
        # Intentional stub: this loader has no meaningful length.
        return None
class MultiplePNGSegmentLoader:
    def __init__(self, video_png_root, single_object_mode=False):
        """
        video_png_root: the folder contains all the masks stored in png
        single_object_mode: whether to load only a single object at a time
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode
        # Read one mask to learn the video resolution (used for empty-mask fallback).
        if self.single_object_mode:
            sample_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
        else:
            sample_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
        sample = np.array(PILImage.open(sample_path))
        self.H = sample.shape[0]
        self.W = sample.shape[1]
        if self.single_object_mode:
            # The folder name encodes the object id; offset by 1 as bg is 0.
            self.obj_id = int(video_png_root.split("/")[-1]) + 1
        else:
            self.obj_id = None

    def load(self, frame_id):
        """Dispatch to the single- or multi-object loading path."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        if os.path.exists(mask_path):
            mask = np.array(PILImage.open(mask_path))
        else:
            # Missing PNG means an empty mask at this frame.
            mask = np.zeros((self.H, self.W), dtype=bool)
        return {self.obj_id: torch.from_numpy(mask > 0)}

    def _load_multiple_pngs(self, frame_id):
        """
        load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict
        """
        # One sub-folder per object; folder name encodes the object id.
        object_folders = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
        assert len(object_folders) > 0
        binary_segments = {}
        for folder in object_folders:
            # folder is {video_name}/{obj_id}; offset by 1 as bg is 0.
            obj_id = int(folder.split("/")[-1]) + 1
            mask_path = os.path.join(folder, f"{frame_id:05d}.png")
            if os.path.exists(mask_path):
                mask = np.array(PILImage.open(mask_path))
            else:
                mask = np.zeros((self.H, self.W), dtype=bool)
            binary_segments[obj_id] = torch.from_numpy(mask > 0)
        return binary_segments

    def __len__(self):
        # Intentional stub: this loader has no meaningful length.
        return None
class LazySegments:
    """
    Dict-like container of RLE segments that decodes a mask only on first
    access, caching the decoded tensor for subsequent lookups.
    """

    def __init__(self):
        self.segments = {}  # key -> raw RLE
        self.cache = {}  # key -> decoded mask tensor

    def __setitem__(self, key, item):
        self.segments[key] = item

    def __getitem__(self, key):
        # Serve from the cache when available (EAFP).
        try:
            return self.cache[key]
        except KeyError:
            pass
        rle = self.segments[key]
        decoded = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
        self.cache[key] = decoded
        return decoded

    def __contains__(self, key):
        return key in self.segments

    def __len__(self):
        return len(self.segments)

    def keys(self):
        return self.segments.keys()
class SA1BSegmentLoader:
    """Loads SA-1B-style per-image annotations, filtering degenerate masks.

    Masks are dropped when their area is not positive, when their stability
    score falls below `uncertain_iou`, or (if `mask_area_frac_thresh` <= 1.0)
    when they cover at least that fraction of the frame.
    """

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        with open(video_mask_path, "r") as f:
            self.frame_annots = json.load(f)

        frame_area = None
        if mask_area_frac_thresh <= 1.0:
            # The area filter is active: read the frame lazily to get its size.
            orig_w, orig_h = PILImage.open(video_frame_path).size
            frame_area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]
        kept_rles = []
        for annot in self.frame_annots:
            # Skip degenerate (non-positive area) masks.
            if not annot["area"] > 0:
                continue
            # uncertain_iou is stability score
            if ("uncertain_iou" in annot) and (annot["uncertain_iou"] < uncertain_iou):
                continue
            # Skip masks covering too large a fraction of the frame.
            if (
                mask_area_frac_thresh <= 1.0
                and (annot["area"] / frame_area) >= mask_area_frac_thresh
            ):
                continue
            kept_rles.append(annot["segmentation"])

        # Decode lazily: only masks that are actually accessed get decoded.
        self.segments = LazySegments()
        for idx, rle in enumerate(kept_rles):
            self.segments[idx] = rle

    def load(self, frame_idx):
        # All frames share the same (single-image) segment container.
        return self.segments