# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from .model_misc import MLP


class LinearPresenceHead(nn.Sequential):
    def __init__(self, d_model):
        # a hack to make `LinearPresenceHead` compatible with old checkpoints
        super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))

    def forward(self, hs, prompt, prompt_mask):
        # prompt and prompt_mask are ignored; they are accepted only so the call
        # signature matches prompt-aware presence heads (e.g. dot-product scorers).
        return super().forward(hs)
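

# Minimal usage sketch (added for illustration, not part of the original module):
# the presence head is a single linear projection to one logit, and the prompt
# arguments are accepted only to match the interface of prompt-aware scorers.
# The d_model and batch sizes below are made up.
if __name__ == "__main__":
    _presence = LinearPresenceHead(d_model=256)
    _pooled = torch.randn(1, 2, 1, 256)  # (1, batch, 1, d_model), as used downstream
    assert _presence(_pooled, prompt=None, prompt_mask=None).shape == (1, 2, 1, 1)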


class MaskPredictor(nn.Module):
    """Predicts per-query masks as a dot product between embedded object queries
    and pixel embeddings."""

    def __init__(self, hidden_dim, mask_dim):
        super().__init__()
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, obj_queries, pixel_embed):
        if len(obj_queries.shape) == 3:
            if pixel_embed.ndim == 3:
                # batch size was omitted
                mask_preds = torch.einsum(
                    "bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed
                )
            else:
                mask_preds = torch.einsum(
                    "bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed
                )
        else:
            # Assumed to have aux masks (an extra leading decoder-layer dimension)
            if pixel_embed.ndim == 3:
                # batch size was omitted
                mask_preds = torch.einsum(
                    "lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed
                )
            else:
                mask_preds = torch.einsum(
                    "lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed
                )
        return mask_preds
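

# Shape sketch (added for illustration, not part of the original module). It shows the
# two most common input layouts handled by MaskPredictor; the sizes are made up, and it
# assumes MLP follows the (input_dim, hidden_dim, output_dim, num_layers) signature
# implied by the constructor call above.
if __name__ == "__main__":
    _predictor = MaskPredictor(hidden_dim=256, mask_dim=256)
    _queries = torch.randn(2, 10, 256)  # (batch, num_queries, channels)
    _pixels = torch.randn(2, 256, 32, 32)  # (batch, mask_dim, H, W)
    assert _predictor(_queries, _pixels).shape == (2, 10, 32, 32)
    # With per-decoder-layer (aux) queries, the leading layer dimension is kept.
    _aux_queries = torch.randn(6, 2, 10, 256)  # (layers, batch, num_queries, channels)
    assert _predictor(_aux_queries, _pixels).shape == (6, 2, 10, 32, 32)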


class SegmentationHead(nn.Module):
    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        use_encoder_inputs=False,
        aux_masks=False,
        no_dec=False,
        pixel_decoder=None,
        act_ckpt=False,
        shared_conv=False,
        compile_mode_pixel_decoder=None,
    ):
        super().__init__()
        self.use_encoder_inputs = use_encoder_inputs
        self.aux_masks = aux_masks
        if pixel_decoder is not None:
            self.pixel_decoder = pixel_decoder
        else:
            self.pixel_decoder = PixelDecoder(
                hidden_dim,
                upsampling_stages,
                shared_conv=shared_conv,
                compile_mode=compile_mode_pixel_decoder,
            )
        self.no_dec = no_dec
        if no_dec:
            self.mask_predictor = nn.Conv2d(
                hidden_dim, 1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)
        self.act_ckpt = act_ckpt
        # used to update the output dictionary
        self.instance_keys = ["pred_masks"]

    @property
    def device(self):
        self._device = getattr(self, "_device", None) or next(self.parameters()).device
        return self._device

    def to(self, *args, **kwargs):
        # clear cached _device in case the model is moved to a different device
        self._device = None
        return super().to(*args, **kwargs)

    def _embed_pixels(
        self,
        backbone_feats: List[torch.Tensor],
        image_ids,
        encoder_hidden_states,
    ) -> torch.Tensor:
        feature_device = backbone_feats[0].device  # features could be on CPU
        model_device = self.device
        image_ids_ = image_ids.to(feature_device)
        if self.use_encoder_inputs:
            if backbone_feats[0].shape[0] > 1:
                # For bs > 1, we construct the per query backbone features
                backbone_visual_feats = []
                for feat in backbone_feats:
                    # Copy the img features per query (pixel decoder won't share img feats)
                    backbone_visual_feats.append(feat[image_ids_, ...].to(model_device))
            else:
                # Bs=1, we rely on broadcasting for query-based processing
                backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
            # Extract visual embeddings
            encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
            spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
            encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(
                -1, *backbone_feats[-1].shape[1:]
            )
            backbone_visual_feats[-1] = encoder_visual_embed
            if self.act_ckpt:
                pixel_embed = checkpoint.checkpoint(
                    self.pixel_decoder, backbone_visual_feats, use_reentrant=False
                )
            else:
                pixel_embed = self.pixel_decoder(backbone_visual_feats)
        else:
            backbone_feats = [x.to(model_device) for x in backbone_feats]
            pixel_embed = self.pixel_decoder(backbone_feats)
            if pixel_embed.shape[0] == 1:
                # For batch_size=1 training, we can avoid the indexing to save memory
                pixel_embed = pixel_embed.squeeze(0)
            else:
                pixel_embed = pixel_embed[image_ids, ...]
        return pixel_embed

    def forward(
        self,
        backbone_feats: List[torch.Tensor],
        obj_queries: torch.Tensor,
        image_ids,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        if self.use_encoder_inputs:
            assert encoder_hidden_states is not None
        pixel_embed = self._embed_pixels(
            backbone_feats=backbone_feats,
            image_ids=image_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        if self.no_dec:
            mask_pred = self.mask_predictor(pixel_embed)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, pixel_embed)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)
        return {"pred_masks": mask_pred}


class PixelDecoder(nn.Module):
    """FPN-style pixel decoder that fuses backbone features top-down into a
    high-resolution pixel embedding."""

    def __init__(
        self,
        hidden_dim,
        num_upsampling_stages,
        interpolation_mode="nearest",
        shared_conv=False,
        compile_mode=None,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_upsampling_stages = num_upsampling_stages
        self.interpolation_mode = interpolation_mode
        conv_layers = []
        norms = []
        num_convs = 1 if shared_conv else num_upsampling_stages
        for _ in range(num_convs):
            conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
            norms.append(nn.GroupNorm(8, self.hidden_dim))
        self.conv_layers = nn.ModuleList(conv_layers)
        self.norms = nn.ModuleList(norms)
        self.shared_conv = shared_conv
        self.out_dim = self.conv_layers[-1].out_channels
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, dynamic=True, fullgraph=True
            )
            # Needed to make checkpointing happy. Since we don't know whether the
            # module will be checkpointed, we disable the DDP optimizer by default.
            torch._dynamo.config.optimize_ddp = False

    def forward(self, backbone_feats: List[torch.Tensor]):
        # Assumes backbone features are already projected (C == hidden dim)
        prev_fpn = backbone_feats[-1]
        fpn_feats = backbone_feats[:-1]
        for layer_idx, bb_feat in enumerate(fpn_feats[::-1]):
            curr_fpn = bb_feat
            prev_fpn = curr_fpn + F.interpolate(
                prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode
            )
            if self.shared_conv:
                # only one conv layer
                layer_idx = 0
            prev_fpn = self.conv_layers[layer_idx](prev_fpn)
            prev_fpn = F.relu(self.norms[layer_idx](prev_fpn))
        return prev_fpn
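

# Shape sketch (added for illustration, not part of the original module); sizes are made
# up. Backbone features are assumed to be ordered from highest to lowest resolution, with
# the last (coarsest) map fused top-down into the earlier ones, as the loop above implies.
if __name__ == "__main__":
    _decoder = PixelDecoder(hidden_dim=256, num_upsampling_stages=2)
    _feats = [
        torch.randn(1, 256, 64, 64),
        torch.randn(1, 256, 32, 32),
        torch.randn(1, 256, 16, 16),
    ]
    # Output is at the resolution of the finest input feature map.
    assert _decoder(_feats).shape == (1, 256, 64, 64)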


class UniversalSegmentationHead(SegmentationHead):
    """This module handles semantic+instance segmentation"""

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        pixel_decoder,
        aux_masks=False,
        no_dec=False,
        act_ckpt=False,
        presence_head: bool = False,
        dot_product_scorer=None,
        cross_attend_prompt=None,
    ):
        super().__init__(
            hidden_dim=hidden_dim,
            upsampling_stages=upsampling_stages,
            use_encoder_inputs=True,
            aux_masks=aux_masks,
            no_dec=no_dec,
            pixel_decoder=pixel_decoder,
            act_ckpt=act_ckpt,
        )
        self.d_model = hidden_dim
        if dot_product_scorer is not None:
            assert presence_head, (
                "Specifying a dot product scorer without a presence head is likely a mistake"
            )
        self.presence_head = None
        if presence_head:
            self.presence_head = (
                dot_product_scorer
                if dot_product_scorer is not None
                else LinearPresenceHead(self.d_model)
            )
        self.cross_attend_prompt = cross_attend_prompt
        if self.cross_attend_prompt is not None:
            self.cross_attn_norm = nn.LayerNorm(self.d_model)
        self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
        self.instance_seg_head = nn.Conv2d(
            self.pixel_decoder.out_dim, self.d_model, kernel_size=1
        )

    def forward(
        self,
        backbone_feats: List[torch.Tensor],
        obj_queries: torch.Tensor,
        image_ids,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        prompt: Optional[torch.Tensor] = None,
        prompt_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, Optional[torch.Tensor]]:
        assert encoder_hidden_states is not None
        bs = encoder_hidden_states.shape[1]

        if self.cross_attend_prompt is not None:
            tgt2 = self.cross_attn_norm(encoder_hidden_states)
            tgt2 = self.cross_attend_prompt(
                query=tgt2,
                key=prompt,
                value=prompt,
                key_padding_mask=prompt_mask,
            )[0]
            encoder_hidden_states = tgt2 + encoder_hidden_states

        presence_logit = None
        if self.presence_head is not None:
            pooled_enc = encoder_hidden_states.mean(0)
            presence_logit = (
                self.presence_head(
                    pooled_enc.view(1, bs, 1, self.d_model),
                    prompt=prompt,
                    prompt_mask=prompt_mask,
                )
                .squeeze(0)
                .squeeze(1)
            )

        pixel_embed = self._embed_pixels(
            backbone_feats=backbone_feats,
            image_ids=image_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        instance_embeds = self.instance_seg_head(pixel_embed)
        if self.no_dec:
            mask_pred = self.mask_predictor(instance_embeds)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, instance_embeds)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)

        return {
            "pred_masks": mask_pred,
            "semantic_seg": self.semantic_seg_head(pixel_embed),
            "presence_logit": presence_logit,
        }
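

# End-to-end sketch (added for illustration, not part of the original module); all sizes
# are made up, and it assumes encoder_hidden_states is laid out as (seq_len, batch, d_model)
# with the first H*W tokens of the coarsest feature map forming the spatial grid, as implied
# by the permute/slice in _embed_pixels. It runs only when the module is executed in its
# package context (the relative MLP import above requires it).
if __name__ == "__main__":
    _pix_dec = PixelDecoder(hidden_dim=256, num_upsampling_stages=2)
    _head = UniversalSegmentationHead(
        hidden_dim=256, upsampling_stages=2, pixel_decoder=_pix_dec
    )
    _feats = [
        torch.randn(1, 256, 64, 64),
        torch.randn(1, 256, 32, 32),
        torch.randn(1, 256, 16, 16),
    ]
    _enc = torch.randn(16 * 16, 1, 256)  # (seq_len, batch, d_model)
    _obj = torch.randn(6, 1, 10, 256)  # (decoder layers, batch, queries, d_model)
    _out = _head(
        _feats,
        _obj,
        image_ids=torch.zeros(1, dtype=torch.long),
        encoder_hidden_states=_enc,
    )
    assert _out["pred_masks"].shape == (1, 10, 64, 64)
    assert _out["semantic_seg"].shape == (1, 1, 64, 64)
    assert _out["presence_logit"] is None  # no presence head in this configuration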