nms.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. # Adapted from https://github.com/stackav-oss/conch/blob/main/conch/kernels/vision/nms.py
  4. import torch
  5. import triton
  6. import triton.language as tl
  7. @triton.autotune(
  8. configs=[
  9. triton.Config({"cxpr_block_size": 128}),
  10. triton.Config({"cxpr_block_size": 256}),
  11. triton.Config({"cxpr_block_size": 512}),
  12. triton.Config({"cxpr_block_size": 1024}),
  13. triton.Config({"cxpr_block_size": 2048}),
  14. triton.Config({"cxpr_block_size": 4096}),
  15. triton.Config({"cxpr_block_size": 8192}),
  16. ],
  17. key=["num_boxes"],
  18. )
  19. @triton.jit
  20. def _nms_suppression_kernel(
  21. # Tensors
  22. iou_mask_ptr: tl.tensor, # [N, N]
  23. keep_mask_ptr: tl.tensor, # [N]
  24. # Scalars
  25. num_boxes: tl.int32,
  26. # Strides
  27. iou_mask_stride: tl.int32,
  28. # Constexprs
  29. cxpr_block_size: tl.constexpr,
  30. ) -> None:
  31. """NMS suppression kernel.
  32. Args:
  33. iou_mask_ptr: Pointer to precomputed IoU mask, shape: (N, N).
  34. keep_mask_ptr: Pointer to keep mask tensor, shape: (N,).
  35. num_boxes: Number of boxes.
  36. iou_mask_stride: Stride for IoU mask tensor.
  37. cxpr_block_size: Block size for processing.
  38. """
  39. # Sequential NMS: for each box in sorted order, suppress later boxes
  40. for current_box_idx in range(num_boxes - 1):
  41. # Check if current box is still kept
  42. is_kept = tl.load(keep_mask_ptr + current_box_idx)
  43. if is_kept:
  44. # IoU mask row offset for the current box
  45. # Because the IoU mask is sorted by score, we will only consider boxes that come after the current box.
  46. # This means we only need to read the upper triangular part of the IoU mask.
  47. iou_row_offset = current_box_idx * iou_mask_stride
  48. # Only process boxes that come after the current box
  49. next_box_idx = current_box_idx + 1
  50. remaining_boxes = num_boxes - next_box_idx
  51. # Iterate blockwise through the columns
  52. for block_idx in range(tl.cdiv(remaining_boxes, cxpr_block_size)):
  53. # Masked load of indices for the target boxes in the current block
  54. block_start = next_box_idx + block_idx * cxpr_block_size
  55. target_box_offsets = block_start + tl.arange(0, cxpr_block_size)
  56. target_box_mask = target_box_offsets < num_boxes
  57. # Suppress boxes with lower scores that have high IoU
  58. suppression_mask = tl.load(
  59. iou_mask_ptr + iou_row_offset + target_box_offsets,
  60. mask=target_box_mask,
  61. other=False,
  62. )
  63. suppression_mask = tl.cast(suppression_mask, tl.int1)
  64. # Conditionally store suppression result for high-IoU boxes
  65. tl.store(
  66. keep_mask_ptr + target_box_offsets, False, mask=suppression_mask
  67. )
  68. # Potential race condition: we need to ensure all threads complete the store before the next
  69. # iteration otherwise we may load stale data for whether or not a box has been suppressed.
  70. tl.debug_barrier()
  71. def nms_triton(
  72. ious: torch.Tensor,
  73. scores: torch.Tensor,
  74. iou_threshold: float,
  75. ) -> torch.Tensor:
  76. """Perform NMS given the iou matrix, the scores and the iou threshold
  77. Args:
  78. ious: Pairwise IoU tensor of shape (N, N).
  79. scores: Scores tensor of shape (N,).
  80. iou_threshold: IoU threshold for suppression.
  81. Returns:
  82. Tensor: Indices of kept boxes, sorted by decreasing score.
  83. """
  84. assert scores.dim() == 1, "Scores must be 1D"
  85. iou_mask = ious > iou_threshold
  86. assert iou_mask.dim() == 2
  87. assert iou_mask.shape[0] == iou_mask.shape[1] == scores.shape[0]
  88. assert iou_mask.device == scores.device
  89. assert iou_mask.dtype == torch.bool
  90. num_boxes = scores.size(0)
  91. keep_mask = torch.ones(len(scores), device=scores.device, dtype=torch.bool)
  92. # Sort boxes by scores in descending order
  93. _, sorted_indices = torch.sort(scores, dim=0, stable=True, descending=True)
  94. iou_mask = iou_mask[sorted_indices][:, sorted_indices].contiguous()
  95. # For the suppression stage, we need to process sequentially, but we'll still take
  96. # advantage of parallelism by processing in blocks in one program.
  97. stage2_grid = (1,)
  98. _nms_suppression_kernel[stage2_grid](
  99. # Tensors
  100. iou_mask_ptr=iou_mask,
  101. keep_mask_ptr=keep_mask,
  102. # Scalars
  103. num_boxes=num_boxes,
  104. # Strides
  105. iou_mask_stride=iou_mask.stride(0),
  106. )
  107. # Extract indices of kept boxes
  108. return sorted_indices[keep_mask]