mask_decoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Type

import torch
from torch import nn

from sam2.modeling.sam2_utils import LayerNorm2d, MLP


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        use_high_res_features: bool = False,
        iou_prediction_use_sigmoid: bool = False,
        dynamic_multimask_via_stability: bool = False,
        dynamic_multimask_stability_delta: float = 0.05,
        dynamic_multimask_stability_thresh: float = 0.98,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        use_multimask_token_for_obj_ptr: bool = False,
    ) -> None:
  28. """
  29. Predicts masks given an image and prompt embeddings, using a
  30. transformer architecture.
  31. Arguments:
  32. transformer_dim (int): the channel dimension of the transformer
  33. transformer (nn.Module): the transformer used to predict masks
  34. num_multimask_outputs (int): the number of masks to predict
  35. when disambiguating masks
  36. activation (nn.Module): the type of activation to use when
  37. upscaling masks
  38. iou_head_depth (int): the depth of the MLP used to predict
  39. mask quality
  40. iou_head_hidden_dim (int): the hidden dimension of the MLP
  41. used to predict mask quality
  42. """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.pred_obj_scores = pred_obj_scores
        if self.pred_obj_scores:
            self.obj_score_token = nn.Embedding(1, transformer_dim)
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.use_high_res_features = use_high_res_features
        if use_high_res_features:
            self.conv_s0 = nn.Conv2d(
                transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
            )
            self.conv_s1 = nn.Conv2d(
                transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
            )

        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            sigmoid_output=iou_prediction_use_sigmoid,
        )
        if self.pred_obj_scores:
            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
            if pred_obj_scores_mlp:
                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

        # When outputting a single mask, optionally we can dynamically fall back to the best
        # multimask output token if the single mask output token gives low stability scores.
        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
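    # Channel bookkeeping (a reference note, not from the original file):
    # `output_upscaling` maps [B, C, H, W] -> [B, C // 8, 4H, 4W] for
    # C = transformer_dim, and each hypernetwork MLP projects a mask token
    # from dim C to C // 8, so the per-token dot product in `predict_masks`
    # lines up with the upscaled embedding channels.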

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  105. """
  106. Predict masks given image and prompt embeddings.
  107. Arguments:
  108. image_embeddings (torch.Tensor): the embeddings from the image encoder
  109. image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
  110. sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
  111. dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
  112. multimask_output (bool): Whether to return multiple masks or a single
  113. mask.
  114. Returns:
  115. torch.Tensor: batched predicted masks
  116. torch.Tensor: batched predictions of mask quality
  117. torch.Tensor: batched SAM token for mask output
  118. """
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            masks = masks[:, 1:, :, :]
            iou_pred = iou_pred[:, 1:]
        elif self.dynamic_multimask_via_stability and not self.training:
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]

        if multimask_output and self.use_multimask_token_for_obj_ptr:
            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
        else:
            # Take the mask output token. Here we *always* use the token for single mask output.
            # At test time, even if we track after 1-click (and using multimask_output=True),
            # we still take the single mask token here. The rationale is that we always track
            # after multiple clicks during training, so the past tokens seen during training
            # are always the single mask token (and we'll let it be the object-memory token).
            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape

        # Prepare output
        return masks, iou_pred, sam_tokens_out, object_score_logits
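    # Shape sketch (illustrative values, not from the original file): with
    # transformer_dim=256 and a 64x64 embedding grid,
    #   image_embeddings:         [B, 256, 64, 64]
    #   image_pe:                 [1, 256, 64, 64]
    #   sparse_prompt_embeddings: [B, num_prompt_tokens, 256]
    #   dense_prompt_embeddings:  [B, 256, 64, 64]
    # forward() returns masks of shape [B, 3, 256, 256] when
    # multimask_output=True (or [B, 1, 256, 256] otherwise), since the two
    # stride-2 transposed convolutions upsample the 64x64 grid by 4x.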

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  156. """Predicts masks. See 'forward' for more details."""
  157. # Concatenate output tokens
  158. s = 0
  159. if self.pred_obj_scores:
  160. output_tokens = torch.cat(
  161. [
  162. self.obj_score_token.weight,
  163. self.iou_token.weight,
  164. self.mask_tokens.weight,
  165. ],
  166. dim=0,
  167. )
  168. s = 1
  169. else:
  170. output_tokens = torch.cat(
  171. [self.iou_token.weight, self.mask_tokens.weight], dim=0
  172. )
  173. output_tokens = output_tokens.unsqueeze(0).expand(
  174. sparse_prompt_embeddings.size(0), -1, -1
  175. )
  176. tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if repeat_image:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            assert image_embeddings.shape[0] == tokens.shape[0]
            src = image_embeddings
        src = src + dense_prompt_embeddings
        assert (
            image_pe.size(0) == 1
        ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, s, :]
        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        if not self.use_high_res_features:
            upscaled_embedding = self.output_upscaling(src)
        else:
            dc1, ln1, act1, dc2, act2 = self.output_upscaling
            feat_s0, feat_s1 = high_res_features
            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        if self.pred_obj_scores:
            assert s == 1
            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
        else:
            # Object score logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

        return masks, iou_pred, mask_tokens_out, object_score_logits
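    # For intuition (an equivalent formulation, not from the original file):
    # the batched matmul above is the same as
    #   masks = torch.einsum("btc,bchw->bthw", hyper_in, upscaled_embedding)
    # i.e. each mask token predicts per-channel weights over the upscaled
    # embedding, acting as a dynamic 1x1 convolution per output mask.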

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores
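    # Worked example (illustrative numbers, not from the original file): with
    # the default stability_delta=0.05, a mask whose flattened logits are
    # [0.2, 0.0, -0.2, -0.3] has area_i = 1 (one logit > 0.05) and
    # area_u = 2 (two logits > -0.05), giving a stability score of 0.5;
    # well below the default 0.98 threshold, so the fallback below triggers.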

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score of the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output tokens 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
        batch_inds = torch.arange(
            multimask_iou_scores.size(0), device=all_iou_scores.device
        )
        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
        best_multimask_logits = best_multimask_logits.unsqueeze(1)
        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to the best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
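

# A minimal smoke-test sketch (not part of the original file). It assumes the
# sam2 package is importable and that `TwoWayTransformer` lives in
# sam2.modeling.sam.transformer; adjust the import to match your checkout.
if __name__ == "__main__":
    from sam2.modeling.sam.transformer import TwoWayTransformer

    decoder = MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(
            depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048
        ),
    )
    b, hw = 2, 64  # two prompt sets on a 64x64 embedding grid
    masks, iou_pred, sam_tokens_out, object_score_logits = decoder(
        image_embeddings=torch.randn(1, 256, hw, hw),
        image_pe=torch.randn(1, 256, hw, hw),
        sparse_prompt_embeddings=torch.randn(b, 2, 256),
        dense_prompt_embeddings=torch.randn(b, 256, hw, hw),
        multimask_output=True,
        repeat_image=True,  # tile the single image embedding per prompt set
    )
    print(masks.shape, iou_pred.shape)  # expected: [2, 3, 256, 256] and [2, 3]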