nms.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import logging
  4. import numpy as np
  5. import torch
  6. from sam3.perflib.masks_ops import mask_iou
  7. try:
  8. from torch_generic_nms import generic_nms as generic_nms_cuda
  9. GENERIC_NMS_AVAILABLE = True
  10. except ImportError:
  11. logging.debug(
  12. "Falling back to triton or CPU mask NMS implementation -- please install `torch_generic_nms` via\n\t"
  13. 'pip uninstall -y torch_generic_nms; TORCH_CUDA_ARCH_LIST="8.0 9.0" pip install git+https://github.com/ronghanghu/torch_generic_nms'
  14. )
  15. GENERIC_NMS_AVAILABLE = False
  16. def nms_masks(
  17. pred_probs: torch.Tensor,
  18. pred_masks: torch.Tensor,
  19. prob_threshold: float,
  20. iou_threshold: float,
  21. ) -> torch.Tensor:
  22. """
  23. Args:
  24. - pred_probs: (num_det,) float Tensor, containing the score (probability) of each detection
  25. - pred_masks: (num_det, H_mask, W_mask) float Tensor, containing the binary segmentation mask of each detection
  26. - prob_threshold: float, score threshold to prefilter detections (NMS is performed on detections above threshold)
  27. - iou_threshold: float, mask IoU threshold for NMS
  28. Returns:
  29. - keep: (num_det,) bool Tensor, indicating whether each detection is kept after score thresholding + NMS
  30. """
  31. # prefilter the detections with prob_threshold ("valid" are those above prob_threshold)
  32. is_valid = pred_probs > prob_threshold # (num_det,)
  33. probs = pred_probs[is_valid] # (num_valid,)
  34. masks_binary = pred_masks[is_valid] > 0 # (num_valid, H_mask, W_mask)
  35. if probs.numel() == 0:
  36. return is_valid # no valid detection, return empty keep mask
  37. ious = mask_iou(masks_binary, masks_binary) # (num_valid, num_valid)
  38. kept_inds = generic_nms(ious, probs, iou_threshold)
  39. # valid_inds are the indices among `probs` of valid detections before NMS (or -1 for invalid)
  40. valid_inds = torch.where(is_valid, is_valid.cumsum(dim=0) - 1, -1) # (num_det,)
  41. keep = torch.isin(valid_inds, kept_inds) # (num_det,)
  42. return keep
  43. def generic_nms(
  44. ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5
  45. ) -> torch.Tensor:
  46. """A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix."""
  47. assert ious.dim() == 2 and ious.size(0) == ious.size(1)
  48. assert scores.dim() == 1 and scores.size(0) == ious.size(0)
  49. if ious.is_cuda:
  50. if GENERIC_NMS_AVAILABLE:
  51. return generic_nms_cuda(ious, scores, iou_threshold, use_iou_matrix=True)
  52. else:
  53. from sam3.perflib.triton.nms import nms_triton
  54. return nms_triton(ious, scores, iou_threshold)
  55. return generic_nms_cpu(ious, scores, iou_threshold)
  56. def generic_nms_cpu(
  57. ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5
  58. ) -> torch.Tensor:
  59. """
  60. A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix. (CPU implementation
  61. based on https://github.com/jwyang/faster-rcnn.pytorch/blob/master/lib/model/nms/nms_cpu.py)
  62. """
  63. ious_np = ious.float().detach().cpu().numpy()
  64. scores_np = scores.float().detach().cpu().numpy()
  65. order = scores_np.argsort()[::-1]
  66. kept_inds = []
  67. while order.size > 0:
  68. i = order.item(0)
  69. kept_inds.append(i)
  70. inds = np.where(ious_np[i, order[1:]] <= iou_threshold)[0]
  71. order = order[inds + 1]
  72. return torch.tensor(kept_inds, dtype=torch.int64, device=scores.device)