sub_quadratic_attention.py 7.1 KB


  1. # original source:
  2. # https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
  3. # license:
  4. # MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
  5. # credit:
  6. # Amin Rezaei (original author)
  7. # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
  8. # brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
  9. # implementation of:
  10. # "Self-attention Does Not Need O(n^2) Memory":
  11. # https://arxiv.org/abs/2112.05682v2
  12. from functools import partial
  13. import torch
  14. from torch import Tensor
  15. from torch.utils.checkpoint import checkpoint
  16. import math
  17. from typing import Optional, NamedTuple, List
  18. def narrow_trunc(
  19. input: Tensor,
  20. dim: int,
  21. start: int,
  22. length: int
  23. ) -> Tensor:
  24. return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
  25. class AttnChunk(NamedTuple):
  26. exp_values: Tensor
  27. exp_weights_sum: Tensor
  28. max_score: Tensor
  29. class SummarizeChunk:
  30. @staticmethod
  31. def __call__(
  32. query: Tensor,
  33. key: Tensor,
  34. value: Tensor,
  35. ) -> AttnChunk: ...
  36. class ComputeQueryChunkAttn:
  37. @staticmethod
  38. def __call__(
  39. query: Tensor,
  40. key: Tensor,
  41. value: Tensor,
  42. ) -> Tensor: ...
  43. def _summarize_chunk(
  44. query: Tensor,
  45. key: Tensor,
  46. value: Tensor,
  47. scale: float,
  48. ) -> AttnChunk:
  49. attn_weights = torch.baddbmm(
  50. torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
  51. query,
  52. key.transpose(1,2),
  53. alpha=scale,
  54. beta=0,
  55. )
  56. max_score, _ = torch.max(attn_weights, -1, keepdim=True)
  57. max_score = max_score.detach()
  58. exp_weights = torch.exp(attn_weights - max_score)
  59. exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
  60. max_score = max_score.squeeze(-1)
  61. return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
  62. def _query_chunk_attention(
  63. query: Tensor,
  64. key: Tensor,
  65. value: Tensor,
  66. summarize_chunk: SummarizeChunk,
  67. kv_chunk_size: int,
  68. ) -> Tensor:
  69. batch_x_heads, k_tokens, k_channels_per_head = key.shape
  70. _, _, v_channels_per_head = value.shape
  71. def chunk_scanner(chunk_idx: int) -> AttnChunk:
  72. key_chunk = narrow_trunc(
  73. key,
  74. 1,
  75. chunk_idx,
  76. kv_chunk_size
  77. )
  78. value_chunk = narrow_trunc(
  79. value,
  80. 1,
  81. chunk_idx,
  82. kv_chunk_size
  83. )
  84. return summarize_chunk(query, key_chunk, value_chunk)
  85. chunks: List[AttnChunk] = [
  86. chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
  87. ]
  88. acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
  89. chunk_values, chunk_weights, chunk_max = acc_chunk
  90. global_max, _ = torch.max(chunk_max, 0, keepdim=True)
  91. max_diffs = torch.exp(chunk_max - global_max)
  92. chunk_values *= torch.unsqueeze(max_diffs, -1)
  93. chunk_weights *= max_diffs
  94. all_values = chunk_values.sum(dim=0)
  95. all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
  96. return all_values / all_weights
  97. # TODO: refactor CrossAttention#get_attention_scores to share code with this
  98. def _get_attention_scores_no_kv_chunking(
  99. query: Tensor,
  100. key: Tensor,
  101. value: Tensor,
  102. scale: float,
  103. ) -> Tensor:
  104. attn_scores = torch.baddbmm(
  105. torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
  106. query,
  107. key.transpose(1,2),
  108. alpha=scale,
  109. beta=0,
  110. )
  111. attn_probs = attn_scores.softmax(dim=-1)
  112. del attn_scores
  113. hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
  114. return hidden_states_slice
  115. class ScannedChunk(NamedTuple):
  116. chunk_idx: int
  117. attn_chunk: AttnChunk
  118. def efficient_dot_product_attention(
  119. query: Tensor,
  120. key: Tensor,
  121. value: Tensor,
  122. query_chunk_size=1024,
  123. kv_chunk_size: Optional[int] = None,
  124. kv_chunk_size_min: Optional[int] = None,
  125. use_checkpoint=True,
  126. ):
  127. """Computes efficient dot-product attention given query, key, and value.
  128. This is efficient version of attention presented in
  129. https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
  130. Args:
  131. query: queries for calculating attention with shape of
  132. `[batch * num_heads, tokens, channels_per_head]`.
  133. key: keys for calculating attention with shape of
  134. `[batch * num_heads, tokens, channels_per_head]`.
  135. value: values to be used in attention with shape of
  136. `[batch * num_heads, tokens, channels_per_head]`.
  137. query_chunk_size: int: query chunks size
  138. kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
  139. kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
  140. use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
  141. Returns:
  142. Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
  143. """
  144. batch_x_heads, q_tokens, q_channels_per_head = query.shape
  145. _, k_tokens, _ = key.shape
  146. scale = q_channels_per_head ** -0.5
  147. kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
  148. if kv_chunk_size_min is not None:
  149. kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
  150. def get_query_chunk(chunk_idx: int) -> Tensor:
  151. return narrow_trunc(
  152. query,
  153. 1,
  154. chunk_idx,
  155. min(query_chunk_size, q_tokens)
  156. )
  157. summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
  158. summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
  159. compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
  160. _get_attention_scores_no_kv_chunking,
  161. scale=scale
  162. ) if k_tokens <= kv_chunk_size else (
  163. # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
  164. partial(
  165. _query_chunk_attention,
  166. kv_chunk_size=kv_chunk_size,
  167. summarize_chunk=summarize_chunk,
  168. )
  169. )
  170. if q_tokens <= query_chunk_size:
  171. # fast-path for when there's just 1 query chunk
  172. return compute_query_chunk_attn(
  173. query=query,
  174. key=key,
  175. value=value,
  176. )
  177. res = torch.zeros_like(query)
  178. for i in range(math.ceil(q_tokens / query_chunk_size)):
  179. attn_scores = compute_query_chunk_attn(
  180. query=get_query_chunk(i * query_chunk_size),
  181. key=key,
  182. value=value,
  183. )
  184. res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
  185. return res