hieradet.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from functools import partial
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr

from sam2.modeling.backbones.utils import (
    PatchEmbed,
    window_partition,
    window_unpartition,
)

from sam2.modeling.sam2_utils import DropPath, MLP


def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    if norm:
        x = norm(x)

    return x


class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool

        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x
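

# Illustrative example (not part of the original file): when a pooling module is
# passed as `q_pool`, only the queries are pooled, so the attention output comes
# back on the downsampled grid. The values below are hypothetical, chosen only
# to show the shape behavior.
#
#     attn = MultiScaleAttention(dim=96, dim_out=192, num_heads=2, q_pool=nn.MaxPool2d(2, 2))
#     y = attn(torch.randn(1, 16, 16, 96))  # -> (1, 8, 8, 192)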


class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
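

# Illustrative example (not part of the original file): a stage-transition block
# that doubles the channel dimension and halves the spatial resolution via Q pooling.
# The constructor arguments below are hypothetical, chosen only for the shape check.
#
#     blk = MultiScaleBlock(dim=96, dim_out=192, num_heads=2, q_stride=(2, 2), window_size=8)
#     y = blk(torch.randn(1, 32, 32, 96))  # -> (1, 16, 16, 192)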


class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            with g_pathmgr.open(weights_path, "rb") as f:
                chkpt = torch.load(f, map_location="cpu")
            logging.info("loading Hiera: %s", self.load_state_dict(chkpt, strict=False))

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed
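
    # Illustrative note (not part of the original file): with the default embed_dim=96
    # and window_spec[0]=8, a 64x64 token grid resizes `pos_embed` (1, 96, 14, 14)
    # bicubically to (1, 96, 64, 64) and tiles `pos_embed_window` (1, 96, 8, 8)
    # 8x8 times across the grid before permuting to channels-last.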

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        return len(self.blocks)
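

# Minimal usage sketch (not part of the original file), assuming the default
# configuration and a 1024x1024 RGB input; with return_interm_layers=True the
# backbone returns channels-first feature maps at strides 4, 8, 16 and 32.
if __name__ == "__main__":
    model = Hiera()  # defaults: embed_dim=96, stages=(2, 3, 16, 3)
    dummy = torch.zeros(1, 3, 1024, 1024)
    with torch.no_grad():
        feats = model(dummy)
    # Expected shapes: (1, 96, 256, 256), (1, 192, 128, 128),
    # (1, 384, 64, 64), (1, 768, 32, 32)
    print([tuple(f.shape) for f in feats])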