memory_encoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d


class MaskDownSampler(nn.Module):
    """
    Progressively downsample a mask by total_stride, each time by stride.
    Note that LayerNorm is applied per *token*, like in ViT.

    With each downsample (by a factor stride**2), channel capacity increases by the same factor.
    In the end, we linearly project to embed_dim channels.
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
    ):
        super().__init__()
        num_layers = int(math.log2(total_stride) // math.log2(stride))
        assert stride**num_layers == total_stride
        self.encoder = nn.Sequential()
        mask_in_chans, mask_out_chans = 1, 1
        for _ in range(num_layers):
            mask_out_chans = mask_in_chans * (stride**2)
            self.encoder.append(
                nn.Conv2d(
                    mask_in_chans,
                    mask_out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            self.encoder.append(LayerNorm2d(mask_out_chans))
            self.encoder.append(activation())
            mask_in_chans = mask_out_chans

        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))

    def forward(self, x):
        return self.encoder(x)
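
# Usage sketch (illustrative, not part of the upstream file): with the default
# arguments (stride=4, total_stride=16, embed_dim=256), a 1-channel mask is
# reduced 16x spatially and projected to embed_dim channels, e.g.:
#
#     downsampler = MaskDownSampler()
#     mask_logits = torch.zeros(1, 1, 1024, 1024)
#     mask_emb = downsampler(mask_logits)  # shape (1, 256, 64, 64)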


# Lightly adapted from ConvNeXt (https://github.com/facebookresearch/ConvNeXt)
class CXBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back

    This block uses Linear layers for the pointwise convs as in (2), which we
    find slightly faster in PyTorch, but applies the normalization in
    (N, C, H, W) via LayerNorm2d before permuting.

    Args:
        dim (int): Number of input channels.
        kernel_size (int): Kernel size of the (depthwise) conv. Default: 7
        padding (int): Padding of the (depthwise) conv. Default: 3
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        use_dwconv (bool): Whether the first conv is depthwise (groups=dim). Default: True
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim if use_dwconv else 1,
        )  # depthwise conv
        self.norm = LayerNorm2d(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, 4 * dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x
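
# Usage sketch (illustrative, not part of the upstream file): a CXBlock is a
# residual block, so it preserves the input shape, e.g.:
#
#     block = CXBlock(dim=256)
#     feats = torch.randn(1, 256, 64, 64)
#     out = block(feats)  # shape (1, 256, 64, 64)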


class Fuser(nn.Module):
    """Fuse features by applying a stack of identical layers (e.g. CXBlocks),
    with an optional 1x1 input projection."""

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        super().__init__()
        self.proj = nn.Identity()
        self.layers = get_clones(layer, num_layers)

        if input_projection:
            assert dim is not None
            self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        # normally x: (N, C, H, W)
        x = self.proj(x)
        for layer in self.layers:
            x = layer(x)
        return x
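
# Usage sketch (illustrative, not part of the upstream file): the Fuser stacks
# num_layers deep copies of the given layer, e.g. two CXBlocks:
#
#     fuser = Fuser(CXBlock(dim=256), num_layers=2)
#     feats = torch.randn(1, 256, 64, 64)
#     fused = fuser(feats)  # shape (1, 256, 64, 64)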


class MemoryEncoder(nn.Module):
    """Fuse per-frame image features with the corresponding mask predictions
    into memory features, along with a positional encoding of the output."""

    def __init__(
        self,
        out_dim,
        mask_downsampler,
        fuser,
        position_encoding,
        in_dim=256,  # in_dim of pix_feats
    ):
        super().__init__()

        self.mask_downsampler = mask_downsampler

        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = fuser
        self.position_encoding = position_encoding
        self.out_proj = nn.Identity()
        if out_dim != in_dim:
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(
        self,
        pix_feat: torch.Tensor,
        masks: torch.Tensor,
        skip_mask_sigmoid: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        ## Process masks
        # sigmoid, so that less domain shift from gt masks which are bool
        if not skip_mask_sigmoid:
            masks = F.sigmoid(masks)
        masks = self.mask_downsampler(masks)

        ## Fuse pix_feats and downsampled masks
        # in case the visual features are on a different device (e.g. CPU),
        # move them to the same device as the downsampled masks
        pix_feat = pix_feat.to(masks.device)

        x = self.pix_feat_proj(pix_feat)
        x = x + masks
        x = self.fuser(x)
        x = self.out_proj(x)

        pos = self.position_encoding(x).to(x.dtype)

        return {"vision_features": x, "vision_pos_enc": [pos]}
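

# End-to-end sketch (illustrative, not part of the upstream file). The
# position_encoding argument expects a module that is called on the output
# feature map; the `_ZeroPosEnc` below is a hypothetical stand-in that just
# returns zeros of the same shape so the example stays self-contained, and
# out_dim=64 is chosen only to exercise the output projection.
if __name__ == "__main__":

    class _ZeroPosEnc(nn.Module):
        def forward(self, x):
            return torch.zeros_like(x)

    encoder = MemoryEncoder(
        out_dim=64,
        mask_downsampler=MaskDownSampler(embed_dim=256),
        fuser=Fuser(CXBlock(dim=256), num_layers=2),
        position_encoding=_ZeroPosEnc(),
        in_dim=256,
    )
    pix_feat = torch.randn(1, 256, 64, 64)       # image features at stride 16
    mask_logits = torch.randn(1, 1, 1024, 1024)  # mask logits at full resolution
    out = encoder(pix_feat, mask_logits)
    print(out["vision_features"].shape)    # torch.Size([1, 64, 64, 64])
    print(out["vision_pos_enc"][0].shape)  # torch.Size([1, 64, 64, 64])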