End-to-End Gemma 2B LoRA Fine-Tuning and Serving on GPU & TPU
If you have ever prototyped a Large Language Model (LLM) on your local GPU and then spent days rewriting your code to scale it on a Google Cloud TPU, you know the pain of hardware lock-in. For the Google TPU Sprint, I wanted to build a solution to this exact problem.
This project provides a lightweight, end-to-end pipeline for fine-tuning Google's Gemma 2B model using LoRA (Low-Rank Adaptation) and serving it via a custom REST API. By leveraging KerasNLP and the JAX backend, we can write our training and inference code once, and execute it natively on both local NVIDIA GPUs (like the RTX 6000) and Google Cloud TPUs.
⚡ Why the Keras 3 + JAX Stack?
Keras 3 was rewritten to act as a "super-connector" that can run on top of PyTorch, TensorFlow, or JAX without changing the code. By explicitly setting our backend to JAX (os.environ["KERAS_BACKEND"] = "jax"), we unlock massive performance gains and hardware flexibility.
- XLA Compilation: JAX translates our Python code into XLA (Accelerated Linear Algebra). XLA is a domain-specific compiler that aggressively optimizes mathematical computations. It fuses multiple operations (like MatMul + ReLU) into a single kernel, avoiding expensive read/write trips to the hardware's High-Bandwidth Memory (HBM).
- Hardware Agnosticism: XLA abstracts the underlying hardware layer. The exact same computation graph generated by JAX on your local RTX 6000 can be compiled by XLA to run natively on a Google Cloud TPU's Matrix Multiply Units (MXUs).
- Throughput Improvements: Switching to the JAX backend in Keras 3 (which is XLA-compiled by default) often yields a 1.5x to 2.5x throughput improvement compared to eager execution frameworks.
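The backend switch itself is a one-liner, but ordering matters: the environment variable must be set before keras is first imported, because the backend is fixed at import time. A minimal sketch (the `current_backend` helper is illustrative, not a Keras API):

```python
import os

# The backend must be chosen BEFORE keras is imported; once keras
# initializes, the backend cannot be changed for the process.
os.environ["KERAS_BACKEND"] = "jax"


def current_backend() -> str:
    """Report which backend Keras will pick up on import."""
    return os.environ.get("KERAS_BACKEND", "tensorflow")


if __name__ == "__main__":
    print(current_backend())  # -> jax
```

If the variable is unset, Keras 3 falls back to its default backend, which is why both scripts in this project set it at the very top of the file.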
📊 Real-World Training Run & Logs
To prove how efficient this pipeline is, I ran the fine-tuning script locally on an NVIDIA workstation. Here is the actual terminal output:
python3 train_local.py
Preparing data...
Loading Gemma 2B model from Hugging Face...
model-00001-of-00002.safetensors: 100%|████████| 4.95G/4.95G [00:06<00:00, 818MB/s]
model-00002-of-00002.safetensors: 100%|████████| 67.1M/67.1M [00:00<00:00, 45.3MB/s]
Starting fine-tuning...
2026-03-23 20:00:25.251164: W external/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Working around this by autotuning the full set of configs instead.
250/250 ━━━━━━━━━━━━━━━━━━━━ 45s 138ms/step - loss: 0.4375 - sparse_categorical_accuracy: 0.5411
Saving fine-tuned LoRA weights...
Model saved to ./gemma-2b-dolly-lora
Breaking down the logs:
- Speed: Notice the training time: 45 seconds for 250 steps (138ms/step). Because we used LoRA to freeze the base model and only train a fraction of the parameters, we achieved blazing-fast iteration speeds on a single GPU.
- The XLA Warning: You'll notice a warning about dot_search_space.cc. This is actually a feature, not a bug! XLA utilizes an "autotuner" that runs behind the scenes during compilation. It tests various low-level kernel configurations (like tile sizes) for dot-products to see which runs fastest on your specific hardware. The warning simply means the autotuner expanded its search space to find the optimal execution path for the RTX GPU. When you move this to a TPU, XLA will autotune specifically for the TPU's architecture.
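The speed-up from LoRA is easy to quantify with a back-of-the-envelope calculation. The sketch below uses an illustrative 2048×2048 projection (not Gemma's exact layer shapes) to show why a rank-16 adapter trains only a tiny fraction of the weights:

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Frozen params in a dense weight vs. trainable params in its LoRA adapter.

    LoRA replaces the update of a (d_in x d_out) matrix W with two
    low-rank factors A (d_in x rank) and B (rank x d_out), so only
    rank * (d_in + d_out) parameters receive gradients.
    """
    frozen = d_in * d_out
    trainable = rank * (d_in + d_out)
    return frozen, trainable


if __name__ == "__main__":
    # Hypothetical 2048x2048 projection with rank=16, as in train_local.py.
    frozen, trainable = lora_param_counts(2048, 2048, rank=16)
    print(f"frozen: {frozen:,}  trainable: {trainable:,} "
          f"({trainable / frozen:.2%} of the dense weight)")
    # -> frozen: 4,194,304  trainable: 65,536 (1.56% of the dense weight)
```

Backpropagating through ~1.6% of each adapted weight matrix (plus the frozen forward pass) is what makes 138ms/step on a single workstation GPU possible.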
🏗️ Architecture Overview
Data flows through our write-once, scale-everywhere pipeline in three stages: download and format the Dolly instruction dataset, fine-tune Gemma 2B with LoRA (train_local.py), then load the saved preset behind a REST endpoint (serve.py) on either GPU or TPU.
🚀 The Code: Training & Serving
Phase 1: Training (train_local.py)
import os

# Set the backend to JAX for native GPU/TPU compatibility.
# This must happen before keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import json
import urllib.request

import keras
import keras_nlp


def prepare_dataset():
    url = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
    filename = "databricks-dolly-15k.jsonl"
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)

    data = []
    with open(filename, "r") as f:
        for i, line in enumerate(f):
            if i >= 1000:
                break
            item = json.loads(line)
            prompt = f"Instruction:\n{item['instruction']}\n\nResponse:\n{item['response']}"
            data.append(prompt)
    return data


def main():
    train_data = prepare_dataset()
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")

    # Enable LoRA with rank 16: the base weights stay frozen and only
    # the low-rank adapter matrices are trained.
    gemma_lm.backbone.enable_lora(rank=16)

    # A static sequence length is REQUIRED by XLA to avoid recompilation.
    gemma_lm.preprocessor.sequence_length = 512

    gemma_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.AdamW(learning_rate=5e-5),
        weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    gemma_lm.fit(train_data, epochs=1, batch_size=4)
    gemma_lm.save_to_preset("./gemma-2b-dolly-lora")


if __name__ == "__main__":
    main()
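The static-shape requirement flagged in the comment above boils down to padding or truncating every example to the same length. The helper below is a hypothetical illustration of what the KerasNLP preprocessor does internally, not part of its API:

```python
def pad_to_static(token_ids, sequence_length=512, pad_id=0):
    """Truncate or right-pad token ids to a fixed length.

    XLA compiles one kernel per input shape, so feeding a constant
    shape avoids a recompilation for every new prompt length.
    """
    ids = list(token_ids)[:sequence_length]
    return ids + [pad_id] * (sequence_length - len(ids))


if __name__ == "__main__":
    print(len(pad_to_static([1, 2, 3])))         # -> 512
    print(len(pad_to_static(list(range(600)))))  # -> 512
```

Whether a prompt tokenizes to 3 ids or 600, the model always sees a 512-wide tensor, so XLA compiles the training step exactly once.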
Phase 2: Serving (serve.py)
Once fine-tuning is complete, the serving script loads the locally saved weights. If you run this on a Cloud TPU VM, XLA automatically compiles the inference graph natively for the TPU's architecture. Notice the max_tokens field, which defaults to a fixed 256: this continues our design pattern of bounding output shapes, since XLA recompiles the graph whenever it sees a new shape.
import os

os.environ["KERAS_BACKEND"] = "jax"

import json
from http.server import HTTPServer, BaseHTTPRequestHandler

import keras
import keras_nlp

MODEL_DIR = "./gemma-2b-dolly-lora"
HOST = "0.0.0.0"
PORT = 8080

gemma_lm = None


def load_model():
    global gemma_lm
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(MODEL_DIR)
    gemma_lm.compile(sampler="top_k")


class InferenceHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers.get("Content-Length", 0))
        body = json.loads(self.rfile.read(content_length))
        prompt = body.get("prompt", "")

        # Enforcing bounded shapes during inference for XLA efficiency.
        max_length = body.get("max_tokens", 256)
        output = gemma_lm.generate(prompt, max_length=max_length)

        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(json.dumps({"generated_text": output}).encode())


def main():
    load_model()
    server = HTTPServer((HOST, PORT), InferenceHandler)
    server.serve_forever()


if __name__ == "__main__":
    main()
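For programmatic access, a small Python client can speak the same JSON contract as InferenceHandler. This is a sketch using only the standard library; the URL and field names mirror serve.py, and it assumes the server is already running:

```python
import json
import urllib.request

SERVER_URL = "http://localhost:8080/"


def build_request(prompt, max_tokens=256, url=SERVER_URL):
    """Build the POST request that InferenceHandler expects."""
    payload = json.dumps({"prompt": prompt, "max_tokens": max_tokens})
    return urllib.request.Request(
        url,
        data=payload.encode(),
        headers={"Content-Type": "application/json"},
    )


def query(prompt, **kwargs):
    """Send a prompt to the running server and return the generated text."""
    with urllib.request.urlopen(build_request(prompt, **kwargs)) as resp:
        return json.loads(resp.read())["generated_text"]


if __name__ == "__main__":
    print(query("Instruction:\nWhat is a TPU?\n\nResponse:\n"))
```

Keeping the client to the standard library means it runs unmodified on a local workstation or inside a TPU VM with no extra dependencies.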
Start the server by running python3 serve.py. You can then query your fine-tuned model using a simple cURL request:
curl -X POST http://localhost:8080/ \
-H "Content-Type: application/json" \
-d '{"prompt": "Instruction:\nExplain the difference between a GPU and a TPU.\n\nResponse:\n", "max_tokens": 256}'
💡 Why This Matters for the TPU Sprint
The developer experience (DevEx) of AI engineering often suffers from hardware fragmentation. By unifying the tech stack with Keras 3 and JAX, developers can iterate cheaply and quickly on a local workstation, and deploy to Google's massive TPU infrastructure for production inference without changing a single line of model logic. Understanding XLA's autotuning and static shape requirements is the key to unlocking this massive scale.