# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import numpy as np
import torch
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
from scipy.optimize import linear_sum_assignment
from torch import nn
  11. def _do_matching(cost, repeats=1, return_tgt_indices=False, do_filtering=False):
  12. if repeats > 1:
  13. cost = np.tile(cost, (1, repeats))
  14. i, j = linear_sum_assignment(cost)
  15. if do_filtering:
  16. # filter out invalid entries (i.e. those with cost > 1e8)
  17. valid_thresh = 1e8
  18. valid_ijs = [(ii, jj) for ii, jj in zip(i, j) if cost[ii, jj] < valid_thresh]
  19. i, j = zip(*valid_ijs) if len(valid_ijs) > 0 else ([], [])
  20. i, j = np.array(i, dtype=np.int64), np.array(j, dtype=np.int64)
  21. if return_tgt_indices:
  22. return i, j
  23. order = np.argsort(j)
  24. return i[order]
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best
    predictions, while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        focal_loss: bool = False,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2,
    ):
        """Creates the matcher.

        Params:
            cost_class: relative weight of the classification error in the matching cost
            cost_bbox: relative weight of the L1 error of the box coordinates in the matching cost
            cost_giou: relative weight of the giou loss of the bounding box in the matching cost
            focal_loss: if True, normalize logits with a sigmoid and use the focal-loss style
                classification cost; otherwise use a softmax over classes
            focal_alpha: focal loss alpha (only used when focal_loss is True)
            focal_gamma: focal loss gamma (only used when focal_loss is True)
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        # Sigmoid for (multi-label) focal matching; softmax for standard CE matching.
        self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
            "all costs cant be 0"
        )
        self.focal_loss = focal_loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    @torch.no_grad()
    def forward(self, outputs, batched_targets):
        """Performs the matching.

        Params:
            outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the
                    classification logits
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted
                    box coordinates
            batched_targets: dict of packed targets, with at least:
                "boxes": Tensor of dim [total_num_target_boxes, 4] with target box coordinates
                "num_boxes": Tensor of dim [batch_size] with the per-sample target counts
                and either "positive_map" (multi-hot class targets, one row per target box)
                or "labels" (class id per target box).

        Returns:
            A tuple (batch_idx, src_idx) of equal-length 1-D long tensors: batch_idx[k] is the
            batch entry and src_idx[k] the query index matched to the k-th target (targets
            taken in their packed, per-sample order).
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        # We flatten to compute the cost matrices in a batch
        out_prob = self.norm(
            outputs["pred_logits"].flatten(0, 1)
        )  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]
        if "positive_map" in batched_targets:
            # In this case we have a multi-hot target
            positive_map = batched_targets["positive_map"]
            assert len(tgt_bbox) == len(positive_map)
            if self.focal_loss:
                # Binarize the soft positive map before applying the focal cost.
                positive_map = positive_map > 1e-4
                alpha = self.focal_alpha
                gamma = self.focal_gamma
                neg_cost_class = (
                    (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
                )
                pos_cost_class = (
                    alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
                )
                cost_class = (
                    (pos_cost_class - neg_cost_class).unsqueeze(1)
                    * positive_map.unsqueeze(0)
                ).sum(-1)
            else:
                # Compute the soft-cross entropy between the predicted token alignment
                # and the GT one for each box
                cost_class = -(out_prob.unsqueeze(1) * positive_map.unsqueeze(0)).sum(
                    -1
                )
        else:
            # In this case we are doing a "standard" cross entropy
            tgt_ids = batched_targets["labels"]
            assert len(tgt_bbox) == len(tgt_ids)
            if self.focal_loss:
                alpha = self.focal_alpha
                gamma = self.focal_gamma
                neg_cost_class = (
                    (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
                )
                pos_cost_class = (
                    alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
                )
                cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
            else:
                # Compute the classification cost. Contrary to the loss, we don't use the NLL,
                # but approximate it in 1 - proba[target class].
                # The 1 is a constant that doesn't change the matching, it can be omitted.
                cost_class = -out_prob[:, tgt_ids]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        assert cost_class.shape == cost_bbox.shape
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )
        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()
        # Split the (queries x all-targets) cost along the target axis into one
        # per-sample cost matrix, keeping only each sample's own query rows.
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        indices = [_do_matching(c) for c in costs]
        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx
class BinaryHungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    Binary variant: the classification term uses a single sigmoid objectness score per query
    rather than per-class probabilities. For efficiency reasons, the targets don't include the
    no_object. Because of this, in general, there are more predictions than targets. In this
    case, we do a 1-to-1 matching of the best predictions, while the others are un-matched
    (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
    ):
        """Creates the matcher.

        Params:
            cost_class: relative weight of the classification error in the matching cost
            cost_bbox: relative weight of the L1 error of the box coordinates in the matching cost
            cost_giou: relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.norm = nn.Sigmoid()
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
            "all costs cant be 0"
        )

    @torch.no_grad()
    def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
        """Performs the matching.

        Params:
            outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, 1] with the
                    objectness logits
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted
                    box coordinates
            batched_targets: dict of packed targets with "boxes"
                [total_num_target_boxes, 4] and "num_boxes" [batch_size].
            repeats: if > 1, allow each target to be matched to up to `repeats` queries
                (one-to-many matching via column tiling in `_do_matching`).
            repeat_batch: unsupported here; must be 1 (see BinaryHungarianMatcherV2).

        Returns:
            A tuple (batch_idx, src_idx, tgt_idx) of 1-D long tensors. tgt_idx is None
            unless some sample has fewer queries than (repeated) targets, in which case
            it gives the global (packed) target index of each match.
        """
        if repeat_batch != 1:
            raise NotImplementedError("please use BinaryHungarianMatcherV2 instead")
        bs, num_queries = outputs["pred_logits"].shape[:2]
        # We flatten to compute the cost matrices in a batch
        out_prob = self.norm(outputs["pred_logits"].flatten(0, 1)).squeeze(
            -1
        )  # [batch_size * num_queries]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # Objectness cost is target-independent: broadcast the per-query score.
        cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox)
        assert cost_class.shape == cost_bbox.shape
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )
        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        # Target indices are only needed when some queries must stay unmatched on the
        # target side (fewer queries than repeated targets).
        return_tgt_indices = False
        for c in costs:
            n_targ = c.shape[1]
            if repeats > 1:
                n_targ *= repeats
            if c.shape[0] < n_targ:
                return_tgt_indices = True
                break
        if return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c, repeats=repeats, return_tgt_indices=return_tgt_indices
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            # Offset per-sample target indices into the packed (concatenated) target list.
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long()
        else:
            indices = [_do_matching(c, repeats=repeats) for c in costs]
            tgt_idx = None
        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx, tgt_idx
class BinaryFocalHungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    Like BinaryHungarianMatcher, but the classification term of the cost uses the binary
    focal loss on the objectness logit. For efficiency reasons, the targets don't include
    the no_object. Because of this, in general, there are more predictions than targets.
    In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        alpha: float = 0.25,
        gamma: float = 2.0,
        stable: bool = False,
    ):
        """Creates the matcher.

        Params:
            cost_class: relative weight of the classification error in the matching cost
            cost_bbox: relative weight of the L1 error of the box coordinates in the matching cost
            cost_giou: relative weight of the giou loss of the bounding box in the matching cost
            alpha: focal loss alpha
            gamma: focal loss gamma
            stable: if True, modulate the objectness probability by the (rescaled) GIoU
                before applying the focal cost (a quality-aware variant)
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.norm = nn.Sigmoid()
        self.alpha = alpha
        self.gamma = gamma
        self.stable = stable
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
            "all costs cant be 0"
        )

    @torch.no_grad()
    def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
        """Performs the matching.

        Params:
            outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, 1] with the
                    objectness logits
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted
                    box coordinates
            batched_targets: dict of packed targets with "boxes"
                [total_num_target_boxes, 4] and "num_boxes" [batch_size].
            repeats: if > 1, allow each target to be matched to up to `repeats` queries.
            repeat_batch: unsupported here; must be 1 (see BinaryHungarianMatcherV2).

        Returns:
            A tuple (batch_idx, src_idx, tgt_idx) of 1-D long tensors. tgt_idx is None
            unless some sample has fewer queries than (repeated) targets, in which case
            it gives the global (packed) target index of each match.
        """
        if repeat_batch != 1:
            raise NotImplementedError("please use BinaryHungarianMatcherV2 instead")
        bs, num_queries = outputs["pred_logits"].shape[:2]
        # We flatten to compute the cost matrices in a batch
        out_score = outputs["pred_logits"].flatten(0, 1).squeeze(-1)
        out_prob = self.norm(out_score)  # [batch_size * num_queries]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        # Also concat the target labels and boxes
        tgt_bbox = batched_targets["boxes"]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )
        if self.stable:
            # Modulate probability by GIoU rescaled from [-1, 1] to [0, 1].
            # NOTE(review): torch.log here can hit 0 (=> -inf) when the modulated
            # probability saturates — confirm inputs keep it in (0, 1).
            rescaled_giou = (-cost_giou + 1) / 2
            out_prob = out_prob.unsqueeze(-1).expand_as(cost_bbox) * rescaled_giou
            cost_class = -self.alpha * (1 - out_prob) ** self.gamma * torch.log(
                out_prob
            ) + (1 - self.alpha) * out_prob**self.gamma * torch.log(1 - out_prob)
        else:
            # directly computing log sigmoid (more numerically stable)
            log_out_prob = torch.nn.functional.logsigmoid(out_score)
            log_one_minus_out_prob = torch.nn.functional.logsigmoid(-out_score)
            cost_class = (
                -self.alpha * (1 - out_prob) ** self.gamma * log_out_prob
                + (1 - self.alpha) * out_prob**self.gamma * log_one_minus_out_prob
            )
        if not self.stable:
            # Per-query cost is target-independent in this branch: broadcast it.
            cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox)
        assert cost_class.shape == cost_bbox.shape
        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu().numpy()
        sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1]
        costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))]
        # Target indices are only needed when some sample has fewer queries than
        # (repeated) targets.
        return_tgt_indices = False
        for c in costs:
            n_targ = c.shape[1]
            if repeats > 1:
                n_targ *= repeats
            if c.shape[0] < n_targ:
                return_tgt_indices = True
                break
        if return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c, repeats=repeats, return_tgt_indices=return_tgt_indices
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            # Offset per-sample target indices into the packed (concatenated) target list.
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long()
        else:
            indices = [_do_matching(c, repeats=repeats) for c in costs]
            tgt_idx = None
        batch_idx = torch.as_tensor(
            sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long
        )
        src_idx = torch.from_numpy(np.concatenate(indices)).long()
        return batch_idx, src_idx, tgt_idx
