Flash Attention: IO-Aware Exact Attention

Interactive visualization of Flash Attention - the breakthrough algorithm that makes attention memory-efficient through tiling, recomputation, and kernel fusion.


Flash Attention: Revolutionizing Transformer Efficiency

Flash Attention is an IO-aware attention algorithm that achieves 2-9× speedup over standard attention while using orders of magnitude less memory - all while computing exact attention, not an approximation.

Interactive Flash Attention Explorer

Visualize how tiling and memory hierarchy optimization transform attention computation:

[Interactive explorer: select an attention method, view the attention matrix tiled into blocks (e.g. 32×32 blocks of size 64×64), and compare live metrics for memory (KB), FLOPs (M), IO operations (K), and runtime (seconds).]

GPU Memory Hierarchy

Level        | Capacity    | Bandwidth   | Latency
Registers    | 1 KB        | 100 TB/s    | 1 cycle
SRAM/Shared  | 100 KB      | 19 TB/s     | ~20 cycles
HBM/DRAM    | 40-80 GB    | 1.5-3 TB/s  | ~200 cycles
System RAM   | 100s of GB  | 100 GB/s    | ~1000 cycles

The Flash Attention insight: keep data in SRAM (19 TB/s) instead of repeatedly going to HBM (up to 3 TB/s), a roughly 6× bandwidth advantage.

Performance Comparison

Metric        | Standard | Flash Attention | Reduction
Memory usage  | 16.0 MB  | 0.4 MB          | 45.3×
IO operations | 12.6M    | 0.1M            | 135.8×

(Values correspond to the example configuration in the explorer above.)

Algorithm Complexity Comparison

Method            | Memory  | FLOPs   | IO Complexity | Speed
Standard          | O(N²)   | O(N²)   | O(N²)         | 1× (baseline)
Flash Attention   | O(N)    | O(N²)   | O(N²/M)       | 2-4×
Flash Attention 2 | O(N)    | O(N²)   | O(N²/M)       | 5-9×
Block-Sparse      | O(N√N)  | O(N√N)  | O(N√N)        | 10-50×

N = sequence length, M = SRAM size

Flash Attention Key Innovations

1. Tiling

Split the attention computation into blocks that fit in SRAM (~100 KB), so the full N×N attention matrix never has to be written to HBM.

2. Recomputation

Recompute the attention blocks on the fly during the backward pass instead of storing the N×N matrix.

3. Kernel Fusion

Fuse softmax, dropout, and masking into a single kernel to minimize memory access.

4. IO-Aware

Algorithm designed around memory bandwidth, not just FLOPs.

Real-World Speedups

Sequence length | Speedup
512 tokens      | 2.4×
1024 tokens     | 3.0×
2048 tokens     | 3.9×
4096 tokens     | 5.1×
8192 tokens     | 7.6×
16384 tokens    | 9.5×
* Speedup vs standard attention on A100 GPU

The Memory Bandwidth Bottleneck

The Hidden Problem

Modern GPUs have massive compute power but limited memory bandwidth:

Resource         | A100 GPU   | Compute-to-bandwidth ratio (FLOPs per byte)
Compute          | 312 TFLOPS | -
HBM bandwidth    | 1.5 TB/s   | 208:1
SRAM bandwidth   | 19 TB/s    | 16:1

Most attention implementations are memory-bound, not compute-bound!
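
To see why, here is a rough back-of-the-envelope estimate of the arithmetic intensity of standard attention; the sequence length, head dimension, and byte counts below are illustrative assumptions, not measurements:

# Standard attention, fp32, single head (illustrative numbers)
N, d, bytes_per_el = 4096, 64, 4

flops = 2 * (2 * N * N * d)                          # Q @ K^T and P @ V, ~2*N^2*d FLOPs each
hbm_bytes = (4 * N * N + 4 * N * d) * bytes_per_el   # write+read S, write+read P, plus Q, K, V, O

intensity = flops / hbm_bytes                        # FLOPs per byte of HBM traffic
print(f"arithmetic intensity ~ {intensity:.0f} FLOPs/byte")   # ~16
# An A100 needs ~208 FLOPs per byte of HBM traffic to be compute-bound,
# so standard attention sits firmly in memory-bound territory.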

Standard Attention's Inefficiency

Attention(Q, K, V) = softmax(QKᵀ / √d) V

Standard implementation:

  1. Compute S = QKᵀ → store the N×N matrix in HBM
  2. Compute P = softmax(S) → read and write the N×N matrix
  3. Compute O = PV → read the N×N matrix again

Total HBM accesses: O(N²)
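
A minimal NumPy sketch of this three-step recipe (single head, no masking or dropout), just to make the materialized N×N intermediates explicit:

import numpy as np

def standard_attention(Q, K, V):
    d = Q.shape[1]
    # Step 1: S = Q K^T -- an N x N matrix that has to live in HBM
    S = Q @ K.T / np.sqrt(d)
    # Step 2: P = softmax(S) -- another full N x N read and write
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    # Step 3: O = P V -- reads the N x N matrix one more time
    return P @ V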

Flash Attention's Innovation

Core Idea: Stay in SRAM

Instead of materializing the full attention matrix:

  1. Tile the computation into blocks that fit in SRAM
  2. Fuse operations to minimize memory transfers
  3. Recompute instead of storing intermediate results

The Tiling Strategy

import numpy as np

def flash_attention(Q, K, V, block_size):
    # Tiled attention with an online softmax (NumPy sketch of the forward pass)
    N, d = Q.shape
    num_blocks = int(np.ceil(N / block_size))
    O = np.zeros((N, d))

    for i in range(num_blocks):
        # Load a block of queries ("into SRAM")
        rows = slice(i * block_size, min((i + 1) * block_size, N))
        Q_block = Q[rows]
        n_rows = Q_block.shape[0]

        # Running output, normalizer, and row-wise max for this query block
        O_block = np.zeros((n_rows, d))
        l_block = np.zeros(n_rows)
        m_block = np.full(n_rows, -np.inf)

        for j in range(num_blocks):
            # Load a block of keys and values ("into SRAM")
            cols = slice(j * block_size, min((j + 1) * block_size, N))
            K_block = K[cols]
            V_block = V[cols]

            # Attention scores for this tile, computed entirely on-chip
            S_block = Q_block @ K_block.T / np.sqrt(d)

            # Online softmax update: rescale the old statistics by exp(m_old - m_new)
            m_new = np.maximum(m_block, S_block.max(axis=1))
            P_block = np.exp(S_block - m_new[:, None])
            scale = np.exp(m_block - m_new)
            l_block = scale * l_block + P_block.sum(axis=1)
            O_block = scale[:, None] * O_block + P_block @ V_block
            m_block = m_new

        # Normalize and write the finished block back ("to HBM")
        O[rows] = O_block / l_block[:, None]

    return O
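
As a quick check that tiling changes only the IO pattern and not the result, the sketch above can be compared against the naive implementation from earlier (both are illustrative NumPy code, not the production kernels):

rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

O_flash = flash_attention(Q, K, V, block_size=64)
O_naive = standard_attention(Q, K, V)
print(np.allclose(O_flash, O_naive))   # True: exact attention, just computed block by block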

Mathematical Foundation

Online Softmax

Key insight: Compute softmax without materializing the full matrix using the online softmax algorithm:

softmax(x)_i = e^(x_i - m) / Σ_j e^(x_j - m)

Where m = max(x). The running maximum m and the normalizer Σ_j e^(x_j - m) can be maintained in a single pass over x, which is exactly what makes block-by-block processing possible.
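
A minimal sketch of the streaming computation, maintaining the running maximum m and normalizer l in a single pass (illustrative NumPy, not kernel code):

import numpy as np

def online_softmax(x):
    # One pass over x to obtain the max m and the normalizer l = sum(exp(x_j - m))
    m, l = -np.inf, 0.0
    for xi in x:
        m_new = max(m, xi)
        l = l * np.exp(m - m_new) + np.exp(xi - m_new)   # rescale the old sum to the new max
        m = m_new
    # Producing the probabilities is then a cheap rescale; Flash Attention folds this
    # rescaling directly into its running output instead of revisiting the scores
    return np.exp(x - m) / l

x = np.random.default_rng(0).standard_normal(16)
print(np.allclose(online_softmax(x), np.exp(x - x.max()) / np.exp(x - x.max()).sum()))   # True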

Safe Softmax Update

When processing blocks incrementally:

l_new = e^(m_old - m_new) · l_old + l_current
O_new = (e^(m_old - m_new) · l_old · O_old + l_current · O_current) / l_new

Here m_new = max(m_old, m_current), and l_current, O_current are the normalizer and output of the block currently being processed (with its scores already shifted by m_new). Note that l_current · O_current is simply P_block @ V_block, so the per-block output never needs to be normalized on its own.
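
The same update written as a small NumPy helper that folds one tile of scores and values into the running statistics; this is a sketch of the equations above, not the actual kernel:

import numpy as np

def merge_block(m_old, l_old, O_old, S_block, V_block):
    # Fold one new tile of scores S_block and values V_block into the running (m, l, O)
    m_new = np.maximum(m_old, S_block.max(axis=1))
    P = np.exp(S_block - m_new[:, None])        # current block, already shifted by m_new
    l_cur = P.sum(axis=1)
    old_weight = np.exp(m_old - m_new) * l_old  # e^(m_old - m_new) * l_old
    l_new = old_weight + l_cur
    # l_current * O_current is just P @ V_block
    O_new = (old_weight[:, None] * O_old + P @ V_block) / l_new[:, None]
    return m_new, l_new, O_new

Starting from m = -inf, l = 0, O = 0 and calling this once per tile reproduces the inner loop of the tiling code above (up to whether O is kept normalized throughout or divided by l only at the end).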

Memory Complexity

Standard attention:

M_standard = O(N² + Nd)

Flash Attention:

M_flash = O(N + Nd)

The only extra state is the per-row softmax statistics m and l, which take O(N) memory on top of the O(Nd) inputs and output. For typical values: a 10-100× memory reduction!
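
For concreteness, a small illustrative calculation of the extra memory per attention head and batch element (the sizes and dtypes here are assumptions, not measurements):

# Extra memory beyond the O(Nd) inputs and output, per head and batch element
N, d = 8192, 64

standard_extra = N * N * 2        # the materialized N x N attention matrix in fp16
flash_extra = 2 * N * 4           # per-row softmax statistics m and l, kept in fp32

print(f"standard: {standard_extra / 2**20:.0f} MiB")   # 128 MiB
print(f"flash:    {flash_extra / 2**10:.0f} KiB")      # 64 KiB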

IO Complexity Analysis

Standard Attention IO

IO_standard = O(Nd + N²)

Breakdown:

  • Load Q, K, V: 3Nd
  • Write and re-read S: 2N²
  • Write and re-read P: 2N²
  • Store O: Nd

Flash Attention IO

IO_flash = O(N²d / M_SRAM)

Breakdown:

  • Each Q block is loaded once and each O block is written once: O(Nd)
  • K and V blocks are re-read once per query block, which gives the O(N²d / M_SRAM) term
  • The N×N matrices S and P are never written to HBM

Reduction factor: M_SRAM / d (typically 16-64×)

Implementation Optimizations

1. Kernel Fusion

Fuse multiple operations into single CUDA kernel:

__global__ void fused_attention_kernel(
    float* Q, float* K, float* V, float* O,
    int N, int d, int block_size
) {
    // Shared memory (SRAM) for the current tiles
    __shared__ float Q_smem[BLOCK_SIZE][HEAD_DIM];
    __shared__ float K_smem[BLOCK_SIZE][HEAD_DIM];
    __shared__ float V_smem[BLOCK_SIZE][HEAD_DIM];

    // Scores, softmax, dropout, and masking are all computed here, fused,
    // with no intermediate writes to global memory
}

2. Warp-Level Primitives

Utilize GPU warp-level operations:

  • __shfl_sync() for reductions
  • Tensor cores for matrix multiplies
  • Warp-wide reductions for softmax

3. Work Partitioning (Flash-2)

Better parallelization:

  • Split across sequence length dimension
  • Each warp handles different queries
  • Reduced synchronization overhead

Variants and Extensions

Flash Attention 2

Improvements over Flash Attention:

  1. Better parallelization: 2× speedup
  2. Reduced non-matmul FLOPs: output rescaling is deferred to the end of the inner loop
  3. Support for different head dimensions
  4. Multi-query/Grouped-query attention

Flash Decoding

Optimized for inference:

  • Parallelizes across the key/value length, so even batch size 1 keeps the GPU busy
  • Splits K and V across threadblocks, then reduces the partial results
  • Efficient for long-context generation
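
A NumPy sketch of the split-K/V idea for a single decoding query: each chunk produces partial (max, normalizer, output) statistics, which are then merged exactly like the block update earlier. The chunking and names here are illustrative, not the real CUDA implementation:

import numpy as np

def flash_decode_step(q, K, V, num_splits=4):
    # q: (d,) query for the token being generated; K, V: (N, d) cached keys/values
    d = q.shape[0]
    partials = []
    for idx in np.array_split(np.arange(K.shape[0]), num_splits):
        s = K[idx] @ q / np.sqrt(d)                  # scores for this chunk (one "threadblock")
        m = s.max()
        p = np.exp(s - m)
        partials.append((m, p.sum(), p @ V[idx]))    # per-chunk max, normalizer, unnormalized output
    # Reduce the per-chunk statistics, rescaling each to the global max
    m_all = max(m for m, _, _ in partials)
    l_all = sum(l * np.exp(m - m_all) for m, l, _ in partials)
    o_all = sum(o * np.exp(m - m_all) for m, _, o in partials)
    return o_all / l_all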

Block-Sparse Flash Attention

Combine with sparsity patterns:

def sparse_flash_attention(Q, K, V, sparsity_mask):
    # Only compute the (i, j) tiles where the block-level mask is non-zero;
    # skipped tiles cost neither FLOPs nor HBM traffic
    for i, j in sparse_blocks(sparsity_mask):
        compute_block_attention(Q[i], K[j], V[j])

Performance Characteristics

Speedup by Sequence Length

Sequence Length | Speedup | Memory Savings
512             | 2.4×    | 10×
1024            | 3.0×    | 15×
2048            | 3.9×    | 20×
4096            | 5.1×    | 40×
8192            | 7.6×    | 60×
16384           | 9.5×    | 100×

Hardware Scaling

Performance on different GPUs:

  • V100: 2-3× speedup (limited SRAM)
  • A100: 3-5× speedup (more SRAM)
  • H100: 5-9× speedup (faster HBM)

Backward Pass

Flash Attention also optimizes the backward pass:

  1. Recomputation: Don't store attention matrix
  2. Gradient accumulation: In SRAM
  3. Fused operations: Single kernel for backward

def flash_attention_backward(dO, Q, K, V, O):
    # Recompute attention blocks on the fly instead of reading a stored N x N matrix,
    # accumulate gradients in SRAM, and make a single tiled pass through the sequence
    dQ, dK, dV = compute_gradients_tiled(dO, Q, K, V, O)
    return dQ, dK, dV

Practical Considerations

When to Use Flash Attention

Use when:

  • Long sequences (>512 tokens)
  • Memory constrained
  • Training large models
  • Need exact attention

Don't use when:

  • Very short sequences (less than 128 tokens)
  • Custom attention patterns needed
  • Hardware doesn't support it (pre-Turing GPUs)

Integration

Most frameworks now include Flash Attention:

# PyTorch: scaled_dot_product_attention dispatches to a Flash Attention kernel
# when the inputs allow it (fp16/bf16 on a supported GPU, no explicit attention mask)
from torch.nn.functional import scaled_dot_product_attention

out = scaled_dot_product_attention(
    Q, K, V,
    dropout_p=0.1,
    is_causal=True,   # causal masking is applied inside the fused kernel
)

# Hugging Face Transformers
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
)
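
To confirm that the Flash backend is actually selected rather than a fallback, recent PyTorch releases (2.3+) provide a context manager that restricts SDPA to specific backends. A minimal sketch, assuming that API is available in your version:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict dispatch to the Flash Attention backend; if it cannot be used for these
# inputs, the call errors out instead of silently falling back to another kernel
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)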

Future Directions

Flash Attention 3

Potential improvements:

  • Persistent kernels: Keep data in SRAM across layers
  • Cross-attention optimization: For encoder-decoder
  • Dynamic sparsity: Adaptive attention patterns

Hardware Co-design

Future hardware optimizations:

  • Larger SRAM (>256KB)
  • Higher SRAM bandwidth
  • Hardware attention units
  • Near-memory computing

Common Misconceptions

"Flash Attention is approximate"

False: Flash Attention computes exact attention, numerically identical to standard attention (within floating-point precision).

"Flash Attention only helps with memory"

False: Primary benefit is speed (2-9×), memory savings are secondary.

"Flash Attention requires special hardware"

False: Works on any GPU with CUDA capability ≥ 7.5 (Turing and newer).

Conclusion

Flash Attention demonstrates that algorithmic innovation can overcome hardware limitations. By reformulating attention as an IO-aware problem rather than a pure compute problem, it enables training and inference of models with much longer contexts than previously possible.

If you found this explanation helpful, consider sharing it with others.
