# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

"""Provides a utility to combine a vision backbone with a language backbone."""

from copy import copy
from typing import List, Optional

import torch
import torch.nn as nn
from torch.nn.attention import sdpa_kernel, SDPBackend

from .act_ckpt_utils import activation_ckpt_wrapper
from .necks import Sam3DualViTDetNeck


class SAM3VLBackbone(nn.Module):
    """Combines a vision backbone and a language backbone without fusion.

    As such, it is mostly a convenience wrapper to handle the two backbones
    together. It adds support for activation checkpointing and compilation.
    """

    def __init__(
        self,
        visual: Sam3DualViTDetNeck,
        text,
        compile_visual: bool = False,
        act_ckpt_whole_vision_backbone: bool = False,
        act_ckpt_whole_language_backbone: bool = False,
        scalp=0,
    ):
        """Initialize the backbone combiner.

        :param visual: The vision backbone to use
        :param text: The text encoder to use
        :param compile_visual: If True, wrap the vision backbone in torch.compile
        :param act_ckpt_whole_vision_backbone: If True, apply activation
            checkpointing to the entire vision backbone during training
        :param act_ckpt_whole_language_backbone: If True, apply activation
            checkpointing to the entire language backbone during training
        :param scalp: Number of lowest-resolution feature levels to discard
        """
        super().__init__()
        self.vision_backbone: Sam3DualViTDetNeck = (
            torch.compile(visual) if compile_visual else visual
        )
        self.language_backbone = text
        self.scalp = scalp
        # Allow running activation checkpointing on the entire vision and
        # language backbones
        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone

    def forward(
        self,
        samples: torch.Tensor,
        captions: List[str],
        input_boxes: Optional[torch.Tensor] = None,
        additional_text: Optional[List[str]] = None,
    ):
        """Forward pass of the backbone combiner.

        :param samples: The input images
        :param captions: The input captions
        :param input_boxes: If the text contains placeholders for boxes, this
            parameter contains the tensor with their spatial features
        :param additional_text: Additional text (different from the captions)
            to encode in the same forward pass of the backbone

        :return: Output dictionary with the following keys:
            - vision_features: The output of the vision backbone
            - language_features: The output of the language backbone
            - language_mask: The attention mask of the language backbone
            - language_embeds: The text embeddings before the text encoder
            - vision_pos_enc: The positional encoding of the vision backbone
            - (optional) additional_text_features: The output of the language
              backbone for the additional text
            - (optional) additional_text_mask: The attention mask of the
              language backbone for the additional text
        """
        output = self.forward_image(samples)
        device = output["vision_features"].device
        output.update(
            self.forward_text(captions, input_boxes, additional_text, device)
        )
        return output

    def forward_image(self, samples: torch.Tensor):
        return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)(
            samples=samples,
            act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
        )

    def _forward_image_no_act_ckpt(self, samples):
        # Forward through the vision backbone
        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(
            samples
        )
        if self.scalp > 0:
            # Discard the lowest-resolution features
            sam3_features, sam3_pos = (
                sam3_features[: -self.scalp],
                sam3_pos[: -self.scalp],
            )
            if sam2_features is not None and sam2_pos is not None:
                sam2_features, sam2_pos = (
                    sam2_features[: -self.scalp],
                    sam2_pos[: -self.scalp],
                )

        sam2_output = None
        if sam2_features is not None and sam2_pos is not None:
            sam2_src = sam2_features[-1]
            sam2_output = {
                "vision_features": sam2_src,
                "vision_pos_enc": sam2_pos,
                "backbone_fpn": sam2_features,
            }

        sam3_src = sam3_features[-1]
        output = {
            "vision_features": sam3_src,
            "vision_pos_enc": sam3_pos,
            "backbone_fpn": sam3_features,
            "sam2_backbone_out": sam2_output,
        }
        return output
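    # Illustrative sketch (not part of the original API): assuming the neck
    # returns, e.g., a 3-level SAM3 feature pyramid and scalp=1, the image
    # branch above yields a dict structured as:
    #
    #   out = self._forward_image_no_act_ckpt(samples)
    #   out["backbone_fpn"]       # remaining feature levels (lowest-res dropped)
    #   out["vision_features"]    # the last kept level, out["backbone_fpn"][-1]
    #   out["vision_pos_enc"]     # positional encodings matching backbone_fpn
    #   out["sam2_backbone_out"]  # same structure for the SAM2 branch, or None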
    def forward_text(
        self, captions, input_boxes=None, additional_text=None, device="cuda"
    ):
        return activation_ckpt_wrapper(self._forward_text_no_act_ckpt)(
            captions=captions,
            input_boxes=input_boxes,
            additional_text=additional_text,
            device=device,
            act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training,
        )

    def _forward_text_no_act_ckpt(
        self,
        captions,
        input_boxes=None,
        additional_text=None,
        device="cuda",
    ):
        output = {}
        # Forward through the text encoder
        text_to_encode = copy(captions)
        if additional_text is not None:
            # If there is additional text, we piggy-back it onto this forward
            # pass; it will be used later for output alignment.
            text_to_encode += additional_text

        sdpa_context = sdpa_kernel(
            [
                SDPBackend.MATH,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.FLASH_ATTENTION,
            ]
        )
        with sdpa_context:
            text_attention_mask, text_memory, text_embeds = self.language_backbone(
                text_to_encode, input_boxes, device=device
            )

        if additional_text is not None:
            # Split the piggy-backed additional text back out of the encoded batch
            output["additional_text_features"] = text_memory[
                :, -len(additional_text) :
            ]
            output["additional_text_mask"] = text_attention_mask[
                -len(additional_text) :
            ]
            text_memory = text_memory[:, : len(captions)]
            text_attention_mask = text_attention_mask[: len(captions)]
            text_embeds = text_embeds[:, : len(captions)]

        output["language_features"] = text_memory
        output["language_mask"] = text_attention_mask
        # Text embeddings before being forwarded to the encoder
        output["language_embeds"] = text_embeds
        return output
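
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; `build_sam3_neck` and `build_text_encoder`
# are hypothetical factory helpers, not part of this module). The text encoder
# is assumed to return `(attention_mask, memory, embeds)` with a batch-second
# memory layout, matching the slicing in `_forward_text_no_act_ckpt` above:
#
#   visual = build_sam3_neck()          # a Sam3DualViTDetNeck instance
#   text = build_text_encoder()
#   backbone = SAM3VLBackbone(
#       visual=visual,
#       text=text,
#       scalp=1,                        # drop the lowest-resolution level
#   )
#   images = torch.randn(2, 3, 1024, 1024)
#   out = backbone(images, captions=["a red car", "a dog"])
#   out["vision_features"]              # last kept SAM3 feature level
#   out["language_features"]            # text memory for the two captions
#   out["language_mask"]                # corresponding attention mask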