nms_helper.py 11 KB


  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import warnings
  4. from typing import Dict, List
  5. import numpy as np
  6. # Check if Numba is available
  7. HAS_NUMBA = False
  8. try:
  9. import numba as nb
  10. HAS_NUMBA = True
  11. except ImportError:
  12. warnings.warn(
  13. "Numba not found. Using slower pure Python implementations.", UserWarning
  14. )
  15. # -------------------- Helper Functions --------------------
  16. def is_zero_box(bbox: list) -> bool:
  17. """Check if bounding box is invalid"""
  18. if bbox is None:
  19. return True
  20. return all(x <= 0 for x in bbox[:4]) or len(bbox) < 4
  21. def convert_bbox_format(bbox: list) -> List[float]:
  22. """Convert bbox from (x,y,w,h) to (x1,y1,x2,y2)"""
  23. x, y, w, h = bbox
  24. return [x, y, x + w, y + h]
  25. # -------------------- Track-level NMS --------------------
  26. def process_track_level_nms(video_groups: Dict, nms_threshold: float) -> Dict:
  27. """Apply track-level NMS to all videos"""
  28. for tracks in video_groups.values():
  29. track_detections = []
  30. # Process tracks
  31. for track_idx, track in enumerate(tracks):
  32. if not track["bboxes"]:
  33. continue
  34. converted_bboxes = []
  35. valid_frames = []
  36. for bbox in track["bboxes"]:
  37. if bbox and not is_zero_box(bbox):
  38. converted_bboxes.append(convert_bbox_format(bbox))
  39. valid_frames.append(True)
  40. else:
  41. converted_bboxes.append([np.nan] * 4)
  42. valid_frames.append(False)
  43. if any(valid_frames):
  44. track_detections.append(
  45. {
  46. "track_idx": track_idx,
  47. "bboxes": np.array(converted_bboxes, dtype=np.float32),
  48. "score": track["score"],
  49. }
  50. )
  51. # Apply NMS
  52. if track_detections:
  53. scores = np.array([d["score"] for d in track_detections], dtype=np.float32)
  54. keep = apply_track_nms(track_detections, scores, nms_threshold)
  55. # Suppress non-kept tracks
  56. for idx, track in enumerate(track_detections):
  57. if idx not in keep:
  58. tracks[track["track_idx"]]["bboxes"] = [None] * len(track["bboxes"])
  59. return video_groups
  60. # -------------------- Frame-level NMS --------------------
  61. def process_frame_level_nms(video_groups: Dict, nms_threshold: float) -> Dict:
  62. """Apply frame-level NMS to all videos"""
  63. for tracks in video_groups.values():
  64. if not tracks:
  65. continue
  66. num_frames = len(tracks[0]["bboxes"])
  67. for frame_idx in range(num_frames):
  68. frame_detections = []
  69. # Collect valid detections
  70. for track_idx, track in enumerate(tracks):
  71. bbox = track["bboxes"][frame_idx]
  72. if bbox and not is_zero_box(bbox):
  73. frame_detections.append(
  74. {
  75. "track_idx": track_idx,
  76. "bbox": np.array(
  77. convert_bbox_format(bbox), dtype=np.float32
  78. ),
  79. "score": track["score"],
  80. }
  81. )
  82. # Apply NMS
  83. if frame_detections:
  84. bboxes = np.stack([d["bbox"] for d in frame_detections])
  85. scores = np.array(
  86. [d["score"] for d in frame_detections], dtype=np.float32
  87. )
  88. keep = apply_frame_nms(bboxes, scores, nms_threshold)
  89. # Suppress non-kept detections
  90. for i, d in enumerate(frame_detections):
  91. if i not in keep:
  92. tracks[d["track_idx"]]["bboxes"][frame_idx] = None
  93. return video_groups
  94. # Track-level NMS helpers ------------------------------------------------------
  95. def compute_track_iou_matrix(
  96. bboxes_stacked: np.ndarray, valid_masks: np.ndarray, areas: np.ndarray
  97. ) -> np.ndarray:
  98. """IoU matrix computation for track-level NMS with fallback to pure Python"""
  99. num_tracks = bboxes_stacked.shape[0]
  100. iou_matrix = np.zeros((num_tracks, num_tracks), dtype=np.float32)
  101. if HAS_NUMBA:
  102. iou_matrix = _compute_track_iou_matrix_numba(bboxes_stacked, valid_masks, areas)
  103. else:
  104. # Pure Python implementation
  105. for i in range(num_tracks):
  106. for j in range(i + 1, num_tracks):
  107. valid_ij = valid_masks[i] & valid_masks[j]
  108. if not valid_ij.any():
  109. continue
  110. bboxes_i = bboxes_stacked[i, valid_ij]
  111. bboxes_j = bboxes_stacked[j, valid_ij]
  112. area_i = areas[i, valid_ij]
  113. area_j = areas[j, valid_ij]
  114. inter_total = 0.0
  115. union_total = 0.0
  116. for k in range(bboxes_i.shape[0]):
  117. x1 = max(bboxes_i[k, 0], bboxes_j[k, 0])
  118. y1 = max(bboxes_i[k, 1], bboxes_j[k, 1])
  119. x2 = min(bboxes_i[k, 2], bboxes_j[k, 2])
  120. y2 = min(bboxes_i[k, 3], bboxes_j[k, 3])
  121. inter = max(0, x2 - x1) * max(0, y2 - y1)
  122. union = area_i[k] + area_j[k] - inter
  123. inter_total += inter
  124. union_total += union
  125. if union_total > 0:
  126. iou_matrix[i, j] = inter_total / union_total
  127. iou_matrix[j, i] = iou_matrix[i, j]
  128. return iou_matrix
