# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .model_misc import LayerScale


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        # Attention
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        # LayerNorm, LayerScale
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        # MLP
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x
        if attn_mask is not None:
            # Leave boolean masks as is
            if not attn_mask.dtype == torch.bool:
                attn_mask = attn_mask.to(q_x.dtype)
        return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # ln_1_kv is only defined in cross-attention variants of this block; for
        # plain self-attention the hasattr check leaves k_x/v_x as None here.
        k_x = (
            self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        )
        v_x = (
            self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        )

        x = q_x + self.ls_1(
            self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        )
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for _ in range(layers)
            ]
        )

        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            if self.grad_checkpointing:
                torch._dynamo.config.optimize_ddp = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for r in self.resblocks:
            if (
                self.grad_checkpointing
                and not torch.jit.is_scripting()
                and self.training
            ):
                x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)
        return x
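# Usage sketch (illustrative): a Transformer is a stack of pre-norm
# ResidualAttentionBlocks acting on [batch, seq_len, width] tensors, so a
# forward pass looks roughly like
#
#   encoder = Transformer(width=512, layers=12, heads=8)
#   y = encoder(torch.randn(2, 77, 512))  # -> [2, 77, 512]
#
# with an optional additive attn_mask of shape [seq_len, seq_len].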
def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
    if pool_type == "first":
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == "last":
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == "argmax":
        # take features from the eot embedding
        # (eot_token is the highest number in each sequence)
        assert text is not None
        pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
    else:
        pooled = tokens = x

    return pooled, tokens


class TextTransformer(nn.Module):
    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(self.vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()

        if no_causal_mask:
            self.attn_mask = None
        else:
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )

        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def build_causal_mask(self) -> torch.Tensor:
        # lazily create causal attention mask, with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_len = text.shape[1]
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        attn_mask = self.attn_mask
        if attn_mask is not None:
            attn_mask = attn_mask[:seq_len, :seq_len]

        x = x + self.positional_embedding[:seq_len]
        x = self.transformer(x, attn_mask=attn_mask)
        x = self.ln_final(x)

        pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection

        if self.output_tokens:
            return pooled, tokens

        return pooled
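# Usage sketch (illustrative): TextTransformer maps integer token ids to
# contextual embeddings; with output_tokens=True it returns both the projected
# pooled output and the per-token features, e.g.
#
#   text_model = TextTransformer(width=512, heads=8, layers=12, output_tokens=True)
#   token_ids = torch.randint(0, text_model.vocab_size, (2, 77))
#   pooled, tokens = text_model(token_ids)
#
# With the default pool_type="none", "pooled" keeps the sequence dimension.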
class VETextEncoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer
        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if isinstance(text[0], str):
            # no use case for this
            assert input_boxes is None or len(input_boxes) == 0, "not supported"

            # Encode the text
            tokenized = self.tokenizer(text, context_length=self.context_length).to(
                device
            )  # [b, seq_len]
            text_attention_mask = (tokenized != 0).bool()
            # manually embed the tokens
            inputs_embeds = self.encoder.token_embedding(
                tokenized
            )  # [b, seq_len, d=1024]
            _, text_memory = self.encoder(tokenized)  # [b, seq_len, d=1024]
            assert text_memory.shape[1] == inputs_embeds.shape[1]
            # Invert the attention mask because pytorch's transformer uses the
            # opposite convention (True marks positions to ignore)
            text_attention_mask = text_attention_mask.ne(1)
            # Transpose memory because pytorch's attention expects sequence first
            text_memory = text_memory.transpose(0, 1)
            # Resize the encoder hidden states to be of the same d_model as the decoder
            text_memory_resized = self.resizer(text_memory)
        else:
            # The text is already encoded, use as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert input_boxes is None or len(input_boxes) == 0, (
                "Can't replace boxes in text if it's already encoded"
            )

        # Note that the inputs_embeds are returned in pytorch's convention
        # (sequence first)
        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )
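
if __name__ == "__main__":
    # Minimal shape-only smoke test (illustrative sketch; run it as a module,
    # e.g. `python -m package.this_module`, since the file uses a relative
    # import). Parameter initialization normally happens elsewhere, so the
    # uninitialized tensors are filled in here just to keep the forward pass
    # numerically well-defined.
    model = TextTransformer(width=512, heads=8, layers=2, output_tokens=True)
    nn.init.normal_(model.positional_embedding, std=0.01)
    nn.init.normal_(model.text_projection, std=0.01)
    token_ids = torch.randint(0, model.vocab_size, (2, model.context_length))
    pooled, tokens = model(token_ids)
    # With pool_type="none" (the default) the pooled output keeps the sequence
    # dimension, so both shapes are [2, 77, 512].
    print(pooled.shape, tokens.shape)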