The Curse of Superposition: Why LLMs are Black Boxes
Large Language Models like Google’s Gemma are incredibly powerful, but they suffer from a phenomenon known as Superposition. Neural networks naturally want to represent more concepts than they have mathematical dimensions. To accomplish this, they pack multiple unrelated concepts into the same neurons—a property called polysemanticity.
When Gemma processes the word "Paris," it doesn't activate a neat, dedicated "City" neuron. Instead, it fires a dense, entangled vector of floating-point numbers in a 2,304-dimensional space that simultaneously represents "France," "capital," "tourism," and "linguistics." For researchers trying to build safer, steerable AI, this is a massive problem. How do we debug an AI if its internal thoughts are entangled in a dense manifold?
The state-of-the-art solution is Mechanistic Interpretability via Sparse Autoencoders (SAEs). By training a separate, unsupervised neural network (the SAE) on the LLM's internal activations, we can decompile these dense vectors into an "overcomplete" (much larger), highly sparse space. In this expanded space, features stop being entangled and become monosemantic (human-readable single concepts).
Why Cloud TPUs and JAX? The Hardware-Software Synergy
Training an SAE is not a standard fine-tuning job; it presents a unique, brutal computational bottleneck.
To untangle Gemma's 2,304-dimensional hidden states, we must project them into a massively expanded dictionary—typically 8x to 16x larger. An 8x expansion requires calculating forward and backward passes for 18,432-dimensional tensors across batch sizes of 4,096. Doing this for billions of tokens on standard consumer GPUs quickly results in memory bandwidth starvation and agonizingly slow matrix math.
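For a rough sense of scale (a back-of-the-envelope estimate, not a measured figure): the encoder matrix multiply alone costs about $2 \times 4096 \times 2304 \times 18432 \approx 3.5 \times 10^{11}$ FLOPs per step, before the decoder and the backward pass add several times more work on top.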
This is precisely where the Google Cloud TPU v5e and JAX ecosystem provide an asymmetric advantage:
- Matrix Multiply Units (MXUs): The TPU v5e's architecture is dominated by systolic arrays (MXUs) designed specifically to chew through massive dense matrix multiplications. Expanding $2304 \to 18432$ is an ideal workload that keeps the MXUs perfectly saturated.
- XLA Kernel Fusion: Calculating the $L_1$ sparsity penalty and applying ReLU activations normally requires the accelerator to read and write memory multiple times per step. JAX’s XLA (Accelerated Linear Algebra) compiler analyzes our Python code and fuses these operations into a single, highly optimized TPU kernel, virtually eliminating memory-bound latency (see the sketch after this list).
- Zero-Copy Memory via Keras 3: By setting the Keras 3 backend to JAX, the massive Gemma 2B model and our custom Flax SAE share the same underlying TPU High Bandwidth Memory (HBM). We can extract internal LLM tensors and pass them directly to our training loop without any CPU-to-Accelerator transfer overhead.
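To make the fusion point concrete, here is a minimal sketch (the function name and the coefficient are illustrative, not part of the training code that follows): under jax.jit, the ReLU and the $L_1$ reduction below compile into a single fused TPU kernel rather than separate memory-bound passes.

import jax
import jax.numpy as jnp

@jax.jit
def sparsity_penalty(pre_acts, l1_coeff):
    # ReLU + absolute value + sum: XLA fuses this element-wise chain into one
    # kernel, so the activations are read from HBM only once.
    acts = jax.nn.relu(pre_acts)
    return l1_coeff * jnp.sum(jnp.abs(acts))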
Code Walkthrough: Tapping into Gemma's Brain
First, we configure our environment and load Gemma 2 2B. We use the Keras Functional API to create a "tap" into the middle of the network (Layer 9), which is typically where high-level semantic concepts are fully formed before the model decides what word to output next.
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
import keras_hub
# Load Gemma 2 2B Backbone natively in JAX
gemma_backbone = keras_hub.models.GemmaBackbone.from_preset("gemma2_2b_en")
# Create a sub-model that outputs the hidden state of Layer 9
layer_name = "decoder_block_9"
extractor = keras.Model(
    inputs=gemma_backbone.inputs,
    outputs=gemma_backbone.get_layer(layer_name).output,
)
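As a quick sanity check, we can run a sentence through the extractor. This is a hedged sketch: the exact tokenizer calls may vary slightly across keras_hub versions, and "token_ids" / "padding_mask" are assumed to be the standard GemmaBackbone input names.

import numpy as np

# Hypothetical sanity check; tokenizer API details may differ by keras_hub version.
tokenizer = keras_hub.models.GemmaTokenizer.from_preset("gemma2_2b_en")
token_ids = np.array(tokenizer("Mechanistic interpretability helps us understand how AI models think."))[None, :]
padding_mask = np.ones_like(token_ids, dtype=bool)

# Residual stream after decoder_block_9: shape (1, seq_len, 2304)
hidden_states = extractor({"token_ids": token_ids, "padding_mask": padding_mask})
print(hidden_states.shape)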
Building the SAE in JAX/Flax
Our Sparse Autoencoder consists of an encoder (to project into the 18,432-dimensional sparse space) and a decoder (to project back to the original 2,304-dimensional LLM space). We optimize a combined loss function that balances the $L_2$ reconstruction accuracy (Mean Squared Error) with an $L_1$ sparsity penalty:

$$\mathcal{L}(x) = \lVert x - \hat{x} \rVert_2^2 + \lambda \lVert f(x) \rVert_1$$

where $f(x)$ are the sparse feature activations, $\hat{x}$ is the reconstruction, and $\lambda$ is the sparsity coefficient (l1_coeff in the code below).
Here is our TPU-optimized architecture in Flax:
import jax
import jax.numpy as jnp
from flax import linen as nn
class SparseAutoencoder(nn.Module):
    d_model: int            # Gemma's hidden size (2304)
    expansion_factor: int   # Expansion factor (8x)

    def setup(self):
        self.d_sparse = self.d_model * self.expansion_factor
        self.encoder = nn.Dense(self.d_sparse, use_bias=True)
        self.decoder = nn.Dense(self.d_model, use_bias=True)

    def __call__(self, x):
        # ReLU enforces non-negative sparse activations (f(x))
        pre_acts = self.encoder(x)
        sparse_acts = nn.relu(pre_acts)
        # Reconstruct original LLM state (\hat{x})
        reconstructed = self.decoder(sparse_acts)
        return reconstructed, sparse_acts
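Before the training step, we need the loss function itself. The notebook's exact implementation is not reproduced here; this is a minimal sketch that follows the $L_2 + \lambda L_1$ objective defined above and matches the signature used in train_step below.

def sae_loss_fn(params, apply_fn, batch, l1_coeff):
    # Run the SAE: batch is a (batch_size, 2304) slab of LLM activations
    reconstructed, sparse_acts = apply_fn({"params": params}, batch)
    # L2 reconstruction term (Mean Squared Error)
    mse_loss = jnp.mean((reconstructed - batch) ** 2)
    # L1 sparsity term on the feature activations f(x)
    l1_loss = jnp.mean(jnp.sum(jnp.abs(sparse_acts), axis=-1))
    total_loss = mse_loss + l1_coeff * l1_loss
    return total_loss, (mse_loss, l1_loss)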
We then define our training step using the @jax.jit decorator. This is where the magic happens: JAX traces the sae_loss_fn, computes the gradients via jax.value_and_grad, and compiles the entire update step into a single XLA graph executed on the TPU.
@jax.jit
def train_step(state, batch, l1_coeff):
    # has_aux=True allows us to return the separated MSE and L1 metrics
    grad_fn = jax.value_and_grad(sae_loss_fn, has_aux=True)
    (total_loss, (mse_loss, l1_loss)), grads = grad_fn(
        state.params, state.apply_fn, batch, l1_coeff
    )
    # Apply Adam optimizer gradients
    state = state.apply_gradients(grads=grads)
    return state, total_loss, mse_loss, l1_loss
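The train_step above assumes a Flax TrainState carrying the SAE parameters and an Adam optimizer. Here is a hedged setup sketch; the learning rate and RNG seed are illustrative values, not ones taken from the article.

import optax
from flax.training import train_state

sae = SparseAutoencoder(d_model=2304, expansion_factor=8)
params = sae.init(jax.random.PRNGKey(0), jnp.zeros((1, 2304)))["params"]

state = train_state.TrainState.create(
    apply_fn=sae.apply,
    params=params,
    tx=optax.adam(learning_rate=1e-4),  # illustrative learning rate
)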
The Reality of Data Distributions: A Research Insight
During our initial TPU benchmarking, we simulated Gemma's activations using standard Gaussian noise (np.random.randn) to test TPU utilization. The model converged beautifully, achieving perfect reconstruction with minimal active features.
However, the real test of Mechanistic Interpretability happens with real text. We passed the sentence "Mechanistic interpretability helps us understand how AI models think." through our Keras extractor, yielding a real (1, 32, 2304) tensor. When we fed the real activation for the word "understand" into our noise-trained SAE, we observed a fascinating result:
Number of active features for this token: 2479
Out of 18432 possible concepts, these IDs fired:
[ 5 7 33 ... 18423 18424 18429]
2,479 active features fired. In a converged, production-grade SAE, we want an $L_0$ sparsity (active features) of around 20 to 100 per token. Why the explosion? Distribution Shift.
Our SAE had only been briefly trained to compress isotropic Gaussian noise. But Gemma's real internal thoughts are highly structured, anisotropic, heavy-tailed geometric representations. When the SAE suddenly saw real data, it fired thousands of features attempting to reconstruct an alien mathematical topology! This perfectly highlights a core research truth: SAEs must be trained on massive, streaming datasets of real token activations to map the true semantic geometry of the LLM. With the zero-copy JAX/Keras pipeline we established, streaming gigabytes of real text through the TPU v5e to properly converge the dictionary becomes a straightforward scaling task.
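Concretely, that streaming loop looks roughly like the sketch below, where text_batches is a hypothetical iterator of pre-tokenized (token_ids, padding_mask) pairs and the L1 coefficient is illustrative.

L1_COEFF = 1e-3  # illustrative sparsity coefficient

for token_ids, padding_mask in text_batches:  # hypothetical pre-tokenized stream
    # Real residual-stream activations, staying in TPU HBM
    acts = extractor({"token_ids": token_ids, "padding_mask": padding_mask})
    # Flatten (batch, seq, 2304) -> (batch * seq, 2304): every token is a training example
    acts = jnp.reshape(acts, (-1, acts.shape[-1]))
    state, total_loss, mse_loss, l1_loss = train_step(state, acts, L1_COEFF)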
Concept Steering: Writing Directly to the Brain
The ultimate goal of extracting these features isn't just observation; it's control. Once we train our SAE on a massive text corpus and identify that a specific feature (e.g., Feature 5) mathematically represents the concept of "Software Code," we can isolate its decoder weights.
By scaling this vector and adding it directly to Gemma's hidden states during inference, we can force the model to output code—without doing any LoRA fine-tuning or prompt engineering. We are literally doing neurosurgery on the AI's thoughts.
# Extract the decoder weights for our target concept
TARGET_FEATURE_IDX = 5
STEERING_STRENGTH = 15.0
# decoder_weights shape: (18432, 2304)
decoder_weights = state.params['decoder']['kernel']
steering_vector = decoder_weights[TARGET_FEATURE_IDX]
# Normalize the vector to prevent exploding activations, then scale it
steering_vector = steering_vector / jnp.linalg.norm(steering_vector)
steering_vector = steering_vector * STEERING_STRENGTH
def steer_gemma_activations(original_hidden_states, steering_vec):
    """
    Intercepts the dense activation in the LLM forward pass,
    adds the steering vector, and allows generation to continue.
    """
    return original_hidden_states + steering_vec
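As a usage sketch (hedged: a full pipeline would re-inject the steered tensor into the remaining decoder blocks, which keras_hub does not expose as a one-liner), we can apply the function to the hidden states extracted earlier:

# hidden_states: the (1, seq_len, 2304) tensor from the extractor above
steered_states = steer_gemma_activations(hidden_states, steering_vector)
print(steered_states.shape)  # shape is unchanged; the content is nudged toward the concept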
Conclusion
Through the combined power of Keras 3, JAX, and Google Cloud TPUs, we built a high-throughput framework for peering into the black box of one of the world's leading open-weights models. While standard SFT and LoRA fine-tuning teach a model new behaviors via gradient descent, Mechanistic Interpretability allows us to understand exactly how those behaviors are encoded, and steer them mathematically.
By utilizing hardware like the TPU v5e—which is purpose-built for the massive matrix multiplications required by overcomplete dictionaries—AI safety researchers can iterate on these high-dimensional mechanistic workflows faster and more cost-effectively than ever before.
Check out the full reproducible codebase in the accompanying Google Colab notebook.
