position_encoding.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, Optional, Tuple

import numpy as np
import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
  11. """
  12. This is a more standard version of the position embedding, very similar to the one
  13. used by the Attention Is All You Need paper, generalized to work on images.
  14. """
    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
        # The following settings are only relevant
        # for warming up the cache for compilation
        warmup_cache: bool = True,
        image_size: int = 1024,
        strides: Tuple[int, ...] = (4, 8, 16, 32),
    ):
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        self.cache = {}
        if warmup_cache and torch.cuda.is_available():
            # Warmup cache for cuda, to help with compilation
            device = torch.device("cuda")
            for stride in strides:
                cache_key = (image_size // stride, image_size // stride)
                self._pe(1, device, *cache_key)

    def _encode_xy(self, x, y):
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
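        # dim_t above is the sinusoidal frequency schedule from "Attention Is
        # All You Need": temperature ** (2 * floor(i / 2) / d), so each
        # consecutive channel pair shares one frequency for the sin/cos pair.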
        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
        ).flatten(1)
        pos_y = torch.stack(
            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
        ).flatten(1)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
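        # x, y: 1D tensors of normalized box coordinates (see _encode_xy);
        # w and h are appended unchanged as two extra channels, so the output
        # has 2 * self.num_pos_feats + 2 channels per box.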
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
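        # x, y: (B, N) normalized point coordinates; labels: (B, N) per-point
        # labels, appended as one extra channel of the encoding.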
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    @torch.no_grad()
    def _pe(self, B, device, *cache_key):
        H, W = cache_key
        if cache_key in self.cache:
            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)

        y_embed = (
            torch.arange(1, H + 1, dtype=torch.float32, device=device)
            .view(1, -1, 1)
            .repeat(B, 1, W)
        )
        x_embed = (
            torch.arange(1, W + 1, dtype=torch.float32, device=device)
            .view(1, 1, -1)
            .repeat(B, H, 1)
        )

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        self.cache[cache_key] = pos[0]
        return pos

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        B = x.shape[0]
        cache_key = (x.shape[-2], x.shape[-1])
        return self._pe(B, x.device, *cache_key)
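
# A minimal usage sketch (not part of the original file): for a feature map of
# shape (B, C, H, W), forward() returns a (B, C, H, W) positional embedding,
# where C equals the num_pos_feats passed to the constructor.
#
#   pe = PositionEmbeddingSine(num_pos_feats=256, warmup_cache=False)
#   feats = torch.zeros(2, 256, 64, 64)
#   pos = pe(feats)  # -> torch.Size([2, 256, 64, 64])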


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )
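
    # The buffer above defines a random Fourier feature mapping: 2D coords are
    # projected through a fixed Gaussian matrix and then through sin/cos in
    # _pe_encoding below (cf. Tancik et al., "Fourier Features Let Networks
    # Learn High Frequency Functions in Low Dimensional Domains").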

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
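
# A minimal usage sketch (not part of the original file): forward() takes a
# grid size and returns a dense (C, H, W) encoding, while forward_with_coords()
# takes unnormalized (B, N, 2) point coordinates plus the image size.
#
#   pe = PositionEmbeddingRandom(num_pos_feats=64)
#   dense = pe((32, 32))                    # -> torch.Size([128, 32, 32])
#   pts = torch.tensor([[[100.0, 200.0]]])  # (B=1, N=1, 2) in pixel units
#   sparse = pe.forward_with_coords(pts, (512, 512))  # -> (1, 1, 128)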


# Rotary Positional Encoding, adapted from:
# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
# 2. https://github.com/naver-ai/rope-vit
# 3. https://github.com/lucidrains/rotary-embedding-torch


def init_t_xy(end_x: int, end_y: int):
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()
    t_y = torch.div(t, end_x, rounding_mode="floor").float()
    return t_x, t_y
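
# Worked example (not part of the original file): init_t_xy(end_x=3, end_y=2)
# enumerates a 3 x 2 grid in row-major order, giving
# t_x = [0, 1, 2, 0, 1, 2] and t_y = [0, 0, 0, 1, 1, 1].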


def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

    t_x, t_y = init_t_xy(end_x, end_y)
    freqs_x = torch.outer(t_x, freqs_x)
    freqs_y = torch.outer(t_y, freqs_y)
    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
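
# The returned table has shape (end_x * end_y, dim // 2): for each position,
# the first dim // 4 complex entries rotate with the x coordinate and the
# remaining dim // 4 with the y coordinate.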


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
    shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
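
# reshape_for_broadcast keeps only the last two dims of x, viewing the
# (seq_len, head_dim // 2) table as (1, ..., 1, seq_len, head_dim // 2) so it
# broadcasts over the leading batch/head dimensions.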


def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = (
        torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        if xk.shape[-2] != 0
        else None
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    if xk_ is None:
        # no keys to rotate, due to dropout
        return xq_out.type_as(xq).to(xq.device), xk
    # repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k:
        r = xk_.shape[-2] // xq_.shape[-2]
        if freqs_cis.is_cuda:
            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
        else:
            # torch.repeat on complex numbers may not be supported on non-CUDA
            # devices (freqs_cis has 4 dims and we repeat on dim 2), so we use
            # expand + flatten instead
            freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
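
# A minimal end-to-end sketch (not part of the original file), assuming queries
# and keys of shape (B, n_heads, H * W, head_dim) over an H x W token grid:
#
#   H = W = 8
#   head_dim = 64
#   freqs_cis = compute_axial_cis(dim=head_dim, end_x=W, end_y=H)  # (64, 32)
#   xq = torch.randn(1, 2, H * W, head_dim)
#   xk = torch.randn(1, 2, H * W, head_dim)
#   xq_rot, xk_rot = apply_rotary_enc(xq, xk, freqs_cis)
#   # xq_rot.shape == xk_rot.shape == torch.Size([1, 2, 64, 64])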