loss_fns.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import logging
  4. import warnings
  5. import torch
  6. import torch.distributed
  7. import torch.nn.functional as F
  8. import torchmetrics
  9. from sam3.model import box_ops
  10. from sam3.model.data_misc import interpolate
  11. from sam3.train.loss.sigmoid_focal_loss import (
  12. triton_sigmoid_focal_loss,
  13. triton_sigmoid_focal_loss_reduce,
  14. )
  15. from torch import nn
  16. from .mask_sampling import (
  17. calculate_uncertainty,
  18. get_uncertain_point_coords_with_randomness,
  19. point_sample,
  20. )
# Dictionary key under which LossWithWeights.forward stores the weighted sum
# of the individual loss terms (see LossWithWeights.reduce_loss).
CORE_LOSS_KEY = "core_loss"
  22. def instance_masks_to_semantic_masks(
  23. instance_masks: torch.Tensor, num_instances: torch.Tensor
  24. ) -> torch.Tensor:
  25. """This function converts instance masks to semantic masks.
  26. It accepts a collapsed batch of instances masks (ie all instance masks are concatenated in a single tensor) and
  27. the number of instances in each image of the batch.
  28. It returns a mask with the same spatial dimensions as the input instance masks, where for each batch element the
  29. semantic mask is the union of all the instance masks in the batch element.
  30. If for a given batch element there are no instances (ie num_instances[i]==0), the corresponding semantic mask will be a tensor of zeros.
  31. Args:
  32. instance_masks (torch.Tensor): A tensor of shape (N, H, W) where N is the number of instances in the batch.
  33. num_instances (torch.Tensor): A tensor of shape (B,) where B is the batch size. It contains the number of instances
  34. in each image of the batch.
  35. Returns:
  36. torch.Tensor: A tensor of shape (B, H, W) where B is the batch size and H, W are the spatial dimensions of the
  37. input instance masks.
  38. """
  39. if num_instances.sum() == 0:
  40. # all negative batch, create a tensor of zeros (B, 1, 1)
  41. return num_instances.unsqueeze(-1).unsqueeze(-1)
  42. masks_per_query = torch.split(instance_masks, num_instances.tolist())
  43. return torch.stack([torch.any(masks, dim=0) for masks in masks_per_query], dim=0)
  44. @torch.no_grad()
  45. def accuracy(output, target, topk=(1,)):
  46. """Computes the precision@k for the specified values of k"""
  47. if target.numel() == 0:
  48. return [torch.zeros([], device=output.device)]
  49. maxk = max(topk)
  50. batch_size = target.size(0)
  51. _, pred = output.topk(maxk, 1, True, True)
  52. pred = pred.t()
  53. correct = pred.eq(target.view(1, -1).expand_as(pred))
  54. res = []
  55. for k in topk:
  56. correct_k = correct[:k].view(-1).float().sum(0)
  57. res.append(correct_k.mul_(100.0 / batch_size))
  58. return res
  59. def dice_loss(inputs, targets, num_boxes, loss_on_multimask=False, reduce=True):
  60. """
  61. Compute the DICE loss, similar to generalized IOU for masks
  62. Args:
  63. inputs: A float tensor of arbitrary shape.
  64. The predictions for each example.
  65. targets: A float tensor with the same shape as inputs. Stores the binary
  66. classification label for each element in inputs
  67. (0 for the negative class and 1 for the positive class).
  68. """
  69. try:
  70. loss = _dice_loss(inputs, targets, num_boxes, loss_on_multimask, reduce)
  71. except torch.OutOfMemoryError:
  72. logging.error("GPU OOM, computing dice loss on CPU")
  73. # try to recover from GPU OOM by moving tensors to CPU and computing loss there
  74. orig_device = inputs.device
  75. inputs = inputs.cpu()
  76. targets = targets.cpu()
  77. if isinstance(num_boxes, torch.Tensor):
  78. num_boxes = num_boxes.cpu()
  79. loss = _dice_loss(inputs, targets, num_boxes, loss_on_multimask, reduce)
  80. loss = loss.to(orig_device)
  81. return loss
  82. def _dice_loss(inputs, targets, num_boxes, loss_on_multimask=False, reduce=True):
  83. inputs = inputs.sigmoid()
  84. if loss_on_multimask:
  85. # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
  86. assert inputs.dim() == 4 and targets.dim() == 4
  87. # flatten spatial dimension while keeping multimask channel dimension
  88. inputs = inputs.flatten(2)
  89. targets = targets.flatten(2)
  90. numerator = 2 * (inputs * targets).sum(-1)
  91. else:
  92. inputs = inputs.flatten(1)
  93. numerator = 2 * (inputs * targets).sum(1)
  94. denominator = inputs.sum(-1) + targets.sum(-1)
  95. loss = 1 - (numerator + 1) / (denominator + 1)
  96. if loss_on_multimask:
  97. return loss / num_boxes
  98. if not reduce:
  99. return loss
  100. return loss.sum() / num_boxes
  101. def sigmoid_focal_loss(
  102. inputs,
  103. targets,
  104. num_boxes,
  105. alpha: float = 0.25,
  106. gamma: float = 2,
  107. loss_on_multimask=False,
  108. reduce=True,
  109. triton=True,
  110. ):
  111. """
  112. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  113. Args:
  114. inputs: A float tensor of arbitrary shape.
  115. The predictions for each example.
  116. targets: A float tensor with the same shape as inputs. Stores the binary
  117. classification label for each element in inputs
  118. (0 for the negative class and 1 for the positive class).
  119. alpha: (optional) Weighting factor in range (0,1) to balance
  120. positive vs negative examples. Default = -1 (no weighting).
  121. gamma: Exponent of the modulating factor (1 - p_t) to
  122. balance easy vs hard examples.
  123. Returns:
  124. Loss tensor
  125. """
  126. if not (0 <= alpha <= 1) and triton:
  127. raise RuntimeError(f"Alpha should be in [0,1], got {alpha}")
  128. if triton:
  129. if reduce and not loss_on_multimask:
  130. loss = triton_sigmoid_focal_loss_reduce(inputs, targets, alpha, gamma)
  131. return loss / (num_boxes * inputs.shape[1])
  132. loss = triton_sigmoid_focal_loss(inputs, targets, alpha, gamma)
  133. else:
  134. prob = inputs.sigmoid()
  135. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  136. p_t = prob * targets + (1 - prob) * (1 - targets)
  137. loss = ce_loss * ((1 - p_t) ** gamma)
  138. if alpha >= 0:
  139. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  140. loss = alpha_t * loss
  141. if not reduce:
  142. return loss
  143. if loss_on_multimask:
  144. # loss is [N, M, H, W] where M corresponds to multiple predicted masks
  145. assert loss.dim() == 4
  146. return loss.flatten(2).mean(-1) / num_boxes # average over spatial dims
  147. return loss.mean(1).sum() / num_boxes
  148. def iou_loss(
  149. inputs, targets, pred_ious, num_boxes, loss_on_multimask=False, use_l1_loss=False
  150. ):
  151. """MSE loss between predicted IoUs and actual IoUs between inputs and targets."""
  152. assert inputs.dim() == 4 and targets.dim() == 4
  153. pred_mask = inputs.flatten(2) > 0
  154. gt_mask = targets.flatten(2) > 0
  155. area_i = torch.sum(pred_mask & gt_mask, dim=-1).float()
  156. area_u = torch.sum(pred_mask | gt_mask, dim=-1).float()
  157. actual_ious = area_i / torch.clamp(area_u, min=1.0)
  158. if use_l1_loss:
  159. loss = F.l1_loss(pred_ious, actual_ious, reduction="none")
  160. else:
  161. loss = F.mse_loss(pred_ious, actual_ious, reduction="none")
  162. if loss_on_multimask:
  163. return loss / num_boxes
  164. return loss.sum() / num_boxes
  165. @torch.jit.script
  166. def _contrastive_align(logits, positive_map):
  167. positive_logits = -logits.masked_fill(~positive_map, 0)
  168. negative_logits = logits # .masked_fill(positive_map, -1000000)
  169. boxes_with_pos = positive_map.any(2)
  170. pos_term = positive_logits.sum(2)
  171. neg_term = negative_logits.logsumexp(2)
  172. nb_pos = positive_map.sum(2) + 1e-6
  173. box_to_token_loss = (
  174. (pos_term / nb_pos + neg_term).masked_fill(~boxes_with_pos, 0).sum()
  175. )
  176. tokens_with_pos = positive_map.any(1)
  177. pos_term = positive_logits.sum(1)
  178. neg_term = negative_logits.logsumexp(1)
  179. nb_pos = positive_map.sum(1) + 1e-6
  180. tokens_to_boxes_loss = (
  181. (pos_term / nb_pos + neg_term).masked_fill(~tokens_with_pos, 0).sum()
  182. )
  183. return (box_to_token_loss + tokens_to_boxes_loss) / 2
  184. def _get_src_permutation_idx(indices):
  185. # permute predictions following indices
  186. batch_idx = torch.cat(
  187. [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
  188. )
  189. src_idx = torch.cat([src for (src, _) in indices])
  190. return batch_idx, src_idx
  191. class LossWithWeights(nn.Module):
  192. def __init__(self, weight_dict, compute_aux, supports_o2m_loss=True):
  193. super().__init__()
  194. # weights for each computed loss key (those losses not in weight_dict
  195. # will not be aggregated in the final reduced core loss)
  196. self.weight_dict = weight_dict if weight_dict is not None else {}
  197. # whether this loss will be applied on auxiliary outputs
  198. self.compute_aux = compute_aux
  199. self.supports_o2m_loss = supports_o2m_loss
  200. self.target_keys = []
  201. def forward(self, *args, is_aux=False, **kwargs):
  202. if is_aux and not self.compute_aux:
  203. return {CORE_LOSS_KEY: 0.0}
  204. losses = self.get_loss(*args, **kwargs)
  205. losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
  206. return losses
  207. def get_loss(self, **kwargs):
  208. raise NotImplementedError()
  209. def reduce_loss(self, losses):
  210. reduced_loss = 0.0
  211. for loss_key, weight in self.weight_dict.items():
  212. if loss_key not in losses:
  213. raise ValueError(f"{type(self)} doesn't compute {loss_key}")
  214. if weight != 0:
  215. reduced_loss += losses[loss_key] * weight
  216. return reduced_loss
class IABCEMdetr(LossWithWeights):
    """Binary cross-entropy loss on per-query logits with soft positive targets.

    Matched (positive) queries are supervised with a soft target
    ``t = prob ** alpha * iou ** (1 - alpha)`` (clamped to at least 0.01), where
    ``prob`` is the predicted probability and ``iou`` the IoU between the matched
    predicted and target boxes. Unmatched queries are supervised towards 0 with
    a ``prob ** gamma`` modulation. Optionally adds a presence loss, masks the
    negatives of non-exhaustively annotated samples (``weak_loss``), or applies
    separate scales to detection and tracking queries.
    """

    def __init__(
        self,
        pos_weight,
        weight_dict=None,
        compute_aux=True,
        gamma=0,
        weak_loss=True,
        alpha=0.25,
        pad_n_queries=None,
        pad_scale_pos=1.0,
        use_separate_loss_for_det_and_trk=False,
        num_det_queries=None,
        det_exhaustive_loss_scale_pos=1.0,
        det_exhaustive_loss_scale_neg=1.0,
        det_non_exhaustive_loss_scale_pos=1.0,
        det_non_exhaustive_loss_scale_neg=1.0,
        trk_loss_scale_pos=1.0,
        trk_loss_scale_neg=1.0,
        no_loss_for_fp_propagation=False,
        apply_loss_to_det_queries_in_video_grounding=True,
        use_presence=False,
        use_presence_semgseg=False,  # If True, use presence scores from the semgseg head.
        presence_alpha=0.5,
        presence_gamma=0.0,
        pos_focal: bool = False,  # for box scores, use focal loss for positives as well
    ):
        """
        Args:
            pos_weight: Multiplier applied to the loss of positive (matched) queries.
            weight_dict, compute_aux: Forwarded to `LossWithWeights`.
            gamma: Exponent of the `prob ** gamma` factor on negatives; also the
                focal gamma for positives when `pos_focal` is set.
            weak_loss: If True, negatives of non-exhaustive samples are masked
                out and "is_exhaustive" is added to the target keys.
            alpha: Exponent balancing probability and IoU in the soft target.
            pad_n_queries, pad_scale_pos: See the NOTE in the body.
            use_separate_loss_for_det_and_trk: Apply the `det_*` / `trk_*`
                scales below; incompatible with `weak_loss`.
            num_det_queries: No longer needed; only triggers a warning when set.
            no_loss_for_fp_propagation: See the comment in `get_loss`.
            apply_loss_to_det_queries_in_video_grounding: If False, the loss of
                the first `Q_det` queries is zeroed for video grounding batches.
            use_presence, use_presence_semgseg, presence_alpha, presence_gamma:
                Settings of the presence handling in `get_loss`.
            pos_focal: Use the focal loss (instead of plain BCE) for positives.
        """
        super().__init__(weight_dict, compute_aux)
        self.pos_weight = pos_weight
        self.gamma = gamma
        self.weak_loss = weak_loss
        self.alpha = alpha
        self.target_keys.append("boxes_xyxy")
        self.no_loss_for_fp_propagation = no_loss_for_fp_propagation
        if self.weak_loss:
            self.target_keys.append("is_exhaustive")
        # NOTE: This is hacky solution to have the same CE loss scale across datasets where the model might predict different number of object queries for different tasks.
        # If not None, we assume there are a total pad_n_queries object queries.
        # For example, if the model predicts only 1 object query and pad_n_queries=100, we pad the predictions with 99 zero preds.
        # Currently this only affects the BCE loss and not the F1 score.
        self.pad_n_queries = pad_n_queries
        self.pad_scale_pos = pad_scale_pos
        if self.pad_scale_pos != 1.0:
            assert self.pad_n_queries is not None
        # whether to use presence scores
        self.use_presence = use_presence
        self.use_presence_semgseg = use_presence_semgseg
        if self.use_presence_semgseg:
            assert self.use_presence
        self.presence_alpha = presence_alpha
        self.presence_gamma = presence_gamma
        self.pos_focal = pos_focal
        # Decoupled loss for detection and tracking queries
        self.apply_loss_to_det_queries_in_video_grounding = (
            apply_loss_to_det_queries_in_video_grounding
        )
        self.use_separate_loss_for_det_and_trk = use_separate_loss_for_det_and_trk
        if num_det_queries is not None:
            logging.warning("note: it's not needed to set num_det_queries anymore")
        if self.use_separate_loss_for_det_and_trk:
            assert not self.weak_loss, (
                "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
            )
            self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
            self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
            self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
            self.det_non_exhaustive_loss_scale_neg = det_non_exhaustive_loss_scale_neg
            self.trk_loss_scale_pos = trk_loss_scale_pos
            self.trk_loss_scale_neg = trk_loss_scale_neg
        else:
            # The scale attributes are only stored in the branch above, so
            # non-default values would be silently ignored here.
            assert (
                det_exhaustive_loss_scale_pos == 1.0
                and det_exhaustive_loss_scale_neg == 1.0
                and det_non_exhaustive_loss_scale_pos == 1.0
                and det_non_exhaustive_loss_scale_neg == 1.0
                and trk_loss_scale_pos == 1.0
                and trk_loss_scale_neg == 1.0
            ), (
                "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
            )

    def get_loss(self, outputs, targets, indices, num_boxes):
        """Compute the classification (and optional presence) losses.

        Args:
            outputs: Dict with "pred_logits" (last dim of size 1) and
                "pred_boxes_xyxy"; optionally "is_video_grounding_batch",
                "Q_det", "presence_logit_dec" and "pred_old_obj_ids".
            targets: Dict with "boxes_xyxy" and, depending on the configuration,
                "is_exhaustive", "object_ids_padded" and "boxes_padded".
            indices: Matching; `indices[0]`/`indices[1]` index the matched
                predictions along the first two dims, `indices[2]` selects the
                matched targets (None means targets are used as given).
            num_boxes: Not used by this loss.

        Returns:
            Dict with "loss_ce", "ce_f1", "presence_loss" and "presence_dec_acc".
        """
        assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
        assert outputs["pred_logits"].shape[-1] == 1, "Incorrect predicted logits shape"
        src_logits = outputs["pred_logits"].squeeze(-1)
        prob = src_logits.sigmoid()
        with torch.no_grad():
            # Hard 0/1 targets: 1 for matched queries, 0 elsewhere.
            target_classes = torch.full(
                src_logits.shape[:2],
                0,
                dtype=torch.float,
                device=src_logits.device,
            )
            target_classes[(indices[0], indices[1])] = 1
            src_boxes_xyxy = outputs["pred_boxes_xyxy"][(indices[0], indices[1])]
            target_boxes_giou = (
                targets["boxes_xyxy"][indices[2]]
                if indices[2] is not None
                else targets["boxes_xyxy"]
            )
            iou = box_ops.fast_diag_box_iou(src_boxes_xyxy, target_boxes_giou)
            # Soft target for positives; detached so it acts as a constant.
            t = prob[(indices[0], indices[1])] ** self.alpha * iou ** (1 - self.alpha)
            t = torch.clamp(t, 0.01).detach()
            positive_target_classes = target_classes.clone()
            positive_target_classes[(indices[0], indices[1])] = t
        # Soft loss on positives
        if self.pos_focal:
            loss_bce = sigmoid_focal_loss(
                src_logits.contiguous(),
                positive_target_classes,
                num_boxes=1,
                alpha=0.5,
                gamma=self.gamma,
                reduce=False,
            )
        else:
            loss_bce = F.binary_cross_entropy_with_logits(
                src_logits, positive_target_classes, reduction="none"
            )
        # Keep only the positive entries (target_classes is 1 there, 0 elsewhere).
        loss_bce = loss_bce * target_classes * self.pos_weight
        if (
            self.pad_n_queries is not None
            and isinstance(self.pad_n_queries, int)
            and loss_bce.size(1) < self.pad_n_queries
        ):
            loss_bce = loss_bce * self.pad_scale_pos
        # Negatives
        loss_bce = loss_bce + F.binary_cross_entropy_with_logits(
            src_logits, target_classes, reduction="none"
        ) * (1 - target_classes) * (prob**self.gamma)
        # Optionally, not applying IABCEMdetr loss to detection queries in video.
        is_video_grounding = outputs.get("is_video_grounding_batch", False)
        if is_video_grounding and not self.apply_loss_to_det_queries_in_video_grounding:
            Q_det = outputs["Q_det"]
            loss_bce[:, :Q_det] *= 0.0
        presence_loss = torch.tensor(0.0, device=src_logits.device)
        presence_dec_acc = torch.tensor(0.0, device=src_logits.device)
        if self.use_presence:
            # no classification loss for individual tokens if no target gt
            # cannot directly use targets["num_boxes"] to check if some
            # GT box exists as there may be dummy boxes for "invisible objects"
            # in video grounding data
            gt_padded_object_ids = targets["object_ids_padded"]  # (B, H)
            gt_padded_boxes = targets["boxes_padded"]  # (B, H, 4) shape, CxCyWH
            gt_padded_is_visible = (
                (gt_padded_object_ids >= 0)
                & (gt_padded_boxes[..., 2] > 0)  # width > 0
                & (gt_padded_boxes[..., 3] > 0)  # height > 0
            )
            # 1.0 for samples with at least one visible GT object, else 0.0.
            keep_loss = (gt_padded_is_visible.sum(dim=-1)[..., None] != 0).float()
            loss_bce = loss_bce * keep_loss
            if self.use_presence_semgseg:
                # no loss here, has its own separate loss computation
                assert "presence_logit_dec" not in outputs
            elif "presence_logit_dec" in outputs:
                presence_logits = outputs["presence_logit_dec"].view_as(keep_loss)
                bs = presence_logits.shape[0]
                presence_loss = sigmoid_focal_loss(
                    presence_logits,
                    keep_loss,
                    # not num_boxes, but we'll use it to normalize by bs
                    num_boxes=bs,
                    alpha=self.presence_alpha,
                    gamma=self.presence_gamma,
                )
                pred = (presence_logits.sigmoid() > 0.5).float()
                presence_dec_acc = (pred == keep_loss).float().mean()
            else:
                # for o2m, nothing to do
                pass
        if self.weak_loss:
            assert not self.use_separate_loss_for_det_and_trk, (
                "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
            )
            # nullify the negative loss for the non-exhaustive classes
            assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
            assert targets["is_exhaustive"].ndim == 1
            loss_mask = (~targets["is_exhaustive"]).view(-1, 1).expand_as(loss_bce)
            # restrict the mask to the negative supervision
            loss_mask = loss_mask & (target_classes < 0.5)
            loss_mask = ~loss_mask
            # Mask the loss
            loss_bce = loss_bce * loss_mask.float()
            # Average
            loss_bce = loss_bce.sum() / (loss_mask.sum() + 1e-6)
        else:
            # apply separate loss weights to detection and tracking queries
            if self.use_separate_loss_for_det_and_trk:
                Q_det = outputs["Q_det"]
                assert loss_bce.size(1) >= Q_det
                is_positive = target_classes > 0.5
                is_positive_det = is_positive[:, :Q_det]
                is_positive_trk = is_positive[:, Q_det:]
                assert loss_bce.size(0) == targets["is_exhaustive"].size(0)
                is_exhaustive = targets["is_exhaustive"].unsqueeze(1).bool()
                loss_scales = torch.zeros_like(loss_bce)
                # detection query loss weights
                loss_scales[:, :Q_det] = (
                    (is_exhaustive & is_positive_det).float()
                    * self.det_exhaustive_loss_scale_pos
                    + (is_exhaustive & ~is_positive_det).float()
                    * self.det_exhaustive_loss_scale_neg
                    + (~is_exhaustive & is_positive_det).float()
                    * self.det_non_exhaustive_loss_scale_pos
                    + (~is_exhaustive & ~is_positive_det).float()
                    * self.det_non_exhaustive_loss_scale_neg
                )
                # tracking query weights
                loss_scales[:, Q_det:] = (
                    is_positive_trk.float() * self.trk_loss_scale_pos
                    + (~is_positive_trk).float() * self.trk_loss_scale_neg
                )
                # apply the loss weights
                # if the id is -2 means it is a fp propagation , we don't apply the loss to them
                if self.no_loss_for_fp_propagation:
                    is_original_queries = outputs["pred_old_obj_ids"] != -2
                    loss_scales *= (is_exhaustive | is_original_queries).float()
                loss_bce = loss_bce * loss_scales
            if self.pad_n_queries is None or loss_bce.size(1) >= self.pad_n_queries:
                loss_bce = loss_bce.mean()
            else:
                assert isinstance(self.pad_n_queries, int)
                assert loss_bce.size(1) < self.pad_n_queries, (
                    f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
                )
                # Normalize as if there were pad_n_queries queries per sample.
                loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
        # Binary F1 of the predicted probabilities against the hard 0/1 targets.
        bce_f1 = torchmetrics.functional.f1_score(
            src_logits.sigmoid().flatten(),
            target=target_classes.flatten().long(),
            task="binary",
        )
        losses = {
            "loss_ce": loss_bce,
            "ce_f1": bce_f1,
            "presence_loss": presence_loss,
            "presence_dec_acc": presence_dec_acc,
        }
        return losses
  454. class Boxes(LossWithWeights):
  455. def __init__(
  456. self,
  457. weight_dict=None,
  458. compute_aux=True,
  459. apply_loss_to_det_queries_in_video_grounding=True,
  460. ):
  461. super().__init__(weight_dict, compute_aux)
  462. self.apply_loss_to_det_queries_in_video_grounding = (
  463. apply_loss_to_det_queries_in_video_grounding
  464. )
  465. self.target_keys.extend(["boxes", "boxes_xyxy"])
  466. def get_loss(self, outputs, targets, indices, num_boxes):
  467. """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
  468. targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
  469. The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.
  470. """
  471. # Optionally, not applying Boxes loss to detection queries in video.
  472. is_video_grounding = outputs.get("is_video_grounding_batch", False)
  473. if is_video_grounding and not self.apply_loss_to_det_queries_in_video_grounding:
  474. indices = _keep_only_trk_queries_in_match_inds(
  475. indices, Q_det=outputs["Q_det"]
  476. )
  477. assert "pred_boxes" in outputs
  478. # idx = self._get_src_permutation_idx(indices)
  479. src_boxes = outputs["pred_boxes"][(indices[0], indices[1])]
  480. src_boxes_xyxy = outputs["pred_boxes_xyxy"][(indices[0], indices[1])]
  481. target_boxes = (
  482. targets["boxes"] if indices[2] is None else targets["boxes"][indices[2]]
  483. )
  484. target_boxes_giou = (
  485. targets["boxes_xyxy"]
  486. if indices[2] is None
  487. else targets["boxes_xyxy"][indices[2]]
  488. )
  489. loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
  490. losses = {}
  491. losses["loss_bbox"] = loss_bbox.sum() / num_boxes
  492. loss_giou = 1 - box_ops.fast_diag_generalized_box_iou(
  493. src_boxes_xyxy, target_boxes_giou
  494. )
  495. losses["loss_giou"] = loss_giou.sum() / num_boxes
  496. return losses
class Masks(LossWithWeights):
    """Focal and dice losses on the predicted masks of matched queries.

    When `num_sample_points` is set, both losses are evaluated on points
    sampled from the masks; otherwise the predictions are bilinearly resized
    to the target resolution and compared densely.
    """

    def __init__(
        self,
        weight_dict=None,
        compute_aux=False,
        focal_alpha=0.25,
        focal_gamma=2,
        num_sample_points=None,
        oversample_ratio=None,
        importance_sample_ratio=None,
        apply_loss_to_det_queries_in_video_grounding=True,
    ):
        """
        Args:
            weight_dict, compute_aux: Forwarded to `LossWithWeights`; a warning
                is emitted when `compute_aux` is True.
            focal_alpha, focal_gamma: Parameters of the focal loss.
            num_sample_points: If not None, enables the point-sampled loss.
            oversample_ratio, importance_sample_ratio: Forwarded to the point
                sampler in `_sampled_loss`.
            apply_loss_to_det_queries_in_video_grounding: If False, detection
                queries are removed from the matching for video grounding batches.
        """
        super().__init__(weight_dict, compute_aux)
        if compute_aux:
            warnings.warn("Masks loss usually shouldn't be applied to aux outputs")
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.num_sample_points = num_sample_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio
        self.apply_loss_to_det_queries_in_video_grounding = (
            apply_loss_to_det_queries_in_video_grounding
        )
        self.target_keys.extend(["masks", "is_valid_mask"])

    def _sampled_loss(self, src_masks, target_masks, num_boxes):
        """Compute the focal and dice losses on sampled points.

        Both mask tensors must be 3-dimensional. The point coordinates are
        chosen from the predicted masks via
        `get_uncertain_point_coords_with_randomness`, and both predictions and
        targets are read at those coordinates.
        """
        assert len(src_masks.shape) == 3 and len(target_masks.shape) == 3
        # Add a singleton channel dimension.
        src_masks = src_masks[:, None]
        target_masks = target_masks[:, None]
        with torch.no_grad():
            # Sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                calculate_uncertainty,
                self.num_sample_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            )
            # get GT labels
            sampled_target_masks = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)
        # Sampled outside no_grad so gradients reach the predictions.
        sampled_src_masks = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)
        losses = {
            "loss_mask": sigmoid_focal_loss(
                sampled_src_masks,
                sampled_target_masks,
                num_boxes,
                alpha=self.focal_alpha,
                gamma=self.focal_gamma,
            ),
            "loss_dice": dice_loss(sampled_src_masks, sampled_target_masks, num_boxes),
        }
        # Not needed for backward
        del src_masks
        del target_masks
        return losses

    def get_loss(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]

        `targets` must also contain "is_valid_mask", which selects the matched
        masks that contribute to the loss. If `targets["masks"]` is None, both
        losses are returned as zeros.

        Returns:
            Dict with "loss_mask" (focal) and "loss_dice".
        """
        assert "pred_masks" in outputs
        assert "is_valid_mask" in targets
        # Optionally, not applying Masks loss to detection queries in video.
        is_video_grounding = outputs.get("is_video_grounding_batch", False)
        if is_video_grounding and not self.apply_loss_to_det_queries_in_video_grounding:
            indices = _keep_only_trk_queries_in_match_inds(
                indices, Q_det=outputs["Q_det"]
            )
        src_masks = outputs["pred_masks"]
        # Dataset doesn't have segmentation masks
        if targets["masks"] is None:
            return {
                "loss_mask": torch.tensor(0.0, device=src_masks.device),
                "loss_dice": torch.tensor(0.0, device=src_masks.device),
            }
        # indices[2] selects the matched targets; None means use them as given.
        target_masks = (
            targets["masks"] if indices[2] is None else targets["masks"][indices[2]]
        )
        # Match the dtype and device of the predictions.
        target_masks = target_masks.to(src_masks)
        keep = (
            targets["is_valid_mask"]
            if indices[2] is None
            else targets["is_valid_mask"][indices[2]]
        )
        src_masks = src_masks[(indices[0], indices[1])]
        # Remove invalid masks from loss
        src_masks = src_masks[keep]
        target_masks = target_masks[keep]
        if self.num_sample_points is not None:
            # Compute loss on sampled points for the Mask
            losses = self._sampled_loss(src_masks, target_masks, num_boxes)
        else:
            # upsample predictions to the target size
            if target_masks.shape[0] == 0 and src_masks.shape[0] == 0:
                # Nothing to resize: just bring both empty tensors to the same 2-D shape.
                src_masks = src_masks.flatten(1)
                target_masks = target_masks.reshape(src_masks.shape)
            else:
                if len(src_masks.shape) == 3:
                    src_masks = src_masks[:, None]
                if src_masks.dtype == torch.bfloat16:
                    # Bilinear interpolation does not support bf16
                    src_masks = src_masks.to(dtype=torch.float32)
                src_masks = interpolate(
                    src_masks,
                    size=target_masks.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
                # Drop the channel dimension and flatten the spatial dims.
                src_masks = src_masks[:, 0].flatten(1)
                target_masks = target_masks.flatten(1)
            losses = {
                "loss_mask": sigmoid_focal_loss(
                    src_masks,
                    target_masks,
                    num_boxes,
                    alpha=self.focal_alpha,
                    gamma=self.focal_gamma,
                ),
                "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
            }
        return losses
  624. # class MultiStepIteractiveMasks(LossWithWeights):
  625. # def __init__(
  626. # self,
  627. # weight_dict=None,
  628. # compute_aux=False,
  629. # focal_alpha=0.25,
  630. # focal_gamma=2,
  631. # ):
  632. # warnings.warn(
  633. # "MultiStepIteractiveMasks is deprecated. Please use MultiStepMultiMasksAndIous",
  634. # DeprecationWarning,
  635. # )
  636. # super().__init__(weight_dict, compute_aux)
  637. # self.focal_alpha = focal_alpha
  638. # self.focal_gamma = focal_gamma
  639. # self.target_keys.extend(["masks"])
  640. # def get_loss(self, outputs, targets, indices, num_boxes):
  641. # """Compute the losses related to the masks: the focal loss and the dice loss.
  642. # targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
  643. # Unlike `Masks`, here the "multistep_pred_masks" can have multiple channels, each
  644. # corresponding to one iterative prediction step in SAM-style training. We treat each
  645. # channel as a mask prediction and sum the loss across channels.
  646. # """
  647. # src_masks = outputs["multistep_pred_masks"]
  648. # target_masks = targets["masks"]
  649. # assert src_masks.size(0) == target_masks.size(0)
  650. # assert src_masks.dim() == 4
  651. # assert target_masks.dim() == 3
  652. # # tile target_masks according to the number of
  653. # # channels `src_masks`.
  654. # num_steps = src_masks.size(1)
  655. # target_masks = target_masks.unsqueeze(1).to(src_masks.dtype)
  656. # if num_steps > 1:
  657. # target_masks = target_masks.repeat(1, num_steps, 1, 1)
  658. # # resize `src_masks` to target mask resolution
  659. # if src_masks.shape != target_masks.shape:
  660. # src_masks = interpolate(
  661. # src_masks,
  662. # size=target_masks.shape[-2:],
  663. # mode="bilinear",
  664. # align_corners=False,
  665. # )
  666. # assert src_masks.shape == target_masks.shape
  667. # # flatten the multiple steps in to the batch dimension
  668. # src_masks = src_masks.flatten(0, 1).flatten(1)
  669. # target_masks = target_masks.flatten(0, 1).flatten(1)
  670. # losses = {
  671. # "loss_mask": sigmoid_focal_loss(
  672. # src_masks,
  673. # target_masks,
  674. # num_boxes,
  675. # alpha=self.focal_alpha,
  676. # gamma=self.focal_gamma,
  677. # ),
  678. # "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
  679. # }
  680. # return losses
  681. # class MultiStepMultiMasksAndIous(LossWithWeights):
  682. # def __init__(
  683. # self,
  684. # weight_dict=None,
  685. # compute_aux=False,
  686. # focal_alpha=0.25,
  687. # focal_gamma=2,
  688. # # if True, back-prop on all predicted ious
  689. # # not just the one with lowest loss_combo
  690. # supervise_all_iou=False,
  691. # # Less slack vs MSE loss in [-1, 1] error range
  692. # iou_use_l1_loss=False,
  693. # # Settings for obj score prediction
  694. # pred_obj_scores=False,
  695. # focal_gamma_obj_score=0.0,
  696. # focal_alpha_obj_score=-1,
  697. # ):
  698. # super().__init__(weight_dict, compute_aux)
  699. # self.focal_alpha = focal_alpha
  700. # self.focal_gamma = focal_gamma
  701. # self.target_keys.extend(["masks"])
  702. # assert "loss_mask" in self.weight_dict
  703. # assert "loss_dice" in self.weight_dict
  704. # assert "loss_iou" in self.weight_dict
  705. # if "loss_class" not in self.weight_dict:
  706. # self.weight_dict["loss_class"] = 0.0
  707. # self.focal_alpha_obj_score = focal_alpha_obj_score
  708. # self.focal_gamma_obj_score = focal_gamma_obj_score
  709. # self.supervise_all_iou = supervise_all_iou
  710. # self.iou_use_l1_loss = iou_use_l1_loss
  711. # self.pred_obj_scores = pred_obj_scores
  712. # def get_loss(self, outputs, targets, indices, num_boxes):
  713. # """
  714. # Compute the losses related to the masks: the focal loss and the dice loss.
  715. # and also the MSE loss between predicted IoUs and actual IoUs.
  716. # Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
  717. # of shape [N, M, H, W], where M could be 1 or larger, corresponding to
  718. # one or multiple predicted masks from a click.
  719. # We back-propagate focal, dice and iou losses only on the prediction channel
  720. # with the lowest focal+dice loss between predicted mask and ground-truth.
  721. # """
  722. # target_masks = targets["masks"].unsqueeze(1).float()
  723. # assert target_masks.dim() == 4 # [N, 1, H, W]
  724. # src_masks_list = outputs["multistep_pred_multimasks_high_res"]
  725. # ious_list = outputs["multistep_pred_ious"]
  726. # object_score_logits_list = outputs["multistep_object_score_logits"]
  727. # assert len(src_masks_list) == len(ious_list)
  728. # assert len(object_score_logits_list) == len(ious_list)
  729. # # Remove invalid masks from loss
  730. # keep = targets["is_valid_mask"]
  731. # target_masks = target_masks[keep]
  732. # # accumulate the loss over prediction steps
  733. # losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
  734. # for src_masks, ious, object_score_logits in zip(
  735. # src_masks_list, ious_list, object_score_logits_list
  736. # ):
  737. # object_score_logits = object_score_logits[keep]
  738. # ious = ious[keep]
  739. # src_masks = src_masks[keep]
  740. # self._update_losses(
  741. # losses, src_masks, target_masks, ious, num_boxes, object_score_logits
  742. # )
  743. # return losses
  744. # def _update_losses(
  745. # self, losses, src_masks, target_masks, ious, num_boxes, object_score_logits
  746. # ):
  747. # target_masks = target_masks.expand_as(src_masks)
  748. # # get focal, dice and iou loss on all output masks in a prediction step
  749. # loss_multimask = sigmoid_focal_loss(
  750. # src_masks,
  751. # target_masks,
  752. # num_boxes,
  753. # alpha=self.focal_alpha,
  754. # gamma=self.focal_gamma,
  755. # loss_on_multimask=True,
  756. # triton=False, # only use triton if alpha > 0
  757. # )
  758. # loss_multidice = dice_loss(
  759. # src_masks, target_masks, num_boxes, loss_on_multimask=True
  760. # )
  761. # if not self.pred_obj_scores:
  762. # loss_class = torch.tensor(
  763. # 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
  764. # )
  765. # target_obj = torch.ones(
  766. # loss_multimask.shape[0],
  767. # 1,
  768. # dtype=loss_multimask.dtype,
  769. # device=loss_multimask.device,
  770. # )
  771. # else:
  772. # target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
  773. # ..., None
  774. # ].float()
  775. # loss_class = sigmoid_focal_loss(
  776. # object_score_logits,
  777. # target_obj,
  778. # num_boxes,
  779. # alpha=self.focal_alpha_obj_score,
  780. # gamma=self.focal_gamma_obj_score,
  781. # triton=False,
  782. # )
  783. # loss_multiiou = iou_loss(
  784. # src_masks,
  785. # target_masks,
  786. # ious,
  787. # num_boxes,
  788. # loss_on_multimask=True,
  789. # use_l1_loss=self.iou_use_l1_loss,
  790. # )
  791. # assert loss_multimask.dim() == 2
  792. # assert loss_multidice.dim() == 2
  793. # assert loss_multiiou.dim() == 2
  794. # if loss_multimask.size(1) > 1:
  795. # # take the mask indices with the smallest focal + dice loss for back propagation
  796. # loss_combo = (
  797. # loss_multimask * self.weight_dict["loss_mask"]
  798. # + loss_multidice * self.weight_dict["loss_dice"]
  799. # )
  800. # best_loss_inds = torch.argmin(loss_combo, dim=-1)
  801. # batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
  802. # loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
  803. # loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
  804. # # calculate the iou prediction and slot losses only in the index
  805. # # with the minimum loss for each mask (to be consistent w/ SAM)
  806. # if self.supervise_all_iou:
  807. # loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
  808. # else:
  809. # loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
  810. # else:
  811. # loss_mask = loss_multimask
  812. # loss_dice = loss_multidice
  813. # loss_iou = loss_multiiou
  814. # # backprop focal, dice and iou loss only if obj present
  815. # loss_mask = loss_mask * target_obj
  816. # loss_dice = loss_dice * target_obj
  817. # loss_iou = loss_iou * target_obj
  818. # # sum over batch dimension (note that the losses are already divided by num_boxes)
  819. # losses["loss_mask"] += loss_mask.sum()
  820. # losses["loss_dice"] += loss_dice.sum()
  821. # losses["loss_iou"] += loss_iou.sum()
  822. # losses["loss_class"] += loss_class
  823. # class TextCriterion(LossWithWeights):
  824. # def __init__(
  825. # self,
  826. # pad_token,
  827. # max_seq_len=100,
  828. # weight_dict=None,
  829. # compute_aux=False,
  830. # ):
  831. # super().__init__(weight_dict, compute_aux)
  832. # self.pad_token = pad_token
  833. # self.max_seq_len = max_seq_len
  834. # self.in_lengths = None
  835. # def get_loss(self, outputs, **kwargs):
  836. # nb_tokens = outputs["captioning_tokenized_target"].input_ids.numel()
  837. # bs, seq_len = outputs["captioning_tokenized_target"].input_ids.shape
  838. # ce = F.cross_entropy(
  839. # outputs["captioning_pred_text"].flatten(0, -2),
  840. # outputs["captioning_tokenized_target"].input_ids.flatten(),
  841. # ignore_index=self.pad_token,
  842. # reduction="sum",
  843. # )
  844. # not_pad = (
  845. # outputs["captioning_tokenized_target"]
  846. # .input_ids.reshape(-1)
  847. # .ne(self.pad_token)
  848. # )
  849. # if nb_tokens > 0:
  850. # nb_non_pad = not_pad.numel()
  851. # ce = ce / nb_non_pad
  852. # preds = outputs["captioning_pred_text"].flatten(0, -2).argmax(-1)[not_pad]
  853. # targets = outputs["captioning_tokenized_target"].input_ids.flatten()[not_pad]
  854. # correct = preds == targets
  855. # correct = correct.sum() / (correct.numel() + 1e-5)
  856. # correct_sequence_level = torch.all(
  857. # (
  858. # outputs["captioning_pred_text"]
  859. # .flatten(0, -2)
  860. # .argmax(-1)
  861. # .reshape(bs, seq_len)
  862. # == outputs["captioning_tokenized_target"].input_ids
  863. # )
  864. # | (~not_pad).view(bs, seq_len),
  865. # dim=1,
  866. # )
  867. # seq_level_acc = correct_sequence_level.float().mean()
  868. # return {"loss_text": ce, "text_acc": correct, "text_seq_acc": seq_level_acc}
  869. def segment_miou(source, target):
  870. """Compute the mean IoU between two sets of masks"""
  871. assert source.shape == target.shape, "The two masks must have the same shape"
  872. assert source.ndim == 3, "The masks must be 3D"
  873. valid_targets = (target.sum(dim=(1, 2)) > 0).sum()
  874. if valid_targets == 0:
  875. return torch.tensor(1.0, device=source.device)
  876. intersection = (source.bool() & target.bool()).sum(dim=(1, 2))
  877. union = (source.bool() | target.bool()).sum(dim=(1, 2))
  878. iou = intersection / (union + 1e-8)
  879. return iou.sum() / valid_targets
  880. class SemanticSegCriterion(LossWithWeights):
  881. def __init__(
  882. self,
  883. weight_dict,
  884. focal: bool = False,
  885. focal_alpha: float = 0.6,
  886. focal_gamma: float = 1.6,
  887. downsample: bool = True,
  888. presence_head: bool = False,
  889. # Option to turn off presence loss, if some other component
  890. # is already doing it, e.g. decoder - in which case,
  891. # we could still set presence_head to True so that
  892. # losses are not propogated to masks when there is no GT mask
  893. presence_loss: bool = True,
  894. ):
  895. super().__init__(weight_dict, False)
  896. self.focal = focal
  897. self.focal_alpha = focal_alpha
  898. self.focal_gamma = focal_gamma
  899. self.downsample = downsample
  900. self.presence_head = presence_head
  901. self.presence_loss = presence_loss
  902. def get_loss(self, out_dict, targets):
  903. outputs = out_dict["semantic_seg"]
  904. presence_logit = out_dict["presence_logit"]
  905. if (
  906. "semantic_masks" in targets
  907. and targets["semantic_masks"] is not None
  908. and targets["semantic_masks"].size(0) > 0
  909. ):
  910. semantic_targets = targets["semantic_masks"]
  911. with torch.no_grad():
  912. if self.downsample:
  913. # downsample targets to the size of predictions
  914. size = outputs.shape[-2:]
  915. semantic_targets = (
  916. F.interpolate(
  917. semantic_targets.float().unsqueeze(1),
  918. size=size,
  919. mode="bilinear",
  920. align_corners=False,
  921. )
  922. .squeeze(1)
  923. .bool()
  924. )
  925. else:
  926. with torch.no_grad():
  927. if self.downsample:
  928. # downsample targets to the size of predictions
  929. size = outputs.shape[-2:]
  930. segments = (
  931. F.interpolate(
  932. targets["masks"].float().unsqueeze(1),
  933. size=size,
  934. mode="bilinear",
  935. align_corners=False,
  936. )
  937. .squeeze(1)
  938. .bool()
  939. )
  940. else:
  941. segments = targets["masks"].bool()
  942. # the annotations are for instance segmentation, so we merge them to get semantic segmentation
  943. semantic_targets = instance_masks_to_semantic_masks(
  944. segments, targets["num_boxes"]
  945. )
  946. if not self.downsample:
  947. # upsample predictions to the target size
  948. size = semantic_targets.shape[-2:]
  949. outputs = F.interpolate(
  950. outputs.float(),
  951. size=size,
  952. mode="bilinear",
  953. align_corners=False,
  954. )
  955. if self.focal:
  956. loss = sigmoid_focal_loss(
  957. outputs.squeeze(1).flatten(-2),
  958. semantic_targets.float().flatten(-2),
  959. num_boxes=len(semantic_targets),
  960. alpha=self.focal_alpha,
  961. gamma=self.focal_gamma,
  962. reduce=not self.presence_head,
  963. )
  964. if self.presence_head:
  965. loss = loss.mean(1)
  966. else:
  967. loss = F.binary_cross_entropy_with_logits(
  968. outputs.squeeze(1),
  969. semantic_targets.float(),
  970. reduction="none" if self.presence_head else "mean",
  971. )
  972. if self.presence_head:
  973. loss = loss.flatten(1).mean(1)
  974. loss_dice = dice_loss(
  975. outputs.squeeze(1).flatten(1),
  976. semantic_targets.flatten(1),
  977. len(semantic_targets),
  978. reduce=not self.presence_head,
  979. )
  980. miou = segment_miou(outputs.sigmoid().squeeze(1) > 0.5, semantic_targets)
  981. loss_dict = {}
  982. if self.presence_head:
  983. presence_target = semantic_targets.flatten(1).any(-1)
  984. if self.presence_loss:
  985. loss_presence = F.binary_cross_entropy_with_logits(
  986. presence_logit.flatten(),
  987. presence_target.float(),
  988. )
  989. presence_acc = (
  990. ((presence_logit.flatten().sigmoid() > 0.5) == presence_target)
  991. .float()
  992. .mean()
  993. )
  994. else:
  995. # Dummy values
  996. loss_presence = torch.tensor(0.0, device=loss.device)
  997. # Whichever component is computing the presence loss,
  998. # should also track presence_acc
  999. presence_acc = torch.tensor(0.0, device=loss.device)
  1000. loss_dict["loss_semantic_presence"] = loss_presence
  1001. loss_dict["presence_acc"] = presence_acc
  1002. # reduce the other losses, skipping the negative ones
  1003. bs = loss.shape[0]
  1004. assert presence_target.numel() == bs
  1005. mask = presence_target
  1006. nb_valid = presence_target.sum().item()
  1007. loss = (loss * mask.float()).sum() / (nb_valid + 1e-6)
  1008. loss_dice = (loss_dice * mask.float()).sum() / (nb_valid + 1e-6)
  1009. loss_dict.update(
  1010. {
  1011. "loss_semantic_seg": loss,
  1012. "loss_semantic_dice": loss_dice,
  1013. "miou_semantic_seg": miou,
  1014. }
  1015. )
  1016. return loss_dict
  1017. class Det2TrkAssoc(LossWithWeights):
  1018. def __init__(
  1019. self,
  1020. weight_dict,
  1021. use_fp_loss=False,
  1022. fp_loss_on_exhaustive_only=True,
  1023. treat_fp_as_new_obj=False,
  1024. ):
  1025. super().__init__(weight_dict, compute_aux=False)
  1026. self.use_fp_loss = use_fp_loss
  1027. self.fp_loss_on_exhaustive_only = fp_loss_on_exhaustive_only
  1028. self.treat_fp_as_new_obj = treat_fp_as_new_obj
  1029. if self.use_fp_loss:
  1030. self.target_keys.append("is_exhaustive")
  1031. def get_loss(self, outputs, targets, indices, num_boxes):
  1032. det2trk_assoc_logits = outputs["det2trk_assoc_logits"]
  1033. device = det2trk_assoc_logits.device
  1034. B, Q_det, Q_trk_plus_2 = det2trk_assoc_logits.shape
  1035. assert Q_trk_plus_2 >= 2
  1036. Q_trk = Q_trk_plus_2 - 2
  1037. # We only apply association losses to those detection queries that either match
  1038. # a GT instance or have score > 0 (i.e. those TP, FN and FP detection queries)
  1039. matched_object_ids = outputs["matched_object_ids"]
  1040. assert matched_object_ids.shape == (B, Q_det + Q_trk)
  1041. matched_obj_ids_det = matched_object_ids[:, :Q_det]
  1042. matched_obj_ids_trk = matched_object_ids[:, Q_det:]
  1043. det_is_matched_to_gt = matched_obj_ids_det >= 0
  1044. trk_is_matched_to_gt = matched_obj_ids_trk >= 0
  1045. # note: -1 label is ignored in the (softmax) cross_entropy loss below
  1046. det2trk_assoc_labels = -torch.ones(B, Q_det, dtype=torch.long, device=device)
  1047. # a) If a detection query is matched to a same object ID as a tracking query,
  1048. # we assign it the index of the tracking query as a label
  1049. det_is_same_obj_id_as_trk = (
  1050. det_is_matched_to_gt[:, :, None]
  1051. & trk_is_matched_to_gt[:, None, :]
  1052. & (matched_obj_ids_det[:, :, None] == matched_obj_ids_trk[:, None, :])
  1053. )
  1054. batch_idx, det_idx, trk_idx = det_is_same_obj_id_as_trk.nonzero(as_tuple=True)
  1055. det2trk_assoc_labels[batch_idx, det_idx] = trk_idx
  1056. # b) If a detection query is matched to GT but not to any tracking query,
  1057. # we assign it a "new_object" label
  1058. det_is_new_obj = det_is_matched_to_gt & ~det_is_same_obj_id_as_trk.any(dim=-1)
  1059. det2trk_assoc_labels[det_is_new_obj] = Q_trk
  1060. # c) If a detection query is not matched to GT but have score > 0,
  1061. # we assign it a "false_positive" label
  1062. if self.use_fp_loss:
  1063. det_is_above_thresh = outputs["pred_logits"][:, :Q_det].squeeze(2) > 0
  1064. det_is_fp = ~det_is_matched_to_gt & det_is_above_thresh
  1065. if self.treat_fp_as_new_obj:
  1066. det2trk_assoc_labels[det_is_fp] = Q_trk
  1067. else:
  1068. if self.fp_loss_on_exhaustive_only:
  1069. # only count FP detections on batches that are exhaustively annotated
  1070. det_is_fp &= targets["is_exhaustive"].unsqueeze(1).bool()
  1071. det2trk_assoc_labels[det_is_fp] = Q_trk + 1
  1072. # softmax cross-entropy loss for detection-to-tracking association
  1073. loss_det2trk_assoc = F.cross_entropy(
  1074. input=det2trk_assoc_logits.flatten(0, 1), # (B * Q_det, Q_trk + 2)
  1075. target=det2trk_assoc_labels.flatten(0, 1), # (B * Q_det)
  1076. ignore_index=-1,
  1077. reduction="none",
  1078. ).view(B, Q_det)
  1079. # skip det2trk assocation loss on frames w/o any (non-padding) tracking queries
  1080. frame_has_valid_trk = trk_is_matched_to_gt.any(dim=-1, keepdims=True) # (B, 1)
  1081. loss_det2trk_assoc = loss_det2trk_assoc * frame_has_valid_trk.float()
  1082. loss_det2trk_assoc = loss_det2trk_assoc.sum() / (B * num_boxes)
  1083. return {"loss_det2trk_assoc": loss_det2trk_assoc}
  1084. class TrackingByDetectionAssoc(LossWithWeights):
  1085. def __init__(self, weight_dict):
  1086. super().__init__(weight_dict, compute_aux=False, supports_o2m_loss=False)
  1087. assert "loss_det2trk_assoc" in self.weight_dict
  1088. assert "loss_trk2det_assoc" in self.weight_dict
  1089. def get_loss(self, outputs, targets, indices, num_boxes):
  1090. # Part A: gather object id matching between detection and tracking
  1091. det2trk_assoc_logits = outputs["det2trk_assoc_logits"] # (B, Q_det+1, Q_trk+1)
  1092. B, Q_det_plus_1, Q_trk_plus_1 = det2trk_assoc_logits.shape
  1093. assert Q_det_plus_1 >= 1 and Q_trk_plus_1 >= 1
  1094. Q_det = Q_det_plus_1 - 1
  1095. Q_trk = Q_trk_plus_1 - 1
  1096. device = det2trk_assoc_logits.device
  1097. matched_obj_ids_det = outputs["matched_object_ids"]
  1098. assert matched_obj_ids_det.shape == (B, Q_det)
  1099. det_is_matched_to_gt = matched_obj_ids_det >= 0
  1100. matched_obj_ids_trk = outputs["prev_trk_object_ids"]
  1101. assert matched_obj_ids_trk.shape == (B, Q_trk)
  1102. trk_is_matched_to_gt = matched_obj_ids_trk >= 0
  1103. frame_has_valid_trk = trk_is_matched_to_gt.any(dim=-1, keepdims=True) # (B, 1)
  1104. # check whether a detection object is the same as a tracking object
  1105. det_is_same_obj_id_as_trk = (
  1106. det_is_matched_to_gt[:, :, None]
  1107. & trk_is_matched_to_gt[:, None, :]
  1108. & (matched_obj_ids_det[:, :, None] == matched_obj_ids_trk[:, None, :])
  1109. ) # (B, Q_det, Q_trk)
  1110. # there should be at most one match for each detection and each previous tracked object
  1111. torch._assert_async(torch.all(det_is_same_obj_id_as_trk.sum(dim=2) <= 1))
  1112. torch._assert_async(torch.all(det_is_same_obj_id_as_trk.sum(dim=1) <= 1))
  1113. batch_idx, det_idx, trk_idx = det_is_same_obj_id_as_trk.nonzero(as_tuple=True)
  1114. # Part B: Detection-to-tracking association loss
  1115. # assign detection-to-tracking labels (note: -1 label is ignored in the loss below)
  1116. det2trk_assoc_labels = -torch.ones(B, Q_det, dtype=torch.long, device=device)
  1117. det2trk_assoc_labels[batch_idx, det_idx] = trk_idx
  1118. # if a detection is matched to GT but not to any tracking, assign it a "new-object" label
  1119. det_is_new_obj = det_is_matched_to_gt & ~det_is_same_obj_id_as_trk.any(dim=2)
  1120. det2trk_assoc_labels[det_is_new_obj] = Q_trk # "Q_trk" label is "new-object"
  1121. # softmax cross-entropy loss for detection-to-tracking association
  1122. loss_det2trk_assoc = F.cross_entropy(
  1123. input=det2trk_assoc_logits[:, :-1].flatten(0, 1), # (B*Q_det, Q_trk+1)
  1124. target=det2trk_assoc_labels.flatten(0, 1), # (B*Q_det)
  1125. ignore_index=-1,
  1126. reduction="none",
  1127. ).view(B, Q_det)
  1128. # skip det2trk assocation loss on frames w/o any (non-padding) tracking queries
  1129. loss_det2trk_assoc = loss_det2trk_assoc * frame_has_valid_trk.float()
  1130. loss_det2trk_assoc = loss_det2trk_assoc.sum() / (B * num_boxes)
  1131. loss_dict = {"loss_det2trk_assoc": loss_det2trk_assoc}
  1132. # Part C: tracking-to-detection association loss
  1133. trk2det_assoc_logits = det2trk_assoc_logits.transpose(1, 2)
  1134. assert trk2det_assoc_logits.shape == (B, Q_trk + 1, Q_det + 1)
  1135. # assign tracking-to-detection labels (note: -1 label is ignored in the loss below)
  1136. trk2det_assoc_labels = -torch.ones(B, Q_trk, dtype=torch.long, device=device)
  1137. trk2det_assoc_labels[batch_idx, trk_idx] = det_idx
  1138. # if a tracking is matched to GT but not to any detection, assign it a "occluded" label
  1139. trk_is_occluded = trk_is_matched_to_gt & ~det_is_same_obj_id_as_trk.any(dim=1)
  1140. trk2det_assoc_labels[trk_is_occluded] = Q_det # "Q_det" label is "occluded"
  1141. # softmax cross-entropy loss for tracking-to-detection association
  1142. loss_trk2det_assoc = F.cross_entropy(
  1143. input=trk2det_assoc_logits[:, :-1].flatten(0, 1), # (B*Q_trk, Q_det+1)
  1144. target=trk2det_assoc_labels.flatten(0, 1), # (B*Q_trk)
  1145. ignore_index=-1,
  1146. reduction="none",
  1147. ).view(B, Q_trk)
  1148. # skip trk2det association loss on frames w/o any (non-padding) tracking queries
  1149. loss_trk2det_assoc = loss_trk2det_assoc * frame_has_valid_trk.float()
  1150. loss_trk2det_assoc = loss_trk2det_assoc.sum() / (B * num_boxes)
  1151. loss_dict["loss_trk2det_assoc"] = loss_trk2det_assoc
  1152. return loss_dict
  1153. def _keep_only_trk_queries_in_match_inds(inds, Q_det):
  1154. """Keep only the tracking query indices in the indices tuple"""
  1155. batch_idx, src_idx, tgt_idx = inds
  1156. if batch_idx.numel() == 0:
  1157. return (batch_idx, src_idx, tgt_idx) # empty indices, nothing to filter
  1158. # keep only the tracking query indices
  1159. is_trk_query = src_idx >= Q_det
  1160. batch_idx_trk = batch_idx[is_trk_query]
  1161. src_idx_trk = src_idx[is_trk_query]
  1162. tgt_idx_trk = tgt_idx[is_trk_query] if tgt_idx is not None else None
  1163. return (batch_idx_trk, src_idx_trk, tgt_idx_trk)