loss_fns.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. from collections import defaultdict
  6. from typing import Dict, List
  7. import torch
  8. import torch.distributed
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from training.trainer import CORE_LOSS_KEY
  12. from training.utils.distributed import get_world_size, is_dist_avail_and_initialized
  13. def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
  14. """
  15. Compute the DICE loss, similar to generalized IOU for masks
  16. Args:
  17. inputs: A float tensor of arbitrary shape.
  18. The predictions for each example.
  19. targets: A float tensor with the same shape as inputs. Stores the binary
  20. classification label for each element in inputs
  21. (0 for the negative class and 1 for the positive class).
  22. num_objects: Number of objects in the batch
  23. loss_on_multimask: True if multimask prediction is enabled
  24. Returns:
  25. Dice loss tensor
  26. """
  27. inputs = inputs.sigmoid()
  28. if loss_on_multimask:
  29. # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
  30. assert inputs.dim() == 4 and targets.dim() == 4
  31. # flatten spatial dimension while keeping multimask channel dimension
  32. inputs = inputs.flatten(2)
  33. targets = targets.flatten(2)
  34. numerator = 2 * (inputs * targets).sum(-1)
  35. else:
  36. inputs = inputs.flatten(1)
  37. numerator = 2 * (inputs * targets).sum(1)
  38. denominator = inputs.sum(-1) + targets.sum(-1)
  39. loss = 1 - (numerator + 1) / (denominator + 1)
  40. if loss_on_multimask:
  41. return loss / num_objects
  42. return loss.sum() / num_objects
  43. def sigmoid_focal_loss(
  44. inputs,
  45. targets,
  46. num_objects,
  47. alpha: float = 0.25,
  48. gamma: float = 2,
  49. loss_on_multimask=False,
  50. ):
  51. """
  52. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  53. Args:
  54. inputs: A float tensor of arbitrary shape.
  55. The predictions for each example.
  56. targets: A float tensor with the same shape as inputs. Stores the binary
  57. classification label for each element in inputs
  58. (0 for the negative class and 1 for the positive class).
  59. num_objects: Number of objects in the batch
  60. alpha: (optional) Weighting factor in range (0,1) to balance
  61. positive vs negative examples. Default = -1 (no weighting).
  62. gamma: Exponent of the modulating factor (1 - p_t) to
  63. balance easy vs hard examples.
  64. loss_on_multimask: True if multimask prediction is enabled
  65. Returns:
  66. focal loss tensor
  67. """
  68. prob = inputs.sigmoid()
  69. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  70. p_t = prob * targets + (1 - prob) * (1 - targets)
  71. loss = ce_loss * ((1 - p_t) ** gamma)
  72. if alpha >= 0:
  73. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  74. loss = alpha_t * loss
  75. if loss_on_multimask:
  76. # loss is [N, M, H, W] where M corresponds to multiple predicted masks
  77. assert loss.dim() == 4
  78. return loss.flatten(2).mean(-1) / num_objects # average over spatial dims
  79. return loss.mean(1).sum() / num_objects
  80. def iou_loss(
  81. inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
  82. ):
  83. """
  84. Args:
  85. inputs: A float tensor of arbitrary shape.
  86. The predictions for each example.
  87. targets: A float tensor with the same shape as inputs. Stores the binary
  88. classification label for each element in inputs
  89. (0 for the negative class and 1 for the positive class).
  90. pred_ious: A float tensor containing the predicted IoUs scores per mask
  91. num_objects: Number of objects in the batch
  92. loss_on_multimask: True if multimask prediction is enabled
  93. use_l1_loss: Whether to use L1 loss is used instead of MSE loss
  94. Returns:
  95. IoU loss tensor
  96. """
  97. assert inputs.dim() == 4 and targets.dim() == 4
  98. pred_mask = inputs.flatten(2) > 0
  99. gt_mask = targets.flatten(2) > 0
  100. area_i = torch.sum(pred_mask & gt_mask, dim=-1).float()
  101. area_u = torch.sum(pred_mask | gt_mask, dim=-1).float()
  102. actual_ious = area_i / torch.clamp(area_u, min=1.0)
  103. if use_l1_loss:
  104. loss = F.l1_loss(pred_ious, actual_ious, reduction="none")
  105. else:
  106. loss = F.mse_loss(pred_ious, actual_ious, reduction="none")
  107. if loss_on_multimask:
  108. return loss / num_objects
  109. return loss.sum() / num_objects
class MultiStepMultiMasksAndIous(nn.Module):
    # Combines focal (mask), dice, IoU-regression and optional object-score losses
    # over multiple prediction steps and multiple predicted masks per step.
    def __init__(
        self,
        weight_dict,
        focal_alpha=0.25,
        focal_gamma=2,
        supervise_all_iou=False,
        iou_use_l1_loss=False,
        pred_obj_scores=False,
        focal_gamma_obj_score=0.0,
        focal_alpha_obj_score=-1,
    ):
        """
        This class computes the multi-step multi-mask and IoU losses.
        Args:
            weight_dict: dict containing weights for focal, dice, iou losses
            focal_alpha: alpha for sigmoid focal loss
            focal_gamma: gamma for sigmoid focal loss
            supervise_all_iou: if True, back-prop iou losses for all predicted masks
            iou_use_l1_loss: use L1 loss instead of MSE loss for iou
            pred_obj_scores: if True, compute loss for object scores
            focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
            focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
        """
        super().__init__()
        self.weight_dict = weight_dict
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        # the three mask-related weights are mandatory; loss_class is optional
        assert "loss_mask" in self.weight_dict
        assert "loss_dice" in self.weight_dict
        assert "loss_iou" in self.weight_dict
        if "loss_class" not in self.weight_dict:
            # default: object-score loss contributes nothing to the reduced loss
            self.weight_dict["loss_class"] = 0.0
        self.focal_alpha_obj_score = focal_alpha_obj_score
        self.focal_gamma_obj_score = focal_gamma_obj_score
        self.supervise_all_iou = supervise_all_iou
        self.iou_use_l1_loss = iou_use_l1_loss
        self.pred_obj_scores = pred_obj_scores

    def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
        """
        Compute losses for a batch of outputs and accumulate them across batch
        entries. Returns a dict of loss tensors including CORE_LOSS_KEY.
        """
        assert len(outs_batch) == len(targets_batch)
        num_objects = torch.tensor(
            (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
        )  # Number of objects is fixed within a batch
        if is_dist_avail_and_initialized():
            # average the object count across workers so per-object normalization
            # is consistent in distributed training; clamp to avoid div-by-zero
            torch.distributed.all_reduce(num_objects)
            num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()
        losses = defaultdict(int)
        for outs, targets in zip(outs_batch, targets_batch):
            cur_losses = self._forward(outs, targets, num_objects)
            for k, v in cur_losses.items():
                losses[k] += v
        return losses

    def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
        """
        Compute the losses related to the masks: the focal loss and the dice loss.
        and also the MAE or MSE loss between predicted IoUs and actual IoUs.

        Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
        of shape [N, M, H, W], where M could be 1 or larger, corresponding to
        one or multiple predicted masks from a click.

        We back-propagate focal, dice losses only on the prediction channel
        with the lowest focal+dice loss between predicted mask and ground-truth.
        If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks.
        """
        target_masks = targets.unsqueeze(1).float()
        assert target_masks.dim() == 4  # [N, 1, H, W]
        src_masks_list = outputs["multistep_pred_multimasks_high_res"]
        ious_list = outputs["multistep_pred_ious"]
        object_score_logits_list = outputs["multistep_object_score_logits"]
        assert len(src_masks_list) == len(ious_list)
        assert len(object_score_logits_list) == len(ious_list)

        # accumulate the loss over prediction steps
        losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
        for src_masks, ious, object_score_logits in zip(
            src_masks_list, ious_list, object_score_logits_list
        ):
            self._update_losses(
                losses, src_masks, target_masks, ious, num_objects, object_score_logits
            )
        losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
        return losses

    def _update_losses(
        self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
    ):
        """Accumulate per-step focal/dice/iou/class losses into `losses` in place."""
        # broadcast the single GT mask channel across the M predicted masks
        target_masks = target_masks.expand_as(src_masks)
        # get focal, dice and iou loss on all output masks in a prediction step
        loss_multimask = sigmoid_focal_loss(
            src_masks,
            target_masks,
            num_objects,
            alpha=self.focal_alpha,
            gamma=self.focal_gamma,
            loss_on_multimask=True,
        )
        loss_multidice = dice_loss(
            src_masks, target_masks, num_objects, loss_on_multimask=True
        )
        if not self.pred_obj_scores:
            # no object-score supervision: zero class loss, object always present
            loss_class = torch.tensor(
                0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
            )
            target_obj = torch.ones(
                loss_multimask.shape[0],
                1,
                dtype=loss_multimask.dtype,
                device=loss_multimask.device,
            )
        else:
            # an object is "present" if its GT mask has any positive pixel
            target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
                ..., None
            ].float()
            loss_class = sigmoid_focal_loss(
                object_score_logits,
                target_obj,
                num_objects,
                alpha=self.focal_alpha_obj_score,
                gamma=self.focal_gamma_obj_score,
            )

        loss_multiiou = iou_loss(
            src_masks,
            target_masks,
            ious,
            num_objects,
            loss_on_multimask=True,
            use_l1_loss=self.iou_use_l1_loss,
        )
        assert loss_multimask.dim() == 2
        assert loss_multidice.dim() == 2
        assert loss_multiiou.dim() == 2
        if loss_multimask.size(1) > 1:
            # take the mask indices with the smallest focal + dice loss for back propagation
            loss_combo = (
                loss_multimask * self.weight_dict["loss_mask"]
                + loss_multidice * self.weight_dict["loss_dice"]
            )
            best_loss_inds = torch.argmin(loss_combo, dim=-1)
            batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
            loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
            loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
            # calculate the iou prediction and slot losses only in the index
            # with the minimum loss for each mask (to be consistent w/ SAM)
            if self.supervise_all_iou:
                loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
            else:
                loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
        else:
            loss_mask = loss_multimask
            loss_dice = loss_multidice
            loss_iou = loss_multiiou

        # backprop focal, dice and iou loss only if obj present
        loss_mask = loss_mask * target_obj
        loss_dice = loss_dice * target_obj
        loss_iou = loss_iou * target_obj

        # sum over batch dimension (note that the losses are already divided by num_objects)
        losses["loss_mask"] += loss_mask.sum()
        losses["loss_dice"] += loss_dice.sum()
        losses["loss_iou"] += loss_iou.sum()
        losses["loss_class"] += loss_class

    def reduce_loss(self, losses):
        """Weighted sum of the accumulated losses per self.weight_dict."""
        reduced_loss = 0.0
        for loss_key, weight in self.weight_dict.items():
            if loss_key not in losses:
                raise ValueError(f"{type(self)} doesn't compute {loss_key}")
            if weight != 0:
                reduced_loss += losses[loss_key] * weight
        return reduced_loss