When I set out to implement Ring Attention for long-context models, I hit a wall that none of the papers prepared me for. My performance was capping out at half of what it should be, no matter how many accelerators I threw at the problem. It turns out there's a hidden performance trap in the standard recipe for causal attention. Here’s the story of that bug, how to prove it exists without burning a single TPU hour, and how a simple trick called "zigzag sharding" fixed it completely.

This post is a walkthrough of ring-flash-jax, a small JAX project that explores this problem and implements the fix. We'll add the four things you actually need before Ring Attention is usable for training a real-world causal language model.

TL;DR

ring-flash-jax is a JAX implementation of ring attention with four key additions to the standard pattern:

- Causal masking — required for any decoder language model.
- Zigzag (striped) sharding — fixes the critical load...
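Before touching any hardware, you can demonstrate the imbalance with a few lines of arithmetic, which is exactly how you prove the bug exists without burning a TPU hour. The sketch below is illustrative only: the helper names (`causal_work`, `zigzag_shards`) are hypothetical, not the ring-flash-jax API. It counts the causal-attention work each device performs under naive contiguous sharding versus zigzag sharding.

```python
def causal_work(token_ids):
    # Token i attends to tokens 0..i, so its cost scales with i + 1.
    return sum(i + 1 for i in token_ids)

def contiguous_shards(seq_len, num_devices):
    # Naive sharding: device d owns one contiguous slice of the sequence.
    chunk = seq_len // num_devices
    return [list(range(d * chunk, (d + 1) * chunk)) for d in range(num_devices)]

def zigzag_shards(seq_len, num_devices):
    # Zigzag sharding: split into 2N chunks and give device d chunk d plus
    # its mirror chunk (2N - 1 - d), pairing a cheap early chunk with an
    # expensive late one so every device does the same total work.
    chunk = seq_len // (2 * num_devices)
    shards = []
    for d in range(num_devices):
        mirror = 2 * num_devices - 1 - d
        shards.append(
            list(range(d * chunk, (d + 1) * chunk))
            + list(range(mirror * chunk, (mirror + 1) * chunk))
        )
    return shards

for name, fn in [("contiguous", contiguous_shards), ("zigzag", zigzag_shards)]:
    work = [causal_work(s) for s in fn(4096, 8)]
    mean = sum(work) / len(work)
    # The slowest device gates every step, so max/mean is the efficiency loss.
    print(f"{name:10s} max/mean = {max(work) / mean:.2f}")
```

Running this prints a max/mean ratio of roughly 1.9 for contiguous sharding, which is the ~50% throughput cap described above, and exactly 1.00 for zigzag: the slowest device under naive sharding does nearly twice the average work, while zigzag balances every device perfectly.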
The Curse of Superposition: Why LLMs are Black Boxes

Large Language Models like Google’s Gemma are incredibly powerful, but they suffer from a phenomenon known as Superposition. Neural networks naturally want to represent more concepts than they have mathematical dimensions. To accomplish this, they pack multiple unrelated concepts into the same neurons—a property called polysemanticity.

When Gemma processes the word "Paris," it doesn't activate a neat, dedicated "City" neuron. Instead, it fires a dense, entangled vector of floating-point numbers in a 2,304-dimensional space that simultaneously represents "France," "capital," "tourism," and "linguistics."

For researchers trying to build safer, steerable AI, this is a massive problem. How do we debug an AI if its internal thoughts are entangled in a dense manifold? The state-of-the-art solution is Mechanistic Interpretability via Sparse Autoencoders (SAEs...
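To make the SAE idea concrete, here is a minimal sketch of a vanilla sparse autoencoder in JAX. The widths, initialization, and L1 penalty are illustrative assumptions, not the exact recipe behind Gemma's published SAEs; only the 2,304-dimensional input width comes from the discussion above.

```python
import jax
import jax.numpy as jnp

D_MODEL = 2304    # Gemma's hidden width, as quoted above
D_SAE = 16384     # a much wider, overcomplete dictionary of features (assumed size)

def init_params(key):
    k_enc, k_dec = jax.random.split(key)
    return {
        "w_enc": jax.random.normal(k_enc, (D_MODEL, D_SAE)) * 0.01,
        "b_enc": jnp.zeros(D_SAE),
        "w_dec": jax.random.normal(k_dec, (D_SAE, D_MODEL)) * 0.01,
        "b_dec": jnp.zeros(D_MODEL),
    }

def sae_forward(params, x):
    # Encode: project the dense activation into a much wider space and apply
    # ReLU, so each concept can claim its own (mostly zero) dictionary entry.
    features = jax.nn.relu(x @ params["w_enc"] + params["b_enc"])
    # Decode: reconstruct the original activation from the sparse features.
    recon = features @ params["w_dec"] + params["b_dec"]
    return features, recon

def loss_fn(params, x, l1_coeff=1e-3):
    features, recon = sae_forward(params, x)
    # Reconstruction error keeps the features faithful to the model's
    # activations; the L1 penalty drives most features to zero, pulling
    # the superposed concepts apart into individually inspectable units.
    return jnp.mean((recon - x) ** 2) + l1_coeff * jnp.mean(jnp.abs(features))
```

The key design choice is the overcomplete dictionary: because D_SAE is far larger than D_MODEL, the network no longer needs to pack several concepts into one dimension, and the sparsity penalty ensures only a handful of features fire for any given token.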