tylertitsworth

flash-attention

"Flash Attention, FlashInfer, SDPA backends, PagedAttention, and attention kernel selection/configuration. Use when choosing or configuring attention backends for training or inference (FlashAttention-2/3, FlashInfer, SDPA, xFormers, PagedAttention, Ring Attention, FlexAttention/FlexDecoding, varlen_attn)."

Install

npx skillscat add tylertitsworth/skills/flash-attention

Install via the SkillsCat registry.

SKILL.md

Flash Attention & Attention Backends

Attention Backend Landscape

| Backend | GPU Arch | Dtypes | Max Head Dim | Use Case |
| --- | --- | --- | --- | --- |
| FlashAttention-2 | Ampere, Ada, Hopper | fp16, bf16 | 256 | General training & inference |
| FlashAttention-3 | Hopper only (H100/H800) | fp16, bf16, fp8 (E4M3 fwd) | 256 | Maximum throughput on H100 |
| FlashInfer | Ampere+ (SM80+) | fp16, bf16, fp8 | 256 | LLM inference kernels: decode, prefill, paged KV |
| SDPA Math | Any CUDA | fp32, fp16, bf16 | Any | Fallback / debugging |
| SDPA Efficient (xFormers/Memory-Efficient) | Ampere+ | fp16, bf16 | 128 | When FA unavailable |
| PagedAttention | Ampere+ | fp16, bf16 | 128+ | KV cache management in inference (vLLM, TGI) |
| Ring Attention | Multi-GPU | fp16, bf16 | 256 | Sequence parallelism for very long contexts |
| FlexAttention | Ampere+ | fp16, bf16 | Any | Custom attention masks via torch.nn.attention.flex_attention; FlexDecoding backend auto-activates for inference decode |

PyTorch SDPA (Scaled Dot Product Attention)

PyTorch routes F.scaled_dot_product_attention() to the best available backend automatically. Override with the context manager:

import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Force FlashAttention only
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(q, k, v)

# Set priority order (try Flash first, fall back to Efficient)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION], set_priority=True):
    output = F.scaled_dot_product_attention(q, k, v)

SDPBackend Enum

| Enum Member | Backend | Notes |
| --- | --- | --- |
| FLASH_ATTENTION | Flash Attention (Dao-AILab) | Requires Ampere+, fp16/bf16, head_dim ≤ 256 |
| EFFICIENT_ATTENTION | Memory-Efficient (xFormers-style) | Broader head_dim support, slightly slower |
| MATH | Naive PyTorch math | No restrictions, no fusion, slow; use for debugging |
| CUDNN_ATTENTION | cuDNN backend | Available in PyTorch 2.2+, limited config |

Backend Selection Logic

PyTorch checks backends in priority order. A backend is skipped if:

  • GPU arch not supported (e.g., FA on Turing T4)
  • Dtype mismatch (e.g., fp32 → no FA)
  • Head dimension exceeds backend limit
  • Causal mask + custom attention mask combination unsupported
  • Dropout > 0.0 with some backends in eval mode

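Check which FlashAttention version your PyTorch build bundles. This is a private helper that may change between releases, and it does not report which backend SDPA picks for a given call. To test a specific backend, restrict SDPA to it with sdpa_kernel: the call raises a RuntimeError if that backend cannot handle the inputs.

from torch.nn.attention import _get_flash_version
# Version string of the FlashAttention kernels bundled with this PyTorch build
print(_get_flash_version())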

FlashAttention-2 Direct API

The flash-attn package provides a direct API bypassing SDPA:

from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

# Separate Q, K, V tensors
# q, k, v: (batch, seqlen, nheads, headdim)
output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,          # Set 0.0 during eval
    softmax_scale=None,      # Default: 1/sqrt(headdim)
    causal=False,
    window_size=(-1, -1),    # Sliding window: (left, right)
    alibi_slopes=None,       # (nheads,) or (batch, nheads) for ALiBi
    deterministic=False,     # Deterministic backward (slower, more memory)
)

# Packed QKV (faster backward — avoids gradient concat)
# qkv: (batch, seqlen, 3, nheads, headdim)
output = flash_attn_qkvpacked_func(qkv, causal=True)

Key Parameters

| Parameter | Type | Default | Effect |
| --- | --- | --- | --- |
| dropout_p | float | 0.0 | Attention dropout probability |
| softmax_scale | float | 1/sqrt(d) | QK scaling factor |
| causal | bool | False | Apply causal mask |
| window_size | (int, int) | (-1, -1) | Sliding window attention bounds |
| alibi_slopes | Tensor | None | ALiBi positional bias slopes |
| deterministic | bool | False | Deterministic backward pass |

GQA/MQA Support

Flash Attention supports grouped-query and multi-query attention natively. Pass K, V with fewer heads than Q — the number of Q heads must be divisible by K/V heads:

# GQA: 32 query heads, 8 KV heads
q = torch.randn(B, S, 32, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, 8, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, S, 8, D, dtype=torch.float16, device="cuda")
output = flash_attn_func(q, k, v, causal=True)
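
To make the divisibility rule concrete, the sketch below maps each query head to the KV head it shares, assuming the contiguous grouping described in the flash-attn README (consecutive query heads share one KV head). `kv_head_for` is a hypothetical helper written for illustration; it is not part of flash-attn, which performs this mapping inside the kernel.

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the index of the KV head that a given query head attends to."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


# 32 query heads sharing 8 KV heads: groups of 4 consecutive query heads
mapping = [kv_head_for(h, 32, 8) for h in range(32)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

MQA is the special case `num_kv_heads=1`, where every query head maps to KV head 0.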

FlashAttention-3 (Hopper-Native)

FA3 is specifically designed for H100/H800, exploiting Hopper architecture features: warp specialization, TMA (Tensor Memory Accelerator), and FP8 tensor cores. ~1.5-2× faster than FA2 on H100 for fp16/bf16.

Key Architecture Differences from FA2

| Feature | FA2 | FA3 |
| --- | --- | --- |
| Warp scheduling | Standard cooperative groups | Warp-specialized: producer warps (load data via TMA) + consumer warps (compute) |
| Memory access | Manual GMEM→SMEM copies | TMA-based async bulk copies (hardware DMA) |
| FP8 support | None | E4M3 forward pass on tensor cores |
| Pipeline | Sequential load→compute | 2-stage async pipeline: load next tile while computing current |
| SMEM usage | Standard allocation | Pingpong SMEM buffers for async pipeline |
| Block size | Fixed per kernel | Larger blocks (256×128) enabled by TMA efficiency |

FA3 API

FA3 is a separate package from the main flash-attn — it requires a distinct install from the hopper/ subdirectory and uses a different import.

Install:

# FA3 is NOT included in pip install flash-attn
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention/hopper
python setup.py install

Usage:

# FA3 uses flash_attn_interface module, NOT flash_attn
import flash_attn_interface

output = flash_attn_interface.flash_attn_func(q, k, v, causal=True)

FP8 forward pass:

# FP8 E4M3 inputs — forward only (backward still in bf16/fp16)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
output = flash_attn_interface.flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True)

Requirements: H100/H800 (SM90a), CUDA ≥ 12.3 (12.8 recommended).

FA3 Performance Characteristics

| Sequence Length | FA3 vs FA2 Speedup (H100, bf16, causal) | Notes |
| --- | --- | --- |
| 1K | ~1.3× | Compute-bound, less TMA advantage |
| 4K | ~1.5× | Sweet spot for TMA pipeline |
| 16K | ~1.7× | Larger tiles amortize overhead |
| 64K+ | ~1.8-2.0× | Maximum pipeline benefit |

When to use FA3 explicitly:

  • Always on H100/H800 — no reason to use FA2
  • FP8 forward pass for inference (saves memory bandwidth)
  • Long-context training (>4K) where pipeline efficiency matters most

FlashInfer

FlashInfer is a kernel library for LLM inference serving (not training). It provides optimized attention kernels, sampling, and KV cache management — used as the default attention backend in SGLang.

Version: 0.6.4 | Install: pip install flashinfer-python

Key Features

| Feature | Description |
| --- | --- |
| Prefill attention | Ragged tensor attention for variable-length prefill batches |
| Decode attention | Optimized single-token decode with paged KV cache |
| POD-Attention | Fused prefill+decode kernel for mixed batching |
| PagedAttention | Block-sparse paged KV cache management |
| Cascade attention | Multi-level KV cache (shared prefix + unique suffix) |
| GQA/MQA | Native grouped/multi-query attention support |
| FP8 KV cache | E4M3/E5M2 KV cache quantization |
| Sampling | GPU-fused Top-K, Top-P, Min-P without sorting |
| LoRA | BGMV/SGMV kernels for batched LoRA |
| BF16 GEMM | Matrix multiplication for SM100+ (Blackwell) |

FlashInfer vs FlashAttention

| Aspect | FlashInfer | FlashAttention |
| --- | --- | --- |
| Focus | Inference serving (variable-length, paged KV) | Training + inference (fixed-length, dense) |
| KV cache | Paged, ragged, block-sparse, built in | External (user manages contiguous tensors) |
| Decode | Highly optimized single-token decode kernels | Decent but not inference-specialized |
| Prefill | Variable-length ragged attention | Packed QKV (flash_attn_varlen_func) |
| Customization | JIT kernel generation for custom configs | Fixed compiled kernels |
| Used by | SGLang (default), vLLM (optional) | Training frameworks, HF Transformers, vLLM |

FlashInfer API (Inference Engine Integration)

FlashInfer is typically used indirectly through serving engines. Direct API for custom inference:

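import torch
import flashinfer

# Scratch space shared by the wrappers (128 MB, the size used in FlashInfer's examples)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# Prefill: variable-length batch
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer)
# Arguments are passed positionally: the prefill plan() uses different keyword
# names than the decode wrapper, so check the signature in your installed version.
prefill_wrapper.plan(
    qo_indptr,          # [0, seq_len_0, seq_len_0+seq_len_1, ...]
    kv_indptr,          # cumulative page counts per request
    kv_indices,         # page table indices
    kv_last_page_len,   # valid entries in each request's last page
    32,                 # num_qo_heads
    8,                  # num_kv_heads (GQA)
    128,                # head_dim
    16,                 # page_size
)
output = prefill_wrapper.run(q, kv_cache)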

# Decode: single-token per request
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer)
decode_wrapper.plan(
    indptr=indptr,
    indices=indices,
    last_page_len=last_page_len,
    num_qo_heads=32,
    num_kv_heads=8,
    head_dim=128,
    page_size=16,
)
output = decode_wrapper.run(q, kv_cache)
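
The index arrays passed to plan() follow a CSR-style layout. The sketch below derives them from per-request KV lengths in plain Python; `paged_kv_metadata` is a hypothetical helper, and it assumes pages are handed out contiguously, whereas a real allocator returns arbitrary free pages. In practice these arrays would be integer tensors on the GPU.

```python
from itertools import accumulate


def paged_kv_metadata(kv_lens, page_size):
    """Build (kv_indptr, kv_indices, kv_last_page_len) for a batch of requests.

    kv_lens: number of cached tokens per request (each must be >= 1).
    """
    pages_per_seq = [(n + page_size - 1) // page_size for n in kv_lens]
    # Request i owns kv_indices[kv_indptr[i]:kv_indptr[i + 1]]
    kv_indptr = [0, *accumulate(pages_per_seq)]
    # Assumption for illustration: pages are numbered in allocation order
    kv_indices = list(range(kv_indptr[-1]))
    # Number of valid tokens in each request's final page (1..page_size)
    kv_last_page_len = [(n - 1) % page_size + 1 for n in kv_lens]
    return kv_indptr, kv_indices, kv_last_page_len


kv_indptr, kv_indices, kv_last_page_len = paged_kv_metadata([40, 16, 1], page_size=16)
print(kv_indptr)         # [0, 3, 4, 5]
print(kv_last_page_len)  # [8, 16, 1]
```

A request with exactly `page_size` tokens reports a full last page (16) rather than 0, which is why the formula uses `(n - 1) % page_size + 1`.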

Selecting FlashInfer in Serving Engines

# SGLang (default on non-Hopper)
import sglang as sgl
llm = sgl.Engine(model_path="...", attention_backend="flashinfer")

# vLLM (optional)
# Set VLLM_ATTENTION_BACKEND=FLASHINFER or use --attention-backend flashinfer

PagedAttention

PagedAttention manages KV cache memory in blocks (pages) rather than contiguous tensors — eliminates memory fragmentation during inference. Used by vLLM, TGI, and SGLang internally.

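Users do not configure PagedAttention directly; it is an implementation detail of inference engines. The following engine settings control how it behaves:

| Setting | Where | Effect |
| --- | --- | --- |
| block_size | vLLM EngineArgs | Page size for KV cache blocks (default: 16) |
| gpu_memory_utilization | vLLM | Fraction of GPU memory the engine may use for weights, activations, and KV cache; the KV cache gets what remains after the first two (default: 0.9) |
| swap_space | vLLM | CPU swap space in GiB per GPU for overflow pages (default: 4) |
| num_gpu_blocks | Computed | Total KV cache blocks available |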

How it works: Instead of pre-allocating max_seq_len × batch_size contiguous memory, PagedAttention allocates fixed-size blocks on demand. Sequences reference block tables mapping logical positions to physical blocks. This allows:

  • Near-zero memory waste from padding
  • Efficient prefix caching (shared blocks across sequences)
  • Copy-on-write for parallel sampling
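
The block-table idea can be sketched in a few lines of plain Python. `BlockTable` is a toy class written for illustration, not vLLM's implementation: it allocates a physical block only when a sequence crosses a block boundary and translates a logical token position into a (physical block, offset) pair.

```python
class BlockTable:
    """Toy per-sequence block table for a paged KV cache."""

    def __init__(self, block_size, free_blocks):
        self.block_size = block_size
        self.free_blocks = free_blocks  # pool of physical block ids shared by all sequences
        self.blocks = []                # physical block ids owned by this sequence
        self.num_tokens = 0

    def append_token(self):
        """Reserve a slot for one new token, allocating a block only when needed."""
        if self.num_tokens % self.block_size == 0:
            self.blocks.append(self.free_blocks.pop())
        self.num_tokens += 1
        return self.lookup(self.num_tokens - 1)

    def lookup(self, position):
        """Map a logical token position to (physical block id, offset in block)."""
        if not 0 <= position < self.num_tokens:
            raise IndexError("position outside the sequence")
        return self.blocks[position // self.block_size], position % self.block_size


free_pool = list(range(100))
seq = BlockTable(block_size=16, free_blocks=free_pool)
for _ in range(40):
    seq.append_token()

# 40 tokens need ceil(40 / 16) = 3 blocks; only the last block is partly empty
wasted_slots = len(seq.blocks) * seq.block_size - seq.num_tokens
print(len(seq.blocks), wasted_slots)
```

Waste is bounded by one partially filled block per sequence, which is where the near-zero padding overhead comes from. Prefix sharing would amount to two tables listing the same physical block id.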

Ring Attention (Sequence Parallelism)

Ring Attention distributes long sequences across GPUs in a ring topology. Each GPU holds a chunk of Q and rotates K, V blocks around the ring:

  • Enables training on sequences longer than single-GPU memory allows
  • Linear memory scaling with number of GPUs
  • Used in: Llama 3.1 (128K context), DeepSeek-V2
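
The rotation scheme works because softmax can be accumulated chunk by chunk with a running maximum and denominator. The following pure-Python sketch simulates the ring on one process for a single head without a causal mask; `full_attention` and `ring_attention` are illustrative names, and real implementations overlap the K/V transfers with compute on each GPU.

```python
import math
import random


def full_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v over lists of vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / total
                    for c in range(len(v[0]))])
    return out


def ring_attention(q_chunks, k_chunks, v_chunks):
    """Each simulated device keeps its Q chunk and sees one K/V chunk per step."""
    n_dev = len(q_chunks)
    d = len(q_chunks[0][0])
    outputs = []
    for dev in range(n_dev):
        q = q_chunks[dev]
        # Running softmax state per query row: max, denominator, weighted V sum
        m = [-math.inf] * len(q)
        l = [0.0] * len(q)
        acc = [[0.0] * len(v_chunks[0][0]) for _ in q]
        for step in range(n_dev):
            src = (dev - step) % n_dev  # chunk that has rotated onto this device
            k, v = k_chunks[src], v_chunks[src]
            for i, qi in enumerate(q):
                scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
                m_new = max(m[i], max(scores))
                scale = math.exp(m[i] - m_new)  # rescale old state to the new max
                w = [math.exp(s - m_new) for s in scores]
                l[i] = l[i] * scale + sum(w)
                acc[i] = [a * scale + sum(wj * vj[c] for wj, vj in zip(w, v))
                          for c, a in enumerate(acc[i])]
                m[i] = m_new
        outputs.extend([[a / l[i] for a in acc[i]] for i in range(len(q))])
    return outputs


random.seed(0)
n_dev, chunk_len, dim = 4, 3, 8


def rand_chunks():
    return [[[random.gauss(0, 1) for _ in range(dim)] for _ in range(chunk_len)]
            for _ in range(n_dev)]


def flatten(chunks):
    return [row for chunk in chunks for row in chunk]


qc, kc, vc = rand_chunks(), rand_chunks(), rand_chunks()
ref = full_attention(flatten(qc), flatten(kc), flatten(vc))
out = ring_attention(qc, kc, vc)
max_err = max(abs(a - b) for r, o in zip(ref, out) for a, b in zip(r, o))
print(f"max abs difference vs full attention: {max_err:.2e}")
```

Each device only ever holds one K/V chunk at a time, so per-device memory depends on the chunk length rather than the full sequence length.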

Integration: Typically activated via framework flags rather than direct API:

  • Megatron-LM: --context-parallel-size N enables ring attention
  • DeepSpeed Ulysses: Alternative approach — all-to-all on attention heads instead of ring on sequence

FlexAttention (PyTorch 2.5+)

flex_attention compiles custom attention patterns into fused kernels via score_mod / mask_mod functions and BlockMask for sparsity:

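from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# create_block_mask expects a mask_mod with signature (b, h, q_idx, kv_idx) -> bool,
# where True keeps the position. A score_mod takes an extra leading `score`
# argument and is passed to flex_attention(score_mod=...) instead.
def causal_with_window(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx < 1024)

# q, k, v: (batch, nheads, seqlen, headdim)
block_mask = create_block_mask(causal_with_window, B, H, Q_LEN, KV_LEN)
output = flex_attention(q, k, v, block_mask=block_mask)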

Advantages: Fuses arbitrary attention patterns (document masking, sliding window + causal, prefix LM) without materializing the full attention matrix.
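
The saving from BlockMask comes from skipping tiles that are entirely masked. The sketch below counts, in plain Python, how many tiles a causal sliding-window mask leaves live. It assumes a 128×128 tile size and a 1024-token window for illustration; `active_blocks` is a hypothetical helper, not a PyTorch API.

```python
def causal_with_window(q_idx, kv_idx, window):
    """Scalar mask: attend to the current token and the previous window - 1 tokens."""
    return q_idx >= kv_idx and q_idx - kv_idx < window


def active_blocks(seq_len, block, window):
    """Return the (q_block, kv_block) tiles containing at least one unmasked pair."""
    n_blocks = seq_len // block
    active = set()
    for qb in range(n_blocks):
        for kb in range(n_blocks):
            q_lo, q_hi = qb * block, qb * block + block - 1
            k_lo, k_hi = kb * block, kb * block + block - 1
            # q - kv spans [q_lo - k_hi, q_hi - k_lo]; the tile is live if that
            # range overlaps the allowed offsets [0, window - 1]
            if q_hi - k_lo >= 0 and q_lo - k_hi < window:
                active.add((qb, kb))
    return active


tiles = active_blocks(seq_len=4096, block=128, window=1024)
total = (4096 // 128) ** 2
print(f"{len(tiles)} of {total} tiles need computing ({len(tiles) / total:.1%})")
```

The live fraction shrinks as the sequence grows relative to the window, since the number of live tiles per query block stays constant.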

FlexDecoding (Inference Backend, PyTorch 2.7+)

torch.compile(flex_attention) auto-switches to a FlexDecoding kernel when query length is much shorter than KV length (decode phase). No API change — just shape-based JIT dispatch:

  • Prefill (q_len ≈ kv_len): uses standard FlexAttention kernel
  • Decode (q_len=1, kv_len=128K): auto-recompiles to FlexDecoding (FlashDecoding-based)

Key patterns for efficient decode: use captured tensor offsets to avoid recompilation per step, precompute BlockMask and slice per-step, and use convert_logical_block_mask() for PT2-native PagedAttention without custom CUDA kernels.

Performance: 1.22×–2.04× faster than SDPA on LLaMA 3.1-8B decode; parity to 1.66× on 70B. On par with FlashDecoding (FA-KV) while supporting arbitrary attention patterns.

See `references/flexattention-inference.md` for: offset patterns, BlockMask slicing, PT2-native PagedAttention architecture, and varlen_attn API.

varlen_attn — Variable-Length Attention (PyTorch 2.10+)

torch.nn.attention.varlen.varlen_attn provides packed/ragged sequence attention in PyTorch core — eliminates padding waste for variable-length batches. Uses Flash Attention kernels under the hood; replaces flash_attn_varlen_func from the flash-attn package.

from torch.nn.attention.varlen import varlen_attn

# Packed tensors: (total_tokens, num_heads, head_dim) — no batch dim, no padding
# cu_seq: cumulative positions, e.g. [0, 128, 384, 448, 960] for 4 sequences
output = varlen_attn(q, k, v, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=True)

Prefer varlen_attn over flash_attn_varlen_func for new code — it's part of PyTorch core and doesn't require a separate package. See `references/flexattention-inference.md` for packed sequence construction and the AuxRequest LSE output.
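
The cumulative-offset array is simple to derive from per-sequence lengths. A minimal sketch in plain Python follows; `build_cu_seqlens` is a hypothetical helper, and in real use the result would be converted to an integer tensor on the same device as q, k, and v.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Cumulative offsets for packed sequences: sequence i occupies [cu[i], cu[i + 1])."""
    return [0, *accumulate(seq_lens)]


seq_lens = [128, 256, 64, 512]
cu_seq = build_cu_seqlens(seq_lens)
max_len = max(seq_lens)                   # value to pass as max_q / max_k
total_tokens = cu_seq[-1]                 # rows in the packed q/k/v tensors
padded_tokens = len(seq_lens) * max_len   # rows a padded batch would need

print(cu_seq)  # [0, 128, 384, 448, 960]
print(f"padding avoided: {1 - total_tokens / padded_tokens:.1%}")
```

For this batch the packed layout stores 960 tokens where a padded batch would store 2048.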

Hugging Face Transformers Integration

Control attention backend via attn_implementation:

import torch
from transformers import AutoModelForCausalLM

# Explicit backend selection
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",  # or "sdpa", "eager"
    torch_dtype=torch.bfloat16,
)

| Value | Backend | Notes |
| --- | --- | --- |
| "flash_attention_2" | FlashAttention-2 via flash-attn package | Requires pip install flash-attn in image |
| "sdpa" | PyTorch native SDPA | Default for PyTorch ≥ 2.1.1 |
| "eager" | Manual attention math | Debugging only, very slow |

Attention Kernel Benchmarking Methodology

When selecting attention backends for a deployment, benchmark systematically:

What to Measure

| Metric | How | Why |
| --- | --- | --- |
| Prefill latency | Time to first token, varying input lengths | Prefill is memory-bandwidth bound at short seq, compute-bound at long seq |
| Decode throughput | Tokens/sec at target batch size | Decode is memory-bandwidth bound, so kernel efficiency matters |
| Peak memory | torch.cuda.max_memory_allocated() | Determines max batch size / sequence length |
| TFLOPS utilization | Measured FLOPS ÷ peak theoretical | Indicates how close to hardware limit |

Benchmarking Script Pattern

import torch
import time

def benchmark_attention(fn, q, k, v, warmup=10, iters=100):
    """Benchmark an attention kernel. Returns median ms."""
    # Warmup
    for _ in range(warmup):
        fn(q, k, v, causal=True)
    torch.cuda.synchronize()

    # Timed iterations
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(iters):
        start.record()
        fn(q, k, v, causal=True)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    times.sort()
    median_ms = times[len(times) // 2]

    # Compute TFLOPS
    B, S, H, D = q.shape
    flops = 4 * B * H * S * S * D  # 2 for QK^T, 2 for attn@V
    flops //= 2  # causal=True skips roughly half of the score matrix
    tflops = flops / (median_ms / 1000) / 1e12
    return median_ms, tflops

# Compare backends across sequence lengths
for seq_len in [1024, 4096, 16384, 65536]:
    q = torch.randn(1, seq_len, 32, 128, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(1, seq_len, 8, 128, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(1, seq_len, 8, 128, dtype=torch.bfloat16, device="cuda")

    # Test each backend
    from flash_attn import flash_attn_func
    ms, tflops = benchmark_attention(flash_attn_func, q, k, v)
    print(f"seq={seq_len}: FA {ms:.2f}ms, {tflops:.1f} TFLOPS")

Benchmarking Best Practices

  • Isolate GPU: CUDA_VISIBLE_DEVICES=0, no other processes
  • Warmup: ≥10 iterations to stabilize CUDA state and clocks
  • Use CUDA events: Not time.time() — events measure GPU time accurately
  • Report median: Not mean — avoids skew from occasional GC pauses
  • Vary dimensions: Test your actual model's head_dim, num_heads, GQA ratio
  • Test both prefill and decode: Different batch shapes trigger different kernels
  • Check utilization: Compare measured TFLOPS vs GPU peak (e.g., H100 = ~989 TFLOPS bf16)
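
As a worked example of the utilization check, the sketch below converts one timing into a fraction of peak throughput. The 5.0 ms figure is a made-up number for illustration, not a measurement, and `attention_flops` / `utilization` are hypothetical helpers; the 989 TFLOPS peak is the H100 bf16 figure quoted above.

```python
def attention_flops(batch, heads, seq_len, head_dim, causal=False):
    """Forward-pass FLOPs: 2*S*S*D for QK^T plus 2*S*S*D for attn@V, per head."""
    flops = 4 * batch * heads * seq_len * seq_len * head_dim
    # A causal mask skips roughly half of the score matrix
    return flops // 2 if causal else flops


def utilization(flops, elapsed_ms, peak_tflops):
    """Fraction of peak throughput achieved by one timed call."""
    achieved_tflops = flops / (elapsed_ms / 1000) / 1e12
    return achieved_tflops / peak_tflops


flops = attention_flops(batch=1, heads=32, seq_len=16384, head_dim=128, causal=True)
# elapsed_ms is a placeholder value, not a benchmark result
util = utilization(flops, elapsed_ms=5.0, peak_tflops=989)
print(f"{flops / 1e12:.2f} TFLOP per call, {util:.1%} of peak")
```

Counting full (non-causal) FLOPs for a causal kernel would double the reported utilization, so the FLOP count should match the mask that was actually benchmarked.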

See scripts/check_attention_backend.py for a runnable benchmark script.

Container Image Considerations

FlashAttention compilation is slow (~30 min). Use pre-built wheels or NGC containers:

# Option 1: NGC PyTorch container (includes FA)
FROM nvcr.io/nvidia/pytorch:24.12-py3

# Option 2: Pre-built wheel
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime
RUN pip install flash-attn --no-build-isolation

# Option 3: Set MAX_JOBS to limit compilation parallelism
ENV MAX_JOBS=4
RUN pip install flash-attn --no-build-isolation

Troubleshooting

"Torch was not compiled with flash attention"

This warning means the PyTorch build lacks FA support. Causes:

  1. GPU arch < Ampere (T4, V100) — FA2 requires SM80+
  2. CUDA toolkit version mismatch
  3. PyTorch built without FA kernel

Fix: Install flash-attn package separately, or use attn_implementation="sdpa" which uses the Efficient Attention fallback.

Head Dimension Errors

FA2 supports head_dim ≤ 256. The backward pass for head_dim > 192 originally required A100/A800 or H100/H800; since flash-attn 2.5.5, head_dim 256 backward also runs on consumer GPUs provided dropout is disabled (dropout_p=0.0).

Memory Errors with Long Sequences

Flash Attention is O(N) memory in sequence length (vs O(N²) for vanilla attention), but KV cache still grows linearly. For very long sequences:

  • Use sliding window attention (window_size parameter)
  • Enable Ring Attention for multi-GPU sequence parallelism
  • Consider chunked prefill in inference engines
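
A quick calculation shows why the quadratic term dominates and what the linear KV cache still costs. The model shape below (32 layers, 32 query heads, 8 KV heads, head_dim 128, 2-byte elements) is a hypothetical configuration chosen for illustration, and both helper functions are sketches rather than library APIs.

```python
GiB = 1024 ** 3


def score_matrix_bytes(seq_len, heads, batch=1, bytes_per_elem=2):
    """Memory to materialize the full seq_len x seq_len score matrix for one layer."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def kv_cache_bytes(seq_len, layers, kv_heads, head_dim, batch=1, bytes_per_elem=2):
    """KV cache size: one K and one V vector per token, per layer, per KV head."""
    return 2 * batch * layers * kv_heads * head_dim * seq_len * bytes_per_elem


for n in (4096, 65536):
    scores = score_matrix_bytes(n, heads=32) / GiB
    cache = kv_cache_bytes(n, layers=32, kv_heads=8, head_dim=128) / GiB
    print(f"seq_len={n}: score matrix {scores:g} GiB per layer, KV cache {cache:g} GiB")
```

Going from 4K to 64K tokens multiplies the materialized score matrix by 256 but the KV cache by only 16, which is why Flash Attention removes the first cost while the second still has to be managed.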

Cross-References

  • pytorch — PyTorch training and F.scaled_dot_product_attention
  • fsdp — FSDP distributed training (uses SDPA internally)
  • megatron-lm — Megatron-LM attention with context parallelism / Ring Attention
  • vllm — vLLM inference engine (uses PagedAttention internally)
  • sglang — SGLang inference engine (uses FlashInfer by default; uses RadixAttention built on Flash Attention)
  • huggingface-transformers — attn_implementation parameter for model loading
  • torch-compile — Compiling attention kernels with torch.compile
  • deepspeed — DeepSpeed training with Flash Attention kernels
  • gpu-operator — GPU driver prerequisites (SM80+ required for FA2)

Reference