sam2_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
    that are temporally closest to the current frame at `frame_idx`. Here, we take
    - a) the closest conditioning frame before `frame_idx` (if any);
    - b) the closest conditioning frame after `frame_idx` (if any);
    - c) any other temporally closest conditioning frames until reaching a total
         of `max_cond_frame_num` conditioning frames.

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        selected_outputs = cond_frame_outputs
        unselected_outputs = {}
    else:
        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
        selected_outputs = {}

        # the closest conditioning frame before `frame_idx` (if any)
        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
        if idx_before is not None:
            selected_outputs[idx_before] = cond_frame_outputs[idx_before]

        # the closest conditioning frame after `frame_idx` (if any)
        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
        if idx_after is not None:
            selected_outputs[idx_after] = cond_frame_outputs[idx_after]

        # add other temporally closest conditioning frames until reaching a total
        # of `max_cond_frame_num` conditioning frames.
        num_remain = max_cond_frame_num - len(selected_outputs)
        inds_remain = sorted(
            (t for t in cond_frame_outputs if t not in selected_outputs),
            key=lambda x: abs(x - frame_idx),
        )[:num_remain]
        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
        unselected_outputs = {
            t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
        }

    return selected_outputs, unselected_outputs
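
# Usage sketch (illustrative, not part of the original file): with conditioning
# outputs keyed by frame index, e.g. {0: o0, 5: o5, 10: o10, 20: o20} (the values
# are hypothetical placeholders), querying frame_idx=12 with max_cond_frame_num=3
# keeps the nearest frame before (10), the nearest frame at or after (20), and then
# the next temporally closest remaining frame (5); frame 0 comes back in
# `unselected_outputs`.
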
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Get 1D sine positional embedding as in the original Transformer paper.
    """
    pe_dim = dim // 2
    dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)

    pos_embed = pos_inds.unsqueeze(-1) / dim_t
    pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
    return pos_embed
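
# Shape example (illustrative): get_1d_sine_pe(torch.arange(16), dim=256) returns a
# (16, 256) tensor whose first half along the last dim holds sines and whose second
# half holds cosines of the positions scaled by `temperature`-spaced frequencies,
# following the sinusoidal embedding of the original Transformer paper.
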
def get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class DropPath(nn.Module):
    # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor
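
# DropPath implements stochastic depth: during training, each sample in the batch has
# its branch output zeroed with probability `drop_prob` (survivors are rescaled by
# 1/keep_prob when `scale_by_keep` is set), and at eval time the module is a no-op.
# A hypothetical usage sketch inside a residual block: x = x + DropPath(0.1)(block(x)).
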
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output
        self.act = activation()

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
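
# Usage sketch (illustrative, not from the original file): MLP(256, 256, 4, num_layers=3)
# builds Linear(256, 256) -> act -> Linear(256, 256) -> act -> Linear(256, 4); the
# activation is applied after every layer except the last, and `sigmoid_output=True`
# additionally squashes the final output into (0, 1).
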
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
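

# Minimal smoke test (an illustrative addition, not part of the upstream module):
# running this file directly exercises the utilities above on toy inputs and checks
# a few basic shape/selection properties.
if __name__ == "__main__":
    # conditioning-frame selection on toy outputs keyed by frame index
    cond = {0: "o0", 5: "o5", 10: "o10", 20: "o20"}
    sel, unsel = select_closest_cond_frames(12, cond, max_cond_frame_num=3)
    assert set(sel) == {5, 10, 20} and set(unsel) == {0}

    # sinusoidal positional embedding has shape (num_positions, dim)
    pe = get_1d_sine_pe(torch.arange(16), dim=256)
    assert pe.shape == (16, 256)

    # MLP applies the activation between layers and a plain linear output at the end
    mlp = MLP(input_dim=32, hidden_dim=64, output_dim=8, num_layers=3)
    assert mlp(torch.randn(2, 32)).shape == (2, 8)

    # LayerNorm2d normalizes over the channel dim of an NCHW tensor
    ln = LayerNorm2d(num_channels=16)
    assert ln(torch.randn(2, 16, 8, 8)).shape == (2, 16, 8, 8)

    # DropPath is the identity in eval mode
    dp = DropPath(drop_prob=0.5).eval()
    x = torch.randn(2, 16)
    assert torch.equal(dp(x), x)

    print("sam2_utils smoke test passed")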