image_encoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class ImageEncoder(nn.Module):
    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        super().__init__()
        self.trunk = trunk
        self.neck = neck
        self.scalp = scalp
        assert (
            self.trunk.channel_list == self.neck.backbone_channel_list
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor):
        # Forward through backbone
        features, pos = self.neck(self.trunk(sample))
        if self.scalp > 0:
            # Discard the lowest resolution features
            features, pos = features[: -self.scalp], pos[: -self.scalp]

        src = features[-1]
        output = {
            "vision_features": src,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
        return output
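
# A minimal usage sketch of ImageEncoder (illustrative only): `backbone` stands
# for any trunk exposing a `channel_list` attribute, and `pe` for a positional
# encoding module such as a sine embedding. These names are assumptions, not
# definitions from this file.
#
#   neck = FpnNeck(position_encoding=pe, d_model=256,
#                  backbone_channel_list=[768, 384, 192, 96])
#   encoder = ImageEncoder(trunk=backbone, neck=neck, scalp=1)
#   out = encoder(images)            # images: (B, 3, H, W)
#   out["vision_features"]           # lowest-resolution feature map kept after scalp
#   out["backbone_fpn"]              # full FPN pyramid, highest resolution first
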
class FpnNeck(nn.Module):
    """
    A modified variant of Feature Pyramid Network (FPN) neck
    (we remove output conv and also do bicubic interpolation similar to ViT
    pos embed interpolation)
    """

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        """Initialize the neck
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model
        :param backbone_channel_list: channel dims of the backbone feature maps,
            ordered from lowest to highest resolution
        """
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        for dim in backbone_channel_list:
            current = nn.Sequential()
            current.add_module(
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )
            self.convs.append(current)
        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in ["sum", "avg"]
        self.fuse_type = fuse_type

        # levels to have top-down features in their outputs
        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
        # have top-down propagation, while outputs of level 0 and level 1 have only
        # lateral features from the same backbone level.
        if fpn_top_down_levels is None:
            # default is to have top-down features on all levels
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):
        out = [None] * len(self.convs)
        pos = [None] * len(self.convs)
        assert len(xs) == len(self.convs)
        # fpn forward pass
        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
        prev_features = None
        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            x = xs[i]
            # convs follow backbone_channel_list order (lowest resolution first),
            # so level i pairs with convs[n - i]
            lateral_features = self.convs[n - i](x)
            if i in self.fpn_top_down_levels and prev_features is not None:
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    align_corners=(
                        None if self.fpn_interp_model == "nearest" else False
                    ),
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                prev_features = lateral_features
            x_out = prev_features
            out[i] = x_out
            pos[i] = self.position_encoding(x_out).to(x_out.dtype)

        return out, pos
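

# ---------------------------------------------------------------------------
# A minimal smoke-test sketch (not part of the original file). `_DummyTrunk`
# and `_DummyPositionEncoding` are hypothetical stand-ins: the real model pairs
# FpnNeck with a Hiera trunk and a sine positional embedding. Shapes assume a
# 1024x1024 input and the usual FPN strides of 4/8/16/32.
# ---------------------------------------------------------------------------
class _DummyTrunk(nn.Module):
    """Emits four random feature maps, highest resolution first."""

    def __init__(self):
        super().__init__()
        # Lowest-resolution stage first, mirroring the order that
        # backbone_channel_list expects (see the assert in ImageEncoder).
        self.channel_list = [768, 384, 192, 96]

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        b = x.shape[0]
        # (channels, spatial size) per level, highest resolution first.
        specs = [(96, 256), (192, 128), (384, 64), (768, 32)]
        return [torch.randn(b, c, s, s) for c, s in specs]


class _DummyPositionEncoding(nn.Module):
    """Zero tensor with the (B, C, H, W) shape of a real sine embedding."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(x)


if __name__ == "__main__":
    neck = FpnNeck(
        position_encoding=_DummyPositionEncoding(),
        d_model=256,
        backbone_channel_list=[768, 384, 192, 96],
        # top-down fusion only on the two lowest-resolution levels
        fpn_top_down_levels=[2, 3],
    )
    encoder = ImageEncoder(trunk=_DummyTrunk(), neck=neck, scalp=1)
    out = encoder(torch.randn(1, 3, 1024, 1024))
    print(out["vision_features"].shape)  # torch.Size([1, 256, 64, 64])
    print([tuple(f.shape) for f in out["backbone_fpn"]])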