class BinaryHungarianMatcherV2(nn.Module):
    """
    This class computes an assignment between the targets and the predictions
    of the network.

    For efficiency reasons, the targets don't include the no_object. Because of
    this, in general, there are more predictions than targets. In this case, we
    do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    This is a more efficient implementation of BinaryHungarianMatcher: it works
    on padded target boxes so the cost matrices can be built in a single
    batched operation, and it supports repeat_batch (stacked aux outputs),
    validity masks, and dropping samples with no GT boxes.
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox: float = 1,
        cost_giou: float = 1,
        focal: bool = False,
        alpha: float = 0.25,
        gamma: float = 2.0,
        stable: bool = False,
        remove_samples_with_0_gt: bool = True,
    ):
        """
        Creates the matcher.

        Params:
            - cost_class: Relative weight of the classification error in the
              matching cost
            - cost_bbox: Relative weight of the L1 error of the bounding box
              coordinates in the matching cost
            - cost_giou: This is the relative weight of the giou loss of the
              bounding box in the matching cost
            - focal: if True, use the binary focal-loss classification cost
            - alpha, gamma: focal loss parameters (only used when focal=True)
            - stable: quality-aware focal variant (only used when focal=True)
            - remove_samples_with_0_gt: if True, batch entries without any GT
              box are dropped before matching (they cannot produce matches)
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.norm = nn.Sigmoid()
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
            "all costs cant be 0"
        )
        self.focal = focal
        if focal:
            # NOTE: alpha/gamma/stable only exist as attributes when focal=True.
            self.alpha = alpha
            self.gamma = gamma
            self.stable = stable
        self.remove_samples_with_0_gt = remove_samples_with_0_gt

    @torch.no_grad()
    def forward(
        self,
        outputs,
        batched_targets,
        repeats=1,
        repeat_batch=1,
        out_is_valid=None,
        target_is_valid_padded=None,
    ):
        """
        Performs the matching. The inputs and outputs are the same as
        BinaryHungarianMatcher.forward, except for the optional validity masks
        and the "boxes_padded" entry of batched_targets.

        Inputs:
        - outputs: A dict with the following keys:
          - "pred_logits": Tensor of shape (batch_size, num_queries, 1) with
            classification logits
          - "pred_boxes": Tensor of shape (batch_size, num_queries, 4) with
            predicted box coordinates in cxcywh format.
        - batched_targets: A dict of targets. There may be a variable number of
          targets per batch entry; suppose that there are T_b targets for batch
          entry 0 <= b < batch_size. It should have the following keys:
          - "num_boxes": int64 Tensor of shape (batch_size,) giving the number
            of ground-truth boxes per batch entry: num_boxes[b] = T_b
          - "boxes_padded": Tensor of shape (batch_size, max_b T_b, 4) giving
            a padded version of ground-truth boxes (as precomputed in the
            collator).
        - repeats: if > 1, allow each target to be matched to up to `repeats`
          queries (one-to-many matching).
        - repeat_batch: if > 1, the outputs are a concatenation of final and
          auxiliary predictions along the batch dim; targets are repeated to
          match.
        - out_is_valid: If not None, it should be a boolean tensor of shape
          (batch_size, num_queries) indicating which predictions are valid.
          Invalid predictions are ignored during matching and won't appear in
          the output indices.
        - target_is_valid_padded: If not None, it should be a boolean tensor of
          shape (batch_size, max_num_gt_boxes) in padded format indicating
          which GT boxes are valid. Invalid GT boxes are ignored during matching
          and won't appear in the output indices.

        Returns:
          A tuple (batch_idx, src_idx, tgt_idx) of 1-D long tensors on the
          input device. tgt_idx is None unless filtering/one-to-many matching
          required explicit target indices; when present it holds global
          (packed) target indices.
        """
        _, num_queries = outputs["pred_logits"].shape[:2]
        out_score = outputs["pred_logits"].squeeze(-1)  # (B, Q)
        out_bbox = outputs["pred_boxes"]  # (B, Q, 4)
        device = out_score.device
        num_boxes = batched_targets["num_boxes"].cpu()
        # Get a padded version of target boxes (as precomputed in the collator).
        # It should work for both repeat==1 (o2o) and repeat>1 (o2m) matching.
        tgt_bbox = batched_targets["boxes_padded"]
        if self.remove_samples_with_0_gt:
            # keep only samples w/ at least 1 GT box in targets (num_boxes and tgt_bbox)
            batch_keep = num_boxes > 0
            num_boxes = num_boxes[batch_keep]
            tgt_bbox = tgt_bbox[batch_keep]
            if target_is_valid_padded is not None:
                target_is_valid_padded = target_is_valid_padded[batch_keep]
        # Repeat the targets (for the case of batched aux outputs in the matcher)
        if repeat_batch > 1:
            # In this case, out_prob and out_bbox will be a concatenation of
            # both final and auxiliary outputs, so we also repeat the targets
            num_boxes = num_boxes.repeat(repeat_batch)
            tgt_bbox = tgt_bbox.repeat(repeat_batch, 1, 1)
            if target_is_valid_padded is not None:
                target_is_valid_padded = target_is_valid_padded.repeat(repeat_batch, 1)
        # keep only samples w/ at least 1 GT box in outputs
        if self.remove_samples_with_0_gt:
            if repeat_batch > 1:
                batch_keep = batch_keep.repeat(repeat_batch)
            out_score = out_score[batch_keep]
            out_bbox = out_bbox[batch_keep]
            if out_is_valid is not None:
                out_is_valid = out_is_valid[batch_keep]
        assert out_bbox.shape[0] == tgt_bbox.shape[0]
        assert out_bbox.shape[0] == num_boxes.shape[0]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )
        out_prob = self.norm(out_score)
        if not self.focal:
            # Objectness cost is target-independent: broadcast the per-query score.
            cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox)
        else:
            if self.stable:
                # Modulate probability by GIoU rescaled from [-1, 1] to [0, 1].
                rescaled_giou = (-cost_giou + 1) / 2
                out_prob = out_prob.unsqueeze(-1).expand_as(cost_bbox) * rescaled_giou
                cost_class = -self.alpha * (1 - out_prob) ** self.gamma * torch.log(
                    out_prob
                ) + (1 - self.alpha) * out_prob**self.gamma * torch.log(1 - out_prob)
            else:
                # directly computing log sigmoid (more numerically stable)
                log_out_prob = torch.nn.functional.logsigmoid(out_score)
                log_one_minus_out_prob = torch.nn.functional.logsigmoid(-out_score)
                cost_class = (
                    -self.alpha * (1 - out_prob) ** self.gamma * log_out_prob
                    + (1 - self.alpha) * out_prob**self.gamma * log_one_minus_out_prob
                )
            if not self.stable:
                # Per-query cost is target-independent in this branch: broadcast it.
                cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox)
        assert cost_class.shape == cost_bbox.shape
        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        # assign a very high cost (1e9) to invalid outputs and targets, so that we can
        # filter them out (in `_do_matching`) from bipartite matching results
        do_filtering = out_is_valid is not None or target_is_valid_padded is not None
        if out_is_valid is not None:
            C = torch.where(out_is_valid[:, :, None], C, 1e9)
        if target_is_valid_padded is not None:
            C = torch.where(target_is_valid_padded[:, None, :], C, 1e9)
        C = C.cpu().numpy()
        # Strip the padding columns: each sample keeps only its own targets.
        costs = [C[i, :, :s] for i, s in enumerate(num_boxes.tolist())]
        return_tgt_indices = (
            do_filtering or torch.any(num_queries < num_boxes * max(repeats, 1)).item()
        )
        if len(costs) == 0:
            # We have size 0 in the batch dimension, so we return empty matching indices
            # (note that this can happen due to `remove_samples_with_0_gt=True` even if
            # the original input batch size is not 0, when all queries have empty GTs).
            indices = []
            tgt_idx = torch.zeros(0).long().to(device) if return_tgt_indices else None
        elif return_tgt_indices:
            indices, tgt_indices = zip(
                *(
                    _do_matching(
                        c,
                        repeats=repeats,
                        return_tgt_indices=return_tgt_indices,
                        do_filtering=do_filtering,
                    )
                    for c in costs
                )
            )
            tgt_indices = list(tgt_indices)
            # Offset per-sample target indices into the packed (concatenated) target list.
            sizes = torch.cumsum(num_boxes, -1)[:-1]
            for i in range(1, len(tgt_indices)):
                tgt_indices[i] += sizes[i - 1].item()
            tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long().to(device)
        else:
            indices = [
                _do_matching(c, repeats=repeats, do_filtering=do_filtering)
                for c in costs
            ]
            tgt_idx = None
        if self.remove_samples_with_0_gt:
            # Map positions in the filtered batch back to original batch indices.
            kept_inds = batch_keep.nonzero().squeeze(1)
            batch_idx = torch.as_tensor(
                sum([[kept_inds[i]] * len(src) for i, src in enumerate(indices)], []),
                dtype=torch.long,
                device=device,
            )
        else:
            batch_idx = torch.as_tensor(
                sum([[i] * len(src) for i, src in enumerate(indices)], []),
                dtype=torch.long,
                device=device,
            )
        # indices could be an empty list (since we remove samples w/ 0 GT boxes)
        if len(indices) > 0:
            src_idx = torch.from_numpy(np.concatenate(indices)).long().to(device)
        else:
            src_idx = torch.empty(0, dtype=torch.long, device=device)
        return batch_idx, src_idx, tgt_idx
