position_encoding.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

import math
from typing import Optional

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
        precompute_resolution: Optional[int] = None,
    ):
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        self.cache = {}
        # Precompute positional encodings at resolution `precompute_resolution` to fill
        # the cache and avoid symbolic shape tracing errors in torch.compile in
        # PyTorch 2.4 nightly.
        if precompute_resolution is not None:
            # We precompute pos enc for strides 4, 8, 16 and 32 to fill `self.cache`.
            # Note: this warm-up assumes a CUDA device is available.
            precompute_sizes = [
                (precompute_resolution // 4, precompute_resolution // 4),
                (precompute_resolution // 8, precompute_resolution // 8),
                (precompute_resolution // 16, precompute_resolution // 16),
                (precompute_resolution // 32, precompute_resolution // 32),
            ]
            for size in precompute_sizes:
                tensors = torch.zeros((1, 1) + size, device="cuda")
                self.forward(tensors)
                # Further clone and detach the cached encoding (just to be safe).
                self.cache[size] = self.cache[size].clone().detach()
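
    # Illustrative example (hypothetical values, not from the original file):
    # PositionEmbeddingSine(256, precompute_resolution=1024) warms the cache for
    # 256x256, 128x128, 64x64 and 32x32 feature maps, i.e. strides 4/8/16/32 of
    # a 1024x1024 input, so those sizes never have to be traced at runtime.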

    def _encode_xy(self, x, y):
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
        ).flatten(1)
        pos_y = torch.stack(
            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
        ).flatten(1)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        # (x, y) are normalized box coordinates; the output concatenates the
        # sinusoidal (y, x) encodings with the raw height and width.
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        # x, y, labels all have shape (batch, num_points); the output appends the
        # label as one extra channel after the (y, x) encodings.
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    @torch.no_grad()
    def forward(self, x):
        # The encoding depends only on the spatial size, so cache it per (H, W).
        cache_key = (x.shape[-2], x.shape[-1])
        if cache_key in self.cache:
            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
        y_embed = (
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
            .view(1, -1, 1)
            .repeat(x.shape[0], 1, x.shape[-1])
        )
        x_embed = (
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
            .view(1, 1, -1)
            .repeat(x.shape[0], x.shape[-2], 1)
        )

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # Cache a single (C, H, W) copy for this spatial size.
        self.cache[cache_key] = pos[0]
        return pos
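

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module above). It
# runs on CPU tensors and therefore leaves `precompute_resolution` unset, since
# the cache warm-up in __init__ allocates its dummy inputs on CUDA.
if __name__ == "__main__":
    pe = PositionEmbeddingSine(num_pos_feats=256)

    # Dense encoding for a dummy (B, C, H, W) feature map; only the spatial
    # shape of `feats` matters, and the result is cached under (32, 32).
    feats = torch.zeros(2, 256, 32, 32)
    pos = pe(feats)  # shape (2, 256, 32, 32)

    # Sparse encoding for normalized point prompts in [0, 1] with labels.
    xs = torch.rand(2, 3)
    ys = torch.rand(2, 3)
    labels = torch.ones(2, 3)
    point_pos = pe.encode_points(xs, ys, labels)  # shape (2, 3, 257)

    print(pos.shape, point_pos.shape)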