||
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- from collections import defaultdict
- from typing import Dict, List
- import torch
- import torch.distributed
- import torch.nn as nn
- import torch.nn.functional as F
- from training.trainer import CORE_LOSS_KEY
- from training.utils.distributed import get_world_size, is_dist_avail_and_initialized
def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of raw logits (a sigmoid is applied here).
            Must be [N, M, H, W] when `loss_on_multimask` is True; otherwise
            any shape with a leading batch dimension (e.g. [N, K] or [N, H, W]).
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch; the loss is divided by it.
        loss_on_multimask: True if multimask prediction is enabled.

    Returns:
        Dice loss tensor. With `loss_on_multimask` the result keeps the
        [N, M] layout (one value per predicted mask); otherwise the
        per-example losses are summed into a single scalar.
    """
    inputs = inputs.sigmoid()
    if loss_on_multimask:
        # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
        assert inputs.dim() == 4 and targets.dim() == 4
        # flatten spatial dimension while keeping multimask channel dimension
        inputs = inputs.flatten(2)
        targets = targets.flatten(2)
    else:
        inputs = inputs.flatten(1)
        # Flatten targets the same way as inputs so that same-shaped tensors of
        # rank > 2 line up element-wise; 1-D / 2-D targets are left untouched so
        # existing callers that pass pre-flattened tensors behave as before.
        if targets.dim() > 2:
            targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # the +1 terms smooth the ratio and avoid a 0/0 when both masks are empty
    loss = 1 - (numerator + 1) / (denominator + 1)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects
def sigmoid_focal_loss(
    inputs,
    targets,
    num_objects,
    alpha: float = 0.25,
    gamma: float = 2,
    loss_on_multimask=False,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of raw logits. Must be [N, M, H, W] when
            `loss_on_multimask` is True; otherwise any shape with a leading
            batch dimension followed by at least one more dimension.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch; the loss is divided by it.
        alpha: (optional) Weighting factor in range (0,1) to balance
            positive vs negative examples. Default = 0.25; any negative
            value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples.
        loss_on_multimask: True if multimask prediction is enabled.

    Returns:
        Focal loss tensor. With `loss_on_multimask` the result has shape
        [N, M] (spatial mean per predicted mask); otherwise the per-example
        mean losses are summed into a single scalar.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the probability the model assigns to the true class of each element
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if loss_on_multimask:
        # loss is [N, M, H, W] where M corresponds to multiple predicted masks
        assert loss.dim() == 4
        return loss.flatten(2).mean(-1) / num_objects  # average over spatial dims
    # Average over every non-batch element, then sum over the batch. For 2-D
    # inputs this equals `loss.mean(1).sum()`; flattening first keeps the
    # per-example mean correct when inputs have more than two dimensions.
    return loss.flatten(1).mean(1).sum() / num_objects
def iou_loss(
    inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
):
    """
    Regression loss between predicted IoU scores and the IoUs actually
    achieved by the predicted masks.

    Args:
        inputs: A 4-dimensional float tensor of mask logits; an element is
            counted as foreground when its value is > 0.
        targets: A 4-dimensional tensor with the same shape as inputs holding
            the ground-truth masks; an element is foreground when > 0.
        pred_ious: A float tensor containing the predicted IoU scores per mask.
        num_objects: Number of objects in the batch; the loss is divided by it.
        loss_on_multimask: True if multimask prediction is enabled.
        use_l1_loss: Whether to use L1 loss instead of MSE loss.

    Returns:
        IoU loss tensor: element-wise (unreduced) when `loss_on_multimask`
        is True, otherwise summed into a scalar.
    """
    assert inputs.dim() == 4 and targets.dim() == 4
    # binarize both masks over the flattened spatial dimension
    predicted = inputs.flatten(2) > 0
    ground_truth = targets.flatten(2) > 0
    intersection = (predicted & ground_truth).sum(dim=-1).float()
    union = (predicted | ground_truth).sum(dim=-1).float()
    # clamp the union so that two empty masks give an IoU of 0 instead of 0/0
    true_ious = intersection / union.clamp(min=1.0)

    regression = F.l1_loss if use_l1_loss else F.mse_loss
    per_mask_loss = regression(pred_ious, true_ious, reduction="none")

    if not loss_on_multimask:
        return per_mask_loss.sum() / num_objects
    return per_mask_loss / num_objects
class MultiStepMultiMasksAndIous(nn.Module):
    """Multi-step, multi-mask segmentation loss with IoU and object-score terms."""

    def __init__(
        self,
        weight_dict,
        focal_alpha=0.25,
        focal_gamma=2,
        supervise_all_iou=False,
        iou_use_l1_loss=False,
        pred_obj_scores=False,
        focal_gamma_obj_score=0.0,
        focal_alpha_obj_score=-1,
    ):
        """
        This class computes the multi-step multi-mask and IoU losses.

        Args:
            weight_dict: dict containing weights for focal, dice, iou losses.
                Must contain "loss_mask", "loss_dice" and "loss_iou"; a missing
                "loss_class" weight defaults to 0.0.
            focal_alpha: alpha for sigmoid focal loss
            focal_gamma: gamma for sigmoid focal loss
            supervise_all_iou: if True, back-prop iou losses for all predicted masks
            iou_use_l1_loss: use L1 loss instead of MSE loss for iou
            pred_obj_scores: if True, compute loss for object scores
            focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
            focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
        """
        super().__init__()
        # Store a shallow copy: the default "loss_class" entry added below must
        # not leak into (mutate) the dict object owned by the caller.
        self.weight_dict = dict(weight_dict)
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        assert "loss_mask" in self.weight_dict
        assert "loss_dice" in self.weight_dict
        assert "loss_iou" in self.weight_dict
        if "loss_class" not in self.weight_dict:
            self.weight_dict["loss_class"] = 0.0

        self.focal_alpha_obj_score = focal_alpha_obj_score
        self.focal_gamma_obj_score = focal_gamma_obj_score
        self.supervise_all_iou = supervise_all_iou
        self.iou_use_l1_loss = iou_use_l1_loss
        self.pred_obj_scores = pred_obj_scores

    def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
        """
        Accumulate the losses over all (outputs, targets) pairs of the batch.

        Args:
            outs_batch: list of per-entry output dicts (see `_forward` for the
                keys that are read).
            targets_batch: tensor of ground-truth masks; it is iterated along
                dim 0 in lockstep with `outs_batch`, and its dim 1 is taken as
                the number of objects.

        Returns:
            A dict mapping each loss name (plus CORE_LOSS_KEY) to its value
            summed over all entries of the batch.
        """
        assert len(outs_batch) == len(targets_batch)
        num_objects = torch.tensor(
            (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
        )  # Number of objects is fixed within a batch
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_objects)
        # average the object count over ranks and never normalize by less than 1
        num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()

        losses = defaultdict(int)
        for outs, targets in zip(outs_batch, targets_batch):
            cur_losses = self._forward(outs, targets, num_objects)
            for k, v in cur_losses.items():
                losses[k] += v

        return losses

    def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
        """
        Compute the losses related to the masks: the focal loss and the dice loss.
        and also the MAE or MSE loss between predicted IoUs and actual IoUs.

        Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
        of shape [N, M, H, W], where M could be 1 or larger), corresponding to
        one or multiple predicted masks from a click.

        We back-propagate focal, dice losses only on the prediction channel
        with the lowest focal+dice loss between predicted mask and ground-truth.
        If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks.

        Returns:
            A dict with "loss_mask", "loss_dice", "loss_iou" and "loss_class"
            summed over prediction steps, plus their weighted sum stored under
            CORE_LOSS_KEY.
        """
        target_masks = targets.unsqueeze(1).float()
        assert target_masks.dim() == 4  # [N, 1, H, W]
        src_masks_list = outputs["multistep_pred_multimasks_high_res"]
        ious_list = outputs["multistep_pred_ious"]
        object_score_logits_list = outputs["multistep_object_score_logits"]

        assert len(src_masks_list) == len(ious_list)
        assert len(object_score_logits_list) == len(ious_list)

        # accumulate the loss over prediction steps
        losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
        for src_masks, ious, object_score_logits in zip(
            src_masks_list, ious_list, object_score_logits_list
        ):
            self._update_losses(
                losses, src_masks, target_masks, ious, num_objects, object_score_logits
            )
        losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
        return losses

    def _update_losses(
        self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
    ):
        """
        Add the losses of a single prediction step to `losses` in place.

        Args:
            losses: running dict with "loss_mask", "loss_dice", "loss_iou" and
                "loss_class" entries; updated in place.
            src_masks: predicted mask logits of this step, [N, M, H, W].
            target_masks: ground-truth masks, [N, 1, H, W]; expanded to the
                shape of `src_masks`.
            ious: predicted IoU scores for the masks of this step.
            num_objects: normalizer passed to every loss function.
            object_score_logits: object-presence logits; only used when
                `pred_obj_scores` is enabled.
        """
        target_masks = target_masks.expand_as(src_masks)
        # get focal, dice and iou loss on all output masks in a prediction step
        loss_multimask = sigmoid_focal_loss(
            src_masks,
            target_masks,
            num_objects,
            alpha=self.focal_alpha,
            gamma=self.focal_gamma,
            loss_on_multimask=True,
        )
        loss_multidice = dice_loss(
            src_masks, target_masks, num_objects, loss_on_multimask=True
        )
        if not self.pred_obj_scores:
            # no object-score supervision: zero class loss, every object counts as present
            loss_class = torch.tensor(
                0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
            )
            target_obj = torch.ones(
                loss_multimask.shape[0],
                1,
                dtype=loss_multimask.dtype,
                device=loss_multimask.device,
            )
        else:
            # an object is present iff its ground-truth mask has any positive pixel
            target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
                ..., None
            ].float()
            loss_class = sigmoid_focal_loss(
                object_score_logits,
                target_obj,
                num_objects,
                alpha=self.focal_alpha_obj_score,
                gamma=self.focal_gamma_obj_score,
            )

        loss_multiiou = iou_loss(
            src_masks,
            target_masks,
            ious,
            num_objects,
            loss_on_multimask=True,
            use_l1_loss=self.iou_use_l1_loss,
        )
        assert loss_multimask.dim() == 2
        assert loss_multidice.dim() == 2
        assert loss_multiiou.dim() == 2
        if loss_multimask.size(1) > 1:
            # take the mask indices with the smallest focal + dice loss for back propagation
            loss_combo = (
                loss_multimask * self.weight_dict["loss_mask"]
                + loss_multidice * self.weight_dict["loss_dice"]
            )
            best_loss_inds = torch.argmin(loss_combo, dim=-1)
            batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
            loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
            loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
            # calculate the iou prediction and slot losses only in the index
            # with the minimum loss for each mask (to be consistent w/ SAM)
            if self.supervise_all_iou:
                loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
            else:
                loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
        else:
            loss_mask = loss_multimask
            loss_dice = loss_multidice
            loss_iou = loss_multiiou

        # backprop focal, dice and iou loss only if obj present
        loss_mask = loss_mask * target_obj
        loss_dice = loss_dice * target_obj
        loss_iou = loss_iou * target_obj

        # sum over batch dimension (note that the losses are already divided by num_objects)
        losses["loss_mask"] += loss_mask.sum()
        losses["loss_dice"] += loss_dice.sum()
        losses["loss_iou"] += loss_iou.sum()
        losses["loss_class"] += loss_class

    def reduce_loss(self, losses):
        """
        Combine the individual losses into one weighted sum.

        Args:
            losses: dict holding a value for every key of `self.weight_dict`.

        Returns:
            The sum of `losses[key] * weight` over all keys whose weight is
            non-zero.

        Raises:
            ValueError: if a key of `self.weight_dict` is missing from `losses`.
        """
        reduced_loss = 0.0
        for loss_key, weight in self.weight_dict.items():
            if loss_key not in losses:
                raise ValueError(f"{type(self)} doesn't compute {loss_key}")
            # zero-weighted terms are skipped entirely rather than multiplied by 0
            if weight != 0:
                reduced_loss += losses[loss_key] * weight

        return reduced_loss
|