# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

import math
from typing import Dict, List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from timm.layers import DropPath
except ModuleNotFoundError:
    # compatibility for older timm versions
    from timm.models.layers import DropPath

from .model_misc import get_clones, LayerNorm2d


class SimpleMaskDownSampler(nn.Module):
    """
    Progressively downsample a mask by total_stride, each time by stride.
    Note that LayerNorm is applied per *token*, as in ViT.

    With each downsample (by a factor of stride**2), the channel capacity
    increases by the same factor. At the end, we linearly project to
    embed_dim channels.
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
        # Option to interpolate the input mask before downsampling with convs;
        # in that case, total_stride is assumed to be measured after the
        # interpolation. If set to the input resolution or None, we don't
        # interpolate. We default to None to be safe (for older configs or if
        # not explicitly set).
        interpol_size=None,
    ):
        super().__init__()
        num_layers = int(math.log2(total_stride) // math.log2(stride))
        assert stride**num_layers == total_stride
        self.encoder = nn.Sequential()
        mask_in_chans, mask_out_chans = 1, 1
        for _ in range(num_layers):
            mask_out_chans = mask_in_chans * (stride**2)
            self.encoder.append(
                nn.Conv2d(
                    mask_in_chans,
                    mask_out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            self.encoder.append(LayerNorm2d(mask_out_chans))
            self.encoder.append(activation())
            mask_in_chans = mask_out_chans
        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))

        self.interpol_size = interpol_size
        if self.interpol_size is not None:
            assert isinstance(self.interpol_size, (list, tuple)), (
                f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
            )
            self.interpol_size = list(interpol_size)
            assert len(self.interpol_size) == 2

    def forward(self, x: torch.Tensor):
        if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
            x = F.interpolate(
                x.float(),
                size=self.interpol_size,
                align_corners=False,
                mode="bilinear",
                antialias=True,
            )
        return self.encoder(x)
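

# Illustrative shape sketch (hypothetical helper, not part of the original
# module): with the defaults (kernel_size=4, stride=4, total_stride=16,
# embed_dim=256), the encoder stacks two stride-4 convs, so a (N, 1, H, W)
# mask becomes (N, 16, H/4, W/4) and then (N, 256, H/16, W/16) before the
# final 1x1 projection to embed_dim channels.
def _demo_simple_mask_downsampler():
    downsampler = SimpleMaskDownSampler(embed_dim=256, stride=4, total_stride=16)
    mask_logits = torch.zeros(2, 1, 1024, 1024)  # (N, 1, H, W)
    out = downsampler(mask_logits)
    assert out.shape == (2, 256, 64, 64)  # spatial dims reduced by total_stride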


# Lightly adapted from ConvNeXt (https://github.com/facebookresearch/ConvNeXt)
class CXBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv;
        all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) ->
        Linear -> GELU -> Linear; Permute back
    We use (2), as we find it slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        kernel_size (int): Kernel size of the depthwise conv. Default: 7
        padding (int): Padding of the depthwise conv. Default: 3
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        use_dwconv (bool): If True, the first conv is depthwise (groups=dim);
            otherwise it is a dense conv. Default: True
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim if use_dwconv else 1,
        )  # depthwise conv
        self.norm = LayerNorm2d(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, 4 * dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        shortcut = x  # residual connection
        x = self.dwconv(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        x = shortcut + self.drop_path(x)
        return x
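

# Illustrative sketch (hypothetical helper): a CXBlock is shape-preserving,
# since the padded depthwise conv keeps (H, W) and the pointwise layers
# expand to 4*dim and contract back to dim, so blocks can be stacked freely
# behind the residual connection.
def _demo_cx_block():
    block = CXBlock(dim=256, kernel_size=7, padding=3)
    x = torch.randn(2, 256, 64, 64)  # (N, C, H, W)
    assert block(x).shape == x.shape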


class SimpleFuser(nn.Module):
    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        super().__init__()
        self.proj = nn.Identity()
        self.layers = get_clones(layer, num_layers)
        if input_projection:
            assert dim is not None
            self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        # normally x: (N, C, H, W)
        x = self.proj(x)
        for layer in self.layers:
            x = layer(x)
        return x
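

# Illustrative sketch (hypothetical helper): the fuser deep-copies the
# template layer num_layers times via get_clones, so each clone has its own
# parameters; with input_projection=True, a 1x1 conv is applied first.
def _demo_simple_fuser():
    fuser = SimpleFuser(CXBlock(dim=256), num_layers=2, dim=256, input_projection=True)
    x = torch.randn(2, 256, 64, 64)
    assert fuser(x).shape == x.shape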


class SimpleMaskEncoder(nn.Module):
    def __init__(
        self,
        out_dim,
        mask_downsampler,
        fuser,
        position_encoding,
        in_dim=256,  # in_dim of pix_feats
    ):
        super().__init__()
        self.mask_downsampler = mask_downsampler
        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = fuser
        self.position_encoding = position_encoding
        self.out_proj = nn.Identity()
        if out_dim != in_dim:
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(
        self,
        pix_feat: torch.Tensor,
        masks: torch.Tensor,
        skip_mask_sigmoid: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        ## Process masks
        # apply sigmoid so that there is less domain shift from ground-truth
        # masks, which are boolean
        if not skip_mask_sigmoid:
            masks = torch.sigmoid(masks)
        masks = self.mask_downsampler(masks)

        ## Fuse pix_feats and downsampled masks
        # in case the visual features are on a different device (e.g. CPU),
        # move them to the same device as the masks
        pix_feat = pix_feat.to(masks.device)

        x = self.pix_feat_proj(pix_feat)
        x = x + masks
        x = self.fuser(x)
        x = self.out_proj(x)

        pos = self.position_encoding(x).to(x.dtype)
        return {"vision_features": x, "vision_pos_enc": [pos]}
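

# Illustrative end-to-end sketch: fuse backbone features at stride 16 with
# mask logits at the input resolution. nn.Identity() is only a stand-in for
# a real positional-encoding module (any module mapping a (N, C, H, W)
# tensor to a same-shape positional tensor would fit here).
if __name__ == "__main__":
    encoder = SimpleMaskEncoder(
        out_dim=64,
        mask_downsampler=SimpleMaskDownSampler(embed_dim=256, total_stride=16),
        fuser=SimpleFuser(CXBlock(dim=256), num_layers=2),
        position_encoding=nn.Identity(),  # stand-in, not the real PE module
        in_dim=256,
    )
    pix_feat = torch.randn(2, 256, 64, 64)  # backbone features at stride 16
    masks = torch.randn(2, 1, 1024, 1024)  # mask logits at input resolution
    out = encoder(pix_feat, masks)
    print(out["vision_features"].shape)  # torch.Size([2, 64, 64, 64])
    print(out["vision_pos_enc"][0].shape)  # torch.Size([2, 64, 64, 64])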