Naive scaled dot-product attention is two matrix multiplications separated by a softmax, and there is nothing algorithmically wrong with that. The problem is operational: at sequence length N, the intermediate attention score matrix holds N×N values per head. At N=32K tokens and BF16, that is 2 GB per attention head per forward pass, written to HBM and then read back, in every layer, at every step. FlashAttention (Dao et al., May 27, 2022, arXiv:2205.14135) eliminated that write entirely. FlashAttention-2 (Dao, July 17, 2023, arXiv:2307.08691) reworked the work partitioning across thread blocks and warps. FlashAttention-3 (Shah et al., July 11, 2024, arXiv:2407.08608) added Hopper-specific asynchronous warp specialization and FP8 support. The progression is a masterclass in GPU microarchitecture applied to one algorithm.
Triton, Philippe Tillet's Python-embedded DSL for GPU kernels, is why you can implement FlashAttention-2 in ~200 lines of Python instead of 2,000 lines of CUDA. The @triton.jit decorator JIT-compiles a tile-parallel kernel whose tl.dot calls lower to tensor core MMA instructions on the target GPU. Raw CUDA is now reserved for the roughly 5% of performance-critical paths where Triton's abstractions leak. For everything else (custom attention variants, fused layernorm + MLP, speculative decode tree expansion) Triton is the right tool. CS336 Assignment 2 is explicitly a Triton implementation of FlashAttention-2 and is the most educational 200 lines you will write in this entire study path.
Online softmax: the mathematical core
Standard softmax over a row of logits requires two passes: first find the maximum (for numerical stability), then compute exp and normalize. You cannot do this in one pass unless you track a running maximum and running sum simultaneously. The online softmax recurrence (Milakov and Gimelshein, 2018) maintains two scalars per row: m (running max) and l (running sum of exp(x - m)). When a new tile of K columns arrives, the old l is rescaled by exp(m_old - m_new) before adding the new partial sum. This rescaling is exact — no approximation — and it allows the full softmax to be computed tile-by-tile without ever materializing all N keys at once.
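The recurrence can be sketched in plain Python. This is a minimal illustration, not the kernel itself: the function names and the tile_size parameter are hypothetical, and the final normalization re-reads the logits only so the result can be compared against the two-pass version (in attention, that normalization is folded into the output instead).

```python
import math

def two_pass_softmax(logits):
    """Reference: one pass for the max, one for exp and normalization."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax(logits, tile_size=4):
    """Consume `logits` tile by tile, tracking only m and l between tiles."""
    m = float("-inf")  # running max
    l = 0.0            # running sum of exp(x - m)
    for start in range(0, len(logits), tile_size):
        tile = logits[start:start + tile_size]
        m_new = max(m, max(tile))
        # Rescale the old sum to the new max, then add this tile's terms.
        # On the first tile m is -inf, so exp(m - m_new) is exactly 0.0.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    # Normalization shown only for comparison with the reference.
    return [math.exp(x - m) / l for x in logits], m, l
```

Calling online_softmax(xs, tile_size=3) on any list of floats should agree with two_pass_softmax(xs) up to floating-point rounding, whatever tile size is chosen.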
FlashAttention builds the attention computation around this recurrence. For each query tile Q_i, iterate over K tiles: load K_j and V_j from HBM into SRAM, compute the dot products Q_i K_j^T, update m and l with the online recurrence, and accumulate the partial output O_i += softmax_tile × V_j (rescaling the old O_i as m updates). When all K tiles are consumed, write O_i back to HBM once. Q is read once and O is written once, O(N·d) each. K and V are streamed once per query tile, so total HBM traffic is Θ(N²d²/M) for on-chip SRAM of size M (Dao et al., 2022), against Θ(N·d + N²) for the standard implementation. The N×N intermediate matrix is never written. At seq=32K in BF16 this avoids at least 4 GB of HBM traffic per head: a 2 GB write plus a 2 GB read-back.
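The loop described above can be sketched for one head in plain Python. This is an illustration under simplifying assumptions: one query row per step instead of a query tile, nested lists instead of SRAM tiles, and hypothetical function names. It shows the rescaling of the output accumulator, not the memory behavior.

```python
import math

def naive_attention(Q, K, V):
    """Reference: builds the full row of scores for every query."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        total = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / total
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_n=2):
    """Stream K/V in tiles of block_n rows, keeping m, l and acc per query."""
    d_v = len(V[0])
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:  # a real kernel processes a tile of queries per program
        m, l, acc = float("-inf"), 0.0, [0.0] * d_v
        for j0 in range(0, len(K), block_n):
            k_tile, v_tile = K[j0:j0 + block_n], V[j0:j0 + block_n]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)        # rescale factor for old state
            p = [math.exp(x - m_new) for x in s]
            l = l * alpha + sum(p)
            acc = [alpha * acc[c]
                   + sum(p[t] * v_tile[t][c] for t in range(len(p)))
                   for c in range(d_v)]
            m = m_new
        out.append([a / l for a in acc])       # normalize once at the end
    return out
```

The two functions should agree to rounding error for any block_n, including sizes that do not divide the number of keys evenly.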
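Causal masking costs almost nothing in this framework. Key tiles that lie entirely in the future of query tile i are skipped outright. Only tiles that straddle the diagonal need masking: there, scores whose key index exceeds the query index are set to -inf before the softmax update. The online recurrence handles those entries correctly because exp(-inf) = 0, provided the running max is already finite, which holds in causal self-attention because every query can see key 0 in the first tile. No separate masking kernel, no memory for the mask matrix.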
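One way to organize causal masking at tile granularity is to classify each (query tile, key tile) pair before doing any work. The sketch below is an assumption about how such a classification could look, not code from any FlashAttention release. The names classify_tile, causal_tile_counts and mask_scores are hypothetical, and seq_len is assumed to be a multiple of both block sizes.

```python
def classify_tile(i, j, block_m, block_n):
    """Classify key tile j against query tile i under a causal mask."""
    q_first, q_last = i * block_m, (i + 1) * block_m - 1
    k_first, k_last = j * block_n, (j + 1) * block_n - 1
    if k_first > q_last:
        return "skip"    # every key is in the future of every query
    if k_last <= q_first:
        return "full"    # every key is visible to every query
    return "masked"      # straddles the diagonal: mask element-wise

def causal_tile_counts(seq_len, block_m, block_n):
    """Count how many tiles fall into each class."""
    counts = {"skip": 0, "full": 0, "masked": 0}
    for i in range(seq_len // block_m):
        for j in range(seq_len // block_n):
            counts[classify_tile(i, j, block_m, block_n)] += 1
    return counts

def mask_scores(scores, q_first, k_first):
    """Replace scores with -inf wherever key index > query index."""
    return [[s if k_first + c <= q_first + r else float("-inf")
             for c, s in enumerate(row)]
            for r, row in enumerate(scores)]
```

With equal block sizes, roughly half of the tiles land in the "skip" class, which is where the compute saving for causal attention comes from.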
Triton: GPU kernels in Python
Triton abstracts GPU programming at the tile level. A @triton.jit function reads its program ID with tl.program_id (analogous to the block index in CUDA) and operates on tiles via tl.load, tl.store, tl.dot, tl.exp, tl.sum. The compiler lowers these to tensor core MMA instructions, shared memory allocations, and memory access patterns for the target GPU. The key constraint: tile shapes must be known at compile time, which is why FlashAttention Triton kernels take BLOCK_M and BLOCK_N as tl.constexpr parameters. Tile sizes must be chosen so that the Q, K and V tiles, plus any pipelining buffers, fit in the SM's shared memory (up to 228 KB per SM on H100, 164 KB on A100), with the output accumulator typically held in registers.
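A back-of-envelope budget check helps when picking BLOCK_M and BLOCK_N. The sketch below is a rough model under stated assumptions: 16-bit inputs, an FP32 output accumulator counted against the same budget (a pessimistic simplification), and K/V buffers multiplied by a pipelining depth. The function names and the default budget are illustrative, and the Triton compiler decides the real placement.

```python
def tile_bytes(block_m, block_n, head_dim,
               in_bytes=2, acc_bytes=4, num_stages=1):
    """Approximate on-chip bytes for one program's working set."""
    q = block_m * head_dim * in_bytes                    # Q tile, loaded once
    kv = 2 * block_n * head_dim * in_bytes * num_stages  # K and V tiles, per stage
    o = block_m * head_dim * acc_bytes                   # FP32 output accumulator
    return q + kv + o

def fits(block_m, block_n, head_dim, budget_bytes=228 * 1024, **kwargs):
    """True if the modeled working set is within the assumed budget."""
    return tile_bytes(block_m, block_n, head_dim, **kwargs) <= budget_bytes
```

For head_dim=128, a 128×64 tiling models at 128 KiB, while 256×128 models at 256 KiB and would not fit under the assumed 228 KB budget.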
A practical Triton FlashAttention-2 forward kernel is ~200 lines: load Q tile, loop over K/V tiles accumulating the online softmax and output, write the output and logsumexp. The backward pass is harder — recomputation of attention scores from the saved logsumexp is the key trick that avoids storing the N×N matrix for the backward pass as well. CS336 Assignment 2 requires the forward pass; the backward pass is optional but recommended for anyone who will write custom attention variants for a VLA fine-tuning loop.
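The recomputation trick rests on one identity: with L = m + log(l) saved per query row, each probability is exp(s - L), so any tile of the attention matrix can be rebuilt from its scores and a single scalar. A minimal sketch for one row, with hypothetical function names:

```python
import math

def forward_row(scores):
    """Return softmax probabilities and the logsumexp L = m + log(l)."""
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return [math.exp(s - m) / l for s in scores], m + math.log(l)

def recompute_tile(scores_tile, lse):
    """Backward-pass style: rebuild one tile of probabilities from L alone."""
    return [math.exp(s - lse) for s in scores_tile]
```

In the backward pass the scores themselves are recomputed from Q and K tiles, so only O(N) values of L need to be stored rather than the N×N matrix.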
Triton limitations worth knowing: tile shapes are compile-time constants (changing a constexpr such as BLOCK_M triggers a recompile, while runtime sequence lengths are passed as ordinary arguments and handled with masked loads), and there is limited control over warp-to-warp asynchronous pipelines (that is raw CUDA / PTX territory, which FlashAttention-3 exploits). For the ~95% of custom kernel use cases, such as fused activation functions, custom normalization, and attention variants, Triton is sufficient and far faster to develop in than CUDA.
FlashAttention-2 vs FlashAttention-3 on Hopper
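FlashAttention-2 (Dao, arXiv:2307.08691, July 2023) improved over FA1 primarily by fixing work partitioning. FA1 split K and V across the warps of a thread block, so each warp had to write partial results to shared memory and synchronize before they could be combined. FA2 splits Q across warps instead, which removes that communication. It also parallelizes over the sequence-length dimension across thread blocks, cuts non-matmul FLOPs, and skips fully masked tiles in causal attention. FA2 reaches roughly 50–73% of peak FP16/BF16 tensor core throughput on A100, compared to FA1's 25–40%. On H100 its utilization drops to about 35% (Shah et al., 2024), which is the gap FA3 targets.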
FlashAttention-3 (Shah et al., arXiv:2407.08608, July 2024) exploits three Hopper-specific features: (1) the Tensor Memory Accelerator (TMA) for asynchronous HBM-to-SRAM transfers, decoupled from computation; (2) warp-group specialization, where producer warps issue TMA loads asynchronously while consumer warps run tensor core MMA, so data movement overlaps with compute; (3) FP8 support for the GEMMs, which doubles peak tensor core throughput when precision allows. The paper reports a 1.5–2.0× speedup over FA2 on H100, with FP16 reaching up to 740 TFLOPs/s (~75% utilization) and FP8 reaching close to 1.2 PFLOPs/s. The catch: FA3 requires Hopper-class hardware (H100) and will not run on Ampere (A100).
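Utilization figures like these are achieved FLOPs/s divided by the hardware peak. The sketch below shows the arithmetic under two assumptions that are not from the text above: the forward pass is counted as two matmuls of 2·N²·d FLOPs each per head, and the H100 SXM dense 16-bit peak is taken as roughly 989 TFLOPs/s. The function names are illustrative.

```python
def attention_forward_flops(seq_len, head_dim, n_heads, batch=1, causal=False):
    """FLOPs for QK^T and PV: 2 matmuls x 2*N*N*d each, per head."""
    flops = 4 * batch * n_heads * seq_len * seq_len * head_dim
    return flops // 2 if causal else flops  # causal does about half the work

def utilization(achieved_tflops, peak_tflops=989.0):
    """Fraction of the assumed peak throughput actually achieved."""
    return achieved_tflops / peak_tflops
```

Under these assumptions, 740 TFLOPs/s corresponds to about 0.75 utilization, and one forward pass at N=16K, d=128 with 16 heads costs about 2.2 TFLOPs of work.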
What custom kernels unlock for your projects
For GR00T N1.5 running cross-attention between vision tokens (512) and language tokens (256) at 50 Hz on Jetson AGX Orin, the attention kernel is the most time-sensitive operation in the DiT. Orin does not support FA3 (Ampere-class GPU), but FA2 via Triton is available. The difference between PyTorch naive attention and FA2 at seq=768 is roughly 2.5× wall-clock — from ~6 ms to ~2.5 ms per attention layer — which can be the margin between meeting and missing the 20 ms control budget.
For DealLens processing VC memos with 32K-token context, FA2 is the difference between fitting the full memo in one call and chunking it. At seq=32K, a materialized BF16 score matrix takes 2 GB per head. Llama 3.1 70B has 64 query heads per layer (and 80 layers), so naive attention would need ~128 GB of HBM for a single layer's scores on one sequence, more than an 80 GB H100 holds. FA2 replaces that with O(N) extra memory for the per-row softmax statistics. Prefix caching (covered in Chapter 19) sits on top of FA2 and further amortizes the prefill cost across requests sharing a common system prompt.
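The size of a materialized score matrix follows directly from sequence length, head count and element width. A small sketch of that arithmetic follows. The function name is illustrative, and the 64-head figure used in the note below is the query-head count from Llama 3.1 70B's published configuration, brought in here as an assumption.

```python
def score_matrix_bytes(seq_len, n_heads=1, batch=1, bytes_per_el=2):
    """Bytes needed to hold N x N attention scores for every head."""
    return batch * n_heads * seq_len * seq_len * bytes_per_el

GIB = 1024 ** 3  # bytes per GiB
```

At seq_len=32768 in BF16 this gives 2 GiB for one head and 128 GiB for 64 heads, before counting weights, activations or the KV cache.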
torch.nn.functional.scaled_dot_product_attention can dispatch to a FlashAttention-2 backend automatically (PyTorch 2.2 and later) on Ampere+ GPUs when the inputs are FP16 or BF16 and the mask is one the fused kernel supports; otherwise it falls back to the memory-efficient or math backend. You get FA2 for free in most training loops. The Triton build is still worth doing: it teaches you what SDPA does, and custom attention variants (e.g., sliding window, cross-attention with custom masking) may require writing your own kernel.