When I set out to implement Ring Attention for long-context models, I hit a wall that none of the papers prepared me for. My performance was capping out at half of what it should be, no matter how many accelerators I threw at the problem. It turns out, there's a hidden performance trap in the standard recipe for causal attention. Here’s the story of that bug, how to prove it exists without burning a single TPU hour, and how a simple trick called "zigzag sharding" fixed it completely.
This post is a walkthrough of ring-flash-jax, a small JAX project that explores this problem and implements the fix. We'll cover the four things you actually need to add before Ring Attention is usable for training a real-world causal language model.
TL;DR
ring-flash-jax is a JAX implementation of ring attention with four key additions to the standard pattern:
- Causal masking — required for any decoder language model.
- Zigzag (striped) sharding — fixes the critical load-imbalance problem that kills performance in the naive causal version.
- An analytical load-balance tool — to prove the performance cliff on paper *before* you burn expensive compute time.
- A Pallas-fused inner kernel — for high-performance execution on TPUs and GPUs.
And thanks to JAX, the entire backward pass just works automatically. It's a trainable layer right out of the box.
Repo: ring-flash-jax · runs on CPU (8 simulated devices), TPU, or GPU · MIT license.
There's also a Colab TPU notebook that runs the whole thing on a free Colab v2-8 runtime.
Why We Need to Shard the Sequence
If you want to run a transformer at long context, you eventually hit a massive memory wall. Running a model like LLaMA-3 70B with a million-token context creates a huge KV cache. In bf16 precision, that cache requires around 328 GB of memory (80 layers × 8 KV heads × 128 head dim × 2 for K+V × 2 bytes/elem × 1M tokens). That simply won't fit on a single accelerator.
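A quick sanity check on that number, using only the figures above:

```python
# Back-of-the-envelope KV-cache size for a 1M-token context in bf16.
layers, kv_heads, head_dim = 80, 8, 128       # LLaMA-3 70B decoder config
bytes_per_elem = 2                            # bf16
tokens = 1_000_000

kv_cache_bytes = layers * kv_heads * head_dim * 2 * bytes_per_elem * tokens  # ×2 for K and V
print(f"{kv_cache_bytes / 1e9:.0f} GB")       # ≈ 328 GB
```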
The standard solution is Ring Attention. You shard the sequence across `N` devices. Then, in `N` steps, the K/V shards rotate around a logical ring (using `jax.lax.ppermute`). Each device computes its piece of the attention puzzle, and after `N` steps, every query has seen every key. It's a clever way to keep memory bounded while overlapping communication and computation.
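To make the loop concrete, here's a minimal non-causal sketch of one ring pass. This is a toy illustration, not the project's API: the function name `ring_attn_fwd`, the mesh axis `ring`, and the shapes are made up for this post, and the inner update is a plain online-softmax in `jnp` with no blocking or fusion.

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def ring_attn_fwd(q, k, v, *, axis_name, axis_size):
    """One device's view: the local Q block stays put, K/V blocks rotate around the ring."""
    perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]

    def step(_, carry):
        o, m, l, k, v = carry
        s = (q @ k.T) / jnp.sqrt(q.shape[-1])        # scores against the K block currently held
        m_new = jnp.maximum(m, s.max(axis=-1))       # running row-wise max
        p = jnp.exp(s - m_new[:, None])
        scale = jnp.exp(m - m_new)                   # rescale the old accumulators
        l = l * scale + p.sum(axis=-1)               # running softmax denominator
        o = o * scale[:, None] + p @ v               # running (unnormalized) output
        k = jax.lax.ppermute(k, axis_name, perm)     # hand K/V to the next device in the ring
        v = jax.lax.ppermute(v, axis_name, perm)
        return o, m_new, l, k, v

    init = (jnp.zeros_like(q),
            jnp.full(q.shape[0], -jnp.inf, dtype=q.dtype),
            jnp.zeros(q.shape[0], dtype=q.dtype),
            k, v)
    o, _, l, _, _ = jax.lax.fori_loop(0, axis_size, step, init)
    return o / l[:, None]

devices = jax.devices()
mesh = Mesh(devices, ("ring",))
fwd = jax.jit(shard_map(
    functools.partial(ring_attn_fwd, axis_name="ring", axis_size=len(devices)),
    mesh=mesh, in_specs=(P("ring", None),) * 3, out_specs=P("ring", None),
))
q = k = v = jnp.ones((len(devices) * 128, 64), jnp.float32)  # (seq, head_dim), sharded over seq
out = fwd(q, k, v)   # matches softmax(q kᵀ / √d) v computed densely
```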
The textbook version is clean. The version you can actually *train* is not. That gap is what this project is about.
The Problem Nobody Mentions in the Diagrams
The classic Ring Attention picture quietly assumes non-causal attention. The moment you add a causal mask—which you need for any decoder language model—the beautiful symmetry breaks, and a nasty performance bug emerges.
Here's why. With contiguous sharding, device `i` holds queries `[i·L, (i+1)·L)`. At step `s`, it sees keys from device `(i−s) mod N`. For causal attention, query `q` can only attend to key `k` if `k ≤ q`. This means:
- Device 0 (earliest queries) does almost nothing. It computes its local block (which is half-masked) and that's it. Every other key shard is from the "future" and gets masked out.
- Device N-1 (latest queries) does almost everything. It attends to every single key block from every device.
- The work increases linearly from device 0 to N-1. Your cluster's speed is dictated by the slowest device.
If we count work in full `L × L` blocks, the slowest device does `(N−0.5)` blocks of work, while the total work is `N²/2` blocks. The achieved speedup over a single device is:
speedup_naive(N) = (Total Work) / (Slowest Device Work)
= (N²/2) / (N − 0.5)
= N² / (2N − 1)
For large N, this function asymptotes to N/2. You bought N TPUs but are only getting the performance of N/2. The other half are sitting idle, masked out by the causal logic. This is the performance trap.
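You can see the cliff directly by evaluating the closed form for a few device counts:

```python
# Closed-form speedup of naive contiguous sharding under a causal mask.
def speedup_naive(n):
    return n**2 / (2 * n - 1)

for n in (8, 16, 32):
    print(n, round(speedup_naive(n), 2))   # 4.27, 8.26, 16.25 — roughly n/2
```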
The Fix: Zigzag (Striped) Sharding
The fix is elegant and comes from Brandon et al., 2023, "Striped Attention". Instead of giving each device one contiguous chunk, you give it two interleaved chunks: one from the beginning of the sequence and one from the end.
# Split the sequence into 2N chunks of length L; device `i` holds two of them:
chunk_A: [ i * L,        (i+1) * L )    (from the first half)
chunk_B: [ (2N-1-i) * L, (2N-i) * L )   (from the second half)
So, Device 0 gets the very first tokens *and* the very last tokens. The magic is that this layout makes the work under a causal mask perfectly uniform. Every device gets an "easy" early chunk (where lots of keys are masked) and a "hard" late chunk (where few keys are masked). The two balance out perfectly.
You recover the full N× speedup you were promised. The implementation is just a permutation before sharding and an inverse permutation after.
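Here's a sketch of that permutation at the index level (the name `zigzag_order` is illustrative, not the repo's):

```python
import numpy as np

def zigzag_order(seq_len, n_devices):
    """Gather indices that reorder tokens so a plain contiguous shard yields the zigzag layout."""
    chunks = np.arange(seq_len).reshape(2 * n_devices, -1)   # 2N equal chunks
    order = [c for i in range(n_devices)
               for c in (chunks[i], chunks[2 * n_devices - 1 - i])]
    return np.concatenate(order)

order = zigzag_order(seq_len=16, n_devices=2)
print(order)                 # [ 0  1  2  3 12 13 14 15  4  5  6  7  8  9 10 11]
inverse = np.argsort(order)  # undo it after attention: y = y_zigzag[inverse]
```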
Proving the Cliff Without Burning TPU Hours
This kind of theoretical analysis is great, but "trust me" isn't a good engineering principle. Before writing a single kernel, I wanted to prove it. So I wrote a tiny analytical tool that just *counts* the unmasked `(q, k)` pairs for each device under each sharding scheme. You can settle architectural debates in milliseconds.
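The counting argument fits in a few lines of numpy. Here's a back-of-the-envelope version (not the repo's `load_analysis` module, but it reproduces the same numbers):

```python
import numpy as np

def work_per_device(owner, seq_len):
    # Under a causal mask, query q attends to keys 0..q, i.e. (q + 1) unmasked pairs.
    return np.bincount(owner, weights=np.arange(1, seq_len + 1))

seq_len, n = 2048, 8
naive = np.repeat(np.arange(n), seq_len // n)     # device owning each query, contiguous shards

zigzag = np.empty(seq_len, dtype=int)
chunks = np.arange(seq_len).reshape(2 * n, -1)
for i in range(n):
    zigzag[chunks[i]] = i                         # early chunk
    zigzag[chunks[2 * n - 1 - i]] = i             # matching late chunk

for name, owner in [("naive ", naive), ("zigzag", zigzag)]:
    w = work_per_device(owner, seq_len)
    print(f"{name} speedup {w.sum() / w.max():.2f}x  imbalance {w.max() / w.mean():.2f}x")
# naive  speedup 4.27x  imbalance 1.87x
# zigzag speedup 8.00x  imbalance 1.00x
```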
The results speak for themselves:
$ python -m ring_flash_jax.load_analysis
Setup: seq_len=2,048, devices=8
Naive causal achieved speedup: 4.27x (imbalance 1.87x)
Zigzag causal achieved speedup: 8.00x (imbalance 1.00x)
Zigzag/naive ratio : 1.87x
Setup: seq_len=4,096, devices=16
Naive causal achieved speedup: 8.26x (imbalance 1.94x)
Zigzag causal achieved speedup: 16.00x (imbalance 1.00x)
Zigzag/naive ratio : 1.94x
Setup: seq_len=8,192, devices=32
Naive causal achieved speedup: 16.25x (imbalance 1.97x)
Zigzag causal achieved speedup: 32.00x (imbalance 1.00x)
Zigzag/naive ratio : 1.97x
The "imbalance" column (`slowest_device / mean_device`) approaches 2.0x for the naive scheme, exactly as the math predicted. For zigzag, it's a perfect 1.00x. The theory holds. This kind of tool is underrated for validating ideas before committing to complex implementation.
The Pallas-fused Inner Block
Each step of the ring algorithm needs an online flash-attention update. While you can write this with standard `jnp` operations, it's much faster to fuse it into a single kernel on TPU/GPU. The project uses JAX Pallas for this, creating a drop-in replacement for the pure Python version. The high-level algorithm doesn't even know whether a fused kernel or a fallback is running, which makes development and testing a breeze.
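For a sense of what that looks like, here is about the smallest possible Pallas kernel for a single attention block. It's deliberately simplified: a plain softmax over one block, with none of the running-max/denominator bookkeeping the real fused kernel carries through the ring loop, so read it as an illustration of the `pallas_call` shape rather than as the project's kernel.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _attn_block_kernel(q_ref, k_ref, v_ref, o_ref):
    q, k, v = q_ref[...], k_ref[...], v_ref[...]
    # q @ k.T without materializing the transpose: contract the head dims.
    s = jax.lax.dot_general(q, k, (((1,), (1,)), ((), ()))) / jnp.sqrt(q.shape[-1])
    o_ref[...] = jnp.dot(jax.nn.softmax(s, axis=-1), v)

def attn_block_fused(q, k, v):
    # Same signature as a pure-jnp fallback, so the ring loop can't tell them apart.
    return pl.pallas_call(
        _attn_block_kernel,
        out_shape=jax.ShapeDtypeStruct(q.shape, q.dtype),
        interpret=jax.default_backend() == "cpu",   # interpreter fallback off-accelerator
    )(q, k, v)

q = k = v = jnp.ones((128, 128), jnp.float32)       # one (block_q, head_dim) block
out = attn_block_fused(q, k, v)
```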
The Backward Pass: Nothing to Write
This is the part I find genuinely delightful about JAX. The forward pass uses `shard_map` + `fori_loop` + `ppermute` + a Pallas kernel. You might expect the backward pass to be a nightmare to implement by hand. But with JAX, it's one line:
dQ, dK, dV = jax.grad(loss_fn, argnums=(0, 1, 2))(Q, K, V)
That's it. `jax.grad` traces through everything—collective permutations, loops, and even the Pallas kernel—and generates a correct and efficient gradient. The tests confirm it's accurate to `1e-6` against a dense reference implementation. For anyone who has manually derived a flash attention backward pass, this feels like magic.
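Applied to the toy ring sketch from earlier in this post (again illustrative, not the repo's training loop), the same idiom looks like this:

```python
# `fwd`, `q`, `k`, `v` come from the ring_attn_fwd sketch above.
def loss_fn(q, k, v):
    return jnp.mean(fwd(q, k, v) ** 2)     # any scalar loss will do

dQ, dK, dV = jax.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)
print(dQ.shape, dK.shape, dV.shape)        # each matches its input's (seq, head_dim) shape
```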
See It for Yourself on a Free Colab TPU
The best way to be convinced is to see it run. I've packaged the whole flow as a Colab notebook. A free Colab TPU runtime is a v2-8 with 8 cores, which is the exact device mesh our simulation uses. You can run the analysis and benchmarks on real silicon with zero code changes.
The notebook will:
- Confirm you have 8 TPU cores.
- Run the analytical tool to predict the performance cliff.
- Run forward and backward correctness tests for both schemes.
- Run a wall-clock benchmark that measures the actual speedup on an 8192-token sequence.
That last step is the fun one. The analytical tool predicts the naive scheme will be ~1.87× slower than zigzag at N=8. When you run the benchmark, you'll see the real TPU wall-clock time lands right in that neighborhood. It's rare and satisfying to see theory and practice align so cleanly.
(The notebook includes a few caveats about Colab TPU availability and compile times, but they aren't showstoppers.)
What's Next?
To keep the project focused, I deliberately left out a few production features like a custom VJP for memory savings, tensor parallelism, or mixed precision. The goal of ring-flash-jax is to be a clear reference for the causal load-balance problem and a solid starting point for building these more advanced systems.
You can also clone the repo and run the same demo locally on a CPU with simulated devices, which is perfect for fast iteration.
git clone <repo> && cd ring-flash-jax
pip install -e .
# 8 simulated devices on CPU
XLA_FLAGS="--xla_force_host_platform_device_count=8" \
python examples/demo.py
References
- Liu, Zaharia & Abbeel, 2023, "Ring Attention with Blockwise Transformers for Near-Infinite Context" — the original ring attention paper.
- Brandon, Nrusimha, Qian, Ankner, Jin, Song & Ragan-Kelley, 2023, "Striped Attention: Faster Ring Attention for Causal Transformers" — the zigzag fix for causal load imbalance.
- JAX Pallas docs — for the inner block kernel.
- JAX `shard_map` docs — for the manual SPMD primitives used throughout.
If you find an issue with the load-balance numbers or the zigzag derivation, please open an issue — happy to fix and credit.