memory_attention.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch import nn, Tensor

from sam2.modeling.sam.transformer import RoPEAttention
from sam2.modeling.sam2_utils import get_activation_fn, get_clones


class MemoryAttentionLayer(nn.Module):
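    """
    One pre-norm transformer block: self-attention over the target tokens,
    cross-attention from the target tokens to the memory tokens, and a
    feed-forward MLP, each with LayerNorm, dropout, and a residual connection.
    The pos_enc_at_* flags control where positional encodings are added to the
    queries and keys.
    """
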
    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def _forward_sa(self, tgt, query_pos):
        # Self-Attention
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-Attention
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
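        """
        Apply self-attention, cross-attention to `memory`, and the MLP block to
        `tgt`. `num_k_exclude_rope` is forwarded to the cross-attention module
        and is only valid when that module is a RoPEAttention.
        """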
        # Self-Attn, Cross-Attn
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt


class MemoryAttention(nn.Module):
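    """
    A stack of `num_layers` cloned MemoryAttentionLayer blocks followed by a
    final LayerNorm. Optionally adds a down-weighted (0.1x) positional encoding
    to the input, and transposes between sequence-first and batch-first layouts
    when `batch_first` is True.
    """
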
    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
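        """
        Attend the current-frame tokens (`curr`) to the memory tokens (`memory`).

        Inputs and outputs are sequence-first, i.e. (num_tokens, batch, d_model);
        tensors are transposed internally when `batch_first` is True.
        `num_obj_ptr_tokens` is forwarded to layers whose cross-attention is a
        RoPEAttention so that object-pointer tokens in `memory` can be excluded
        from the rotary position encoding.
        """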
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = (
                curr[0],
                curr_pos[0],
            )

        assert (
            curr.shape[1] == memory.shape[1]
        ), "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            curr_pos = curr_pos.transpose(0, 1)
            memory = memory.transpose(0, 1)
            memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            curr_pos = curr_pos.transpose(0, 1)

        return normed_output
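

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the upstream module): a minimal smoke test
# of MemoryAttention with a stand-in attention module. SAM 2's configs
# typically plug RoPEAttention into `self_attention` and `cross_attention`;
# the `_DemoAttention` wrapper and the hyperparameters below are illustrative
# assumptions that only mirror the `attn(q, k, v=...)` call convention this
# file expects.
if __name__ == "__main__":

    class _DemoAttention(nn.Module):
        """Thin wrapper matching the `attn(q, k, v=...)` call convention."""

        def __init__(self, d_model: int, num_heads: int = 1):
            super().__init__()
            self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
            out, _ = self.mha(q, k, v)
            return out

    d_model = 256
    layer = MemoryAttentionLayer(
        activation="relu",
        cross_attention=_DemoAttention(d_model),
        d_model=d_model,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        self_attention=_DemoAttention(d_model),
    )
    attn = MemoryAttention(
        d_model=d_model, pos_enc_at_input=True, layer=layer, num_layers=2
    )

    # Sequence-first inputs: (num_tokens, batch, d_model).
    curr = torch.randn(64, 2, d_model)     # current-frame tokens
    memory = torch.randn(256, 2, d_model)  # memory-bank tokens
    curr_pos = torch.randn_like(curr)
    memory_pos = torch.randn_like(memory)

    out = attn(curr, memory, curr_pos=curr_pos, memory_pos=memory_pos)
    print(out.shape)  # torch.Size([64, 2, 256])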