# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .model_misc import LayerScale


class ResidualAttentionBlock(nn.Module):
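    """Pre-norm transformer block: multi-head self-attention and an MLP, each added
    back through a residual connection with optional LayerScale scaling.

    Inputs are ``[batch, seq_len, d_model]`` (the attention module is ``batch_first``).

    Illustrative usage sketch (shapes follow from the layer definitions below)::

        block = ResidualAttentionBlock(d_model=512, n_head=8)
        y = block(torch.randn(2, 77, 512))  # -> [2, 77, 512]
    """
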
    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        # Attention
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        # LayerNorm, LayerScale
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        # MLP
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x

        if attn_mask is not None:
            # Leave boolean masks as is; cast additive (float) masks to the query dtype
            if attn_mask.dtype != torch.bool:
                attn_mask = attn_mask.to(q_x.dtype)

        return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Without a separate key/value norm (ln_1_kv), any provided k_x/v_x are
        # dropped and attention() falls back to self-attention on the normed q_x.
        k_x = (
            self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        )
        v_x = (
            self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        )

        x = q_x + self.ls_1(
            self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        )
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
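    """A stack of ``layers`` ResidualAttentionBlocks of width ``width``.

    Optionally wraps ``forward`` in ``torch.compile`` and/or applies per-block
    activation checkpointing during training to trade compute for memory.
    """
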
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for _ in range(layers)
            ]
        )

        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            if self.grad_checkpointing:
                # Activation checkpointing inside a compiled forward conflicts with
                # dynamo's DDP graph-split optimization, so turn that optimization off.
                torch._dynamo.config.optimize_ddp = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for r in self.resblocks:
            if (
                self.grad_checkpointing
                and not torch.jit.is_scripting()
                and self.training
            ):
                # checkpoint() recomputes the block during backward to save activation
                # memory; the None placeholders fill the k_x and v_x positional args.
                x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)
        return x


def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
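    """Pool token features ``x`` of shape ``[batch, seq_len, width]`` into per-example vectors.

    ``pool_type`` selects the first token, the last token, or (for ``"argmax"``) the
    token at each sequence's highest id, which is the EOT token in CLIP-style
    vocabularies; any other value returns ``x`` unpooled. Returns ``(pooled, tokens)``.
    """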
    if pool_type == "first":
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == "last":
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == "argmax":
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None
        pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
    else:
        pooled = tokens = x

    return pooled, tokens


class TextTransformer(nn.Module):
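    """CLIP-style text transformer: token + positional embeddings, a causally
    masked Transformer, an optional final LayerNorm, pooling, and a learned
    projection to ``output_dim``.

    Illustrative usage sketch (note the positional embedding and projection are
    created with ``torch.empty``, so outputs are only meaningful once trained
    weights are loaded)::

        model = TextTransformer(context_length=77, vocab_size=49408, width=512)
        tokens = torch.randint(0, 49408, (2, 77))
        out = model(tokens)  # [2, 77, 512] with the default pool_type="none"
    """
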
    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(self.vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()

        if no_causal_mask:
            self.attn_mask = None
        else:
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )

        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def build_causal_mask(self) -> torch.Tensor:
        # causal attention mask; pytorch expects an additive mask, so disallowed
        # (future) positions are -inf and allowed positions are 0
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and the lower triangle
        return mask

    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
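        """Encode ``text`` token ids of shape ``[batch, seq_len]``.

        Returns the projected pooled features, or ``(pooled, tokens)`` when
        ``output_tokens`` is set. With ``pool_type="none"`` the "pooled" output is
        the full projected token sequence ``[batch, seq_len, output_dim]``.
        """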
        seq_len = text.shape[1]
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

        attn_mask = self.attn_mask
        if attn_mask is not None:
            attn_mask = attn_mask[:seq_len, :seq_len]

        x = x + self.positional_embedding[:seq_len]
        x = self.transformer(x, attn_mask=attn_mask)
        x = self.ln_final(x)

        pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection

        if self.output_tokens:
            return pooled, tokens

        return pooled


class VETextEncoder(nn.Module):
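    """Text encoder wrapper: tokenizes raw strings, runs the TextTransformer, and
    resizes the per-token hidden states to the decoder width ``d_model``.

    Illustrative usage sketch (``my_tokenizer`` is hypothetical; any callable
    matching the ``tokenizer(text, context_length=...) -> LongTensor`` usage in
    ``forward`` works)::

        encoder = VETextEncoder(d_model=256, tokenizer=my_tokenizer)
        mask, memory, embeds = encoder(
            ["a photo of a cat"], device=torch.device("cpu")
        )
    """
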
    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer
        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
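        """Encode ``text`` (a list of raw strings, or an already-encoded triple).

        Returns ``(text_attention_mask, text_memory_resized, inputs_embeds)``, where
        the mask is True at padding positions and the last two tensors are in
        sequence-first layout ``[seq_len, batch, dim]``.
        """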
        if isinstance(text[0], str):
            # no use case for passing boxes together with raw strings
            assert input_boxes is None or len(input_boxes) == 0, "not supported"

            # Encode the text
            tokenized = self.tokenizer(text, context_length=self.context_length).to(
                device
            )  # [b, seq_len]
            text_attention_mask = (tokenized != 0).bool()

            # manually embed the tokens
            inputs_embeds = self.encoder.token_embedding(
                tokenized
            )  # [b, seq_len, d=1024]
            _, text_memory = self.encoder(tokenized)  # [b, seq_len, d=1024]
            assert text_memory.shape[1] == inputs_embeds.shape[1]

            # Invert the attention mask because pytorch's transformer uses the opposite
            # convention (True marks positions to ignore)
            text_attention_mask = text_attention_mask.ne(1)
            # Transpose memory because pytorch's attention expects sequence first
            text_memory = text_memory.transpose(0, 1)
            # Resize the encoder hidden states to match the decoder's d_model
            text_memory_resized = self.resizer(text_memory)
        else:
            # The text is already encoded, use as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert input_boxes is None or len(input_boxes) == 0, (
                "Can't replace boxes in text if it's already encoded"
            )

        # Note that the inputs_embeds are returned in pytorch's convention (sequence first)
        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )