| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- # Adapted from https://github.com/stackav-oss/conch/blob/main/conch/kernels/vision/nms.py
- import torch
- import triton
- import triton.language as tl
# The block size is autotuned per distinct `num_boxes` value (the autotune key).
# NOTE(review): autotuning re-runs the kernel on the same `keep_mask` buffer; this
# looks safe because the kernel only ever clears entries and re-running it on an
# already-processed mask clears nothing new -- confirm if the kernel is changed.
@triton.autotune(
    configs=[
        triton.Config({"cxpr_block_size": 128}),
        triton.Config({"cxpr_block_size": 256}),
        triton.Config({"cxpr_block_size": 512}),
        triton.Config({"cxpr_block_size": 1024}),
        triton.Config({"cxpr_block_size": 2048}),
        triton.Config({"cxpr_block_size": 4096}),
        triton.Config({"cxpr_block_size": 8192}),
    ],
    key=["num_boxes"],
)
@triton.jit
def _nms_suppression_kernel(
    # Tensors
    iou_mask_ptr: tl.tensor,  # [N, N]
    keep_mask_ptr: tl.tensor,  # [N]
    # Scalars
    num_boxes: tl.int32,
    # Strides
    iou_mask_stride: tl.int32,
    # Constexprs
    cxpr_block_size: tl.constexpr,
) -> None:
    """Sequential NMS suppression over a precomputed IoU mask.

    Walks the boxes in index order. Every box whose keep flag is still set
    clears the keep flag of each later box whose entry in that box's IoU-mask
    row is truthy. The kernel never reads a program id, so it is meant to be
    launched as a single program; parallelism comes only from processing each
    row in blocks of ``cxpr_block_size`` columns.

    Args:
        iou_mask_ptr: Pointer to the precomputed IoU mask, shape (N, N).
            Rows/columns are expected to be ordered by descending score and
            the last dimension is expected to have unit stride.
        keep_mask_ptr: Pointer to the keep mask, shape (N,). Updated in place:
            entries of suppressed boxes are set to False.
        num_boxes: Number of boxes (N).
        iou_mask_stride: Element stride between consecutive IoU-mask rows.
        cxpr_block_size: Number of columns processed per block iteration.
    """
    # The last box has no successors, so it never needs to act as a suppressor.
    for current_box_idx in range(num_boxes - 1):
        # A box that was already suppressed must not suppress anything itself.
        is_kept = tl.load(keep_mask_ptr + current_box_idx)
        if is_kept:
            # Offset of the current box's row in the IoU mask. Computed in
            # 64-bit: a 32-bit product wraps once N * N exceeds 2**31 - 1.
            iou_row_offset = tl.cast(current_box_idx, tl.int64) * iou_mask_stride

            # Only columns after the current box are visited, i.e. only the
            # strict upper triangle of the IoU mask is ever read.
            next_box_idx = current_box_idx + 1
            remaining_boxes = num_boxes - next_box_idx

            # Iterate blockwise through the remaining columns.
            for block_idx in range(tl.cdiv(remaining_boxes, cxpr_block_size)):
                block_start = next_box_idx + block_idx * cxpr_block_size
                target_box_offsets = block_start + tl.arange(0, cxpr_block_size)
                # Lanes past the end of the row are masked out of the load.
                target_box_mask = target_box_offsets < num_boxes

                # Truthy where the target box overlaps the current box; lanes
                # outside the row read as False and are therefore never stored.
                suppression_mask = tl.load(
                    iou_mask_ptr + iou_row_offset + target_box_offsets,
                    mask=target_box_mask,
                    other=False,
                )
                suppression_mask = tl.cast(suppression_mask, tl.int1)

                # Clear the keep flag only for the overlapping boxes.
                tl.store(
                    keep_mask_ptr + target_box_offsets, False, mask=suppression_mask
                )

                # Potential race condition: all threads must complete the store
                # before the next iteration, otherwise a stale keep flag could
                # be loaded for a box that has just been suppressed.
                tl.debug_barrier()
def nms_triton(
    ious: torch.Tensor,
    scores: torch.Tensor,
    iou_threshold: float,
) -> torch.Tensor:
    """Perform NMS given the IoU matrix, the scores and the IoU threshold.

    Boxes are visited in order of decreasing score (ties keep their input
    order); a box that is still kept suppresses every lower-ranked box whose
    IoU with it is strictly greater than ``iou_threshold``.

    Args:
        ious: Pairwise IoU tensor of shape (N, N).
        scores: Scores tensor of shape (N,), on the same device as ``ious``.
        iou_threshold: IoU threshold for suppression.

    Returns:
        Tensor: Indices (into ``scores``) of kept boxes, sorted by decreasing
        score.

    Raises:
        AssertionError: If the shapes or devices of ``ious`` and ``scores``
            are inconsistent.
    """
    assert scores.dim() == 1, "Scores must be 1D"
    iou_mask = ious > iou_threshold
    assert iou_mask.dim() == 2, "IoUs must be 2D"
    assert (
        iou_mask.shape[0] == iou_mask.shape[1] == scores.shape[0]
    ), "IoUs must have shape (N, N) for N scores"
    assert iou_mask.device == scores.device, "IoUs and scores must share a device"
    assert iou_mask.dtype == torch.bool

    num_boxes = scores.size(0)

    # Sort boxes by scores in descending order; stable so ties are deterministic.
    _, sorted_indices = torch.sort(scores, dim=0, stable=True, descending=True)

    # With zero or one box nothing can be suppressed: skip the kernel launch
    # (and the autotuning pass it would trigger for this `num_boxes`).
    if num_boxes <= 1:
        return sorted_indices

    keep_mask = torch.ones(num_boxes, device=scores.device, dtype=torch.bool)

    # Permute rows and columns into score order with one broadcasted gather:
    # result[i, j] = iou_mask[sorted_indices[i], sorted_indices[j]]. This avoids
    # the intermediate (N, N) copy that indexing rows and columns separately
    # would create. The kernel addresses columns with unit stride, hence
    # `.contiguous()`.
    iou_mask = iou_mask[sorted_indices[:, None], sorted_indices].contiguous()

    # For the suppression stage, we need to process sequentially, but we'll still
    # take advantage of parallelism by processing in blocks in one program.
    stage2_grid = (1,)
    _nms_suppression_kernel[stage2_grid](
        # Tensors
        iou_mask_ptr=iou_mask,
        keep_mask_ptr=keep_mask,
        # Scalars
        num_boxes=num_boxes,
        # Strides
        iou_mask_stride=iou_mask.stride(0),
    )

    # Extract indices of kept boxes (keep_mask is in sorted order).
    return sorted_indices[keep_mask]
|