"Flash Attention, FlashInfer, SDPA backends, PagedAttention, and attention kernel selection/configuration. Use when choosing or configuring attention backends for training or inference (FlashAttention-2/3, FlashInfer, SDPA, xFormers, PagedAttention, Ring Attention, FlexAttention/FlexDecoding, varlen_attn)."
Flash Attention & Attention Backends
Attention Backend Landscape
| Backend | GPU Arch | Dtypes | Max Head Dim | Use Case |
|---|---|---|---|---|
| FlashAttention-2 | Ampere, Ada, Hopper | fp16, bf16 | 256 | General training & inference |
| FlashAttention-3 | Hopper only (H100/H800) | fp16, bf16, fp8 (E4M3 fwd) | 256 | Maximum throughput on H100 |
| FlashInfer | Ampere+ (SM80+) | fp16, bf16, fp8 | 256 | LLM inference kernels — decode, prefill, paged KV |
| SDPA Math | Any CUDA | fp32, fp16, bf16 | Any | Fallback / debugging |
| SDPA Efficient (xFormers/Memory-Efficient) | Pre-Ampere and newer (bf16 needs Ampere+) | fp32, fp16, bf16 | 128 | When FA unavailable (e.g., older GPUs or fp32) |
| PagedAttention | Ampere+ | fp16, bf16 | 128+ | KV cache management in inference (vLLM, TGI) |
| Ring Attention | Multi-GPU | fp16, bf16 | 256 | Sequence parallelism for very long contexts |
| FlexAttention | Ampere+ | fp16, bf16 | Any | Custom attention masks via torch.nn.attention.flex_attention; FlexDecoding backend auto-activates for inference decode |
PyTorch SDPA (Scaled Dot Product Attention)
PyTorch routes F.scaled_dot_product_attention() to the best available backend automatically. Override with the context manager:
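```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# SDPA layout: (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 8, 1024, 64, dtype=torch.bfloat16, device="cuda") for _ in range(3))

# Force FlashAttention only (raises RuntimeError if Flash cannot run these inputs)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(q, k, v)

# Set priority order (try Flash first, fall back to Efficient)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION], set_priority=True):
    output = F.scaled_dot_product_attention(q, k, v)
```

SDPBackend Enum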
| Enum Value | Backend | Notes |
|---|---|---|
| FLASH_ATTENTION | Flash Attention (Dao-AILab) | Requires Ampere+, fp16/bf16, head_dim ≤ 256 |
| EFFICIENT_ATTENTION | Memory-Efficient (xFormers-style) | Broader GPU and dtype support (incl. fp32), usually slightly slower |
| MATH | Naive PyTorch math | No restrictions, no fusion, slow; use for debugging |
| CUDNN_ATTENTION | cuDNN backend | Available in PyTorch 2.2+, limited config |
Backend Selection Logic
PyTorch checks backends in priority order. A backend is skipped if:
- GPU arch not supported (e.g., FA on Turing T4)
- Dtype mismatch (e.g., fp32 → no FA)
- Head dimension exceeds backend limit
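- Attention mask not supported by the backend (e.g., FLASH_ATTENTION takes is_causal=True but not an arbitrary attn_mask)
- dropout_p > 0.0 on a backend without dropout support (SDPA does not look at module train/eval mode, so pass dropout_p=0.0 yourself at inference)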
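Check which backends are globally enabled (these flags say whether a backend is allowed, not which one a given call will pick):

```python
import torch

print(torch.backends.cuda.flash_sdp_enabled())          # FLASH_ATTENTION allowed?
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # EFFICIENT_ATTENTION allowed?
print(torch.backends.cuda.math_sdp_enabled())           # MATH allowed?
```

The private torch.nn.attention._get_flash_version() helper only returns a version string for the bundled FlashAttention kernels; it does not tell you which backend a call will use.

FlashAttention-2 Direct API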
The flash-attn package provides a direct API bypassing SDPA:
```python
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

B, S, H, D = 2, 1024, 16, 64
# Separate Q, K, V tensors: (batch, seqlen, nheads, headdim), fp16/bf16 on GPU
q, k, v = (torch.randn(B, S, H, D, dtype=torch.bfloat16, device="cuda") for _ in range(3))
output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,          # Set 0.0 during eval
    softmax_scale=None,     # Default: 1/sqrt(headdim)
    causal=False,
    window_size=(-1, -1),   # Sliding window: (left, right); -1 = unbounded
    alibi_slopes=None,      # fp32, (nheads,) or (batch, nheads) for ALiBi
    deterministic=False,    # Deterministic backward (slower, more memory)
)

# Packed QKV (faster backward: avoids concatenating the q/k/v gradients)
# qkv: (batch, seqlen, 3, nheads, headdim)
qkv = torch.stack([q, k, v], dim=2)
output = flash_attn_qkvpacked_func(qkv, causal=True)
```

Key Parameters
| Parameter | Type | Default | Effect |
|---|---|---|---|
| dropout_p | float | 0.0 | Attention dropout probability |
| softmax_scale | float | 1/sqrt(d) | QK scaling factor |
| causal | bool | False | Apply causal mask |
| window_size | (int, int) | (-1, -1) | Sliding window attention bounds |
| alibi_slopes | Tensor | None | ALiBi positional bias slopes |
| deterministic | bool | False | Deterministic backward pass |
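The window_size and causal semantics follow the flash_attn docstring: with seqlen_q queries and seqlen_k keys, query i attends to keys in [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right] inclusive, where -1 means unbounded on that side, and causal=True is aligned to the bottom-right corner. The pure-Python sketch below (the function name `visible_keys` is hypothetical) reproduces that rule so you can check which keys a configuration exposes without a GPU.

```python
def visible_keys(i, seqlen_q, seqlen_k, window_size=(-1, -1), causal=False):
    """Hypothetical helper: key indices query i may attend to in flash_attn_func.

    Mirrors the documented rule: keys in
    [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right],
    with -1 meaning "no limit" on that side. Bottom-right aligned causal
    masking is equivalent to forcing right = 0.
    """
    left, right = window_size
    diag = i + seqlen_k - seqlen_q  # key position aligned with query i
    if causal:
        right = 0
    lo = 0 if left < 0 else max(0, diag - left)
    hi = seqlen_k - 1 if right < 0 else min(seqlen_k - 1, diag + right)
    return list(range(lo, hi + 1))


# Decode step (1 query, 10 cached keys) with a causal window of 3 past tokens
print(visible_keys(0, 1, 10, window_size=(3, 0)))
```

The decode case shows why the alignment matters: a single query sits at the end of the KV sequence, so its window covers the most recent keys rather than the first ones.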
GQA/MQA Support
Flash Attention supports grouped-query and multi-query attention natively. Pass K, V with fewer heads than Q — the number of Q heads must be divisible by K/V heads:
```python
import torch
from flash_attn import flash_attn_func

B, S, D = 2, 1024, 128
# GQA: 32 query heads, 8 KV heads (each KV head serves 4 query heads)
q = torch.randn(B, S, 32, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, 8, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, S, 8, D, dtype=torch.float16, device="cuda")
output = flash_attn_func(q, k, v, causal=True)
```

FlashAttention-3 (Hopper-Native)
FA3 is specifically designed for H100/H800, exploiting Hopper architecture features: warp specialization, TMA (Tensor Memory Accelerator), and FP8 tensor cores. ~1.5-2× faster than FA2 on H100 for fp16/bf16.
Key Architecture Differences from FA2
| Feature | FA2 | FA3 |
|---|---|---|
| Warp scheduling | Standard cooperative groups | Warp-specialized: producer warps (load data via TMA) + consumer warps (compute) |
| Memory access | Manual GMEM→SMEM copies | TMA-based async bulk copies (hardware DMA) |
| FP8 support | None | E4M3 forward pass on tensor cores |
| Pipeline | Sequential load→compute | 2-stage async pipeline: load next tile while computing current |
| SMEM usage | Standard allocation | Pingpong SMEM buffers for async pipeline |
| Block size | Fixed per kernel | Larger blocks (256×128) enabled by TMA efficiency |
FA3 API
FA3 is a separate package from the main flash-attn — it requires a distinct install from the hopper/ subdirectory and uses a different import.
Install:
```bash
# FA3 is NOT included in pip install flash-attn
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention/hopper
python setup.py install
```

Usage:
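```python
# FA3 uses the flash_attn_interface module, NOT flash_attn
import flash_attn_interface

# Some FA3 versions return (out, softmax_lse) instead of out; check your installed version
output = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
```

FP8 forward pass: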
```python
# FP8 E4M3 inputs: forward pass only (there is no FP8 backward, so this is an inference path).
# A plain cast like this skips the per-tensor scaling a production FP8 pipeline would apply.
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
output = flash_attn_interface.flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True)
```

Requirements: H100/H800 (SM90a), CUDA ≥ 12.3 (12.8 recommended).
FA3 Performance Characteristics
| Sequence Length | FA3 vs FA2 Speedup (H100, bf16, causal) | Notes |
|---|---|---|
| 1K | ~1.3× | Compute-bound, less TMA advantage |
| 4K | ~1.5× | Sweet spot for TMA pipeline |
| 16K | ~1.7× | Larger tiles amortize overhead |
| 64K+ | ~1.8-2.0× | Maximum pipeline benefit |
When to use FA3 explicitly:
- Default choice on H100/H800, unless you need a feature that only FA2 provides
- FP8 forward pass for inference (saves memory bandwidth)
- Long-context training (>4K) where pipeline efficiency matters most
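The guidance above can be turned into a simple dispatch policy. This is a sketch of one possible policy, not something the flash-attention project ships: the function names are hypothetical, the module names flash_attn_interface (FA3) and flash_attn (FA2) come from this page, and the SM80 to SM90 range for FA2 reflects the Ampere/Ada/Hopper row of the backend table. Anything else falls back to SDPA, which picks its own kernel.

```python
import importlib.util


def pick_attention_kernel(capability, has_fa3, has_fa2):
    """Hypothetical policy: return "fa3", "fa2", or "sdpa".

    capability: (major, minor) from torch.cuda.get_device_capability(), or None without a GPU.
    """
    if capability is None:
        return "sdpa"
    if capability == (9, 0) and has_fa3:
        return "fa3"  # Hopper with the hopper/ build installed
    if (8, 0) <= capability <= (9, 0) and has_fa2:
        return "fa2"  # Ampere, Ada, or Hopper without FA3
    return "sdpa"  # let PyTorch choose Flash / Efficient / Math


def detect_attention_kernel():
    """Probe the local environment; torch is optional here."""
    has_fa3 = importlib.util.find_spec("flash_attn_interface") is not None
    has_fa2 = importlib.util.find_spec("flash_attn") is not None
    capability = None
    if importlib.util.find_spec("torch") is not None:
        import torch

        if torch.cuda.is_available():
            capability = torch.cuda.get_device_capability()
    return pick_attention_kernel(capability, has_fa3, has_fa2)


print(detect_attention_kernel())
```

Keeping the decision in a pure function makes it easy to unit-test and to extend, for example by preferring FA2 whenever the run needs attention dropout.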
FlashInfer
FlashInfer is a kernel library for LLM inference serving (not training). It provides optimized attention kernels, sampling, and KV cache management — used as the default attention backend in SGLang.
Version: 0.6.4 | Install: pip install flashinfer-python
Key Features
| Feature | Description |
|---|---|
| Prefill attention | Ragged tensor attention for variable-length prefill batches |
| Decode attention | Optimized single-token decode with paged KV cache |
| POD-Attention | Fused prefill+decode kernel for mixed batching |
| PagedAttention | Block-sparse paged KV cache management |
| Cascade attention | Multi-level KV cache (shared prefix + unique suffix) |
| GQA/MQA | Native grouped/multi-query attention support |
| FP8 KV cache | E4M3/E5M2 KV cache quantization |
| Sampling | GPU-fused Top-K, Top-P, Min-P without sorting |
| LoRA | BGMV/SGMV kernels for batched LoRA |
| BF16 GEMM | Matrix multiplication for SM100+ (Blackwell) |
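Cascade attention works because attention over a concatenated KV set can be computed in pieces (e.g., shared prefix, then unique suffix) and merged exactly, as long as each piece also returns the log-sum-exp (LSE) of its scores. The pure-Python sketch below shows that merge for one query with scalar values; it uses the natural log, and the function names are illustrative rather than FlashInfer's API (FlashInfer's kernels may use a different log base internally).

```python
import math


def partial_attention(scores, values):
    """Softmax attention for one query over one KV chunk: returns (output, lse)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / total
    return out, m + math.log(total)


def merge_states(out_a, lse_a, out_b, lse_b):
    """Combine two partial results into attention over the union of both chunks."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Each partial output is reweighted by its share of the total softmax mass
    out = math.exp(lse_a - lse) * out_a + math.exp(lse_b - lse) * out_b
    return out, lse


prefix = partial_attention([0.5, 1.2, -0.3], [1.0, 2.0, 3.0])  # shared prefix
suffix = partial_attention([2.0, 0.1], [4.0, 5.0])             # per-request suffix
merged, _ = merge_states(*prefix, *suffix)
```

Because the prefix result can be reused across every request that shares it, only the suffix part has to be recomputed per request.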
FlashInfer vs FlashAttention
| Aspect | FlashInfer | FlashAttention |
|---|---|---|
| Focus | Inference serving (variable-length, paged KV) | Training + inference (dense and variable-length) |
| KV cache | Paged, ragged, block-sparse; built in | User-managed; flash_attn_with_kvcache accepts a block_table for paged KV |
| Decode | Highly optimized single-token decode kernels | Decent but not inference-specialized |
| Prefill | Variable-length ragged attention | Variable-length packed sequences (flash_attn_varlen_func) |
| Customization | JIT kernel generation for custom configs | Fixed compiled kernels |
| Used by | SGLang (default), vLLM (optional) | Training frameworks, HF Transformers, vLLM |
FlashInfer API (Inference Engine Integration)
FlashInfer is typically used indirectly through serving engines. Direct API for custom inference:
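```python
import torch
import flashinfer

# Scratch buffer shared by the wrappers (128 MB here)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# Prefill: variable-length batch over a paged KV cache.
# qo_indptr, kv_indptr, kv_indices, kv_last_page_len, q and kv_cache come from
# your scheduler / page table (index arrays are int32 tensors on the GPU).
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer)
prefill_wrapper.plan(
    qo_indptr,         # [0, seq_len_0, seq_len_0+seq_len_1, ...]
    kv_indptr,         # same layout, counted in pages
    kv_indices,        # page table: physical page ids, concatenated per request
    kv_last_page_len,  # tokens used in each request's last page
    32,                # num_qo_heads
    8,                 # num_kv_heads (GQA)
    128,               # head_dim
    16,                # page_size
    causal=True,
)
output = prefill_wrapper.run(q, kv_cache)

# Decode: one new token per request
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer)
decode_wrapper.plan(
    indptr=indptr,
    indices=indices,
    last_page_len=last_page_len,
    num_qo_heads=32,
    num_kv_heads=8,
    head_dim=128,
    page_size=16,
)
output = decode_wrapper.run(q, kv_cache)
```

Selecting FlashInfer in Serving Engines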
```python
# SGLang (default on non-Hopper)
import sglang as sgl

llm = sgl.Engine(model_path="...", attention_backend="flashinfer")

# vLLM (optional)
# Set VLLM_ATTENTION_BACKEND=FLASHINFER or use --attention-backend flashinfer
```

PagedAttention
PagedAttention manages KV cache memory in blocks (pages) rather than contiguous tensors — eliminates memory fragmentation during inference. Used by vLLM, TGI, and SGLang internally.
Users rarely configure PagedAttention directly, since it is an implementation detail of the inference engine, but these engine settings determine how much KV cache it can manage:
| Setting | Where | Effect |
|---|---|---|
| block_size | vLLM EngineArgs | Page size for KV cache blocks, in tokens (default: 16) |
| gpu_memory_utilization | vLLM | Fraction of GPU memory the vLLM instance may use in total (weights, activations, KV cache); the KV cache gets what remains (default: 0.9) |
| swap_space | vLLM | CPU swap space in GiB for KV blocks of swapped-out sequences |
| num_gpu_blocks | Computed at startup | Total KV cache blocks available |
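A back-of-the-envelope estimate of num_gpu_blocks: each block stores K and V for block_size tokens across every layer. This sketch assumes a standard (non-MLA) decoder, a hypothetical 10 GiB left for the KV cache after weights and activations, and no allocator overhead; vLLM actually profiles memory at startup, so treat the result as an approximation. The Llama-3.1-8B shape (32 layers, 8 KV heads, head_dim 128) is from its published model config; the function names are illustrative.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes):
    # 2x for K and V, stored for every layer
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def estimate_num_gpu_blocks(kv_budget_bytes, block_size, num_layers, num_kv_heads, head_dim, dtype_bytes):
    block_bytes = block_size * kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes)
    return kv_budget_bytes // block_bytes


GiB = 1024 ** 3
# Llama-3.1-8B in bf16 (2 bytes): 32 layers, 8 KV heads, head_dim 128
per_token = kv_bytes_per_token(32, 8, 128, 2)  # 131072 bytes = 128 KiB per token
blocks = estimate_num_gpu_blocks(10 * GiB, 16, 32, 8, 128, 2)
print(per_token, blocks, blocks * 16)  # bytes/token, blocks, total cacheable tokens
```

Note how GQA matters here: with 32 KV heads instead of 8, the same budget would hold a quarter as many tokens.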
How it works: Instead of pre-allocating max_seq_len × batch_size contiguous memory, PagedAttention allocates fixed-size blocks on demand. Sequences reference block tables mapping logical positions to physical blocks. This allows:
- Near-zero memory waste from padding
- Efficient prefix caching (shared blocks across sequences)
- Copy-on-write for parallel sampling
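The block-table mechanism can be sketched in pure Python. The allocator below is a hypothetical toy, not vLLM's implementation, and it omits prefix sharing and copy-on-write: it hands out physical block ids on demand, keeps a per-sequence block table, and flattens the tables into the CSR-style indptr / indices / last_page_len arrays that paged-KV kernels such as FlashInfer's plan() consume.

```python
class BlockAllocator:
    """Toy paged KV allocator: fixed-size blocks handed out on demand."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks - 1, -1, -1))  # stack; pops 0, 1, 2, ...
        self.tables = {}   # seq_id -> list of physical block ids (the block table)
        self.lengths = {}  # seq_id -> number of tokens stored

    def append_tokens(self, seq_id, n):
        table = self.tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        for _ in range(n):
            if length % self.block_size == 0:  # no block yet, or last block is full
                if not self.free:
                    raise MemoryError("out of KV cache blocks")
                table.append(self.free.pop())
            length += 1
        self.lengths[seq_id] = length

    def free_seq(self, seq_id):
        self.free.extend(reversed(self.tables.pop(seq_id)))
        del self.lengths[seq_id]

    def to_csr(self, seq_ids):
        """Flatten block tables into (indptr, indices, last_page_len)."""
        indptr, indices, last_page_len = [0], [], []
        for sid in seq_ids:
            indices.extend(self.tables[sid])
            indptr.append(len(indices))
            rem = self.lengths[sid] % self.block_size
            last_page_len.append(rem if rem else self.block_size)
        return indptr, indices, last_page_len


alloc = BlockAllocator(num_blocks=8, block_size=4)
alloc.append_tokens("A", 10)  # needs 3 blocks, last one holds 2 tokens
alloc.append_tokens("B", 5)   # needs 2 blocks, last one holds 1 token
print(alloc.to_csr(["A", "B"]))
```

Waste is bounded by at most one partially filled block per sequence, which is where the "near-zero memory waste" above comes from.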
Ring Attention (Sequence Parallelism)
Ring Attention distributes long sequences across GPUs in a ring topology. Each GPU holds a chunk of Q and rotates K, V blocks around the ring:
- Enables training on sequences longer than single-GPU memory allows
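- Per-GPU activation and KV memory shrinks roughly as 1/N with N GPUs, so the trainable context length grows roughly linearly with GPU count
- Related context-parallel schemes are used for long-context training; for example, Llama 3 used an all-gather-based variant rather than a ring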
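The rotation can be pictured with a small simulation. Assuming P ranks that each own one contiguous Q chunk and one KV chunk, with KV chunks passed one hop around the ring (rank r to r + 1) after every step, rank r holds KV chunk (r - t) mod P at step t. This schematic (the function name is hypothetical) ignores the compute/communication overlap and the LSE-based merging of partial outputs that real implementations perform; it mainly shows that under causal masking, ranks owning early chunks sit idle, which is what load-balanced chunk orderings (e.g., striped or zigzag layouts) address.

```python
def ring_schedule(num_ranks, causal=False):
    """Return schedule[t][r] = KV chunk rank r attends to at step t (None if fully masked).

    Assumes contiguous chunks: rank r owns positions of chunk r, and KV chunks
    move one hop around the ring (r -> r + 1) after every step.
    """
    schedule = []
    for step in range(num_ranks):
        row = []
        for rank in range(num_ranks):
            kv_chunk = (rank - step) % num_ranks
            # Under causal masking, a query chunk never attends to a later KV chunk
            row.append(None if causal and kv_chunk > rank else kv_chunk)
        schedule.append(row)
    return schedule


for step, row in enumerate(ring_schedule(4, causal=True)):
    print(f"step {step}: {row}")
```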
Integration: typically activated via framework flags rather than a direct API:
- Megatron-LM: --context-parallel-size N enables ring attention
- DeepSpeed Ulysses: alternative approach, using all-to-all on attention heads instead of a ring over the sequence
FlexAttention (PyTorch 2.5+)
flex_attention compiles custom attention patterns into fused kernels via score_mod / mask_mod functions and BlockMask for sparsity:
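```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 2, 8, 4096, 64
q, k, v = (torch.randn(B, H, S, D, dtype=torch.bfloat16, device="cuda") for _ in range(3))

# mask_mod: return True where query q_idx may attend to key kv_idx
def causal_with_window(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx < 1024)

# B=None, H=None broadcast the mask over batch and heads
block_mask = create_block_mask(causal_with_window, B=None, H=None, Q_LEN=S, KV_LEN=S)
# torch.compile generates the fused kernel; uncompiled calls use a slow reference path
flex_attention_compiled = torch.compile(flex_attention)
output = flex_attention_compiled(q, k, v, block_mask=block_mask)
```

Advantages: Fuses arbitrary attention patterns (document masking, sliding window + causal, prefix LM) without materializing the full attention matrix.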
FlexDecoding (Inference Backend, PyTorch 2.7+)
torch.compile(flex_attention) auto-switches to a FlexDecoding kernel when query length is much shorter than KV length (decode phase). No API change — just shape-based JIT dispatch:
- Prefill (q_len ≈ kv_len): uses standard FlexAttention kernel
- Decode (q_len=1, kv_len=128K): auto-recompiles to FlexDecoding (FlashDecoding-based)
Key patterns for efficient decode:
- Use captured tensor offsets to avoid recompiling at every step
- Precompute a BlockMask and slice it per step
- Use convert_logical_block_mask() for PT2-native PagedAttention without custom CUDA kernels
Performance: 1.22×–2.04× faster than SDPA on LLaMA 3.1-8B decode; parity to 1.66× on 70B. On par with FlashDecoding (FA-KV) while supporting arbitrary attention patterns.
See `references/flexattention-inference.md` for: offset patterns, BlockMask slicing, PT2-native PagedAttention architecture, and varlen_attn API.
varlen_attn — Variable-Length Attention (PyTorch 2.10+)
torch.nn.attention.varlen.varlen_attn provides packed/ragged sequence attention in PyTorch core — eliminates padding waste for variable-length batches. Uses Flash Attention kernels under the hood; replaces flash_attn_varlen_func from the flash-attn package.
```python
import torch
from torch.nn.attention.varlen import varlen_attn

# Packed tensors: (total_tokens, num_heads, head_dim), no batch dim and no padding
# cu_seq: cumulative positions, e.g. [0, 128, 384, 448, 960] for 4 sequences
cu_seq = torch.tensor([0, 128, 384, 448, 960], dtype=torch.int32, device="cuda")
max_len = 512  # longest sequence (448 -> 960)
q, k, v = (torch.randn(960, 16, 64, dtype=torch.bfloat16, device="cuda") for _ in range(3))
output = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=True)
```

Prefer varlen_attn over flash_attn_varlen_func for new code, since it's part of PyTorch core and doesn't require a separate package. See `references/flexattention-inference.md` for packed sequence construction and the AuxRequest LSE output.
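The cumulative offsets are prefix sums of the per-sequence lengths with a leading 0, and the packed tensor is just the non-padding rows concatenated in batch order. A minimal sketch of both steps, assuming a right-padded (batch, max_len, heads, dim) batch and int32 offsets (the dtype flash-attn's varlen API uses; check the varlen_attn docs for its exact requirements). The helper name `pack_padded` is hypothetical.

```python
import torch
import torch.nn.functional as F


def pack_padded(x, lengths):
    """Pack a right-padded (batch, max_len, heads, dim) tensor for varlen attention.

    Returns (packed, cu_seq, max_len): packed is (total_tokens, heads, dim) and
    cu_seq holds cumulative sequence boundaries starting at 0.
    """
    lengths = torch.as_tensor(lengths, device=x.device)
    positions = torch.arange(x.shape[1], device=x.device)
    valid = positions[None, :] < lengths[:, None]  # (batch, max_len) non-padding mask
    packed = x[valid]  # rows come out in batch order, then position order
    cu_seq = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
    return packed, cu_seq, int(lengths.max())


lengths = [128, 256, 64, 512]
padded = torch.randn(len(lengths), max(lengths), 8, 64)
packed, cu_seq, max_len = pack_padded(padded, lengths)
print(packed.shape, cu_seq.tolist(), max_len)
```

The same cu_seq and max_len can serve as both the query and key arguments for self-attention, since Q and K share sequence boundaries.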
Hugging Face Transformers Integration
Control attention backend via attn_implementation:
```python
import torch
from transformers import AutoModelForCausalLM

# Explicit backend selection
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",  # or "sdpa", "eager"
    torch_dtype=torch.bfloat16,  # FA2 needs fp16/bf16
)
```

| Value | Backend | Notes |
|---|---|---|
| "flash_attention_2" | FlashAttention-2 via flash-attn package | Requires pip install flash-attn in image |
| "sdpa" | PyTorch native SDPA | Default for PyTorch ≥ 2.1.1 |
| "eager" | Manual attention math | Debugging only, very slow |
Attention Kernel Benchmarking Methodology
When selecting attention backends for a deployment, benchmark systematically:
What to Measure
| Metric | How | Why |
|---|---|---|
| Prefill latency | Time to first token (TTFT), varying input lengths | Prefill is memory-bandwidth bound at short seq, compute-bound at long seq |
| Decode throughput | Tokens/sec at target batch size | Decode is memory-bandwidth bound, so kernel efficiency matters |
| Peak memory | torch.cuda.max_memory_allocated() | Determines max batch size / sequence length |
| TFLOPS utilization | Measured FLOPS ÷ peak theoretical | Indicates how close to hardware limit |
Benchmarking Script Pattern
```python
import torch
from flash_attn import flash_attn_func


def benchmark_attention(fn, q, k, v, causal=True, warmup=10, iters=100):
    """Benchmark an attention kernel taking (batch, seqlen, nheads, headdim) inputs.

    Returns (median ms, achieved TFLOPS).
    """
    # Warmup
    for _ in range(warmup):
        fn(q, k, v, causal=causal)
    torch.cuda.synchronize()

    # Timed iterations with CUDA events (GPU time, not host time)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(iters):
        start.record()
        fn(q, k, v, causal=causal)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    times.sort()
    median_ms = times[len(times) // 2]

    # Forward FLOPs per query head: 2 for QK^T + 2 for attn@V; causal does ~half
    B, S, H, D = q.shape
    flops = 4 * B * H * S * S * D
    if causal:
        flops //= 2
    tflops = flops / (median_ms / 1000) / 1e12
    return median_ms, tflops


# Compare across sequence lengths (GQA: 32 query heads, 8 KV heads)
for seq_len in [1024, 4096, 16384, 65536]:
    q = torch.randn(1, seq_len, 32, 128, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(1, seq_len, 8, 128, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(1, seq_len, 8, 128, dtype=torch.bfloat16, device="cuda")
    ms, tflops = benchmark_attention(flash_attn_func, q, k, v)
    print(f"seq={seq_len}: FA {ms:.2f}ms, {tflops:.1f} TFLOPS")
```

Benchmarking Best Practices
- Isolate GPU: CUDA_VISIBLE_DEVICES=0, no other processes
- Warmup: ≥10 iterations to stabilize CUDA state and clocks
- Use CUDA events, not time.time(): events measure GPU time accurately
- Report median, not mean: avoids skew from occasional GC pauses
- Vary dimensions: Test your actual model's head_dim, num_heads, GQA ratio
- Test both prefill and decode: Different batch shapes trigger different kernels
- Check utilization: Compare measured TFLOPS vs GPU peak (e.g., H100 = ~989 TFLOPS bf16)
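The utilization check boils down to counting the two matmuls: QK^T and attention @ V each cost 2·S_q·S_k·D FLOPs per query head (one multiply plus one add per term), so a full forward pass is 4·B·H·S_q·S_k·D and a causal pass does roughly half. A small sketch for turning a measured time into TFLOPS and a fraction of peak; the helper names are hypothetical, and the default peak is the ~989 dense bf16 TFLOPS H100 figure quoted above.

```python
def attention_fwd_flops(batch, num_q_heads, seq_q, seq_k, head_dim, causal=False):
    # QK^T and P @ V: 2 * seq_q * seq_k * head_dim FLOPs each, per query head
    flops = 4 * batch * num_q_heads * seq_q * seq_k * head_dim
    return flops // 2 if causal else flops  # causal masks roughly half the scores


def utilization(flops, elapsed_ms, peak_tflops=989.0):
    """Return (achieved TFLOPS, fraction of peak)."""
    tflops = flops / (elapsed_ms / 1000) / 1e12
    return tflops, tflops / peak_tflops


flops = attention_fwd_flops(1, 32, 4096, 4096, 128, causal=True)
tflops, frac = utilization(flops, elapsed_ms=0.5)  # 0.5 ms is an illustrative timing
print(f"{tflops:.1f} TFLOPS ({frac:.0%} of peak)")
```

Forgetting the causal halving is a common reason benchmark numbers look impossibly close to (or above) hardware peak.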
See scripts/check_attention_backend.py for a runnable benchmark script.
Container Image Considerations
FlashAttention compilation is slow (~30 min). Use pre-built wheels or NGC containers:
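```dockerfile
# Pick ONE of the options below

# Option 1: NGC PyTorch container (includes FA)
FROM nvcr.io/nvidia/pytorch:24.12-py3

# Option 2: PyTorch image + flash-attn. pip first tries a matching pre-built wheel;
# if none matches it compiles, which needs nvcc, so use a -devel image
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
RUN pip install flash-attn --no-build-isolation

# Option 3: When compiling, set MAX_JOBS to limit parallelism (and RAM use)
ENV MAX_JOBS=4
RUN pip install flash-attn --no-build-isolation
```

Troubleshooting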
"Torch was not compiled with flash attention"
This warning means the PyTorch build lacks FA support. Causes:
- GPU arch < Ampere (T4, V100) — FA2 requires SM80+
- CUDA toolkit version mismatch
- PyTorch built without FA kernel
Fix: Install flash-attn package separately, or use attn_implementation="sdpa" which uses the Efficient Attention fallback.
Head Dimension Errors
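FA2 supports head_dim ≤ 256. Backward for head_dim > 192 originally required A100/A800 or H100/H800; newer flash-attn releases also run head_dim 256 backward on consumer GPUs, but only without dropout. If training with head_dim 256 on consumer GPUs, set dropout_p=0.0.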
Memory Errors with Long Sequences
Flash Attention is O(N) memory in sequence length (vs O(N²) for vanilla attention), but KV cache still grows linearly. For very long sequences:
- Use sliding window attention (window_size parameter)
- Enable Ring Attention for multi-GPU sequence parallelism
- Consider chunked prefill in inference engines
Cross-References
- pytorch: PyTorch training and F.scaled_dot_product_attention
- fsdp: FSDP distributed training (wrapped models typically use SDPA for attention)
- megatron-lm: Megatron-LM attention with context parallelism / Ring Attention
- vllm: vLLM inference engine (uses PagedAttention internally)
- sglang: SGLang inference engine (uses FlashInfer by default; RadixAttention reuses prefix KV cache via a radix tree)
- huggingface-transformers: attn_implementation parameter for model loading
- torch-compile: Compiling attention kernels with torch.compile
- deepspeed: DeepSpeed training with Flash Attention kernels
- gpu-operator: GPU driver prerequisites (SM80+ required for FA2)
Reference
- FlashAttention-2 GitHub
- FlashAttention-3 GitHub
- FlashInfer GitHub
- FlashInfer docs
- FlashInfer paper (arXiv:2501.01005)
- PyTorch SDPA docs
- FlexAttention blog
- FlexAttention for Inference blog
- varlen_attn docs
- `scripts/check_attention_backend.py`: check available SDPA backends and benchmark them across dtypes
- `flashattention3-and-flashinfer.md`: FlashAttention-3 Hopper features and FlashInfer PagedAttention engine
- `flexattention-inference.md`: FlexDecoding backend, KV cache offset patterns, PT2-native PagedAttention, varlen_attn packed sequence API
- `troubleshooting.md`: Installation failures, runtime errors, correctness issues, and performance debugging