
Write Once, Scale Everywhere

End-to-End Gemma 2B LoRA Fine-Tuning and Serving on GPU & TPU


If you have ever prototyped a Large Language Model (LLM) on your local GPU and then spent days rewriting your code to scale it on a Google Cloud TPU, you know the pain of hardware lock-in. For the Google TPU Sprint, I wanted to build a solution to this exact problem.

This project provides a lightweight, end-to-end pipeline for fine-tuning Google's Gemma 2B model using LoRA (Low-Rank Adaptation) and serving it via a custom REST API. By leveraging KerasNLP and the JAX backend, we can write our training and inference code once, and execute it natively on both local NVIDIA GPUs (like the RTX 6000) and Google Cloud TPUs.

⚡ Why the Keras 3 + JAX Stack?

Keras 3 was rewritten to act as a "super-connector" that can run on top of PyTorch, TensorFlow, or JAX without changing the code. By explicitly setting our backend to JAX (os.environ["KERAS_BACKEND"] = "jax"), we unlock massive performance gains and hardware flexibility.

  • XLA Compilation: JAX translates our Python code into XLA (Accelerated Linear Algebra). XLA is a domain-specific compiler that aggressively optimizes mathematical computations. It fuses multiple operations (like MatMul + ReLU) into a single kernel, avoiding expensive read/write trips to the hardware's High-Bandwidth Memory (HBM).
  • Hardware Agnosticism: XLA abstracts the underlying hardware layer. The exact same computation graph generated by JAX on your local RTX 6000 can be compiled by XLA to run natively on a Google Cloud TPU's Matrix Multiply Units (MXUs).
  • Throughput Improvements: Switching to the JAX backend in Keras 3 (which is XLA-compiled by default) often yields a 1.5x to 2.5x throughput improvement compared to eager execution frameworks.
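To make the fusion point concrete, here is a minimal JAX sketch (independent of the Gemma scripts below). `jax.jit` hands the traced computation to XLA, which is then free to fuse the matmul and the ReLU into a single kernel rather than materializing the intermediate result in HBM:

```python
import jax
import jax.numpy as jnp

@jax.jit  # XLA compiles the whole function; it may fuse the ops into one kernel
def fused_dense_relu(x, w):
    # MatMul followed by ReLU: without fusion, the matmul result would be
    # written to HBM and read back just to apply the elementwise maximum
    return jnp.maximum(x @ w, 0.0)

x = jnp.ones((4, 8))
w = jnp.ones((8, 16))
out = fused_dense_relu(x, w)  # first call triggers XLA compilation
print(out.shape)
```

The same function runs unchanged on CPU, GPU, or TPU; only the XLA backend that compiles it differs.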

📊 Real-World Training Run & Logs

To prove how efficient this pipeline is, I ran the fine-tuning script locally on an NVIDIA workstation. Here is the actual terminal output:

python3 train_local.py 
Preparing data...
Loading Gemma 2B model from Hugging Face...
model-00001-of-00002.safetensors: 100%|████████| 4.95G/4.95G [00:06<00:00, 818MB/s]
model-00002-of-00002.safetensors: 100%|████████| 67.1M/67.1M [..., 45.3MB/s]
2026-03-23 20:00:25.251164: W external/.../dot_search_space.cc:200] All autotuning configs were filtered out because none of them sufficiently match the hints. Working around this by using the full search space instead.
Starting fine-tuning...
250/250 ━━━━━━━━━━ 45s 138ms/step - loss: 0.4375 - sparse_categorical_accuracy: 0.5411
Saving fine-tuned model weights...
Done. Model saved to ./gemma-2b-dolly-lora

Breaking down the logs:

  • Speed: Notice the training time: 45 seconds for 250 steps (138ms/step). Because we used LoRA to freeze the base model and only train a fraction of the parameters, we achieved blazing-fast iteration speeds on a single GPU.
  • The XLA Warning: You'll notice a warning about dot_search_space.cc. This is actually a feature, not a bug! XLA utilizes an "autotuner" that runs behind the scenes during compilation. It tests various low-level kernel configurations (like tile sizes) for dot-products to see which runs fastest on your specific hardware. The warning simply means the autotuner expanded its search space to find the optimal execution path for the RTX GPU. When you move this to a TPU, XLA will autotune specifically for the TPU's architecture.
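A quick back-of-the-envelope calculation shows why LoRA iteration is so fast. These are illustrative numbers for a single square projection matrix (Gemma 2B's hidden size is 2048, and this post uses rank 16):

```python
# Why LoRA trains so few parameters: for a d x d weight matrix, LoRA learns
# two low-rank factors A (d x r) and B (r x d) instead of the full matrix.
d = 2048      # hidden size (Gemma 2B)
rank = 16     # LoRA rank used in this post

full_params = d * d          # parameters in the frozen dense weight
lora_params = 2 * d * rank   # trainable parameters added by LoRA

ratio = lora_params / full_params
print(f"LoRA trains {lora_params:,} of {full_params:,} params per matrix "
      f"({ratio:.1%})")
```

Roughly 1.6% of each adapted matrix is trainable, which is why the optimizer state and gradient computation stay small enough for a single workstation GPU.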

🏗️ Architecture Overview

Below is a Mermaid diagram of how data flows through our write-once, scale-everywhere pipeline:

graph TD
    subgraph Training Phase [Training Environment GPU/TPU]
        DS[(Databricks Dolly 15k Dataset)] --> T[train_local.py]
        HF((Hugging Face Hub google/gemma-2b)) --> T
        T -->|LoRA Fine-tuning Rank 16| M[(Fine-Tuned Model ./gemma-2b-dolly-lora)]
    end
    subgraph Serving Phase [Inference Environment GPU/TPU]
        M --> S[serve.py]
        S --> API[HTTP Server Port 8080]
    end
    Client([Client Application]) <-->|POST / JSON| API

🚀 The Code: Training & Serving

Phase 1: Training (train_local.py)

import os
# Set the backend to JAX for native GPU/TPU compatibility
os.environ["KERAS_BACKEND"] = "jax"

import json
import urllib.request
import keras
import keras_nlp

def prepare_dataset():
    url = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
    filename = "databricks-dolly-15k.jsonl"
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)

    data = []
    with open(filename, "r") as f:
        for i, line in enumerate(f):
            if i >= 1000: break
            item = json.loads(line)
            prompt = f"Instruction:\n{item['instruction']}\n\nResponse:\n{item['response']}"
            data.append(prompt)
    return data

def main():
    train_data = prepare_dataset()
    
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")
    
    # Enable LoRA with rank 16
    gemma_lm.backbone.enable_lora(rank=16)
    
    # Static sequence length is REQUIRED by XLA to avoid recompilation
    gemma_lm.preprocessor.sequence_length = 512

    gemma_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.AdamW(learning_rate=5e-5),
        weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )

    gemma_lm.fit(train_data, epochs=1, batch_size=4)
    gemma_lm.save_to_preset("./gemma-2b-dolly-lora")

if __name__ == "__main__":
    main()

Phase 2: Serving (serve.py)

Once fine-tuning is complete, the serving script loads the locally saved fine-tuned weights. If you run it on a Cloud TPU VM, XLA compiles the inference graph natively for the TPU's architecture. Note the max_tokens parameter, which defaults to 256: XLA compiles a separate graph for each distinct generation length, so keeping this value fixed across requests continues our design pattern of enforcing static shapes and avoids repeated recompilation.

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_nlp
from http.server import HTTPServer, BaseHTTPRequestHandler
import json

MODEL_DIR = "./gemma-2b-dolly-lora"
HOST = "0.0.0.0"
PORT = 8080
gemma_lm = None

def load_model():
    global gemma_lm
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(MODEL_DIR)
    gemma_lm.compile(sampler="top_k")

class InferenceHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers.get("Content-Length", 0))
        body = json.loads(self.rfile.read(content_length))
        prompt = body.get("prompt", "")
        
        # Default to a fixed length; each distinct value triggers a fresh
        # XLA compilation, so clients should reuse a small set of values
        max_length = body.get("max_tokens", 256)
        
        output = gemma_lm.generate(prompt, max_length=max_length)
        
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(json.dumps({"generated_text": output}).encode())

def main():
    load_model()
    server = HTTPServer((HOST, PORT), InferenceHandler)
    server.serve_forever()

if __name__ == "__main__":
    main()
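The handler above accepts an arbitrary max_tokens per request, and XLA compiles a fresh graph for every distinct generation length it sees. One simple mitigation (a sketch, not part of the scripts in this post) is to round requested lengths up to a small set of fixed buckets, bounding the number of compiled graphs:

```python
# Snap arbitrary max_tokens requests to a few fixed buckets so XLA only ever
# compiles a handful of inference graphs instead of one per distinct length.
BUCKETS = (64, 128, 256, 512)

def bucket_max_length(requested: int) -> int:
    """Round the requested length up to the nearest bucket (cap at the largest)."""
    for b in BUCKETS:
        if requested <= b:
            return b
    return BUCKETS[-1]

print(bucket_max_length(100))  # rounds up to 128
```

Inside do_POST you would then write `max_length = bucket_max_length(body.get("max_tokens", 256))`, trading a little wasted generation length for predictable compile behavior.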

Start the server by running python3 serve.py. You can then query your fine-tuned model using a simple cURL request:

curl -X POST http://localhost:8080/ \
     -H "Content-Type: application/json" \
     -d '{"prompt": "Instruction:\nExplain the difference between a GPU and a TPU.\n\nResponse:\n", "max_tokens": 256}'
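If you prefer Python to cURL, a minimal client using only the standard library might look like this. The URL, endpoint, and JSON field names match serve.py above; it assumes the server is already running on localhost:8080:

```python
import json
import urllib.request

def build_request(prompt: str, max_tokens: int = 256) -> urllib.request.Request:
    """Build the POST request expected by serve.py's InferenceHandler."""
    body = json.dumps({"prompt": prompt, "max_tokens": max_tokens}).encode()
    return urllib.request.Request(
        "http://localhost:8080/",
        data=body,
        headers={"Content-Type": "application/json"},
    )

def query(prompt: str, max_tokens: int = 256) -> str:
    """Send the request and return the model's generated text."""
    with urllib.request.urlopen(build_request(prompt, max_tokens)) as resp:
        return json.loads(resp.read())["generated_text"]
```

Calling `query("Instruction:\nExplain LoRA.\n\nResponse:\n")` returns the generated_text field from the server's JSON response.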

💡 Why This Matters for the TPU Sprint

The developer experience (DevEx) of AI engineering often suffers from hardware fragmentation. By unifying the tech stack with Keras 3 and JAX, developers can iterate cheaply and quickly on a local workstation, and deploy to Google's massive TPU infrastructure for production inference without changing a single line of model logic. Understanding XLA's autotuning and static shape requirements is the key to unlocking this massive scale.
