prompt_encoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

from typing import Any, Optional, Tuple, Type

import numpy as np
import torch
from torch import nn

from .common import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as
            input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [
            nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
        ]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (
            4 * image_embedding_size[0],
            4 * image_embedding_size[1],
        )
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)
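
    # Instantiation sketch (added for illustration; the values below match the
    # configuration commonly used with the public SAM release, where a
    # 1024x1024 padded input is downscaled 16x to a 64x64 embedding):
    #   encoder = PromptEncoder(
    #       embed_dim=256,
    #       image_embedding_size=(64, 64),
    #       input_image_size=(1024, 1024),
    #       mask_in_chans=16,
    #   )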

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
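
    # Usage sketch (hypothetical call site): the mask decoder consumes this
    # alongside the image embedding, e.g. `image_pe=encoder.get_dense_pe()`;
    # with the configuration sketched above the shape is 1 x 256 x 64 x 64.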

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )
        point_embedding = torch.where(
            (labels == -1).unsqueeze(-1),
            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
            point_embedding,
        )
        point_embedding = torch.where(
            (labels == 0).unsqueeze(-1),
            point_embedding + self.point_embeddings[0].weight,
            point_embedding,
        )
        point_embedding = torch.where(
            (labels == 1).unsqueeze(-1),
            point_embedding + self.point_embeddings[1].weight,
            point_embedding,
        )
        point_embedding = torch.where(
            (labels == 2).unsqueeze(-1),
            point_embedding + self.point_embeddings[2].weight,
            point_embedding,
        )
        point_embedding = torch.where(
            (labels == 3).unsqueeze(-1),
            point_embedding + self.point_embeddings[3].weight,
            point_embedding,
        )
        return point_embedding
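
    # Note on label conventions (inferred from the embedding choices above and
    # from _embed_boxes below): -1 marks a padding / not-a-point entry, 0 a
    # negative click, 1 a positive click, and 2/3 box corners, which share
    # point_embeddings[2] and [3] with _embed_boxes.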

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding
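
    # Shape note: boxes arrive as B x 4 in (x1, y1, x2, y2) pixel coordinates;
    # the reshape above treats each box as two corner points, so the result
    # is B x 2 x embed_dim.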

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or None): point coordinates
            and labels to embed.
          boxes (torch.Tensor or None): boxes to embed
          masks (torch.Tensor or None): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty(
            (bs, 0, self.embed_dim), device=self._get_device()
        )
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )
        return sparse_embeddings, dense_embeddings
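
    # Usage sketch (hypothetical tensors, added for illustration):
    #   coords = torch.tensor([[[512.0, 512.0]]])  # B x N x 2, (x, y) pixels
    #   labels = torch.tensor([[1]])               # 1 = positive click
    #   sparse, dense = encoder(points=(coords, labels), boxes=None, masks=None)
    # With no box given, the single click is padded with one not-a-point entry,
    # so sparse is 1 x 2 x embed_dim; dense is the broadcast no-mask embedding,
    # 1 x embed_dim x 64 x 64 for the configuration sketched above.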


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
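
    # In formula form: a coordinate x in [0, 1]^2 is mapped to [-1, 1]^2 and
    # encoded as [sin(2*pi*x @ B), cos(2*pi*x @ B)], where B is the random
    # 2 x num_pos_feats Gaussian matrix registered above; this is the random
    # Fourier feature construction, so C = 2 * num_pos_feats.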

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W
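
    # The cumsum-of-ones trick above yields pixel-center coordinates:
    # y_embed[i, j] = (i + 0.5) / h and x_embed[i, j] = (j + 0.5) / w,
    # i.e. a grid of normalized (0, 1) centers fed to _pe_encoding.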

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
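

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original file). It only exercises PositionEmbeddingRandom, which does not
    # depend on LayerNorm2d, and assumes the module is run inside its package
    # (e.g. `python -m <package>.prompt_encoder`) so the relative import at the
    # top resolves.
    pe = PositionEmbeddingRandom(num_pos_feats=128)
    dense = pe((64, 64))
    print(dense.shape)  # torch.Size([256, 64, 64]): C x H x W, C = 2 * 128
    pts = torch.tensor([[[512.0, 512.0], [100.0, 200.0]]])  # B x N x 2, (x, y)
    print(pe.forward_with_coords(pts, (1024, 1024)).shape)  # torch.Size([1, 2, 256])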