@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import math
 import warnings
 from functools import partial
@@ -14,12 +15,30 @@ import torch.nn.functional as F
 from torch import nn, Tensor
 
 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
-
 from sam2.modeling.sam2_utils import MLP
 from sam2.utils.misc import get_sdpa_settings
 
 warnings.simplefilter(action="ignore", category=FutureWarning)
+# Check whether Flash Attention is available (and use it by default)
 OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+# A fallback setting to allow all available kernels if Flash Attention fails
+ALLOW_ALL_KERNELS = False
+
+
+def sdp_kernel_context(dropout_p):
+    """
+    Get the context for the attention scaled dot-product kernel. We use Flash Attention
+    by default, but fall back to all available kernels if Flash Attention fails.
+    """
+    if ALLOW_ALL_KERNELS:
+        return contextlib.nullcontext()
+
+    return torch.backends.cuda.sdp_kernel(
+        enable_flash=USE_FLASH_ATTN,
+        # if Flash attention kernel is off, then math kernel needs to be enabled
+        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        enable_mem_efficient=OLD_GPU,
+    )
 
 
 class TwoWayTransformer(nn.Module):
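The new function implements a one-way latch: the first Flash Attention failure flips the module-level `ALLOW_ALL_KERNELS` flag, and every later call receives an unrestricted `contextlib.nullcontext()` instead of the constrained `torch.backends.cuda.sdp_kernel` context. Below is a minimal runnable sketch of that control flow (not part of the patch; `_restricted` and `attention_out` are hypothetical stand-ins for the real kernel context and attention call):

```python
import contextlib
import warnings

ALLOW_ALL_KERNELS = False  # module-level latch, as in the patch


@contextlib.contextmanager
def _restricted():
    # Hypothetical stand-in for torch.backends.cuda.sdp_kernel(...);
    # simulates a Flash Attention kernel rejecting the input.
    raise RuntimeError("no available kernel")
    yield


def sdp_kernel_context():
    # Mirrors the patch: unrestricted once the latch has been set.
    return contextlib.nullcontext() if ALLOW_ALL_KERNELS else _restricted()


def attention_out():
    global ALLOW_ALL_KERNELS
    try:
        with sdp_kernel_context():
            return "ok"
    except Exception as e:
        warnings.warn(f"kernel failed: {e}; allowing all kernels", UserWarning)
        ALLOW_ALL_KERNELS = True  # one-way: never reset back to False
        return "ok (after fallback)"


print(attention_out())  # warns once, then "ok (after fallback)"
print(attention_out())  # "ok" with no warning: the restricted context is skipped
```

Worth noting as a design choice: the latch is global and never reset, so a single failure in any `Attention` or `RoPEAttention` instance permanently widens the kernel set for the whole process, trading potential speed for robustness.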
@@ -246,12 +265,19 @@ class Attention(nn.Module):
 
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
             out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 
         out = self._recombine_heads(out)
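One API note: `torch.backends.cuda.sdp_kernel` is deprecated in recent PyTorch releases in favor of `torch.nn.attention.sdpa_kernel`. A hedged sketch of a roughly equivalent restricted context under the newer API, reusing the patch's module-level flags (assumes PyTorch >= 2.3; not part of the patch):

```python
# Sketch only: rough modern equivalent of sdp_kernel_context(dropout_p),
# assuming PyTorch >= 2.3 where sdpa_kernel accepts a list of SDPBackend values.
from torch.nn.attention import SDPBackend, sdpa_kernel


def sdpa_context(dropout_p):
    allowed = []
    if USE_FLASH_ATTN:
        allowed.append(SDPBackend.FLASH_ATTENTION)
    if (OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON:
        allowed.append(SDPBackend.MATH)
    if OLD_GPU:
        allowed.append(SDPBackend.EFFICIENT_ATTENTION)
    return sdpa_kernel(allowed)
```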
@@ -313,12 +339,19 @@ class RoPEAttention(Attention):
 
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
             out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 
         out = self._recombine_heads(out)
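Finally, a usage sketch showing that callers are unaffected by the change (the module path and the `Attention(embedding_dim, num_heads)` constructor are assumed from the surrounding SAM 2 code base; shapes are illustrative):

```python
import torch
from sam2.modeling.sam.transformer import Attention  # assumed module path

attn = Attention(embedding_dim=256, num_heads=8).eval()  # eval: dropout_p = 0.0
q = k = v = torch.randn(2, 64, 256)  # (batch, tokens, embedding_dim)
with torch.no_grad():
    out = attn(q, k, v)
# First call per process: tries the restricted context from sdp_kernel_context();
# on failure it warns once, sets ALLOW_ALL_KERNELS, and retries with all kernels.
print(out.shape)  # torch.Size([2, 64, 256])
```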