# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

"""
Transformer decoder.
Inspired by PyTorch's version; adds the pre-norm variant.
"""

from typing import Any, Dict, List, Optional

import numpy as np
import torch
from sam3.sam.transformer import RoPEAttention
from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign

from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import (
    gen_sineembed_for_position,
    get_activation_fn,
    get_clones,
    inverse_sigmoid,
    MLP,
)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        activation: str,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        cross_attention: nn.Module,
        n_heads: int,
        use_text_cross_attention: bool = False,
    ):
        super().__init__()

        # cross attention
        self.cross_attn = cross_attention
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)

        # cross attention text
        self.use_text_cross_attention = use_text_cross_attention
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
            self.catext_dropout = (
                nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            )
            self.catext_norm = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        with torch.amp.autocast(device_type="cuda", enabled=False):
            tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],  # nq, bs, d_model
        tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
        tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4
        memory_text: Optional[Tensor] = None,  # num_token, bs, d_model
        text_attention_mask: Optional[Tensor] = None,  # bs, num_token
        # for memory
        memory: Optional[Tensor] = None,  # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        memory_pos: Optional[Tensor] = None,  # pos for memory
        # sa
        self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
        cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
        # dac
        dac=False,
        dac_use_selfatt_ln=True,
        presence_token=None,
        # skip inside deformable attn
        identity=0.0,
        **kwargs,  # additional kwargs for compatibility
    ):
        """
        Input:
            - tgt/tgt_query_pos: nq, bs, d_model
        """
        # self attention
        if self.self_attn is not None:
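            # When `dac` is set, the (tiled) query set is split in half: the first
            # half ("o2o" queries) goes through self-attention below, while the
            # second half ("o2m" queries) skips it and is concatenated back after
            # the residual. The o2o/o2m naming presumably refers to one-to-one vs.
            # one-to-many assignment, in the spirit of DAC-DETR-style training.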
            if dac:
                # we only apply self attention to the first half of the queries
                assert tgt.shape[0] % 2 == 0
                num_o2o_queries = tgt.shape[0] // 2
                tgt_o2o = tgt[:num_o2o_queries]
                tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
                tgt_o2m = tgt[num_o2o_queries:]
            else:
                tgt_o2o = tgt
                tgt_query_pos_o2o = tgt_query_pos

            if presence_token is not None:
                tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
                tgt_query_pos_o2o = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0
                )
                tgt_query_pos = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos], dim=0
                )

            q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
            tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0]
            tgt_o2o = tgt_o2o + self.dropout2(tgt2)

            if dac:
                if not dac_use_selfatt_ln:
                    tgt_o2o = self.norm2(tgt_o2o)
                tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)  # Recombine
                if dac_use_selfatt_ln:
                    tgt = self.norm2(tgt)
            else:
                tgt = tgt_o2o
                tgt = self.norm2(tgt)

        if self.use_text_cross_attention:
            tgt2 = self.ca_text(
                self.with_pos_embed(tgt, tgt_query_pos),
                memory_text,
                memory_text,
                key_padding_mask=text_attention_mask,
            )[0]
            tgt = tgt + self.catext_dropout(tgt2)
            tgt = self.catext_norm(tgt)

        if presence_token is not None:
            presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
            cross_attn_mask = torch.cat(
                [presence_token_mask, cross_attn_mask], dim=1
            )  # (bs*nheads, 1+nq, hw)

        # Cross attention to image
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory,
            attn_mask=cross_attn_mask,
            key_padding_mask=(
                memory_key_padding_mask.transpose(0, 1)
                if memory_key_padding_mask is not None
                else None
            ),
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        presence_token_out = None
        if presence_token is not None:
            presence_token_out = tgt[:1]
            tgt = tgt[1:]

        return tgt, presence_token_out


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        frozen: bool,
        interaction_layer,
        layer,
        num_layers: int,
        num_queries: int,
        return_intermediate: bool,
        box_refine: bool = False,
        num_o2m_queries: int = 0,
        dac: bool = False,
        boxRPB: str = "none",
        # Experimental: An object query for SAM 2 tasks
        instance_query: bool = False,
        # Defines the number of additional instance queries,
        # 1 or 4 are the most likely for single vs multi mask support
        num_instances: int = 1,  # Irrelevant if instance_query is False
        dac_use_selfatt_ln: bool = True,
        use_act_checkpoint: bool = False,
        compile_mode=None,
        presence_token: bool = False,
        clamp_presence_logits: bool = True,
        clamp_presence_logit_max_val: float = 10.0,
        use_normed_output_consistently: bool = True,
        separate_box_head_instance: bool = False,
        separate_norm_instance: bool = False,
        resolution: Optional[int] = None,
        stride: Optional[int] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.fine_layers = (
            get_clones(interaction_layer, num_layers)
            if interaction_layer is not None
            else [None] * num_layers
        )
        self.num_layers = num_layers
        self.num_queries = num_queries
        self.dac = dac
        if dac:
            self.num_o2m_queries = num_queries
            tot_num_queries = num_queries
        else:
            self.num_o2m_queries = num_o2m_queries
            tot_num_queries = num_queries + num_o2m_queries
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
        self.bbox_embed = MLP(d_model, d_model, 4, 3)
        self.query_embed = nn.Embedding(tot_num_queries, d_model)
        self.instance_query_embed = None
        self.instance_query_reference_points = None
        self.use_instance_query = instance_query
        self.num_instances = num_instances
        self.use_normed_output_consistently = use_normed_output_consistently
        self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
        self.instance_bbox_embed = None
        if separate_box_head_instance:
            self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
        if instance_query:
            self.instance_query_embed = nn.Embedding(num_instances, d_model)

        self.box_refine = box_refine
        if box_refine:
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
            self.reference_points = nn.Embedding(num_queries, 4)
            if instance_query:
                self.instance_reference_points = nn.Embedding(num_instances, 4)

        assert boxRPB in ["none", "log", "linear", "both"]
        self.boxRPB = boxRPB
        if boxRPB != "none":
            try:
                nheads = self.layers[0].cross_attn_image.num_heads
            except AttributeError:
                nheads = self.layers[0].cross_attn.num_heads
            n_input = 4 if boxRPB == "both" else 2
            self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
            self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)

        self.compilable_cord_cache = None
        self.compilable_stored_size = None
        self.coord_cache = {}
        if resolution is not None and stride is not None:
            feat_size = resolution // stride
            coords_h, coords_w = self._get_coords(
                feat_size, feat_size, device="cuda"
            )
            self.compilable_cord_cache = (coords_h, coords_w)
            self.compilable_stored_size = (feat_size, feat_size)

        self.roi_pooler = (
            RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
            if interaction_layer is not None
            else None
        )
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)

        self.presence_token = None
        self.clamp_presence_logits = clamp_presence_logits
        self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
        if presence_token:
            self.presence_token = nn.Embedding(1, d_model)
            self.presence_token_head = MLP(d_model, d_model, 1, 3)
            self.presence_token_out_norm = nn.LayerNorm(d_model)

        self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
        self.dac_use_selfatt_ln = dac_use_selfatt_ln
        self.use_act_checkpoint = use_act_checkpoint

        nn.init.normal_(self.query_embed.weight.data)
        if self.instance_query_embed is not None:
            nn.init.normal_(self.instance_query_embed.weight.data)

        assert self.roi_pooler is None
        assert self.return_intermediate, "support return_intermediate only"
        assert self.box_refine, "support box refine only"

        self.compile_mode = compile_mode
        self.compiled = False
        # We defer compilation till after the first forward, to first warm-up the boxRPB cache

        # assign layer index to each layer so that some layers can decide what to do
        # based on which layer index they are (e.g. cross attention to memory bank only
        # in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

    @staticmethod
    def _get_coords(H, W, device):
        coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H
        coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W
        return coords_h, coords_w

    def _get_rpb_matrix(self, reference_boxes, feat_size):
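        """
        Build the box relative position bias ("boxRPB") used as an additive
        cross-attention bias: per-axis offsets between each query's reference
        box edges and every feature-map coordinate are embedded by small MLPs,
        and the x/y contributions are summed over the grid.

        Returns a tensor of shape (bs, n_heads, num_queries, H*W).
        """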
        H, W = feat_size
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
        bs, num_queries, _ = boxes_xyxy.shape
        if self.compilable_cord_cache is None:
            self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
            self.compilable_stored_size = (H, W)
        if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
            H,
            W,
        ):
            # good, hitting the cache, will be compilable
            coords_h, coords_w = self.compilable_cord_cache
        else:
            # cache miss, will create compilation issue
            # In case we're not compiling, we'll still rely on the dict-based cache
            if feat_size not in self.coord_cache:
                self.coord_cache[feat_size] = self._get_coords(
                    H, W, reference_boxes.device
                )
            coords_h, coords_w = self.coord_cache[feat_size]
        assert coords_h.shape == (H,)
        assert coords_w.shape == (W,)
        deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
        deltas_y = deltas_y.view(bs, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
        deltas_x = deltas_x.view(bs, num_queries, -1, 2)
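        # For the "log"/"both" variants, offsets are passed through a signed log
        # transform, d -> sign(8d) * log2(|8d| + 1) / log2(8), which stays roughly
        # linear for small offsets and compresses large ones (similar in spirit to
        # the log-spaced continuous position bias of Swin Transformer V2).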
        if self.boxRPB in ["log", "both"]:
            deltas_x_log = deltas_x * 8  # normalize to -8, 8
            deltas_x_log = (
                torch.sign(deltas_x_log)
                * torch.log2(torch.abs(deltas_x_log) + 1.0)
                / np.log2(8)
            )
            deltas_y_log = deltas_y * 8  # normalize to -8, 8
            deltas_y_log = (
                torch.sign(deltas_y_log)
                * torch.log2(torch.abs(deltas_y_log) + 1.0)
                / np.log2(8)
            )
            if self.boxRPB == "log":
                deltas_x = deltas_x_log
                deltas_y = deltas_y_log
            else:
                deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
                deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
        if self.training:
            assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
        deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)(
            x=deltas_x,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, W, n_heads
        deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)(
            x=deltas_y,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, H, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert deltas_x.shape[:3] == (bs, num_queries, W)
            assert deltas_y.shape[:3] == (bs, num_queries, H)
        B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
            2
        )  # bs, num_queries, H, W, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[:4] == (bs, num_queries, H, W)
        B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
        B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
        B = B.contiguous()  # memeff attn likes ordered strides
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[2:] == (num_queries, H * W)
        return B

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        reference_boxes: Optional[Tensor] = None,  # num_queries, bs, 4
        # for memory
        level_start_index: Optional[Tensor] = None,  # num_levels
        spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        valid_ratios: Optional[Tensor] = None,
        # for text
        memory_text: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
        # if `apply_dac` is None, it will default to `self.dac`
        apply_dac: Optional[bool] = None,
        is_instance_prompt=False,
        decoder_extra_kwargs: Optional[Dict] = None,
        # ROI memory bank
        obj_roi_memory_feat=None,
        obj_roi_memory_mask=None,
        box_head_trk=None,
    ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: \\sum{hw}, bs, d_model
            - pos: \\sum{hw}, bs, d_model
            - reference_boxes: nq, bs, 4 (after sigmoid)
            - valid_ratios/spatial_shapes: bs, nlevel, 2
        """
        if memory_mask is not None:
            assert self.boxRPB == "none", (
                "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
            )

        apply_dac = apply_dac if apply_dac is not None else self.dac
        if apply_dac:
            assert (tgt.shape[0] == self.num_queries) or (
                self.use_instance_query
                and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
            )
            tgt = tgt.repeat(2, 1, 1)
            # note that we don't tile tgt_mask, since DAC doesn't
            # use self-attention in o2m queries
            if reference_boxes is not None:
                assert (reference_boxes.shape[0] == self.num_queries) or (
                    self.use_instance_query
                    and (
                        reference_boxes.shape[0]
                        == self.instance_query_embed.num_embeddings
                    )
                )
                reference_boxes = reference_boxes.repeat(2, 1, 1)

        bs = tgt.shape[1]
        intermediate = []
        intermediate_presence_logits = []
        presence_feats = None

        if self.box_refine:
            if reference_boxes is None:
                # In this case, we're in a one-stage model, so we generate the reference boxes
                reference_boxes = self.reference_points.weight.unsqueeze(1)
                reference_boxes = (
                    reference_boxes.repeat(2, bs, 1)
                    if apply_dac
                    else reference_boxes.repeat(1, bs, 1)
                )
                reference_boxes = reference_boxes.sigmoid()
            intermediate_ref_boxes = [reference_boxes]
        else:
            reference_boxes = None
            intermediate_ref_boxes = None

        output = tgt

        presence_out = None
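        # The optional presence token is a single learned embedding prepended to
        # the queries inside every decoder layer; after each layer its (normed)
        # feature is fed to an MLP head to produce a per-layer presence logit
        # (see the per-layer block further below).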
        if self.presence_token is not None and is_instance_prompt is False:
            # expand to batch dim
            presence_out = self.presence_token.weight[None].expand(1, bs, -1)

        box_head = self.bbox_embed
        if is_instance_prompt and self.instance_bbox_embed is not None:
            box_head = self.instance_bbox_embed
        out_norm = self.norm
        if is_instance_prompt and self.instance_norm is not None:
            out_norm = self.instance_norm

        for layer_idx, layer in enumerate(self.layers):
            reference_points_input = (
                reference_boxes[:, :, None]
                * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
            )  # nq, bs, nlevel, 4
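            # The query positional embedding is re-derived from the current
            # reference box at every layer: a sinusoidal embedding of the box is
            # mapped through `ref_point_head`, akin to DAB-DETR-style dynamic
            # anchor-box conditioning.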
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :], self.d_model
            )  # nq, bs, d_model*2

            # conditional query
            query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, d_model

            if self.boxRPB != "none" and reference_boxes is not None:
                assert spatial_shapes.shape[0] == 1, (
                    "only single scale support implemented"
                )
                memory_mask = self._get_rpb_matrix(
                    reference_boxes,
                    (spatial_shapes[0, 0], spatial_shapes[0, 1]),
                )
                memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)

            if self.training:
                assert self.use_act_checkpoint, (
                    "Activation checkpointing not enabled in the decoder"
                )
            output, presence_out = activation_ckpt_wrapper(layer)(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,
                memory_text=memory_text,
                text_attention_mask=text_attention_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                dac=apply_dac,
                dac_use_selfatt_ln=self.dac_use_selfatt_ln,
                presence_token=presence_out,
                **(decoder_extra_kwargs or {}),
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                # ROI memory bank
                obj_roi_memory_feat=obj_roi_memory_feat,
                obj_roi_memory_mask=obj_roi_memory_mask,
            )

            # iter update
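            # Each layer predicts a box delta in logit space and refines the
            # reference: b_new = sigmoid(box_head(output) + inverse_sigmoid(b)).
            # The refined boxes are detached before becoming the next layer's
            # reference, so gradients do not flow through the refinement chain.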
            if self.box_refine:
                reference_before_sigmoid = inverse_sigmoid(reference_boxes)
                if box_head_trk is None:
                    # delta_unsig = self.bbox_embed(output)
                    if not self.use_normed_output_consistently:
                        delta_unsig = box_head(output)
                    else:
                        delta_unsig = box_head(out_norm(output))
                else:
                    # box_head_trk uses a separate box head for tracking queries
                    Q_det = decoder_extra_kwargs["Q_det"]
                    assert output.size(0) >= Q_det
                    delta_unsig_det = self.bbox_embed(output[:Q_det])
                    delta_unsig_trk = box_head_trk(output[Q_det:])
                    delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()
                reference_boxes = new_reference_points.detach()
                if layer_idx != self.num_layers - 1:
                    intermediate_ref_boxes.append(new_reference_points)
            else:
                raise NotImplementedError("not implemented yet")

            intermediate.append(out_norm(output))
            if self.presence_token is not None and is_instance_prompt is False:
                # norm, mlp head
                intermediate_layer_presence_logits = self.presence_token_head(
                    self.presence_token_out_norm(presence_out)
                ).squeeze(-1)
                # clamp to mitigate numerical issues
                if self.clamp_presence_logits:
                    intermediate_layer_presence_logits = (
                        intermediate_layer_presence_logits.clamp(
                            min=-self.clamp_presence_logit_max_val,
                            max=self.clamp_presence_logit_max_val,
                        )
                    )
                intermediate_presence_logits.append(intermediate_layer_presence_logits)
                presence_feats = presence_out.clone()
        if not self.compiled and self.compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=self.compile_mode, fullgraph=True
            )
            self.compiled = True

        return (
            torch.stack(intermediate),
            torch.stack(intermediate_ref_boxes),
            (
                torch.stack(intermediate_presence_logits)
                if self.presence_token is not None and is_instance_prompt is False
                else None
            ),
            presence_feats,
        )


class TransformerEncoderCrossAttention(nn.Module):
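    """
    A stack of decoder-style layers that interleaves self-attention over `src`
    with cross-attention to `prompt` tokens, returning the normed output along
    with its positional encoding and padding mask. The `num_obj_ptr_tokens`
    argument is forwarded to RoPEAttention layers as `num_k_exclude_rope`,
    presumably so that object-pointer tokens appended to the memory bypass
    rotary position encoding (as in SAM 2-style memory attention).
    """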

    def __init__(
        self,
        d_model: int,
        frozen: bool,
        pos_enc_at_input: bool,
        layer,
        num_layers: int,
        use_act_checkpoint: bool = False,
        batch_first: bool = False,  # Do layers expect batch first input?
        # which layers to exclude cross attention? default: None, means all
        # layers use cross attention
        remove_cross_attention_layers: Optional[list] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.use_act_checkpoint = use_act_checkpoint
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)
        self.batch_first = batch_first

        # remove cross attention layers if specified
        self.remove_cross_attention_layers = [False] * self.num_layers
        if remove_cross_attention_layers is not None:
            for i in remove_cross_attention_layers:
                self.remove_cross_attention_layers[i] = True
        assert len(self.remove_cross_attention_layers) == len(self.layers)
        for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers):
            if remove_cross_attention:
                self.layers[i].cross_attn_image = None
                self.layers[i].norm2 = None
                self.layers[i].dropout2 = None

    def forward(
        self,
        src,  # self-attention inputs
        prompt,  # cross-attention inputs
        src_mask: Optional[Tensor] = None,  # att. mask for self-attention inputs
        prompt_mask: Optional[Tensor] = None,  # att. mask for cross-attention inputs
        src_key_padding_mask: Optional[Tensor] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        src_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        prompt_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        feat_sizes: Optional[list] = None,
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
        if isinstance(src, list):
            assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list)
            assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1
            src, src_key_padding_mask, src_pos = (
                src[0],
                src_key_padding_mask[0],
                src_pos[0],
            )

        assert src.shape[1] == prompt.shape[1], (
            "Batch size must be the same for src and prompt"
        )

        output = src
        if self.pos_enc_at_input and src_pos is not None:
            output = output + 0.1 * src_pos

        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            src_pos = src_pos.transpose(0, 1)
            prompt = prompt.transpose(0, 1)
            prompt_pos = prompt_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = activation_ckpt_wrapper(layer)(
                tgt=output,
                memory=prompt,
                tgt_mask=src_mask,
                memory_mask=prompt_mask,
                tgt_key_padding_mask=src_key_padding_mask,
                memory_key_padding_mask=prompt_key_padding_mask,
                pos=prompt_pos,
                query_pos=src_pos,
                dac=False,
                attn_bias=None,
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            src_pos = src_pos.transpose(0, 1)

        return {
            "memory": normed_output,
            "pos_embed": src_pos,
            "padding_mask": src_key_padding_mask,
        }


class TransformerDecoderLayerv1(nn.Module):
    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        self.pre_norm = pre_norm
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        **kwargs,
    ):
        q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
        # Self attention
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross attention to image
        tgt2 = self.cross_attn_image(
            query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # FFN
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwargs,
    ):
        if dac:
            # we only apply self attention to the first half of the queries
            assert tgt.shape[0] % 2 == 0
            other_tgt = tgt[tgt.shape[0] // 2 :]
            tgt = tgt[: tgt.shape[0] // 2]

        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt2,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout1(tgt2)

        if dac:
            # Recombine
            tgt = torch.cat((tgt, other_tgt), dim=0)

        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            attn_bias=attn_bias,
        )[0]
        tgt = tgt + self.dropout2(tgt2)

        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwds: Any,
    ) -> torch.Tensor:
        fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
        return fwd_fn(
            tgt,
            memory,
            dac=dac,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
            attn_bias=attn_bias,
            **kwds,
        )


class TransformerDecoderLayerv2(TransformerDecoderLayerv1):
    def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any):
        super().__init__(*args, **kwds)
        self.cross_attention_first = cross_attention_first

    def _forward_sa(self, tgt, query_pos):
        # Self-Attention
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        if self.cross_attn_image is None:
            return tgt
        kwds = {}
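        # `num_k_exclude_rope` asks RoPEAttention to skip rotary position encoding
        # for the trailing keys (presumably object-pointer tokens, which carry no
        # spatial position; see `num_obj_ptr_tokens` in the encoder above).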
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
        # Cross-Attention
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ):
        assert dac is False
        assert tgt_mask is None
        assert memory_mask is None
        assert tgt_key_padding_mask is None
        assert memory_key_padding_mask is None
        assert attn_bias is None

        if self.cross_attention_first:
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
            tgt = self._forward_sa(tgt, query_pos)
        else:
            tgt = self._forward_sa(tgt, query_pos)
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)

        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, *args: Any, **kwds: Any) -> torch.Tensor:
        if self.pre_norm:
            return self.forward_pre(*args, **kwds)
        raise NotImplementedError