
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Type

import torch
from torch import nn

from sam2.modeling.position_encoding import PositionEmbeddingRandom
from sam2.modeling.sam2_utils import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as
            input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [
            nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
        ]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (
            4 * image_embedding_size[0],
            4 * image_embedding_size[1],
        )
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
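
    # Note: in SAM-style architectures, the tensor returned by get_dense_pe()
    # is typically handed to the mask decoder as its image positional encoding,
    # alongside the image embedding; that pairing happens outside this module.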

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts.

        Each label selects a learned embedding added on top of the positional
        encoding: 0 = negative click, 1 = positive click, 2/3 = box corners,
        and -1 = padding ("not a point").
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        point_embedding[labels == 2] += self.point_embeddings[2].weight
        point_embedding[labels == 3] += self.point_embeddings[3].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)  # (x1, y1, x2, y2) -> two corner points
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding
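
    # Shape note: with image_embedding_size = (H, W), input masks are expected
    # at mask_input_size = (4*H, 4*W); the two stride-2 convs in
    # mask_downscaling reduce them back to (H, W) while projecting from a
    # single channel up to embed_dim channels.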

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or None): point coordinates
            and labels to embed.
          boxes (torch.Tensor or None): boxes to embed
          masks (torch.Tensor or None): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty(
            (bs, 0, self.embed_dim), device=self._get_device()
        )
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )
        return sparse_embeddings, dense_embeddings
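

if __name__ == "__main__":
    # Minimal usage sketch. The constructor arguments below are illustrative
    # SAM-style values, assumed for this example rather than defaults defined
    # in this module.
    encoder = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    # One image, two point prompts: a positive click and a negative click.
    coords = torch.tensor([[[500.0, 375.0], [200.0, 100.0]]])  # (B=1, N=2, 2)
    labels = torch.tensor([[1, 0]])  # 1 = positive, 0 = negative
    sparse, dense = encoder(points=(coords, labels), boxes=None, masks=None)
    # With no box given, a padding point is appended, so sparse is (1, 3, 256);
    # dense is the learned no-mask embedding broadcast to (1, 256, 64, 64).
    print(sparse.shape, dense.shape)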