# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

"""Various utility modules"""

import copy
import math
import weakref
from collections.abc import Iterator
from contextlib import AbstractContextManager
from enum import auto, Enum
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from typing_extensions import override


def inverse_sigmoid(x, eps=1e-3):
    """
    The inverse of the sigmoid activation function (the logit function).

    Note: it may run into numerical issues in fp16 when eps is small.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
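

# Usage sketch for inverse_sigmoid (illustrative; the `_demo_*` helper below is
# hypothetical and not part of the module's API). It shows that the function
# maps probabilities back to logits up to the eps clamping, which is how
# DETR-style models refine box coordinates in logit space.
def _demo_inverse_sigmoid():
    probs = torch.tensor([0.0, 0.25, 0.5, 0.999])
    logits = inverse_sigmoid(probs)
    # Passing the logits back through sigmoid approximately recovers the
    # clamped probabilities.
    assert torch.allclose(
        torch.sigmoid(logits), probs.clamp(1e-3, 1 - 1e-3), atol=1e-4
    )
    return logits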


class MultiheadAttentionWrapper(nn.MultiheadAttention):
    def forward(self, *args, **kwargs):
        kwargs["need_weights"] = False
        return super().forward(*args, **kwargs)


class DotProductScoring(torch.nn.Module):
    def __init__(
        self,
        d_model,
        d_proj,
        prompt_mlp=None,
        clamp_logits=True,
        clamp_max_val=12.0,
    ):
        super().__init__()
        self.d_proj = d_proj
        assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
        self.prompt_mlp = prompt_mlp  # an optional MLP projection for the prompt
        self.prompt_proj = torch.nn.Linear(d_model, d_proj)
        self.hs_proj = torch.nn.Linear(d_model, d_proj)
        self.scale = float(1.0 / np.sqrt(d_proj))
        self.clamp_logits = clamp_logits
        if self.clamp_logits:
            self.clamp_max_val = clamp_max_val

    def mean_pool_text(self, prompt, prompt_mask):
        # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
        is_valid = (~prompt_mask).float().permute(1, 0)[..., None]
        # num_valid has shape (bs, 1)
        num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
        # mean pool over all the valid tokens -- pooled_prompt has shape (bs, d_model)
        pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
        return pooled_prompt

    def forward(self, hs, prompt, prompt_mask):
        # hs has shape (num_layer, bs, num_query, d_model)
        # prompt has shape (seq, bs, d_model)
        # prompt_mask has shape (bs, seq), where True marks padding and False marks valid tokens
        assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2
        # apply MLP on prompt if specified
        if self.prompt_mlp is not None:
            prompt = self.prompt_mlp(prompt)
        # first, get the mean-pooled version of the prompt
        pooled_prompt = self.mean_pool_text(prompt, prompt_mask)
        # then, project pooled_prompt and hs to d_proj dimensions
        proj_pooled_prompt = self.prompt_proj(pooled_prompt)  # (bs, d_proj)
        proj_hs = self.hs_proj(hs)  # (num_layer, bs, num_query, d_proj)
        # finally, get dot-product scores of shape (num_layer, bs, num_query, 1)
        scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
        scores *= self.scale
        # clamp scores to a max value to avoid numerical issues in loss or matcher
        if self.clamp_logits:
            scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)
        return scores
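

# Usage sketch for DotProductScoring (illustrative; the helper and the sizes
# below are hypothetical, chosen only to show the expected tensor layout):
def _demo_dot_product_scoring():
    num_layers, bs, num_queries, seq, d_model, d_proj = 2, 3, 5, 7, 16, 8
    scoring = DotProductScoring(d_model=d_model, d_proj=d_proj)
    hs = torch.rand(num_layers, bs, num_queries, d_model)
    prompt = torch.rand(seq, bs, d_model)
    # True marks padded prompt tokens; here no token is padded.
    prompt_mask = torch.zeros(bs, seq, dtype=torch.bool)
    scores = scoring(hs, prompt, prompt_mask)
    assert scores.shape == (num_layers, bs, num_queries, 1)
    return scores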


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
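

# Usage sketch for LayerNorm2d (illustrative; the helper is hypothetical): it
# normalizes over the channel dimension of an NCHW tensor, unlike nn.LayerNorm,
# which expects channels-last input.
def _demo_layer_norm_2d():
    norm = LayerNorm2d(num_channels=32)
    x = torch.rand(2, 32, 16, 16)
    out = norm(x)
    assert out.shape == x.shape
    return out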


class TransformerWrapper(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        d_model: int,
        two_stage_type="none",  # ["none"] only for now
        pos_enc_at_input_dec=True,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_queries = decoder.num_queries if decoder is not None else None
        self.pos_enc_at_input_dec = pos_enc_at_input_dec
        # for two stage
        assert two_stage_type in ["none"], "unknown param {} of two_stage_type".format(
            two_stage_type
        )
        self.two_stage_type = two_stage_type
        self._reset_parameters()
        self.d_model = d_model

    def _reset_parameters(self):
        for n, p in self.named_parameters():
            if p.dim() > 1:
                if (
                    "box_embed" not in n
                    and "query_embed" not in n
                    and "reference_points" not in n
                ):
                    nn.init.xavier_uniform_(p)


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float = 0.0,
        residual: bool = False,
        out_norm: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        # whether to add the output as a residual connection to the input
        if residual and input_dim != output_dim:
            raise ValueError("residual is only supported if input_dim == output_dim")
        self.residual = residual
        # whether to apply a normalization layer to the output
        assert isinstance(out_norm, nn.Module) or out_norm is None
        self.out_norm = out_norm or nn.Identity()

    def forward(self, x):
        orig_x = x
        for i, layer in enumerate(self.layers):
            x = self.drop(F.relu(layer(x))) if i < self.num_layers - 1 else layer(x)
        if self.residual:
            x = x + orig_x
        x = self.out_norm(x)
        return x
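

# Usage sketch for MLP (illustrative; the dimensions are hypothetical): a
# 3-layer perceptron with a residual connection, which requires input_dim and
# output_dim to match.
def _demo_mlp():
    mlp = MLP(
        input_dim=256, hidden_dim=1024, output_dim=256, num_layers=3, residual=True
    )
    x = torch.rand(4, 100, 256)  # e.g. (bs, num_queries, d_model)
    out = mlp(x)
    assert out.shape == x.shape
    return out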


def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def get_clones_seq(module, N):
    return nn.Sequential(*[copy.deepcopy(module) for _ in range(N)])


def get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def get_activation_module(activation):
    """Return an activation module class given a string"""
    if activation == "relu":
        return nn.ReLU
    if activation == "gelu":
        return nn.GELU
    if activation == "glu":
        return nn.GLU
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def get_valid_ratio(mask):
    _, H, W = mask.shape
    valid_H = torch.sum(~mask[:, :, 0], 1)
    valid_W = torch.sum(~mask[:, 0, :], 1)
    valid_ratio_h = valid_H.float() / H
    valid_ratio_w = valid_W.float() / W
    valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
    return valid_ratio
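

# Usage sketch for get_valid_ratio (illustrative; the helper is hypothetical):
# the mask follows the padding convention used throughout this file
# (True = padded pixel), so the returned (w, h) ratios give the fraction of
# each padded image that carries real content.
def _demo_get_valid_ratio():
    mask = torch.ones(1, 8, 8, dtype=torch.bool)
    mask[0, :4, :6] = False  # top-left 4x6 region is valid, the rest is padding
    ratio = get_valid_ratio(mask)
    assert torch.allclose(ratio, torch.tensor([[6 / 8, 4 / 8]]))
    return ratio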


def gen_sineembed_for_position(pos_tensor, num_feats=256):
    assert num_feats % 2 == 0
    num_feats = num_feats // 2
    # n_query, bs, _ = pos_tensor.size()
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos
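

# Usage sketch for gen_sineembed_for_position (illustrative; the shapes are
# hypothetical): 2-d (x, y) points yield embeddings of size num_feats, while
# 4-d (x, y, w, h) boxes yield embeddings of size 2 * num_feats.
def _demo_gen_sineembed_for_position():
    num_queries, bs = 10, 2
    points = torch.rand(num_queries, bs, 2)
    boxes = torch.rand(num_queries, bs, 4)
    point_embed = gen_sineembed_for_position(points, num_feats=256)
    box_embed = gen_sineembed_for_position(boxes, num_feats=256)
    assert point_embed.shape == (num_queries, bs, 256)
    assert box_embed.shape == (num_queries, bs, 512)
    return point_embed, box_embed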


class SAM3Output(list):
    """
    A class representing the output of a SAM3 model.

    It provides an iterable interface that supports different iteration modes:
    iterating over all steps per stage, over the last step per stage, or over a
    flattened view of the output.

    Attributes:
        output: The output of the SAM3 model, represented as a list of lists.
        iter_mode: The current iteration mode.

    Example:
        >>> output = [[1, 2], [3, 4], [5, 6]]
        >>> sam3_output = SAM3Output(output)
        >>> for step in sam3_output:
        ...     print(step)
        [1, 2]
        [3, 4]
        [5, 6]
        >>> with SAM3Output.iteration_mode(
        ...     sam3_output, SAM3Output.IterMode.LAST_STEP_PER_STAGE
        ... ) as sam3_last_step_out:
        ...     for step in sam3_last_step_out:
        ...         print(step)
        2
        4
        6
        >>> with SAM3Output.iteration_mode(
        ...     sam3_output, SAM3Output.IterMode.FLATTENED
        ... ) as sam3_flattened_out:
        ...     for step in sam3_flattened_out:
        ...         print(step)
        1
        2
        3
        4
        5
        6
    """

    class IterMode(Enum):
        # Defines the type of iterator over outputs.
        ALL_STEPS_PER_STAGE = auto()
        LAST_STEP_PER_STAGE = auto()
        # Returns each interactivity step as if it were a separate stage
        # (this is used in the SAM3Image model).
        FLATTENED = auto()

    def __init__(
        self,
        output: Optional[List[List[Dict]]] = None,
        iter_mode: IterMode = IterMode.ALL_STEPS_PER_STAGE,
        loss_stages: Optional[List[int]] = None,
    ):
        if output is not None:
            assert (
                isinstance(output, list)
                and len(output) > 0
                and isinstance(output[0], list)
            ), "Expected output to be a list of lists"
            self.output = output
        else:
            self.output = []
        assert isinstance(iter_mode, SAM3Output.IterMode), (
            f"iter_mode should be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
        )
        self.iter_mode = iter_mode
        # We create a weak reference to self to be used in the lambda functions.
        # This is to avoid cyclic references and let SAM3Output be garbage collected.
        self_ref = weakref.ref(self)
        self._mode2iter = {
            SAM3Output.IterMode.ALL_STEPS_PER_STAGE: lambda: iter(self_ref().output),
            SAM3Output.IterMode.LAST_STEP_PER_STAGE: lambda: (
                inner_list[-1] for inner_list in self_ref().output
            ),
            SAM3Output.IterMode.FLATTENED: lambda: (
                element for inner_list in self_ref().output for element in inner_list
            ),
        }
        self.loss_stages = loss_stages

    @override
    def __iter__(self) -> Iterator:
        return self._mode2iter[self.iter_mode]()

    def __getitem__(self, index):
        """
        Returns the item at the specified index.

        Args:
            index (int): The index of the item to return.

        Returns:
            list or element: The item at the specified index.
        """
        assert isinstance(index, int), f"index should be an integer. Got {type(index)}"
        if self.iter_mode == SAM3Output.IterMode.ALL_STEPS_PER_STAGE:
            return self.output[index]
        elif self.iter_mode == SAM3Output.IterMode.LAST_STEP_PER_STAGE:
            return self.output[index][-1]
        elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
            if index == -1:
                return self.output[-1][-1]
            else:
                flattened_output = sum(self.output, [])
                return flattened_output[index]

    class _IterationMode(AbstractContextManager):
        """
        A context manager that temporarily changes the iteration mode of a SAM3Output object.

        This class is used internally by the SAM3Output.iteration_mode method.
        """

        def __init__(
            self, model_output: "SAM3Output", iter_mode: "SAM3Output.IterMode"
        ):
            self._model_output = model_output
            self._orig_iter_mode = model_output.iter_mode
            self._new_iter_mode = iter_mode

        @override
        def __enter__(self) -> "SAM3Output":
            self._model_output.iter_mode = self._new_iter_mode
            return self._model_output

        @override
        def __exit__(self, exc_type, exc_value, traceback):
            self._model_output.iter_mode = self._orig_iter_mode
            return super().__exit__(exc_type, exc_value, traceback)

    @staticmethod
    def iteration_mode(
        model_output: "SAM3Output", iter_mode: IterMode
    ) -> _IterationMode:
        """
        Returns a context manager that allows you to temporarily change the iteration mode of the SAM3Output object.

        Args:
            model_output: The SAM3Output object.
            iter_mode: The new iteration mode.

        Returns:
            SAM3Output._IterationMode: A context manager that changes the iteration mode of the SAM3Output object.
        """
        return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)

    def append(self, item: list):
        assert isinstance(item, list), (
            f"Only list items are supported. Got {type(item)}"
        )
        self.output.append(item)

    def __repr__(self):
        return self.output.__repr__()

    def __len__(self):
        if self.iter_mode in [
            SAM3Output.IterMode.ALL_STEPS_PER_STAGE,
            SAM3Output.IterMode.LAST_STEP_PER_STAGE,
        ]:
            return len(self.output)
        elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
            flattened_output = sum(self.output, [])
            return len(flattened_output)
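

# Usage sketch for SAM3Output (illustrative; the nested dicts are placeholders
# for the per-step output dicts produced by the model):
def _demo_sam3_output():
    output = SAM3Output([[{"step": 0}, {"step": 1}], [{"step": 0}]])
    # The default mode iterates over stages, each a list of interactivity steps.
    stage_lengths = [len(stage) for stage in output]
    assert stage_lengths == [2, 1]
    # Temporarily iterate over only the last step of every stage.
    with SAM3Output.iteration_mode(
        output, SAM3Output.IterMode.LAST_STEP_PER_STAGE
    ) as last_steps:
        assert [step["step"] for step in last_steps] == [1, 0]
    # The original iteration mode is restored on exiting the context manager.
    assert output.iter_mode == SAM3Output.IterMode.ALL_STEPS_PER_STAGE
    return output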