| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the sav_dataset directory of this source tree.
- import json
- import os
- from typing import Dict, List, Optional, Tuple
- import cv2
- import matplotlib.pyplot as plt
- import numpy as np
- import pycocotools.mask as mask_util
def decode_video(video_path: str) -> List[np.ndarray]:
    """
    Decode the video and return the RGB frames.

    Args:
        video_path: path of a video file readable by OpenCV.

    Returns:
        All frames that could be read, in order, converted from OpenCV's BGR
        channel order to RGB. If the capture cannot be opened the list is empty.
    """
    video = cv2.VideoCapture(video_path)
    video_frames = []
    try:
        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                # read() reports failure at the end of the stream (or on a read error).
                break
            video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture handle, even if decoding/conversion raises.
        video.release()
    return video_frames
def show_anns(masks, colors: List, borders=True) -> None:
    """
    Overlay masks on the current matplotlib axes.

    Masks are painted in decreasing order of pixel count, so smaller masks end
    up drawn on top of larger ones. Each mask is filled with its paired color at
    alpha 0.55 and, when ``borders`` is True, outlined with a near-black contour.

    Args:
        masks: 2-D masks (expected to be boolean arrays of equal height/width).
        colors: RGB colors paired positionally with ``masks``; pairing follows
            ``zip`` semantics, so unmatched entries on either side are ignored.
        borders: whether to draw the mask outlines.
    """
    # Nothing to draw.
    if not len(masks):
        return

    pairs = sorted(zip(masks, colors), key=lambda pair: pair[0].sum(), reverse=True)

    first_mask = pairs[0][0]
    height, width = first_mask.shape[0], first_mask.shape[1]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0  # start fully transparent
    thickness = max(1, int(min(5, 0.01 * min(height, width))))

    for mask, color in pairs:
        overlay[mask] = np.concatenate([color, [0.55]])
        if not borders:
            continue
        outlines, _ = cv2.findContours(
            np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
        )
        cv2.drawContours(
            overlay, outlines, -1, (0.05, 0.05, 0.05, 1), thickness=thickness
        )

    plt.gca().imshow(overlay)
class SAVDataset:
    """
    SAVDataset is a class to load the SAV dataset and visualize the annotations.

    For a given ``video_id`` the files are looked up directly under ``sav_dir``:
    ``<video_id>.mp4`` plus the optional ``<video_id>_manual.json`` and
    ``<video_id>_auto.json`` annotation files.
    """

    def __init__(self, sav_dir, annot_sample_rate=4):
        """
        Args:
            sav_dir: the directory of the SAV dataset
            annot_sample_rate: the sampling rate of the annotations.
                The annotations are aligned with the videos at 6 fps.
        """
        self.sav_dir = sav_dir
        self.annot_sample_rate = annot_sample_rate
        # Separate random RGB palettes (256 entries each) for manual and
        # automatic masklets; a masklet's index within a frame selects its color.
        self.manual_mask_colors = np.random.random((256, 3))
        self.auto_mask_colors = np.random.random((256, 3))

    def read_frames(self, mp4_path: str) -> Optional[List[np.ndarray]]:
        """
        Read the frames and downsample them to align with the annotations.

        Args:
            mp4_path: path of the video file.

        Returns:
            Every ``annot_sample_rate``-th decoded RGB frame, starting with the
            first one, or None if ``mp4_path`` does not exist.
        """
        if not os.path.exists(mp4_path):
            print(f"{mp4_path} doesn't exist.")
            return None
        # decode the video
        frames = decode_video(mp4_path)
        print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")
        # downsample the frames to align with the annotations
        frames = frames[:: self.annot_sample_rate]
        print(
            f"Videos are annotated every {self.annot_sample_rate} frames. "
            "To align with the annotations, "
            f"downsample the video to {len(frames)} frames."
        )
        return frames

    @staticmethod
    def _load_json(path: str, missing_msg: str) -> Optional[Dict]:
        """
        Load the JSON file at ``path``.

        If the file does not exist, print ``missing_msg`` and return None.
        """
        if not os.path.exists(path):
            print(missing_msg)
            return None
        # Context manager closes the handle promptly; JSON is UTF-8 by spec.
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def get_frames_and_annotations(
        self, video_id: str
    ) -> Tuple[Optional[List[np.ndarray]], Optional[Dict], Optional[Dict]]:
        """
        Get the frames and annotations for video.

        Args:
            video_id: base name of the video inside ``sav_dir`` (without suffix).

        Returns:
            ``(frames, manual_annot, auto_annot)``. If the video file is missing,
            all three are None; otherwise each annotation is None when its JSON
            file is missing.
        """
        # load the video
        mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
        frames = self.read_frames(mp4_path)
        if frames is None:
            return None, None, None
        # load the manual annotations
        manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
        manual_annot = self._load_json(
            manual_annot_path,
            f"{manual_annot_path} doesn't exist. Something might be wrong.",
        )
        # load the auto annotations
        auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
        auto_annot = self._load_json(
            auto_annot_path, f"{auto_annot_path} doesn't exist."
        )
        return frames, manual_annot, auto_annot

    @staticmethod
    def _colors_for(palette: np.ndarray, count: int) -> List[np.ndarray]:
        """
        Return ``count`` rows of ``palette``, cycling when ``count`` exceeds its length.

        For ``count <= len(palette)`` this equals ``list(palette[:count])``.
        Cycling guarantees one color per mask, so no mask is dropped when the
        masks and colors are paired up for drawing.
        """
        return [palette[i % len(palette)] for i in range(count)]

    def visualize_annotation(
        self,
        frames: List[np.ndarray],
        auto_annot: Optional[Dict],
        manual_annot: Optional[Dict],
        annotated_frame_id: int,
        show_auto=True,
        show_manual=True,
    ) -> None:
        """
        Visualize the annotations on the annotated_frame_id.
        If show_manual is True, show the manual annotations.
        If show_auto is True, show the auto annotations.
        By default, show both auto and manual annotations.

        ``annotated_frame_id`` must satisfy ``0 <= annotated_frame_id < len(frames)``;
        otherwise "invalid annotated_frame_id" is printed and nothing is drawn.
        Each annotation's ``"masklet"`` entry is indexed by this id to get the
        per-frame RLEs, which are decoded with ``pycocotools``.
        """
        # Reject negative ids as well: Python would otherwise index from the end.
        if not 0 <= annotated_frame_id < len(frames):
            print("invalid annotated_frame_id")
            return
        rles = []
        colors = []
        if show_manual and manual_annot is not None:
            frame_rles = manual_annot["masklet"][annotated_frame_id]
            rles.extend(frame_rles)
            colors.extend(self._colors_for(self.manual_mask_colors, len(frame_rles)))
        if show_auto and auto_annot is not None:
            frame_rles = auto_annot["masklet"][annotated_frame_id]
            rles.extend(frame_rles)
            colors.extend(self._colors_for(self.auto_mask_colors, len(frame_rles)))
        plt.imshow(frames[annotated_frame_id])
        if rles:
            masks = [mask_util.decode(rle) > 0 for rle in rles]
            show_anns(masks, colors)
        else:
            print("No annotation will be shown")
        plt.axis("off")
        plt.show()
|