fa3.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import torch


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func_op(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    # FA3 returns (output, softmax_lse); the custom op only exposes the output,
    # matching the single-Tensor return annotation and the fake impl below.
    from flash_attn_interface import flash_attn_func as fa3

    return fa3(q, k, v)[0]


def flash_attn_func(q, k, v):
    # Cast q/k/v to FP8 for the FA3 kernel, then cast the result back to the
    # caller's dtype.
    dtype = torch.float8_e4m3fn
    return flash_attn_func_op(q.to(dtype), k.to(dtype), v.to(dtype)).to(q.dtype)


@flash_attn_func_op.register_fake
def _(q, k, v, **kwargs):
    # FA3 produces two outputs:
    # 1. output: (batch, seq_len, num_heads, head_dim)
    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
    # The fake impl only models the first, and its dtype must be bfloat16, not float8!
    meta_q = torch.empty_like(q, dtype=torch.bfloat16).contiguous()
    return meta_q
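
A minimal usage sketch, assuming the file above is importable as fa3, that an FA3 build providing flash_attn_interface is installed, and that a CUDA GPU is available; the tensor shapes and the torch.compile call are illustrative, not part of the file itself:

import torch

from fa3 import flash_attn_func  # assumed import path for the wrapper above

# (batch, seq_len, num_heads, head_dim) in bfloat16 on the GPU; shapes are arbitrary.
q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)

# Eager: q/k/v are cast to float8_e4m3fn, the FA3 kernel runs, and the output
# is cast back to the inputs' bfloat16 dtype.
out = flash_attn_func(q, k, v)

# Because the kernel is wrapped as a torch.library custom op with a registered
# fake impl, the call can also be traced by torch.compile without running the
# real FA3 kernel at trace time.
compiled = torch.compile(flash_attn_func)
out_compiled = compiled(q, k, v)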