| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- """Necks are the interface between a vision backbone and the rest of the detection model"""
- from copy import deepcopy
- from typing import List, Optional, Tuple
- import torch
- import torch.nn as nn
class Sam3DualViTDetNeck(nn.Module):
    """SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).

    For each requested scale factor, the trunk's last feature map is resampled
    (transposed convs / identity / max-pool) and projected to ``d_model``
    channels with a 1x1 conv followed by a 3x3 conv. It supports a "dual neck"
    setting, where we have two identical necks (for SAM3 and SAM2), with
    different weights.
    """

    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors: Tuple[float, ...] = (4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """
        :param trunk: the backbone
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model
        :param scale_factors: per-output-level resampling factor relative to
            the trunk's last feature map; only 4.0, 2.0, 1.0 and 0.5 are
            supported
        :param add_sam2_neck: if True, keep a second, independently-weighted
            clone of the conv stacks for SAM2 (initialized as an exact copy)
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.scale_factors = scale_factors
        dim: int = self.trunk.channel_list[-1]
        for scale in scale_factors:
            current = nn.Sequential()
            if scale == 4.0:
                # Two stride-2 transposed convs: 4x upsampling, channels -> dim/4
                current.add_module(
                    "dconv_2x2_0",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                current.add_module("gelu", nn.GELU())
                current.add_module(
                    "dconv_2x2_1",
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                )
                out_dim = dim // 4
            elif scale == 2.0:
                # One stride-2 transposed conv: 2x upsampling, channels -> dim/2
                current.add_module(
                    "dconv_2x2",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                out_dim = dim // 2
            elif scale == 1.0:
                out_dim = dim
            elif scale == 0.5:
                current.add_module(
                    "maxpool_2x2",
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                out_dim = dim
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
            # Project every level to d_model: 1x1 channel projection + 3x3 refinement
            current.add_module(
                "conv_1x1",
                nn.Conv2d(
                    in_channels=out_dim,
                    out_channels=d_model,
                    kernel_size=1,
                    bias=True,
                ),
            )
            current.add_module(
                "conv_3x3",
                nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    kernel_size=3,
                    padding=1,
                    bias=True,
                ),
            )
            self.convs.append(current)
        self.sam2_convs: Optional[nn.ModuleList] = None
        if add_sam2_neck:
            # Assumes sam2 neck is just a clone of the original neck
            self.sam2_convs = deepcopy(self.convs)

    def forward(
        self, tensor_list: List[torch.Tensor]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
    ]:
        """Run the trunk, then produce one (feature, pos-encoding) pair per scale.

        :param tensor_list: input batch, forwarded unchanged to the trunk
        :return: (sam3 features, sam3 pos encodings, sam2 features, sam2 pos
            encodings); the last two are None unless ``add_sam2_neck`` was set
        """
        xs = self.trunk(tensor_list)
        sam3_out, sam3_pos = [], []
        sam2_out, sam2_pos = None, None
        if self.sam2_convs is not None:
            sam2_out, sam2_pos = [], []
        # simpleFPN: every output level is derived from the trunk's LAST feature map
        x = xs[-1]
        for i in range(len(self.convs)):
            sam3_x_out = self.convs[i](x)
            sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
            sam3_out.append(sam3_x_out)
            sam3_pos.append(sam3_pos_out)
            if self.sam2_convs is not None:
                # The SAM2 branch reuses the same trunk feature and pos-encoder,
                # only the conv weights differ
                sam2_x_out = self.sam2_convs[i](x)
                sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
                sam2_out.append(sam2_x_out)
                sam2_pos.append(sam2_pos_out)
        return sam3_out, sam3_pos, sam2_out, sam2_pos
|