image_encoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class ImageEncoder(nn.Module):
    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        super().__init__()
        self.trunk = trunk
        self.neck = neck
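        # scalp: number of lowest-resolution feature levels dropped in forward()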
        self.scalp = scalp
        assert (
            self.trunk.channel_list == self.neck.backbone_channel_list
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor):
        # Forward through backbone
        features, pos = self.neck(self.trunk(sample))
        if self.scalp > 0:
            # Discard the lowest resolution features
            features, pos = features[: -self.scalp], pos[: -self.scalp]
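        # The last retained level (the lowest remaining resolution) is exposed
        # as the primary vision feature.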
        src = features[-1]
        output = {
            "vision_features": src,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
        return output


class FpnNeck(nn.Module):
    """
    A modified variant of Feature Pyramid Network (FPN) neck
    (we remove output conv and also do bicubic interpolation similar to ViT
    pos embed interpolation)
    """

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
  54. """Initialize the neck
  55. :param trunk: the backbone
  56. :param position_encoding: the positional encoding to use
  57. :param d_model: the dimension of the model
  58. :param neck_norm: the normalization to use
  59. """
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        self.d_model = d_model
        for dim in backbone_channel_list:
            current = nn.Sequential()
            current.add_module(
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )
            self.convs.append(current)
        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in ["sum", "avg"]
        self.fuse_type = fuse_type

        # levels to have top-down features in their outputs
        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
        # have top-down propagation, while outputs of level 0 and level 1 have only
        # lateral features from the same backbone level.
        if fpn_top_down_levels is None:
            # default is to have top-down features on all levels
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):
        out = [None] * len(self.convs)
        pos = [None] * len(self.convs)
        assert len(xs) == len(self.convs)
        # fpn forward pass
        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
        prev_features = None
        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            x = xs[i]
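            # xs runs from high to low resolution while self.convs follows
            # backbone_channel_list in the opposite order, hence the reversed index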
            lateral_features = self.convs[n - i](x)
            if i in self.fpn_top_down_levels and prev_features is not None:
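                # Upsample the coarser level 2x in float32 (interpolation can
                # lose precision in low-precision dtypes, e.g. under autocast)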
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    align_corners=(
                        None if self.fpn_interp_model == "nearest" else False
                    ),
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                prev_features = lateral_features
            x_out = prev_features
            out[i] = x_out
            pos[i] = self.position_encoding(x_out).to(x_out.dtype)
        return out, pos
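

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. The dummy trunk and
    # positional-encoding modules below are stand-ins invented for illustration;
    # a real setup would plug in an actual backbone (e.g. Hiera) and a real
    # positional encoding instead.

    class _DummyTrunk(nn.Module):
        # channel_list is ordered low-res first to match backbone_channel_list,
        # while the forward output list is ordered high-res first.
        channel_list = [256, 128, 64, 32]

        def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
            b = x.shape[0]
            # Random feature maps at strides 4/8/16/32 for a 256x256 input.
            return [
                torch.randn(b, c, s, s)
                for c, s in zip([32, 64, 128, 256], [64, 32, 16, 8])
            ]

    class _DummyPosEnc(nn.Module):
        # Stand-in positional encoding: zeros shaped like the input.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.zeros_like(x)

    neck = FpnNeck(
        position_encoding=_DummyPosEnc(),
        d_model=256,
        backbone_channel_list=[256, 128, 64, 32],
    )
    encoder = ImageEncoder(trunk=_DummyTrunk(), neck=neck, scalp=1)
    out = encoder(torch.randn(1, 3, 256, 256))
    print(out["vision_features"].shape)  # torch.Size([1, 256, 16, 16])
    for f in out["backbone_fpn"]:
        print(f.shape)  # strides 4/8/16 remain after the scalp=1 cut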