vl_combiner.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Provides a utility to combine a vision backbone with a language backbone."""
from copy import copy
from typing import List, Optional

import torch
import torch.nn as nn
from torch.nn.attention import sdpa_kernel, SDPBackend

from .act_ckpt_utils import activation_ckpt_wrapper
from .necks import Sam3DualViTDetNeck


class SAM3VLBackbone(nn.Module):
    """This backbone combines a vision backbone and a language backbone without fusion.

    As such, it is more of a convenience wrapper that handles the two backbones
    together. It adds support for activation checkpointing and compilation.
    """

    def __init__(
        self,
        visual: Sam3DualViTDetNeck,
        text,
        compile_visual: bool = False,
        act_ckpt_whole_vision_backbone: bool = False,
        act_ckpt_whole_language_backbone: bool = False,
        scalp=0,
    ):
        """Initialize the backbone combiner.

        :param visual: The vision backbone to use
        :param text: The text encoder to use
        :param compile_visual: If True, wrap the vision backbone in torch.compile
        :param act_ckpt_whole_vision_backbone: If True, checkpoint activations over
            the whole vision backbone during training
        :param act_ckpt_whole_language_backbone: If True, checkpoint activations over
            the whole language backbone during training
        :param scalp: Number of lowest-resolution feature levels to discard
        """
        super().__init__()
        self.vision_backbone: Sam3DualViTDetNeck = (
            torch.compile(visual) if compile_visual else visual
        )
        self.language_backbone = text
        self.scalp = scalp
        # allow running activation checkpointing on the entire vision and language backbones
        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone

    def forward(
        self,
        samples: torch.Tensor,
        captions: List[str],
        input_boxes: Optional[torch.Tensor] = None,
        additional_text: Optional[List[str]] = None,
    ):
        """Forward pass of the backbone combiner.

        :param samples: The input images
        :param captions: The input captions
        :param input_boxes: If the text contains placeholders for boxes, this
            parameter contains the tensor with their spatial features
        :param additional_text: This can be used to encode some additional text
            (different from the captions) in the same forward of the backbone
        :return: Output dictionary with the following keys:
            - vision_features: The output of the vision backbone
            - vision_pos_enc: The positional encoding of the vision backbone
            - backbone_fpn: The multi-scale (FPN) features of the vision backbone
            - language_features: The output of the language backbone
            - language_mask: The attention mask of the language backbone
            - language_embeds: The text embeddings from before the language encoder
            - (optional) additional_text_features: The output of the language
              backbone for the additional text
            - (optional) additional_text_mask: The attention mask of the
              language backbone for the additional text
        """
        output = self.forward_image(samples)
        device = output["vision_features"].device
        output.update(self.forward_text(captions, input_boxes, additional_text, device))
        return output

    def forward_image(self, samples: torch.Tensor):
        return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)(
            samples=samples,
            act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
        )

    def _forward_image_no_act_ckpt(self, samples):
        # Forward through backbone
        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(
            samples
        )
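        # The feature/pos lists are assumed to be ordered from highest to lowest
        # resolution, so slicing off the tail below drops the coarsest levels.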
        if self.scalp > 0:
            # Discard the lowest resolution features
            sam3_features, sam3_pos = (
                sam3_features[: -self.scalp],
                sam3_pos[: -self.scalp],
            )
            if sam2_features is not None and sam2_pos is not None:
                sam2_features, sam2_pos = (
                    sam2_features[: -self.scalp],
                    sam2_pos[: -self.scalp],
                )

        sam2_output = None
        if sam2_features is not None and sam2_pos is not None:
            sam2_src = sam2_features[-1]
            sam2_output = {
                "vision_features": sam2_src,
                "vision_pos_enc": sam2_pos,
                "backbone_fpn": sam2_features,
            }

        sam3_src = sam3_features[-1]
        output = {
            "vision_features": sam3_src,
            "vision_pos_enc": sam3_pos,
            "backbone_fpn": sam3_features,
            "sam2_backbone_out": sam2_output,
        }
        return output

    def forward_text(
        self, captions, input_boxes=None, additional_text=None, device="cuda"
    ):
        return activation_ckpt_wrapper(self._forward_text_no_act_ckpt)(
            captions=captions,
            input_boxes=input_boxes,
            additional_text=additional_text,
            device=device,
            act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training,
        )

    def _forward_text_no_act_ckpt(
        self,
        captions,
        input_boxes=None,
        additional_text=None,
        device="cuda",
    ):
        output = {}
        # Forward through the text encoder
        text_to_encode = copy(captions)
        if additional_text is not None:
            # If there is additional text, we piggy-back it onto this forward.
            # It will be used later for output alignment.
            text_to_encode += additional_text
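
        # Restrict scaled-dot-product attention to the backends listed below;
        # PyTorch picks among the enabled ones at runtime.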
        sdpa_context = sdpa_kernel(
            [
                SDPBackend.MATH,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.FLASH_ATTENTION,
            ]
        )
        with sdpa_context:
            text_attention_mask, text_memory, text_embeds = self.language_backbone(
                text_to_encode, input_boxes, device=device
            )
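
        # Split the encoder outputs back into the caption part and the
        # additional-text part. Note that text_memory and text_embeds are sliced
        # along their second (batch) dimension, while the mask is batch-first.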
        if additional_text is not None:
            output["additional_text_features"] = text_memory[:, -len(additional_text) :]
            output["additional_text_mask"] = text_attention_mask[
                -len(additional_text) :
            ]
            text_memory = text_memory[:, : len(captions)]
            text_attention_mask = text_attention_mask[: len(captions)]
            text_embeds = text_embeds[:, : len(captions)]

        output["language_features"] = text_memory
        output["language_mask"] = text_attention_mask
        # Text embeddings from before being forwarded through the encoder
        output["language_embeds"] = text_embeds
        return output
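
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `build_vision_neck` and
# `build_text_encoder` are hypothetical factories: any Sam3DualViTDetNeck and
# any text encoder returning an (attention_mask, memory, embeds) triple fits
# the contract assumed above.
#
#     neck = build_vision_neck()            # -> Sam3DualViTDetNeck (hypothetical)
#     text_encoder = build_text_encoder()   # hypothetical text encoder
#     backbone = SAM3VLBackbone(
#         visual=neck,
#         text=text_encoder,
#         compile_visual=False,
#         scalp=1,  # discard the coarsest feature level
#     )
#     images = torch.randn(2, 3, 1024, 1024)  # assumed (B, C, H, W) input layout
#     out = backbone(images, captions=["a cat", "a dog"])
#     out["vision_features"], out["language_features"], out["language_mask"]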