misc.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import os
  6. import warnings
  7. from threading import Thread
  8. import numpy as np
  9. import torch
  10. from PIL import Image
  11. from tqdm import tqdm
  12. def get_sdpa_settings():
  13. if torch.cuda.is_available():
  14. old_gpu = torch.cuda.get_device_properties(0).major < 7
  15. # only use Flash Attention on Ampere (8.0) or newer GPUs
  16. use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
  17. if not use_flash_attn:
  18. warnings.warn(
  19. "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
  20. category=UserWarning,
  21. stacklevel=2,
  22. )
  23. # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
  24. # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
  25. pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
  26. if pytorch_version < (2, 2):
  27. warnings.warn(
  28. f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
  29. "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
  30. category=UserWarning,
  31. stacklevel=2,
  32. )
  33. math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
  34. else:
  35. old_gpu = True
  36. use_flash_attn = False
  37. math_kernel_on = True
  38. return old_gpu, use_flash_attn, math_kernel_on
  39. def get_connected_components(mask):
  40. """
  41. Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
  42. Inputs:
  43. - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
  44. background.
  45. Outputs:
  46. - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
  47. for foreground pixels and 0 for background pixels.
  48. - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
  49. components for foreground pixels and 0 for background pixels.
  50. """
  51. from sam2 import _C
  52. return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
  53. def mask_to_box(masks: torch.Tensor):
  54. """
  55. compute bounding box given an input mask
  56. Inputs:
  57. - masks: [B, 1, H, W] boxes, dtype=torch.Tensor
  58. Returns:
  59. - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
  60. """
  61. B, _, h, w = masks.shape
  62. device = masks.device
  63. xs = torch.arange(w, device=device, dtype=torch.int32)
  64. ys = torch.arange(h, device=device, dtype=torch.int32)
  65. grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
  66. grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
  67. grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
  68. min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
  69. max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
  70. min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
  71. max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
  72. bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
  73. return bbox_coords
  74. def _load_img_as_tensor(img_path, image_size):
  75. img_pil = Image.open(img_path)
  76. img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
  77. if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
  78. img_np = img_np / 255.0
  79. else:
  80. raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
  81. img = torch.from_numpy(img_np).permute(2, 0, 1)
  82. video_width, video_height = img_pil.size # the original video size
  83. return img, video_height, video_width
  84. class AsyncVideoFrameLoader:
  85. """
  86. A list of video frames to be load asynchronously without blocking session start.
  87. """
  88. def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
  89. self.img_paths = img_paths
  90. self.image_size = image_size
  91. self.offload_video_to_cpu = offload_video_to_cpu
  92. self.img_mean = img_mean
  93. self.img_std = img_std
  94. # items in `self._images` will be loaded asynchronously
  95. self.images = [None] * len(img_paths)
  96. # catch and raise any exceptions in the async loading thread
  97. self.exception = None
  98. # video_height and video_width be filled when loading the first image
  99. self.video_height = None
  100. self.video_width = None
  101. # load the first frame to fill video_height and video_width and also
  102. # to cache it (since it's most likely where the user will click)
  103. self.__getitem__(0)
  104. # load the rest of frames asynchronously without blocking the session start
  105. def _load_frames():
  106. try:
  107. for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
  108. self.__getitem__(n)
  109. except Exception as e:
  110. self.exception = e
  111. self.thread = Thread(target=_load_frames, daemon=True)
  112. self.thread.start()
  113. def __getitem__(self, index):
  114. if self.exception is not None:
  115. raise RuntimeError("Failure in frame loading thread") from self.exception
  116. img = self.images[index]
  117. if img is not None:
  118. return img
  119. img, video_height, video_width = _load_img_as_tensor(
  120. self.img_paths[index], self.image_size
  121. )
  122. self.video_height = video_height
  123. self.video_width = video_width
  124. # normalize by mean and std
  125. img -= self.img_mean
  126. img /= self.img_std
  127. if not self.offload_video_to_cpu:
  128. img = img.cuda(non_blocking=True)
  129. self.images[index] = img
  130. return img
  131. def __len__(self):
  132. return len(self.images)
  133. def load_video_frames(
  134. video_path,
  135. image_size,
  136. offload_video_to_cpu,
  137. img_mean=(0.485, 0.456, 0.406),
  138. img_std=(0.229, 0.224, 0.225),
  139. async_loading_frames=False,
  140. ):
  141. """
  142. Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
  143. The frames are resized to image_size x image_size and are loaded to GPU if
  144. `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
  145. You can load a frame asynchronously by setting `async_loading_frames` to `True`.
  146. """
  147. if isinstance(video_path, str) and os.path.isdir(video_path):
  148. jpg_folder = video_path
  149. else:
  150. raise NotImplementedError("Only JPEG frames are supported at this moment")
  151. frame_names = [
  152. p
  153. for p in os.listdir(jpg_folder)
  154. if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
  155. ]
  156. frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
  157. num_frames = len(frame_names)
  158. if num_frames == 0:
  159. raise RuntimeError(f"no images found in {jpg_folder}")
  160. img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
  161. img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
  162. img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
  163. if async_loading_frames:
  164. lazy_images = AsyncVideoFrameLoader(
  165. img_paths, image_size, offload_video_to_cpu, img_mean, img_std
  166. )
  167. return lazy_images, lazy_images.video_height, lazy_images.video_width
  168. images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
  169. for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
  170. images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
  171. if not offload_video_to_cpu:
  172. images = images.cuda()
  173. img_mean = img_mean.cuda()
  174. img_std = img_std.cuda()
  175. # normalize by mean and std
  176. images -= img_mean
  177. images /= img_std
  178. return images, video_height, video_width
  179. def fill_holes_in_mask_scores(mask, max_area):
  180. """
  181. A post processor to fill small holes in mask scores with area under `max_area`.
  182. """
  183. # Holes are those connected components in background with area <= self.max_area
  184. # (background regions are those with mask scores <= 0)
  185. assert max_area > 0, "max_area must be positive"
  186. labels, areas = get_connected_components(mask <= 0)
  187. is_hole = (labels > 0) & (areas <= max_area)
  188. # We fill holes with a small positive mask score (0.1) to change them to foreground.
  189. mask = torch.where(is_hole, 0.1, mask)
  190. return mask
  191. def concat_points(old_point_inputs, new_points, new_labels):
  192. """Add new points and labels to previous point inputs (add at the end)."""
  193. if old_point_inputs is None:
  194. points, labels = new_points, new_labels
  195. else:
  196. points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
  197. labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
  198. return {"point_coords": points, "point_labels": labels}