# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""
ViTDet backbone adapted from Detectron2.
This module implements a Vision Transformer (ViT) backbone for object detection.
RoPE embedding code adapted from:
1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
2. https://github.com/naver-ai/rope-vit
3. https://github.com/lucidrains/rotary-embedding-torch
"""
import math
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

try:
    from timm.layers import DropPath, Mlp, trunc_normal_
except ModuleNotFoundError:
    # compatibility for older timm versions
    from timm.models.layers import DropPath, Mlp, trunc_normal_

from torch import Tensor

from .model_misc import LayerScale


def init_t_xy(
    end_x: int, end_y: int, scale: float = 1.0, offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()
    t_y = torch.div(t, end_x, rounding_mode="floor").float()
    return t_x * scale + offset, t_y * scale + offset


def compute_axial_cis(
    dim: int,
    end_x: int,
    end_y: int,
    theta: float = 10000.0,
    scale_pos: float = 1.0,
    offset: int = 0,
) -> torch.Tensor:
    freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    t_x, t_y = init_t_xy(end_x, end_y, scale_pos, offset)
    freqs_x = torch.outer(t_x, freqs_x)
    freqs_y = torch.outer(t_y, freqs_y)
    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
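
# Shape sketch (illustrative, not part of the original file): for a 14x14 token grid
# with per-head dim 64, each token gets dim // 2 = 32 complex rotations, half driven by
# its x coordinate and half by its y coordinate.
#
#     freqs = compute_axial_cis(dim=64, end_x=14, end_y=14)
#     assert freqs.shape == (14 * 14, 32) and freqs.is_complex()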


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
    shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = (
        torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        if xk.shape[-2] != 0
        else None
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    if xk_ is None:
        # no keys to rotate, due to dropout
        return xq_out.type_as(xq).to(xq.device), xk
    # repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k:
        r = xk_.shape[-2] // xq_.shape[-2]
        freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
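
# Usage sketch (illustrative, not part of the original file): rotate q/k of shape
# [B, n_heads, L, head_dim] with the axial frequencies computed above. Shapes are
# preserved; only the last dimension is rotated pairwise.
#
#     q = torch.randn(2, 8, 14 * 14, 64)
#     k = torch.randn(2, 8, 14 * 14, 64)
#     freqs = compute_axial_cis(dim=64, end_x=14, end_y=14)  # (196, 32)
#     q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis=freqs)
#     assert q_rot.shape == q.shape and k_rot.shape == k.shape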


def window_partition(x: Tensor, window_size: int) -> Tuple[Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> Tensor:
    """
    Window unpartition into original sequences and remove padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.reshape(
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :]
    return x
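
# Round-trip sketch (illustrative, not part of the original file): partition a 20x24
# feature map into 7x7 windows and reassemble it losslessly.
#
#     x = torch.randn(2, 20, 24, 64)
#     windows, (Hp, Wp) = window_partition(x, 7)  # pads to 21x28 -> [2 * 12, 7, 7, 64]
#     y = window_unpartition(windows, 7, (Hp, Wp), (20, 24))
#     assert torch.equal(x, y)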


def get_rel_pos(q_size: int, k_size: int, rel_pos: Tensor) -> Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).
    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
            align_corners=False,
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos
    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel_pos_resized[relative_coords.long()]


def get_abs_pos(
    abs_pos: Tensor,
    has_cls_token: bool,
    hw: Tuple[int, int],
    retain_cls_token: bool = False,
    tiling: bool = False,
) -> Tensor:
    """
    Calculate absolute positional embeddings. If needed, resize the embeddings and
    remove the cls_token dimension from the original embeddings.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, abs_pos has 1 embedding for the cls token.
        hw (Tuple): size of input image tokens.
        retain_cls_token (bool): whether to retain the cls_token.
        tiling (bool): whether to tile the embeddings *instead* of interpolating (a la abs_win).
    Returns:
        Absolute positional embeddings after processing, with shape (1, H, W, C)
        if retain_cls_token is False, otherwise (1, 1 + H * W, C).
    """
    if retain_cls_token:
        assert has_cls_token
    h, w = hw
    if has_cls_token:
        cls_pos = abs_pos[:, :1]
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num
    if size != h or size != w:
        new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2)
        if tiling:
            new_abs_pos = new_abs_pos.tile(
                [1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])]
            )[:, :, :h, :w]
        else:
            new_abs_pos = F.interpolate(
                new_abs_pos,
                size=(h, w),
                mode="bicubic",
                align_corners=False,
            )
        if not retain_cls_token:
            return new_abs_pos.permute(0, 2, 3, 1)
        else:
            # add cls_token back, flatten spatial dims
            assert has_cls_token
            return torch.cat(
                [cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)],
                dim=1,
            )
    else:
        if not retain_cls_token:
            return abs_pos.reshape(1, h, w, -1)
        else:
            assert has_cls_token
            return torch.cat([cls_pos, abs_pos], dim=1)
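
# Shape sketch (illustrative, not part of the original file): a pretrained table for a
# 14x14 grid plus a cls token is resized to a 64x64 grid, either by bicubic
# interpolation (tiling=False) or by repeating the 14x14 tile (tiling=True).
#
#     pos = torch.zeros(1, 1 + 14 * 14, 768)
#     out = get_abs_pos(pos, has_cls_token=True, hw=(64, 64), tiling=True)
#     assert out.shape == (1, 64, 64, 768)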


def concat_rel_pos(
    q: Tensor,
    k: Tensor,
    q_hw: Tuple[int, int],
    k_hw: Tuple[int, int],
    rel_pos_h: Tensor,
    rel_pos_w: Tensor,
    rescale: bool = False,
    relative_coords: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Concatenate rel pos coeffs to the q & k tensors, so that qk^T effectively
    includes the rel pos biases.
    Args:
        q (Tensor): q tensor with shape (B, L_q, C).
        k (Tensor): k tensor with shape (B, L_k, C).
        q_hw, k_hw: spatial sizes of the q & k tensors.
        rel_pos_h, rel_pos_w: relative pos embeddings/params for height and width.
        rescale (bool): whether to rescale q, e.g. when using sdpa, where pytorch would
            otherwise scale by the wrong factor due to the concat.
    Returns:
        q, k: padded so that qk^T accounts for rel pos biases.
    """
    q_h, q_w = q_hw
    k_h, k_w = k_hw
    assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"
    if relative_coords is not None:
        Rh = rel_pos_h[relative_coords]
        Rw = rel_pos_w[relative_coords]
    else:
        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
        Rw = get_rel_pos(q_w, k_w, rel_pos_w)
    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    old_scale = dim**0.5
    new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale  # for sdpa
    # attn will be divided by new_scale, but we want to divide q by old_scale
    scale_ratio = new_scale / old_scale
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale  # (B, q_h, q_w, k_w)
    eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
    eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)
    eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
    eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])
    q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
    k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(
        B, k_h * k_w, -1
    )
    return q, k
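
# Shape sketch (illustrative, not part of the original file): for a 14x14 grid of
# queries/keys with per-head dim 64, the padded q/k gain k_h + k_w = 28 extra channels,
# so the dot products passed to sdpa carry the decomposed rel-pos terms.
#
#     q = torch.randn(8, 14 * 14, 64)
#     k = torch.randn(8, 14 * 14, 64)
#     rel_h = torch.randn(2 * 14 - 1, 64)
#     rel_w = torch.randn(2 * 14 - 1, 64)
#     q2, k2 = concat_rel_pos(q, k, (14, 14), (14, 14), rel_h, rel_w, rescale=True)
#     assert q2.shape == k2.shape == (8, 14 * 14, 64 + 14 + 14)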


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
        bias: bool = True,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            bias (bool): whether the projection layer uses a bias.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
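
# Usage sketch (illustrative, not part of the original file): a 224x224 RGB image with
# 16x16 patches becomes a 14x14 grid of 768-d tokens in channel-last layout.
#
#     embed = PatchEmbed()
#     tokens = embed(torch.randn(1, 3, 224, 224))
#     assert tokens.shape == (1, 14, 14, 768)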


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings and 2d-rope."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
        cls_token: bool = False,
        use_rope: bool = False,
        rope_theta: float = 10000.0,
        rope_pt_size: Optional[Tuple[int, int]] = None,
        rope_interp: bool = False,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size or rope size.
            cls_token: whether a cls_token is present.
            use_rope: whether to use rope 2d (indep of use_rel_pos, as they can be used together).
            rope_theta: controls the frequencies of rope.
            rope_pt_size: size of rope in a previous stage of training, needed for interpolation or tiling.
            rope_interp: whether to interpolate (or extrapolate) rope to match the input size.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.cls_token = cls_token
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        # rel_pos embeddings and rope
        self.use_rel_pos = use_rel_pos
        self.input_size = input_size
        self.use_rope = use_rope
        self.rope_theta = rope_theta
        self.rope_pt_size = rope_pt_size
        self.rope_interp = rope_interp
        # init rel_pos embeddings and rope
        self._setup_rel_pos(rel_pos_zero_init)
        self._setup_rope_freqs()

    def _setup_rel_pos(self, rel_pos_zero_init: bool = True) -> None:
        if not self.use_rel_pos:
            self.rel_pos_h = None
            self.rel_pos_w = None
            return
        assert self.input_size is not None
        assert self.cls_token is False, "not supported"
        # initialize relative positional embeddings
        self.rel_pos_h = nn.Parameter(
            torch.zeros(2 * self.input_size[0] - 1, self.head_dim)
        )
        self.rel_pos_w = nn.Parameter(
            torch.zeros(2 * self.input_size[1] - 1, self.head_dim)
        )
        if not rel_pos_zero_init:
            trunc_normal_(self.rel_pos_h, std=0.02)
            trunc_normal_(self.rel_pos_w, std=0.02)
        # Precompute the relative coords
        H, W = self.input_size
        q_coords = torch.arange(H)[:, None]
        k_coords = torch.arange(W)[None, :]
        relative_coords = (q_coords - k_coords) + (H - 1)
        self.register_buffer("relative_coords", relative_coords.long())

    def _setup_rope_freqs(self) -> None:
        if not self.use_rope:
            self.freqs_cis = None
            return
        assert self.input_size is not None
        # determine rope input size
        if self.rope_pt_size is None:
            self.rope_pt_size = self.input_size
        # initialize 2d rope freqs
        self.compute_cis = partial(
            compute_axial_cis,
            dim=self.head_dim,
            theta=self.rope_theta,
        )
        # interpolate rope
        scale_pos = 1.0
        if self.rope_interp:
            scale_pos = self.rope_pt_size[0] / self.input_size[0]
        # get scaled freqs_cis
        freqs_cis = self.compute_cis(
            end_x=self.input_size[0],
            end_y=self.input_size[1],
            scale_pos=scale_pos,
        )
        if self.cls_token:
            t = torch.zeros(
                self.head_dim // 2,
                dtype=torch.float32,
                device=freqs_cis.device,
            )
            cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
            freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)
        self.register_buffer("freqs_cis", freqs_cis)

    def _apply_rope(self, q, k) -> Tuple[Tensor, Tensor]:
        if not self.use_rope:
            return q, k
        assert self.freqs_cis is not None
        return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis)

    def forward(self, x: Tensor) -> Tensor:
        s = 1 if self.cls_token else 0  # used to exclude cls_token
        if x.ndim == 4:
            B, H, W, _ = x.shape
            assert s == 0  # no cls_token
            L = H * W
            ndim = 4
        else:
            assert x.ndim == 3
            B, L, _ = x.shape
            ndim = 3
            H = W = int(math.sqrt(L - s))
        # qkv with shape (3, B, nHead, L, C)
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
        # q, k, v with shape (B, nHead, L, C)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # handle rope and rel pos embeddings
        q, k = self._apply_rope(q, k)
        if self.use_rel_pos:
            q, k = concat_rel_pos(
                q.flatten(0, 1),
                k.flatten(0, 1),
                (H, W),
                x.shape[1:3],
                self.rel_pos_h,
                self.rel_pos_w,
                rescale=True,
                relative_coords=self.relative_coords,
            )
            # sdpa expects [B, nheads, H*W, C] so we transpose back
            q = q.reshape(B, self.num_heads, H * W, -1)
            k = k.reshape(B, self.num_heads, H * W, -1)
        x = F.scaled_dot_product_attention(q, k, v)
        if ndim == 4:
            x = (
                x.view(B, self.num_heads, H, W, -1)
                .permute(0, 2, 3, 1, 4)
                .reshape(B, H, W, -1)
            )
        else:
            x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
        x = self.proj(x)
        return x
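
# Usage sketch (illustrative, not part of the original file): channel-last input of
# shape [B, H, W, C] maps to the same shape, with decomposed rel-pos biases enabled.
#
#     attn = Attention(dim=768, num_heads=12, use_rel_pos=True, input_size=(14, 14))
#     out = attn(torch.randn(2, 14, 14, 768))
#     assert out.shape == (2, 14, 14, 768)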


class Block(nn.Module):
    """Transformer block with support for window attention."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
        use_rope: bool = False,
        rope_pt_size: Optional[Tuple[int, int]] = None,
        rope_tiled: bool = False,
        rope_interp: bool = False,
        use_ve_rope: bool = False,
        cls_token: bool = False,
        dropout: float = 0.0,
        init_values: Optional[float] = None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If 0, window
                attention is not used.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
            dropout (float): Dropout rate.
            cls_token: whether a cls_token is present.
            use_rope: whether to use rope 2d (indep of use_rel_pos, as they can be used together).
            rope_pt_size: size of rope in a previous stage of training, needed for interpolation or tiling.
            rope_interp: whether to interpolate (or extrapolate) rope to match the target input size;
                the source size is expected in rope_pt_size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
            use_rope=use_rope,
            rope_pt_size=rope_pt_size,
            rope_interp=rope_interp,
            cls_token=cls_token,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=(dropout, 0.0),
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)
        x = self.ls1(self.attn(x))
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
        x = shortcut + self.dropout(self.drop_path(x))
        x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
        return x
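
# Usage sketch (illustrative, not part of the original file): a windowed block and a
# global-attention block both map [B, H, W, C] -> [B, H, W, C].
#
#     win_blk = Block(dim=192, num_heads=3, window_size=7, input_size=(14, 14))
#     glb_blk = Block(dim=192, num_heads=3, window_size=0, input_size=(14, 14))
#     y = glb_blk(win_blk(torch.randn(1, 14, 14, 192)))
#     assert y.shape == (1, 14, 14, 192)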


class ViT(nn.Module):
    """
    This module implements the Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        norm_layer: Union[Callable[..., nn.Module], str] = "LayerNorm",
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        tile_abs_pos: bool = True,
        rel_pos_blocks: Union[Tuple[int, ...], bool] = (2, 5, 8, 11),
        rel_pos_zero_init: bool = True,
        window_size: int = 14,
        global_att_blocks: Tuple[int, ...] = (2, 5, 8, 11),
        use_rope: bool = False,
        rope_pt_size: Optional[int] = None,
        use_interp_rope: bool = False,
        pretrain_img_size: int = 224,
        pretrain_use_cls_token: bool = True,
        retain_cls_token: bool = True,
        dropout: float = 0.0,
        return_interm_layers: bool = False,
        init_values: Optional[float] = None,  # for layerscale
        ln_pre: bool = False,
        ln_post: bool = False,
        bias_patch_embed: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        """
        Args:
            img_size (int): Input image size. Only relevant for rel pos or rope.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolating them.
            rel_pos_blocks (list): Blocks which have rel pos embeddings.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_att_blocks (list): Indexes of blocks using global attention (other blocks use window attention).
            use_rope (bool): whether to use rope 2d (indep of rel_pos_blocks, as they can be used together).
            rope_pt_size (int): size of rope in a previous stage of training, needed for interpolation or tiling.
            use_interp_rope: whether to interpolate (or extrapolate) rope to match the target input size;
                the source size is expected in rope_pt_size.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use a class token.
            retain_cls_token: whether the cls_token should be retained.
            dropout (float): Dropout rate. Applied in the residual branches of attn and mlp, and inside the mlp.
            return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
            init_values: layer scale init value; None for no layer scale.
            ln_pre (bool): If True, apply layer norm before the transformer blocks.
            ln_post (bool): If True, apply layer norm after the transformer blocks.
            bias_patch_embed (bool): whether to use a bias in the patch embedding conv.
            compile_mode (str): mode used to torch.compile the forward pass.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token
        window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
        self.full_attn_ids = list(global_att_blocks)
        self.rel_pos_blocks = [False] * depth
        if isinstance(rel_pos_blocks, bool) and rel_pos_blocks:
            self.rel_pos_blocks = [True] * depth
        else:
            for i in rel_pos_blocks:
                self.rel_pos_blocks[i] = True
        self.retain_cls_token = retain_cls_token
        if self.retain_cls_token:
            assert pretrain_use_cls_token
            assert len(window_block_indexes) == 0, (
                "windowing not supported with cls token"
            )
            assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
            scale = embed_dim**-0.5
            self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))
        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias_patch_embed,
        )
        # Handle absolute positional embedding
        self.tile_abs_pos = tile_abs_pos
        self.use_abs_pos = use_abs_pos
        if self.tile_abs_pos:
            assert self.use_abs_pos
        if self.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (
                pretrain_img_size // patch_size
            )
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList()
        cur_stage = 1
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=self.rel_pos_blocks[i],
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
                use_rope=use_rope,
                rope_pt_size=(
                    (window_size, window_size)
                    if rope_pt_size is None
                    else (rope_pt_size, rope_pt_size)
                ),
                rope_interp=use_interp_rope,
                cls_token=self.retain_cls_token,
                dropout=dropout,
                init_values=init_values,
            )
            if i not in window_block_indexes:
                cur_stage += 1
            self.use_act_checkpoint = use_act_checkpoint
            self.blocks.append(block)
        self.return_interm_layers = return_interm_layers
        self.channel_list = (
            [embed_dim] * len(self.full_attn_ids)
            if return_interm_layers
            else [embed_dim]
        )
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)
        self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
        self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()
        self.apply(self._init_weights)
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            if self.use_act_checkpoint and self.training:
                torch._dynamo.config.optimize_ddp = False

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        h, w = x.shape[1], x.shape[2]
        s = 0
        if self.retain_cls_token:
            # If the cls_token is retained, we don't maintain the spatial shape
            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
            s = 1
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed,
                self.pretrain_use_cls_token,
                (h, w),
                self.retain_cls_token,
                tiling=self.tile_abs_pos,
            )
        x = self.ln_pre(x)
        outputs = []
        for i, blk in enumerate(self.blocks):
            if self.use_act_checkpoint and self.training:
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if (i == self.full_attn_ids[-1]) or (
                self.return_interm_layers and i in self.full_attn_ids
            ):
                if i == self.full_attn_ids[-1]:
                    x = self.ln_post(x)
                feats = x[:, s:]
                if feats.ndim == 4:
                    feats = feats.permute(0, 3, 1, 2)
                else:
                    assert feats.ndim == 3
                    h = w = int(math.sqrt(feats.shape[1]))
                    feats = feats.reshape(
                        feats.shape[0], h, w, feats.shape[-1]
                    ).permute(0, 3, 1, 2)
                outputs.append(feats)
        return outputs

    def get_layer_id(self, layer_name: str) -> int:
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()
        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("ln_pre") != -1:
            return 0
        elif layer_name.find("pos_embed") != -1 or layer_name.find("cls_token") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        return len(self.blocks)
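

if __name__ == "__main__":
    # Smoke-test sketch (not part of the original file): build a small ViTDet-style
    # backbone and run a forward pass. The hyperparameters below are illustrative, not a
    # released configuration, and because of the relative import of LayerScale this
    # module must be executed as part of its package (e.g. python -m <package>.vitdet).
    model = ViT(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=4,
        num_heads=3,
        window_size=7,
        global_att_blocks=(1, 3),
        rel_pos_blocks=(1, 3),
        retain_cls_token=False,
        use_act_checkpoint=False,
    )
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for f in feats:
        # One channel-first feature map per returned level; spatial stride = patch_size.
        print(f.shape)  # torch.Size([1, 192, 14, 14])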