# The compiled kernel is only defined when the numba import succeeded.
if HAS_NUMBA:

    @nb.jit(nopython=True, parallel=True)
    def _compute_track_iou_matrix_numba(bboxes_stacked, valid_masks, areas):
        """Numba-optimized IoU matrix computation for track-level NMS.

        Mirrors the pure Python branch of ``compute_track_iou_matrix``: for
        each pair of tracks (i, j) with i < j, intersections and unions are
        summed over the frames that are valid in both tracks, and the ratio
        of the sums is written symmetrically into the result.

        Args:
            bboxes_stacked: Array indexed as ``[track, frame, coord]``;
                coords 0/1 are combined with ``max`` and coords 2/3 with
                ``min``, i.e. presumably (x1, y1, x2, y2) corners.
            valid_masks: Boolean array indexed as ``[track, frame]``.
            areas: Per-box areas indexed as ``[track, frame]``.

        Returns:
            A ``float32`` matrix of shape ``(num_tracks, num_tracks)``; the
            diagonal and pairs without shared valid frames remain zero.
        """
        num_tracks = bboxes_stacked.shape[0]
        iou_matrix = np.zeros((num_tracks, num_tracks), dtype=np.float32)
        # The outer loop is a prange; each (i, j) pair with i < j is visited
        # exactly once, so every matrix cell is written by a single iteration.
        for i in nb.prange(num_tracks):
            for j in range(i + 1, num_tracks):
                # Only frames where both tracks have a valid box contribute.
                valid_ij = valid_masks[i] & valid_masks[j]
                if not valid_ij.any():
                    continue
                bboxes_i = bboxes_stacked[i, valid_ij]
                bboxes_j = bboxes_stacked[j, valid_ij]
                area_i = areas[i, valid_ij]
                area_j = areas[j, valid_ij]
                inter_total = 0.0
                union_total = 0.0
                for k in range(bboxes_i.shape[0]):
                    x1 = max(bboxes_i[k, 0], bboxes_j[k, 0])
                    y1 = max(bboxes_i[k, 1], bboxes_j[k, 1])
                    x2 = min(bboxes_i[k, 2], bboxes_j[k, 2])
                    y2 = min(bboxes_i[k, 3], bboxes_j[k, 3])
                    # Clamp at zero so non-overlapping boxes add no area.
                    inter = max(0, x2 - x1) * max(0, y2 - y1)
                    union = area_i[k] + area_j[k] - inter
                    inter_total += inter
                    union_total += union
                # Guard against division by zero; the cell stays 0 otherwise.
                if union_total > 0:
                    iou_matrix[i, j] = inter_total / union_total
                    iou_matrix[j, i] = iou_matrix[i, j]
        return iou_matrix
  159. def apply_track_nms(
  160. track_detections: List[dict], scores: np.ndarray, nms_threshold: float
  161. ) -> List[int]:
  162. """Vectorized track-level NMS implementation"""
  163. if not track_detections:
  164. return []
  165. bboxes_stacked = np.stack([d["bboxes"] for d in track_detections], axis=0)
  166. valid_masks = ~np.isnan(bboxes_stacked).any(axis=2)
  167. areas = (bboxes_stacked[:, :, 2] - bboxes_stacked[:, :, 0]) * (
  168. bboxes_stacked[:, :, 3] - bboxes_stacked[:, :, 1]
  169. )
  170. areas[~valid_masks] = 0
  171. iou_matrix = compute_track_iou_matrix(bboxes_stacked, valid_masks, areas)
  172. keep = []
  173. order = np.argsort(-scores)
  174. suppress = np.zeros(len(track_detections), dtype=bool)
  175. for i in range(len(order)):
  176. if not suppress[order[i]]:
  177. keep.append(order[i])
  178. suppress[order[i:]] = suppress[order[i:]] | (
  179. iou_matrix[order[i], order[i:]] >= nms_threshold
  180. )
  181. return keep
  182. # Frame-level NMS helpers ------------------------------------------------------
  183. def compute_frame_ious(bbox: np.ndarray, bboxes: np.ndarray) -> np.ndarray:
  184. """IoU computation for frame-level NMS with fallback to pure Python"""
  185. if HAS_NUMBA:
  186. return _compute_frame_ious_numba(bbox, bboxes)
  187. else:
  188. # Pure Python implementation
  189. ious = np.zeros(len(bboxes), dtype=np.float32)
  190. for i in range(len(bboxes)):
  191. x1 = max(bbox[0], bboxes[i, 0])
  192. y1 = max(bbox[1], bboxes[i, 1])
  193. x2 = min(bbox[2], bboxes[i, 2])
  194. y2 = min(bbox[3], bboxes[i, 3])
  195. inter = max(0, x2 - x1) * max(0, y2 - y1)
  196. area1 = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
  197. area2 = (bboxes[i, 2] - bboxes[i, 0]) * (bboxes[i, 3] - bboxes[i, 1])
  198. union = area1 + area2 - inter
  199. ious[i] = inter / union if union > 0 else 0.0
  200. return ious
