associate_det_trk.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. from collections import defaultdict
  4. import torch
  5. import torch.nn.functional as F
  6. from sam3.perflib.masks_ops import mask_iou
  7. from scipy.optimize import linear_sum_assignment
  8. def associate_det_trk(
  9. det_masks,
  10. track_masks,
  11. iou_threshold=0.5,
  12. iou_threshold_trk=0.5,
  13. det_scores=None,
  14. new_det_thresh=0.0,
  15. ):
  16. """
  17. Optimized implementation of detection <-> track association that minimizes DtoH syncs.
  18. Args:
  19. det_masks: (N, H, W) tensor of predicted masks
  20. track_masks: (M, H, W) tensor of track masks
  21. Returns:
  22. new_det_indices: list of indices in det_masks considered 'new'
  23. unmatched_trk_indices: list of indices in track_masks considered 'unmatched'
  24. """
  25. with torch.autograd.profiler.record_function("perflib: associate_det_trk"):
  26. assert isinstance(det_masks, torch.Tensor), "det_masks should be a tensor"
  27. assert isinstance(track_masks, torch.Tensor), "track_masks should be a tensor"
  28. if det_masks.size(0) == 0 or track_masks.size(0) == 0:
  29. return list(range(det_masks.size(0))), [], {}, {} # all detections are new
  30. if list(det_masks.shape[-2:]) != list(track_masks.shape[-2:]):
  31. # resize to the smaller size to save GPU memory
  32. if torch.numel(det_masks[-2:]) < torch.numel(track_masks[-2:]):
  33. track_masks = (
  34. F.interpolate(
  35. track_masks.unsqueeze(1).float(),
  36. size=det_masks.shape[-2:],
  37. mode="bilinear",
  38. align_corners=False,
  39. ).squeeze(1)
  40. > 0
  41. )
  42. else:
  43. # resize detections to track size
  44. det_masks = (
  45. F.interpolate(
  46. det_masks.unsqueeze(1).float(),
  47. size=track_masks.shape[-2:],
  48. mode="bilinear",
  49. align_corners=False,
  50. ).squeeze(1)
  51. > 0
  52. )
  53. det_masks = det_masks > 0
  54. track_masks = track_masks > 0
  55. iou = mask_iou(det_masks, track_masks) # (N, M)
  56. igeit = iou >= iou_threshold
  57. igeit_any_dim_1 = igeit.any(dim=1)
  58. igeit_trk = iou >= iou_threshold_trk
  59. iou_list = iou.cpu().numpy().tolist()
  60. igeit_list = igeit.cpu().numpy().tolist()
  61. igeit_any_dim_1_list = igeit_any_dim_1.cpu().numpy().tolist()
  62. igeit_trk_list = igeit_trk.cpu().numpy().tolist()
  63. det_scores_list = (
  64. det_scores
  65. if det_scores is None
  66. else det_scores.cpu().float().numpy().tolist()
  67. )
  68. # Hungarian matching for tracks (one-to-one: each track matches at most one detection)
  69. # For detections: allow many tracks to match to the same detection (many-to-one)
  70. # If either is empty, return all detections as new
  71. if det_masks.size(0) == 0 or track_masks.size(0) == 0:
  72. return list(range(det_masks.size(0))), [], {}
  73. # Hungarian matching: maximize IoU for tracks
  74. cost_matrix = 1 - iou.cpu().numpy() # Hungarian solves for minimum cost
  75. row_ind, col_ind = linear_sum_assignment(cost_matrix)
  76. def branchy_hungarian_better_uses_the_cpu(
  77. cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks
  78. ):
  79. matched_trk = set()
  80. matched_det = set()
  81. matched_det_scores = {} # track index -> [det_score, det_score * iou] det score of matched detection mask
  82. for d, t in zip(row_ind, col_ind):
  83. matched_det_scores[t] = [
  84. det_scores_list[d],
  85. det_scores_list[d] * iou_list[d][t],
  86. ]
  87. if igeit_trk_list[d][t]:
  88. matched_trk.add(t)
  89. matched_det.add(d)
  90. # Tracks not matched by Hungarian assignment above threshold are unmatched
  91. unmatched_trk_indices = [
  92. t for t in range(track_masks.size(0)) if t not in matched_trk
  93. ]
  94. # For detections: allow many tracks to match to the same detection (many-to-one)
  95. # So, a detection is 'new' if it does not match any track above threshold
  96. assert track_masks.size(0) == igeit.size(
  97. 1
  98. ) # Needed for loop optimizaiton below
  99. new_det_indices = []
  100. for d in range(det_masks.size(0)):
  101. if not igeit_any_dim_1_list[d]:
  102. if det_scores is not None and det_scores[d] >= new_det_thresh:
  103. new_det_indices.append(d)
  104. # for each detection, which tracks it matched to (above threshold)
  105. det_to_matched_trk = defaultdict(list)
  106. for d in range(det_masks.size(0)):
  107. for t in range(track_masks.size(0)):
  108. if igeit_list[d][t]:
  109. det_to_matched_trk[d].append(t)
  110. return (
  111. new_det_indices,
  112. unmatched_trk_indices,
  113. det_to_matched_trk,
  114. matched_det_scores,
  115. )
  116. return (branchy_hungarian_better_uses_the_cpu)(
  117. cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks
  118. )