sam3_loss.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import torch

from sam3.model.model_misc import SAM3Output
from sam3.train.utils.distributed import get_world_size

from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks


class DummyLoss(torch.nn.Module):
    """A dummy loss that always returns 0 (as a placeholder for eval)."""

    def __init__(
        self,
        core_loss_key: str = CORE_LOSS_KEY,
        device: str = "cuda",
        **kwargs,
    ):
        super().__init__()
        self.core_loss_key = core_loss_key
        self.device = torch.device(device)

    def forward(self, *args, **kwargs):
        return {self.core_loss_key: torch.tensor(0.0, device=self.device)}

    def accumulate(self, out_dict):
        """Called by iterative losses."""
        if self.core_loss_key not in out_dict:
            out_dict[self.core_loss_key] = torch.tensor(0.0, device=self.device)
        return out_dict


class Sam3LossWrapper(torch.nn.Module):
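    """
    Wrapper that computes the full set of SAM3 "find" losses for a model output.

    Each loss in `loss_fns_find` is applied to the main outputs, the auxiliary
    outputs and the first-stage outputs; when one-to-many (o2m) predictions are
    present, an o2m-matched copy of each loss is added, scaled by `o2m_weight`.
    Individual (suffixed) loss terms are returned for logging, and their sum is
    accumulated under `CORE_LOSS_KEY`.
    """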
    def __init__(
        self,
        loss_fns_find,
        normalization="global",
        matcher=None,
        o2m_matcher=None,
        o2m_weight=1.0,
        use_o2m_matcher_on_o2m_aux=True,
        loss_fn_semantic_seg=None,
        normalize_by_valid_object_num=False,
        normalize_by_stage_num=False,
        scale_by_find_batch_size=False,
    ):
        super().__init__()
        self.loss_fns_find = loss_fns_find
        assert normalization in ["global", "local", "none"]
        self.normalization = normalization
        self.normalize_by_valid_object_num = normalize_by_valid_object_num
        self.normalize_by_stage_num = normalize_by_stage_num
        self.matcher = matcher
        self.o2m_matcher = o2m_matcher
        self.o2m_weight = o2m_weight
        # whether to use the o2m matcher on the o2m queries in auxiliary outputs
        self.use_o2m_matcher_on_o2m_aux = use_o2m_matcher_on_o2m_aux
        self.loss_fn_semantic_seg = loss_fn_semantic_seg
        self.scale_by_find_batch_size = scale_by_find_batch_size

    def _get_num_boxes(self, targets):
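        """
        Return the number of target boxes used to normalize the losses.

        With `normalization="global"`, per-rank counts are all-reduced and divided
        by the world size, so every rank normalizes by the same average (e.g. with
        2 ranks holding 3 and 5 boxes, both use num_boxes = 4); "local" uses only
        this rank's count, and "none" disables normalization.
        """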
        # the average number of target boxes for loss normalization
        if self.normalize_by_valid_object_num:
            # valid boxes are those with non-zero height and width
            # (while padded invisible boxes have zero width and height)
            boxes_hw = targets["boxes"].view(-1, 4)  # cx, cy, w, h
            num_boxes = (boxes_hw[:, 2:] > 0).all(dim=-1).sum().float()
        else:
            num_boxes = targets["num_boxes"].sum().float()
        if self.normalization == "global":
            torch.distributed.all_reduce(num_boxes)
            num_boxes = torch.clamp(num_boxes / get_world_size(), min=1)
        elif self.normalization == "local":
            num_boxes = torch.clamp(num_boxes, min=1)
        elif self.normalization == "none":
            num_boxes = 1
        return num_boxes

    def compute_loss(self, nested_out, targets):
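        """
        Compute all find losses for one output dict (one step of one stage).

        The returned dict holds each individual loss term, suffixed with
        `_aux_{i}` for auxiliary outputs, `_fs` for first-stage outputs and
        `_o2m` for one-to-many matched losses, plus their total under
        `CORE_LOSS_KEY`.
        """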
        num_boxes = self._get_num_boxes(targets)
        o2m_out_is_valid = nested_out.get("o2m_out_is_valid", None)
        o2m_target_is_valid_padded = nested_out.get("o2m_target_is_valid_padded", None)
        # Get a list of outputs, including auxiliary and first stage outputs
        output_list = [(nested_out, "", False)]  # (out, suffix, is_aux)
        if "aux_outputs" in nested_out:
            output_list.extend(
                (aux_out, f"_aux_{i}", True)
                for i, aux_out in enumerate(nested_out["aux_outputs"])
            )
        if "first_stage" in nested_out:
            output_list.append((nested_out["first_stage"], "_fs", True))

        # Compute all the requested losses
        losses = {}
        total_core_loss = 0.0
        for out, suffix, is_aux in output_list:
            # o2o matcher indices need to be computed by the model (as the video model requires
            # a specific way of matching free and locked indices beyond just calling the matcher)
            indices = out["indices"]
            has_o2m_out = "pred_logits_o2m" in out
            if has_o2m_out:
                o2m_out = {
                    k[: -len("_o2m")]: v for k, v in out.items() if k.endswith("_o2m")
                }
                # o2m targets are the same as the o2o targets (assuming repeat=1)
                o2m_targets = targets
                if self.use_o2m_matcher_on_o2m_aux or not is_aux:
                    o2m_indices = self.o2m_matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )
                else:
                    o2m_indices = self.matcher(
                        o2m_out,
                        o2m_targets,
                        out_is_valid=o2m_out_is_valid,
                        target_is_valid_padded=o2m_target_is_valid_padded,
                    )
            for loss_fn in self.loss_fns_find:
                l_dict = loss_fn(
                    outputs=out,
                    targets=targets,
                    indices=indices,
                    num_boxes=num_boxes,
                    is_aux=is_aux,
                )
                total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                losses.update({f"{k}{suffix}": v for k, v in l_dict.items()})

                compute_o2m_loss = has_o2m_out
                # a special handling to allow turning off mask loss in o2m
                # (to be compatible with the original implementation)
                if isinstance(loss_fn, Masks):
                    compute_o2m_loss = compute_o2m_loss and "pred_masks" in o2m_out
                if isinstance(loss_fn, Det2TrkAssoc):
                    compute_o2m_loss = False  # Det2TrkAssoc does not support o2m
                if compute_o2m_loss:
                    l_dict = loss_fn(
                        outputs=o2m_out,
                        targets=o2m_targets,
                        indices=o2m_indices,
                        num_boxes=num_boxes,
                        is_aux=is_aux,
                    )
                    for k in l_dict:
                        l_dict[k] *= self.o2m_weight
                    total_core_loss += l_dict.pop(CORE_LOSS_KEY)
                    losses.update({f"{k}{suffix}_o2m": v for k, v in l_dict.items()})

        losses[CORE_LOSS_KEY] = total_core_loss
        return losses

    def forward(self, find_stages: SAM3Output, find_targets):
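        """
        Compute the losses over all find stages (and all steps within each stage,
        e.g. for interactive refinement), summing matching loss terms across them.
        """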
        if find_stages.loss_stages is not None:
            find_targets = [find_targets[i] for i in find_stages.loss_stages]
        with SAM3Output.iteration_mode(
            find_stages, iter_mode=SAM3Output.IterMode.ALL_STEPS_PER_STAGE
        ) as find_stages:
            assert len(find_stages) == len(find_targets)
            total_losses = {}
            for stage_outputs, stage_targets in zip(find_stages, find_targets):
                stage_targets = [stage_targets] * len(stage_outputs)
                # If there are multiple steps within a stage, compute the loss for all of them (e.g. interactivity)
                for outputs, targets in zip(stage_outputs, stage_targets):
                    cur_losses = self.compute_loss(outputs, targets)
                    if self.loss_fn_semantic_seg is not None:
                        cur_losses_semantic = self.loss_fn_semantic_seg(
                            outputs, targets
                        )
                        cur_losses[CORE_LOSS_KEY] += cur_losses_semantic.pop(
                            CORE_LOSS_KEY
                        )
                        # make sure the semantic losses don't overlap with the find losses
                        assert set(cur_losses).isdisjoint(set(cur_losses_semantic))
                        cur_losses.update(cur_losses_semantic)
                    # Optionally, normalize the loss by the number of find stages (training video frames) so that
                    # image batches and video batches have similar loss scales. (Otherwise video batches would
                    # have a much higher loss scale due to summing the losses over all the find stages.)
                    if self.normalize_by_stage_num:
                        cur_losses[CORE_LOSS_KEY] /= len(find_stages)
                    if self.scale_by_find_batch_size:
                        bs = targets["num_boxes"].shape[0]
                        # sqrt scaling based on the "effective" batch size
                        cur_losses[CORE_LOSS_KEY] *= bs**0.5
                    for k, v in cur_losses.items():
                        if k not in total_losses:
                            total_losses[k] = v
                        else:
                            total_losses[k] += v
            return total_losses
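

# Usage sketch (illustrative only, not part of this module): a Sam3LossWrapper is
# built from the find-loss modules and matchers defined elsewhere in sam3.train;
# `my_loss_fns`, `my_o2o_matcher`, and `my_o2m_matcher` below are hypothetical
# placeholders for such objects.
#
#     loss_wrapper = Sam3LossWrapper(
#         loss_fns_find=my_loss_fns,    # e.g. a list of loss modules such as Masks
#         normalization="global",       # average num_boxes across ranks
#         matcher=my_o2o_matcher,
#         o2m_matcher=my_o2m_matcher,
#         o2m_weight=1.0,
#     )
#     losses = loss_wrapper(find_stages, find_targets)  # find_stages: SAM3Output
#     losses[CORE_LOSS_KEY].backward()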