# The compiled kernel is only defined when the numba import succeeded.
if HAS_NUMBA:

    @nb.jit(nopython=True, parallel=True)
    def _compute_frame_ious_numba(bbox, bboxes):
        """Numba-optimized IoU computation.

        Computes the IoU between ``bbox`` and each row of ``bboxes``.

        Args:
            bbox: Reference box, indexed ``[0..3]``; indices 0/1 are combined
                with ``max`` and 2/3 with ``min``, i.e. presumably
                (x1, y1, x2, y2) corners.
            bboxes: Boxes to compare against, indexed ``[i, 0..3]``.

        Returns:
            A ``float32`` array with one IoU per row of ``bboxes``; entries
            whose union is not strictly positive are ``0.0``.
        """
        ious = np.zeros(len(bboxes), dtype=np.float32)
        # Each iteration writes only ious[i], so the prange iterations are
        # independent of one another.
        for i in nb.prange(len(bboxes)):
            x1 = max(bbox[0], bboxes[i, 0])
            y1 = max(bbox[1], bboxes[i, 1])
            x2 = min(bbox[2], bboxes[i, 2])
            y2 = min(bbox[3], bboxes[i, 3])
            # Clamp at zero so non-overlapping boxes have no intersection.
            inter = max(0, x2 - x1) * max(0, y2 - y1)
            area1 = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            area2 = (bboxes[i, 2] - bboxes[i, 0]) * (bboxes[i, 3] - bboxes[i, 1])
            union = area1 + area2 - inter
            # Guard against division by zero for degenerate boxes.
            ious[i] = inter / union if union > 0 else 0.0
        return ious
  217. def apply_frame_nms(
  218. bboxes: np.ndarray, scores: np.ndarray, nms_threshold: float
  219. ) -> List[int]:
  220. """Frame-level NMS implementation with fallback to pure Python"""
  221. if HAS_NUMBA:
  222. return _apply_frame_nms_numba(bboxes, scores, nms_threshold)
  223. else:
  224. # Pure Python implementation
  225. order = np.argsort(-scores)
  226. keep = []
  227. suppress = np.zeros(len(bboxes), dtype=bool)
  228. for i in range(len(order)):
  229. if not suppress[order[i]]:
  230. keep.append(order[i])
  231. current_bbox = bboxes[order[i]]
  232. remaining_bboxes = bboxes[order[i + 1 :]]
  233. if len(remaining_bboxes) > 0: # Check if there are any remaining boxes
  234. ious = compute_frame_ious(current_bbox, remaining_bboxes)
  235. suppress[order[i + 1 :]] = suppress[order[i + 1 :]] | (
  236. ious >= nms_threshold
  237. )
  238. return keep
# The compiled kernel is only defined when the numba import succeeded.
if HAS_NUMBA:

    @nb.jit(nopython=True)
    def _apply_frame_nms_numba(bboxes, scores, nms_threshold):
        """Numba-optimized NMS implementation.

        Visits boxes from highest to lowest score. Each box that has not been
        suppressed is kept, and every lower-ranked box whose IoU with it is
        at least ``nms_threshold`` is marked as suppressed.

        Args:
            bboxes: Boxes indexed ``[i, 0..3]``.
            scores: One score per row of ``bboxes``.
            nms_threshold: IoU at or above which a box is suppressed.

        Returns:
            A list of the kept indices into ``bboxes``, in descending score
            order.
        """
        # Negating the scores makes argsort yield a descending-score order.
        order = np.argsort(-scores)
        keep = []
        suppress = np.zeros(len(bboxes), dtype=nb.boolean)
        for i in range(len(order)):
            if not suppress[order[i]]:
                keep.append(order[i])
                current_bbox = bboxes[order[i]]
                if i + 1 < len(order):  # Check bounds
                    # Compare only against boxes ranked after the current one.
                    ious = _compute_frame_ious_numba(
                        current_bbox, bboxes[order[i + 1 :]]
                    )
                    suppress[order[i + 1 :]] = suppress[order[i + 1 :]] | (
                        ious >= nms_threshold
                    )
        return keep