hieradet.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from sam2.modeling.backbones.utils import (
    PatchEmbed,
    window_partition,
    window_unpartition,
)
from sam2.modeling.sam2_utils import DropPath, MLP


def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    if norm:
        x = norm(x)

    return x
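
# Shape note (illustrative): with the (2, 2) MaxPool2d that MultiScaleBlock
# passes in as `pool` below, do_pool maps (B, H, W, C) -> (B, H // 2, W // 2, C),
# halving the spatial grid while leaving the channel dimension unchanged.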


class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads

        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x
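
# Note (illustrative): when q_pool is set, only the queries are pooled; keys and
# values keep the original H * W tokens, so the attention output adopts the
# pooled (downsampled) resolution. This is how the token count drops at stage
# transitions.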


class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
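
# Note (illustrative): at a stage transition (dim != dim_out) the shortcut is
# projected to dim_out and pooled with the same q_stride as the queries, so both
# branches of the residual share the downsampled resolution before being summed.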


class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed
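
    # Note (illustrative): self.pos_embed is a low-resolution background embedding
    # interpolated to the full (H, W) grid, while self.pos_embed_window is a
    # window-sized embedding tiled across it; their sum gives the windowed absolute
    # position embedding referenced above (https://arxiv.org/abs/2311.05613).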
    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs
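

# Illustrative usage sketch (an assumption, not from this file): with the default
# arguments above and assuming PatchEmbed's default stride of 4 (defined in
# sam2.modeling.backbones.utils), a 1024x1024 input should yield one (B, C, H, W)
# feature map per stage, at strides 4, 8, 16, and 32.
if __name__ == "__main__":
    model = Hiera()
    img = torch.randn(1, 3, 1024, 1024)
    with torch.no_grad():
        feats = model(img)
    for f in feats:
        # e.g. (1, 96, 256, 256), (1, 192, 128, 128), (1, 384, 64, 64), (1, 768, 32, 32)
        print(tuple(f.shape))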