# sav_utils.py (5.7 KB)
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the sav_dataset directory of this source tree.
  5. import json
  6. import os
  7. from typing import Dict, List, Optional, Tuple
  8. import cv2
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. import pycocotools.mask as mask_util
  12. def decode_video(video_path: str) -> List[np.ndarray]:
  13. """
  14. Decode the video and return the RGB frames
  15. """
  16. video = cv2.VideoCapture(video_path)
  17. video_frames = []
  18. while video.isOpened():
  19. ret, frame = video.read()
  20. if ret:
  21. frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  22. video_frames.append(frame)
  23. else:
  24. break
  25. return video_frames
  26. def show_anns(masks, colors: List, borders=True) -> None:
  27. """
  28. show the annotations
  29. """
  30. # return if no masks
  31. if len(masks) == 0:
  32. return
  33. # sort masks by size
  34. sorted_annot_and_color = sorted(
  35. zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True
  36. )
  37. H, W = sorted_annot_and_color[0][0].shape[0], sorted_annot_and_color[0][0].shape[1]
  38. canvas = np.ones((H, W, 4))
  39. canvas[:, :, 3] = 0 # set the alpha channel
  40. contour_thickness = max(1, int(min(5, 0.01 * min(H, W))))
  41. for mask, color in sorted_annot_and_color:
  42. canvas[mask] = np.concatenate([color, [0.55]])
  43. if borders:
  44. contours, _ = cv2.findContours(
  45. np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
  46. )
  47. cv2.drawContours(
  48. canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness
  49. )
  50. ax = plt.gca()
  51. ax.imshow(canvas)
  52. class SAVDataset:
  53. """
  54. SAVDataset is a class to load the SAV dataset and visualize the annotations.
  55. """
  56. def __init__(self, sav_dir, annot_sample_rate=4):
  57. """
  58. Args:
  59. sav_dir: the directory of the SAV dataset
  60. annot_sample_rate: the sampling rate of the annotations.
  61. The annotations are aligned with the videos at 6 fps.
  62. """
  63. self.sav_dir = sav_dir
  64. self.annot_sample_rate = annot_sample_rate
  65. self.manual_mask_colors = np.random.random((256, 3))
  66. self.auto_mask_colors = np.random.random((256, 3))
  67. def read_frames(self, mp4_path: str) -> None:
  68. """
  69. Read the frames and downsample them to align with the annotations.
  70. """
  71. if not os.path.exists(mp4_path):
  72. print(f"{mp4_path} doesn't exist.")
  73. return None
  74. else:
  75. # decode the video
  76. frames = decode_video(mp4_path)
  77. print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")
  78. # downsample the frames to align with the annotations
  79. frames = frames[:: self.annot_sample_rate]
  80. print(
  81. f"Videos are annotated every {self.annot_sample_rate} frames. "
  82. "To align with the annotations, "
  83. f"downsample the video to {len(frames)} frames."
  84. )
  85. return frames
  86. def get_frames_and_annotations(
  87. self, video_id: str
  88. ) -> Tuple[List | None, Dict | None, Dict | None]:
  89. """
  90. Get the frames and annotations for video.
  91. """
  92. # load the video
  93. mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
  94. frames = self.read_frames(mp4_path)
  95. if frames is None:
  96. return None, None, None
  97. # load the manual annotations
  98. manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
  99. if not os.path.exists(manual_annot_path):
  100. print(f"{manual_annot_path} doesn't exist. Something might be wrong.")
  101. manual_annot = None
  102. else:
  103. manual_annot = json.load(open(manual_annot_path))
  104. # load the manual annotations
  105. auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
  106. if not os.path.exists(auto_annot_path):
  107. print(f"{auto_annot_path} doesn't exist.")
  108. auto_annot = None
  109. else:
  110. auto_annot = json.load(open(auto_annot_path))
  111. return frames, manual_annot, auto_annot
  112. def visualize_annotation(
  113. self,
  114. frames: List[np.ndarray],
  115. auto_annot: Optional[Dict],
  116. manual_annot: Optional[Dict],
  117. annotated_frame_id: int,
  118. show_auto=True,
  119. show_manual=True,
  120. ) -> None:
  121. """
  122. Visualize the annotations on the annotated_frame_id.
  123. If show_manual is True, show the manual annotations.
  124. If show_auto is True, show the auto annotations.
  125. By default, show both auto and manual annotations.
  126. """
  127. if annotated_frame_id >= len(frames):
  128. print("invalid annotated_frame_id")
  129. return
  130. rles = []
  131. colors = []
  132. if show_manual and manual_annot is not None:
  133. rles.extend(manual_annot["masklet"][annotated_frame_id])
  134. colors.extend(
  135. self.manual_mask_colors[
  136. : len(manual_annot["masklet"][annotated_frame_id])
  137. ]
  138. )
  139. if show_auto and auto_annot is not None:
  140. rles.extend(auto_annot["masklet"][annotated_frame_id])
  141. colors.extend(
  142. self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])]
  143. )
  144. plt.imshow(frames[annotated_frame_id])
  145. if len(rles) > 0:
  146. masks = [mask_util.decode(rle) > 0 for rle in rles]
  147. show_anns(masks, colors)
  148. else:
  149. print("No annotation will be shown")
  150. plt.axis("off")
  151. plt.show()