# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

import math
from typing import Optional

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the "Attention Is All You Need" paper, generalized to work on images.
    (A minimal usage sketch appears at the bottom of this file.)
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
        precompute_resolution: Optional[int] = None,
    ):
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        self.cache = {}
        # Precompute positional encodings under `precompute_resolution` to fill the cache
        # and avoid symbolic shape tracing errors in torch.compile in PyTorch 2.4 nightly.
        # NOTE: this precompute path assumes a CUDA device is available.
        if precompute_resolution is not None:
            # We precompute pos enc for strides 4, 8, 16 and 32 to fill `self.cache`.
            precompute_sizes = [
                (precompute_resolution // 4, precompute_resolution // 4),
                (precompute_resolution // 8, precompute_resolution // 8),
                (precompute_resolution // 16, precompute_resolution // 16),
                (precompute_resolution // 32, precompute_resolution // 32),
            ]
            for size in precompute_sizes:
                tensors = torch.zeros((1, 1) + size, device="cuda")
                self.forward(tensors)
                # further clone and detach it in the cache (just to be safe)
                self.cache[size] = self.cache[size].clone().detach()

    def _encode_xy(self, x, y):
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
        ).flatten(1)
        pos_y = torch.stack(
            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
        ).flatten(1)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    @torch.no_grad()
    def forward(self, x):
        cache_key = (x.shape[-2], x.shape[-1])
        if cache_key in self.cache:
            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
        y_embed = (
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
            .view(1, -1, 1)
            .repeat(x.shape[0], 1, x.shape[-1])
        )
        x_embed = (
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
            .view(1, 1, -1)
            .repeat(x.shape[0], x.shape[-2], 1)
        )
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        self.cache[cache_key] = pos[0]
        return pos
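

# A minimal usage sketch (illustrative only; the feature-map shape below is an
# assumption for the example, not a value taken from any particular model):
# construct the embedding for a 256-channel feature map and query the sinusoidal
# positional encoding over its spatial grid.
if __name__ == "__main__":
    pe = PositionEmbeddingSine(num_pos_feats=256)
    feats = torch.zeros(2, 256, 64, 64)  # (batch, channels, height, width)
    pos = pe(feats)
    # One 256-dim sin/cos code per spatial location, repeated across the batch.
    assert pos.shape == (2, 256, 64, 64)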