necks.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Necks are the interface between a vision backbone and the rest of the detection model."""

from copy import deepcopy
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


class Sam3DualViTDetNeck(nn.Module):
    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """
        SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).

        It supports a "dual neck" setting: two architecturally identical necks
        (one for SAM3, one for SAM2) that hold separate weights. See the usage
        sketch at the bottom of this file.

        :param trunk: the backbone
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model
        :param scale_factors: output scales relative to the trunk's final feature map
        :param add_sam2_neck: if True, add a cloned neck for SAM2
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.scale_factors = scale_factors
        use_bias = True
        # Channel dim of the trunk's last (deepest) feature map.
        dim: int = self.trunk.channel_list[-1]
        for scale in scale_factors:
            current = nn.Sequential()
            if scale == 4.0:
                # 4x upsampling: two stride-2 transposed convs, halving the
                # channel count at each step.
                current.add_module(
                    "dconv_2x2_0",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                current.add_module("gelu", nn.GELU())
                current.add_module(
                    "dconv_2x2_1",
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                )
                out_dim = dim // 4
            elif scale == 2.0:
                # 2x upsampling: a single stride-2 transposed conv.
                current.add_module(
                    "dconv_2x2",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                out_dim = dim // 2
            elif scale == 1.0:
                # Native resolution: no resampling module needed.
                out_dim = dim
            elif scale == 0.5:
                # 2x downsampling via max pooling.
                current.add_module(
                    "maxpool_2x2",
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                out_dim = dim
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
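            # Shape check (illustrative; assumes the trunk's final channel dim
            # is 1024, the actual value comes from trunk.channel_list): the
            # branches above yield 256 channels at 4x resolution, 512 at 2x,
            # 1024 at 1x, and 1024 at 0.5x, before the projection below.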
            # Project each level to the shared model width, then refine with a
            # 3x3 conv.
            current.add_module(
                "conv_1x1",
                nn.Conv2d(
                    in_channels=out_dim,
                    out_channels=d_model,
                    kernel_size=1,
                    bias=use_bias,
                ),
            )
            current.add_module(
                "conv_3x3",
                nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                ),
            )
            self.convs.append(current)
        self.sam2_convs = None
        if add_sam2_neck:
            # The SAM2 neck is a clone of the original: same architecture,
            # independent weights.
            self.sam2_convs = deepcopy(self.convs)

    def forward(
        self, tensor_list: List[torch.Tensor]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
    ]:
        """Return (sam3_features, sam3_pos, sam2_features, sam2_pos); the SAM2
        entries are None unless the neck was built with add_sam2_neck=True."""
        xs = self.trunk(tensor_list)
        sam3_out, sam3_pos = [], []
        sam2_out, sam2_pos = None, None
        if self.sam2_convs is not None:
            sam2_out, sam2_pos = [], []
        # SimpleFPN: every pyramid level is computed from the trunk's last
        # feature map.
        x = xs[-1]
        for i in range(len(self.convs)):
            sam3_x_out = self.convs[i](x)
            sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
            sam3_out.append(sam3_x_out)
            sam3_pos.append(sam3_pos_out)
            if self.sam2_convs is not None:
                sam2_x_out = self.sam2_convs[i](x)
                sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
                sam2_out.append(sam2_x_out)
                sam2_pos.append(sam2_pos_out)
        return sam3_out, sam3_pos, sam2_out, sam2_pos
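

if __name__ == "__main__":
    # Usage sketch (a minimal illustration, not part of the original module).
    # `_DummyTrunk` and `_DummyPosEnc` are stand-in assumptions that only mimic
    # the interfaces this neck relies on: a `channel_list` attribute plus
    # callables returning feature maps / position encodings. The real modules
    # live elsewhere in the SAM3 codebase.

    class _DummyTrunk(nn.Module):
        channel_list = [1024]  # channel dim of the last (only) feature map

        def forward(self, tensor_list: List[torch.Tensor]) -> List[torch.Tensor]:
            # Pretend the backbone maps a 512x512 image to a 32x32 feature map.
            b = tensor_list[0].shape[0]
            return [torch.randn(b, 1024, 32, 32)]

    class _DummyPosEnc(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Zeros stand in for a real (e.g. sinusoidal) position encoding.
            return torch.zeros_like(x)

    neck = Sam3DualViTDetNeck(
        trunk=_DummyTrunk(),
        position_encoding=_DummyPosEnc(),
        d_model=256,
        add_sam2_neck=True,
    )
    feats, pos, feats2, pos2 = neck([torch.randn(1, 3, 512, 512)])
    # Four levels at 4x / 2x / 1x / 0.5x the trunk resolution, all projected
    # to d_model=256 channels: spatial sizes 128, 64, 32, 16 here.
    print([tuple(f.shape) for f in feats])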