hieradet.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from sam2.modeling.backbones.utils import (
    PatchEmbed,
    window_partition,
    window_unpartition,
)

from sam2.modeling.sam2_utils import DropPath, MLP


def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    if norm:
        x = norm(x)

    return x
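

# Shape sketch for do_pool (illustrative only; the tensor sizes and the 2x2 max
# pool below are hypothetical, not taken from the module): the input is flipped
# to channels-first, pooled, and flipped back, so the spatial grid shrinks while
# the channel dim is preserved.
#   x = torch.randn(1, 14, 14, 96)          # (B, H, W, C)
#   y = do_pool(x, nn.MaxPool2d(2, 2))      # -> (1, 7, 7, 96)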


class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        head_dim = dim_out // num_heads
        self.scale = head_dim**-0.5

        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x
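

# Illustrative usage of MultiScaleAttention (a minimal sketch; the sizes are
# hypothetical): without a q_pool module the spatial grid is unchanged and the
# channel dim goes from `dim` to `dim_out`.
#   attn = MultiScaleAttention(dim=96, dim_out=96, num_heads=1)
#   y = attn(torch.randn(1, 8, 8, 96))      # -> (1, 8, 8, 96)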


class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
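

# Illustrative usage of MultiScaleBlock (a minimal sketch; the sizes are
# hypothetical): a block given a q_stride halves the spatial grid via Q pooling,
# and a differing dim/dim_out projects the residual path to the new width.
#   blk = MultiScaleBlock(dim=96, dim_out=192, num_heads=2,
#                         q_stride=(2, 2), window_size=8)
#   y = blk(torch.randn(1, 56, 56, 96))     # -> (1, 28, 28, 192)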


class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed
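
    # Shape note for _get_pos_embed (illustrative; the 56x56 token grid is a
    # hypothetical size, the 8x8 window embedding is the default): the learned
    # background embedding is bicubically resized to (1, C, 56, 56), the
    # (1, C, 8, 8) window embedding is tiled 7x7 times to cover it, and the sum
    # is returned in (1, 56, 56, C) layout so it can be added to the tokens.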

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs
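

# Minimal smoke-test sketch (not part of the original module): builds a Hiera
# with its default configuration and prints the per-stage feature shapes. The
# batch size and the 224x224 input resolution are assumptions chosen only for
# illustration.
if __name__ == "__main__":
    model = Hiera()
    feats = model(torch.randn(1, 3, 224, 224))
    for f in feats:
        # One (B, C, H, W) feature map per stage end, from finest to coarsest.
        print(tuple(f.shape))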