class BinaryOneToManyMatcher(nn.Module):
    """
    This class computes a greedy assignment between the targets and the predictions of
    the network.

    In this formulation, several predictions can be assigned to each target, but each
    prediction can be assigned to at most one target.
    See DAC-Detr for details.
    """

    def __init__(
        self,
        alpha: float = 0.3,
        threshold: float = 0.4,
        topk: int = 6,
    ):
        """
        Creates the matcher.

        Params:
            alpha: relative balancing between classification and localization
            threshold: threshold used to select positive predictions
            topk: number of top scoring predictions to consider per target
        """
        super().__init__()
        self.norm = nn.Sigmoid()
        self.alpha = alpha
        self.threshold = threshold
        self.topk = topk

    @torch.no_grad()
    def forward(
        self,
        outputs,
        batched_targets,
        repeats=1,
        repeat_batch=1,
        out_is_valid=None,
        target_is_valid_padded=None,
    ):
        """
        Performs the matching. The inputs and outputs are the same as
        BinaryHungarianMatcherV2.forward.

        Inputs:
        - outputs: A dict with the following keys:
          - "pred_logits": Tensor of shape (batch_size, num_queries, 1) with
            classification logits
          - "pred_boxes": Tensor of shape (batch_size, num_queries, 4) with
            predicted box coordinates in cxcywh format.
        - batched_targets: A dict of targets. There may be a variable number of
          targets per batch entry; suppose that there are T_b targets for batch
          entry 0 <= b < batch_size. It should have the following keys:
          - "num_boxes": int64 Tensor of shape (batch_size,) giving the number
            of ground-truth boxes per batch entry: num_boxes[b] = T_b
          - "boxes_padded": Tensor of shape (batch_size, max_b T_b, 4) giving
            a padded version of ground-truth boxes.
        - repeats, repeat_batch: must both be <= 1 (unsupported here).
        - out_is_valid: If not None, it should be a boolean tensor of shape
          (batch_size, num_queries) indicating which predictions are valid.
          Invalid predictions are ignored during matching and won't appear in
          the output indices.
        - target_is_valid_padded: If not None, it should be a boolean tensor of
          shape (batch_size, max_num_gt_boxes) in padded format indicating
          which GT boxes are valid. Invalid GT boxes are ignored during matching
          and won't appear in the output indices.

        Returns:
          A tuple (batch_idx, src_idx, tgt_idx) of 1-D long tensors; tgt_idx
          holds global (packed) target indices. A query may appear at most
          once per target, but several queries can share a target.
        """
        assert repeats <= 1 and repeat_batch <= 1
        bs, num_queries = outputs["pred_logits"].shape[:2]
        out_prob = self.norm(outputs["pred_logits"]).squeeze(-1)  # (B, Q)
        out_bbox = outputs["pred_boxes"]  # (B, Q, 4)
        num_boxes = batched_targets["num_boxes"]
        # Get a padded version of target boxes (as precomputed in the collator).
        tgt_bbox = batched_targets["boxes_padded"]
        assert len(tgt_bbox) == bs
        num_targets = tgt_bbox.shape[1]
        if num_targets == 0:
            # No targets anywhere in the batch: return empty matches.
            return (
                torch.empty(0, dtype=torch.long, device=out_prob.device),
                torch.empty(0, dtype=torch.long, device=out_prob.device),
                torch.empty(0, dtype=torch.long, device=out_prob.device),
            )
        iou, _ = box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
        assert iou.shape == (bs, num_queries, num_targets)
        # Final cost matrix (higher is better in `C`; this is unlike the case
        # of BinaryHungarianMatcherV2 above where lower is better in its `C`)
        C = self.alpha * out_prob.unsqueeze(-1) + (1 - self.alpha) * iou
        if out_is_valid is not None:
            C = torch.where(out_is_valid[:, :, None], C, -1e9)
        if target_is_valid_padded is not None:
            C = torch.where(target_is_valid_padded[:, None, :], C, -1e9)
        # Selecting topk predictions (per target, via a quantile over the query dim)
        matches = C > torch.quantile(
            C, 1 - self.topk / num_queries, dim=1, keepdim=True
        )
        # Selecting predictions above threshold
        matches = matches & (C > self.threshold)
        if out_is_valid is not None:
            matches = matches & out_is_valid[:, :, None]
        if target_is_valid_padded is not None:
            matches = matches & target_is_valid_padded[:, None, :]
        # Removing padding
        matches = matches & (
            torch.arange(0, num_targets, device=num_boxes.device)[None]
            < num_boxes[:, None]
        ).unsqueeze(1)
        batch_idx, src_idx, tgt_idx = torch.nonzero(matches, as_tuple=True)
        # Convert per-sample target indices into global (packed) target indices.
        cum_num_boxes = torch.cat(
            [
                torch.zeros(1, dtype=num_boxes.dtype, device=num_boxes.device),
                num_boxes.cumsum(-1)[:-1],
            ]
        )
        tgt_idx += cum_num_boxes[batch_idx]
        return batch_idx, src_idx, tgt_idx