maskformer_segmentation.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from .model_misc import MLP


class LinearPresenceHead(nn.Sequential):
    def __init__(self, d_model):
        # a hack to make `LinearPresenceHead` compatible with old checkpoints
        super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))

    def forward(self, hs, prompt, prompt_mask):
        return super().forward(hs)


class MaskPredictor(nn.Module):
    def __init__(self, hidden_dim, mask_dim):
        super().__init__()
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, obj_queries, pixel_embed):
        if len(obj_queries.shape) == 3:
            if pixel_embed.ndim == 3:
                # batch size was omitted
                mask_preds = torch.einsum(
                    "bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed
                )
            else:
                mask_preds = torch.einsum(
                    "bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed
                )
        else:
            # Assumed to have aux masks
            if pixel_embed.ndim == 3:
                # batch size was omitted
                mask_preds = torch.einsum(
                    "lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed
                )
            else:
                mask_preds = torch.einsum(
                    "lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed
                )
        return mask_preds
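

# Shape sketch for `MaskPredictor` (illustrative sizes, assumed for this
# example only):
#
#   predictor = MaskPredictor(hidden_dim=256, mask_dim=256)
#   queries = torch.randn(2, 100, 256)       # (batch, num_queries, channels)
#   pixels = torch.randn(2, 256, 64, 64)     # (batch, mask_dim, H, W)
#   masks = predictor(queries, pixels)       # -> (2, 100, 64, 64)
#
# When auxiliary decoder outputs are stacked, `queries` has an extra leading
# layer dimension (L, batch, num_queries, channels) and the masks come out as
# (L, batch, num_queries, H, W). A 3-dim `pixels` tensor (mask_dim, H, W) is
# shared (broadcast) across the batch.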


class SegmentationHead(nn.Module):
    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        use_encoder_inputs=False,
        aux_masks=False,
        no_dec=False,
        pixel_decoder=None,
        act_ckpt=False,
        shared_conv=False,
        compile_mode_pixel_decoder=None,
    ):
        super().__init__()
        self.use_encoder_inputs = use_encoder_inputs
        self.aux_masks = aux_masks
        if pixel_decoder is not None:
            self.pixel_decoder = pixel_decoder
        else:
            self.pixel_decoder = PixelDecoder(
                hidden_dim,
                upsampling_stages,
                shared_conv=shared_conv,
                compile_mode=compile_mode_pixel_decoder,
            )
        self.no_dec = no_dec
        if no_dec:
            self.mask_predictor = nn.Conv2d(
                hidden_dim, 1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)
        self.act_ckpt = act_ckpt
        # used to update the output dictionary
        self.instance_keys = ["pred_masks"]

    @property
    def device(self):
        self._device = getattr(self, "_device", None) or next(self.parameters()).device
        return self._device

    def to(self, *args, **kwargs):
        # clear cached _device in case the model is moved to a different device
        self._device = None
        return super().to(*args, **kwargs)

    def _embed_pixels(
        self,
        backbone_feats: List[torch.Tensor],
        image_ids,
        encoder_hidden_states,
    ) -> torch.Tensor:
        feature_device = backbone_feats[0].device  # features could be on CPU
        model_device = self.device
        image_ids_ = image_ids.to(feature_device)
        if self.use_encoder_inputs:
            if backbone_feats[0].shape[0] > 1:
                # For bs > 1, we construct the per query backbone features
                backbone_visual_feats = []
                for feat in backbone_feats:
                    # Copy the img features per query (pixel decoder won't share img feats)
                    backbone_visual_feats.append(feat[image_ids_, ...].to(model_device))
            else:
                # Bs=1, we rely on broadcasting for query-based processing
                backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
            # Extract visual embeddings
            encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
            spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
            encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(
                -1, *backbone_feats[-1].shape[1:]
            )
            backbone_visual_feats[-1] = encoder_visual_embed
            if self.act_ckpt:
                pixel_embed = checkpoint.checkpoint(
                    self.pixel_decoder, backbone_visual_feats, use_reentrant=False
                )
            else:
                pixel_embed = self.pixel_decoder(backbone_visual_feats)
        else:
            backbone_feats = [x.to(model_device) for x in backbone_feats]
            pixel_embed = self.pixel_decoder(backbone_feats)
            if pixel_embed.shape[0] == 1:
                # For batch_size=1 training, we can avoid the indexing to save memory
                pixel_embed = pixel_embed.squeeze(0)
            else:
                pixel_embed = pixel_embed[image_ids, ...]
        return pixel_embed

    def forward(
        self,
        backbone_feats: List[torch.Tensor],
        obj_queries: torch.Tensor,
        image_ids,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        if self.use_encoder_inputs:
            assert encoder_hidden_states is not None
        pixel_embed = self._embed_pixels(
            backbone_feats=backbone_feats,
            image_ids=image_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        if self.no_dec:
            mask_pred = self.mask_predictor(pixel_embed)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, pixel_embed)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)
        return {"pred_masks": mask_pred}
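

# Usage sketch for `SegmentationHead` with assumed sizes (hidden_dim=256,
# three feature levels, 100 queries); these values are illustrative, not taken
# from a real config:
#
#   head = SegmentationHead(hidden_dim=256, upsampling_stages=2)
#   feats = [
#       torch.randn(1, 256, 64, 64),   # finest level
#       torch.randn(1, 256, 32, 32),
#       torch.randn(1, 256, 16, 16),   # coarsest level (top-down path starts here)
#   ]
#   queries = torch.randn(6, 1, 100, 256)         # (decoder layers, batch, queries, dim)
#   image_ids = torch.zeros(1, dtype=torch.long)  # query batch -> source image index
#   out = head(feats, queries, image_ids)
#   # out["pred_masks"]: (1, 100, 64, 64), computed from the last decoder layer's queries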


class PixelDecoder(nn.Module):
    def __init__(
        self,
        hidden_dim,
        num_upsampling_stages,
        interpolation_mode="nearest",
        shared_conv=False,
        compile_mode=None,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_upsampling_stages = num_upsampling_stages
        self.interpolation_mode = interpolation_mode
        conv_layers = []
        norms = []
        num_convs = 1 if shared_conv else num_upsampling_stages
        for _ in range(num_convs):
            conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
            norms.append(nn.GroupNorm(8, self.hidden_dim))
        self.conv_layers = nn.ModuleList(conv_layers)
        self.norms = nn.ModuleList(norms)
        self.shared_conv = shared_conv
        self.out_dim = self.conv_layers[-1].out_channels
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, dynamic=True, fullgraph=True
            )
            # Needed to make checkpointing happy. But we don't know if the module is
            # checkpointed, so we disable it by default.
            torch._dynamo.config.optimize_ddp = False

    def forward(self, backbone_feats: List[torch.Tensor]):
        # Assumes backbone features are already projected (C == hidden dim)
        prev_fpn = backbone_feats[-1]
        fpn_feats = backbone_feats[:-1]
        for layer_idx, bb_feat in enumerate(fpn_feats[::-1]):
            curr_fpn = bb_feat
            prev_fpn = curr_fpn + F.interpolate(
                prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode
            )
            if self.shared_conv:
                # only one conv layer
                layer_idx = 0
            prev_fpn = self.conv_layers[layer_idx](prev_fpn)
            prev_fpn = F.relu(self.norms[layer_idx](prev_fpn))
        return prev_fpn
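

# `PixelDecoder` sketch: an FPN-style top-down pass that starts from the last
# (coarsest) feature map, then repeatedly upsamples, adds the next finer map,
# and applies conv + GroupNorm + ReLU. Sizes below are illustrative:
#
#   decoder = PixelDecoder(hidden_dim=256, num_upsampling_stages=2)
#   feats = [
#       torch.randn(1, 256, 64, 64),
#       torch.randn(1, 256, 32, 32),
#       torch.randn(1, 256, 16, 16),
#   ]
#   pixel_embed = decoder(feats)   # -> (1, 256, 64, 64), the resolution of feats[0]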


class UniversalSegmentationHead(SegmentationHead):
    """This module handles semantic+instance segmentation"""

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        pixel_decoder,
        aux_masks=False,
        no_dec=False,
        act_ckpt=False,
        presence_head: bool = False,
        dot_product_scorer=None,
        cross_attend_prompt=None,
    ):
        super().__init__(
            hidden_dim=hidden_dim,
            upsampling_stages=upsampling_stages,
            use_encoder_inputs=True,
            aux_masks=aux_masks,
            no_dec=no_dec,
            pixel_decoder=pixel_decoder,
            act_ckpt=act_ckpt,
        )
        self.d_model = hidden_dim
        if dot_product_scorer is not None:
            assert presence_head, (
                "Specifying a dot product scorer without a presence head is likely a mistake"
            )
        self.presence_head = None
        if presence_head:
            self.presence_head = (
                dot_product_scorer
                if dot_product_scorer is not None
                else LinearPresenceHead(self.d_model)
            )
        self.cross_attend_prompt = cross_attend_prompt
        if self.cross_attend_prompt is not None:
            self.cross_attn_norm = nn.LayerNorm(self.d_model)
        self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
        self.instance_seg_head = nn.Conv2d(
            self.pixel_decoder.out_dim, self.d_model, kernel_size=1
        )

    def forward(
        self,
        backbone_feats: List[torch.Tensor],
        obj_queries: torch.Tensor,
        image_ids,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        prompt: Optional[torch.Tensor] = None,
        prompt_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, Optional[torch.Tensor]]:
        assert encoder_hidden_states is not None
        bs = encoder_hidden_states.shape[1]
        if self.cross_attend_prompt is not None:
            tgt2 = self.cross_attn_norm(encoder_hidden_states)
            tgt2 = self.cross_attend_prompt(
                query=tgt2,
                key=prompt,
                value=prompt,
                key_padding_mask=prompt_mask,
            )[0]
            encoder_hidden_states = tgt2 + encoder_hidden_states
        presence_logit = None
        if self.presence_head is not None:
            pooled_enc = encoder_hidden_states.mean(0)
            presence_logit = (
                self.presence_head(
                    pooled_enc.view(1, bs, 1, self.d_model),
                    prompt=prompt,
                    prompt_mask=prompt_mask,
                )
                .squeeze(0)
                .squeeze(1)
            )
        pixel_embed = self._embed_pixels(
            backbone_feats=backbone_feats,
            image_ids=image_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        instance_embeds = self.instance_seg_head(pixel_embed)
        if self.no_dec:
            mask_pred = self.mask_predictor(instance_embeds)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, instance_embeds)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)
        return {
            "pred_masks": mask_pred,
            "semantic_seg": self.semantic_seg_head(pixel_embed),
            "presence_logit": presence_logit,
        }
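

# End-to-end sketch for `UniversalSegmentationHead` (assumed sizes, for
# illustration only). `encoder_hidden_states` is laid out as
# (tokens, batch, channels); its first H*W tokens, with H and W taken from the
# coarsest backbone level, replace that level before the pixel decoder runs.
#
#   head = UniversalSegmentationHead(
#       hidden_dim=256,
#       upsampling_stages=2,
#       pixel_decoder=PixelDecoder(hidden_dim=256, num_upsampling_stages=2),
#       presence_head=True,
#   )
#   feats = [
#       torch.randn(1, 256, 64, 64),
#       torch.randn(1, 256, 32, 32),
#       torch.randn(1, 256, 16, 16),
#   ]
#   enc = torch.randn(16 * 16, 1, 256)        # (tokens, batch, channels)
#   queries = torch.randn(6, 1, 100, 256)     # (decoder layers, batch, queries, dim)
#   image_ids = torch.zeros(1, dtype=torch.long)
#   out = head(feats, queries, image_ids, encoder_hidden_states=enc)
#   # out["pred_masks"]: (1, 100, 64, 64)     instance mask logits per query
#   # out["semantic_seg"]: (1, 1, 64, 64)     single-channel semantic logits
#   # out["presence_logit"]: (1, 1)           presence score from pooled encoder tokens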