rope.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. """
  4. Adapted from:
  5. 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
  6. 2. https://github.com/naver-ai/rope-vit
  7. 3. https://github.com/lucidrains/rotary-embedding-torch
  8. """
  9. from typing import Optional
  10. import torch
  11. from einops import rearrange, repeat
  12. from torch import broadcast_tensors, nn
  13. def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0, device=None):
  14. t = torch.arange(end_x * end_y, dtype=torch.float32, device=device)
  15. t_x = (t % end_x).float()
  16. t_y = torch.div(t, end_x, rounding_mode="floor").float()
  17. return t_x * scale + offset, t_y * scale + offset
  18. def compute_axial_cis(
  19. dim: int,
  20. end_x: int,
  21. end_y: int,
  22. theta: float = 10000.0,
  23. scale_pos: float = 1.0,
  24. offset: int = 0,
  25. device=None,
  26. ):
  27. freqs_x = 1.0 / (
  28. theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim)
  29. )
  30. freqs_y = 1.0 / (
  31. theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim)
  32. )
  33. t_x, t_y = init_t_xy(end_x, end_y, scale_pos, offset, device=device)
  34. freqs_x = torch.outer(t_x, freqs_x)
  35. freqs_y = torch.outer(t_y, freqs_y)
  36. freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
  37. freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
  38. return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
  39. def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
  40. ndim = x.ndim
  41. assert 0 <= 1 < ndim
  42. assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
  43. shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
  44. return freqs_cis.view(*shape)
  45. def apply_rotary_enc(
  46. xq: torch.Tensor,
  47. xk: torch.Tensor,
  48. freqs_cis: torch.Tensor,
  49. repeat_freqs_k: bool = False,
  50. ):
  51. xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
  52. xk_ = (
  53. torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
  54. if xk.shape[-2] != 0
  55. else None
  56. )
  57. freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
  58. xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
  59. if xk_ is None:
  60. # no keys to rotate, due to dropout
  61. return xq_out.type_as(xq).to(xq.device), xk
  62. # repeat freqs along seq_len dim to match k seq_len
  63. if repeat_freqs_k:
  64. r = xk_.shape[-2] // xq_.shape[-2]
  65. freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
  66. xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
  67. return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
  68. def complex_mult(xq_real, xq_imag, freqs_cis_real, freqs_cis_imag):
  69. # Compute the real part of the product
  70. real_part = xq_real * freqs_cis_real - xq_imag * freqs_cis_imag
  71. # Compute the imaginary part of the product
  72. imag_part = xq_real * freqs_cis_imag + xq_imag * freqs_cis_real
  73. # Stack the real and imaginary parts along the last dimension
  74. return torch.stack([real_part, imag_part], dim=-1)
  75. def apply_rotary_enc_real(
  76. xq: torch.Tensor,
  77. xk: torch.Tensor,
  78. freqs_cis_real: torch.Tensor,
  79. freqs_cis_imag: torch.Tensor,
  80. repeat_freqs_k: bool = False,
  81. ):
  82. assert xk is not None
  83. assert xk.shape[-2] != 0
  84. xq_real = xq.float().reshape(*xq.shape[:-1], -1, 2)[..., 0]
  85. xq_imag = xq.float().reshape(*xq.shape[:-1], -1, 2)[..., 1]
  86. xk_real = xk.float().reshape(*xk.shape[:-1], -1, 2)[..., 0]
  87. xk_imag = xk.float().reshape(*xk.shape[:-1], -1, 2)[..., 1]
  88. freqs_cis_real = reshape_for_broadcast(freqs_cis_real, xq_real)
  89. freqs_cis_imag = reshape_for_broadcast(freqs_cis_imag, xq_imag)
  90. xq_out = complex_mult(xq_real, xq_imag, freqs_cis_real, freqs_cis_imag).flatten(3)
  91. if repeat_freqs_k:
  92. r = xk_real.shape[-2] // xq_real.shape[-2]
  93. freqs_cis_real = freqs_cis_real.repeat(*([1] * (freqs_cis_real.ndim - 2)), r, 1)
  94. freqs_cis_imag = freqs_cis_imag.repeat(*([1] * (freqs_cis_imag.ndim - 2)), r, 1)
  95. xk_out = complex_mult(xk_real, xk_imag, freqs_cis_real, freqs_cis_imag).flatten(3)
  96. # xq_out = torch.view_as_real(torch.complex(xq_real, xq_imag) * torch.complex(freqs_cis_real, freqs_cis_imag)).flatten(3)
  97. # xk_out = torch.view_as_real(torch.compelx(xk_real, xk_imag) * torch.complex(freqs_cis_real, freqs_cis_imag)).flatten(3)
  98. return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
  99. # rotary embedding helper functions
  100. def broadcat(tensors, dim=-1):
  101. broadcasted_tensors = broadcast_tensors(*tensors)
  102. return torch.cat(broadcasted_tensors, dim=dim)
  103. def rotate_half(x: torch.Tensor):
  104. x = rearrange(x, "... (d r) -> ... d r", r=2)
  105. x1, x2 = x.unbind(dim=-1)
  106. x = torch.stack((-x2, x1), dim=-1)
  107. return rearrange(x, "... d r -> ... (d r)")
  108. class VisionRotaryEmbeddingVE(nn.Module):
  109. def __init__(
  110. self,
  111. dim: int,
  112. seq_len: int,
  113. pt_seq_len: Optional[int] = None,
  114. theta: float = 10000.0,
  115. offset: int = 1, # specific to VE
  116. ):
  117. super().__init__()
  118. freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
  119. scale = 1.0
  120. if pt_seq_len is not None:
  121. scale = pt_seq_len / seq_len
  122. # offset of +1 following VE - even though for the
  123. # attention op only differences matter
  124. t = torch.arange(seq_len) * scale + offset
  125. freqs = torch.einsum("..., f -> ... f", t, freqs)
  126. freqs = repeat(freqs, "... n -> ... (n r)", r=2)
  127. freqs = broadcat((freqs[None, :, :], freqs[:, None, :]), dim=-1)
  128. freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
  129. freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
  130. self.register_buffer("freqs_cos", freqs_cos)
  131. self.register_buffer("freqs_sin", freqs_sin)
  132. def forward(self, t: torch.Tensor):
  133. return t * self.freqs_cos + rotate_half(t) * self.freqs_sin