# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch import nn, Tensor

from sam2.modeling.sam.transformer import RoPEAttention
from sam2.modeling.sam2_utils import get_activation_fn, get_clones


class MemoryAttentionLayer(nn.Module):
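    """
    A single transformer block used inside MemoryAttention: pre-norm
    self-attention over the current queries, cross-attention from those
    queries to the memory features, and a feed-forward MLP, each followed by
    dropout and a residual connection. The `pos_enc_at_*` flags control where
    positional encodings are added to the attention inputs.
    """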
    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
    def _forward_sa(self, tgt, query_pos):
        # Self-Attention
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-Attention
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
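        """
        Run self-attention on `tgt`, cross-attention from `tgt` to `memory`,
        and the feed-forward MLP. `num_k_exclude_rope` excludes the last keys
        (e.g. object-pointer tokens) from rotary position encoding and is only
        valid when the cross-attention module is a RoPEAttention.
        """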
        # Self-Attn, Cross-Attn
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
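

# --- Illustrative sketch (not from the original module) ---------------------
# MemoryAttentionLayer only assumes that `self_attention` and `cross_attention`
# are modules callable as `module(q, k, v=...)`; in SAM 2 these are typically
# RoPEAttention instances (hence the isinstance checks above and below).
# `_StubAttention` and `_build_example_layer` are hypothetical helpers with
# illustrative hyperparameters, added only so the smoke test at the bottom of
# this file runs without the RoPE-specific setup.
class _StubAttention(nn.Module):
    """Hypothetical attention stand-in with a `(q, k, v)` call convention."""

    def __init__(self, d_model: int, num_heads: int = 1):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        out, _ = self.mha(q, k, v)
        return out


def _build_example_layer(d_model: int = 256) -> MemoryAttentionLayer:
    """Wire a MemoryAttentionLayer with the stub attention modules."""
    return MemoryAttentionLayer(
        activation="relu",
        cross_attention=_StubAttention(d_model),
        d_model=d_model,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        self_attention=_StubAttention(d_model),
    )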


class MemoryAttention(nn.Module):
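    """
    A stack of `num_layers` cloned MemoryAttentionLayer blocks followed by a
    final LayerNorm. Optionally adds a down-scaled positional encoding to the
    input, and converts between sequence-first and batch-first layouts when
    `batch_first` is True.
    """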
    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
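        """
        Run `curr` through all layers, cross-attending to `memory`, and return
        the normalized output. `curr`, `memory` and their positional encodings
        are expected in sequence-first layout, i.e. (seq_len, batch, d_model);
        the output is returned in the same layout.
        """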
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = (
                curr[0],
                curr_pos[0],
            )

        assert (
            curr.shape[1] == memory.shape[1]
        ), "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            curr_pos = curr_pos.transpose(0, 1)
            memory = memory.transpose(0, 1)
            memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            curr_pos = curr_pos.transpose(0, 1)

        return normed_output
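

# --- Illustrative sketch (not from the original module) ---------------------
# Minimal smoke test, assuming the hypothetical `_build_example_layer` helper
# defined above: stack two cloned layers inside MemoryAttention and push dummy
# sequence-first tensors through them. Guarded so it never runs on import.
if __name__ == "__main__":
    d_model, seq_len, mem_len, batch = 256, 64, 128, 2
    memory_attention = MemoryAttention(
        d_model=d_model,
        pos_enc_at_input=True,
        layer=_build_example_layer(d_model),
        num_layers=2,
    )
    curr = torch.randn(seq_len, batch, d_model)    # queries for self-attention
    memory = torch.randn(mem_len, batch, d_model)  # memory for cross-attention
    curr_pos = torch.randn_like(curr)              # pos enc for the queries
    memory_pos = torch.randn_like(memory)          # pos enc for the memory
    out = memory_attention(curr, memory, curr_pos=curr_pos, memory_pos=memory_pos)
    print(tuple(out.shape))  # expected: (64, 2, 256), i.e. (seq_len, batch, d_model)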