Showing posts from March, 2026

Write Once, Scale Everywhere

End-to-End Gemma 2B LoRA Fine-Tuning and Serving on GPU & TPU

If you have ever prototyped a Large Language Model (LLM) on your local GPU and then spent days rewriting your code to scale it on a Google Cloud TPU, you know the pain of hardware lock-in. For the Google TPU Sprint, I wanted to build a solution to this exact problem.

This project provides a lightweight, end-to-end pipeline for fine-tuning Google's Gemma 2B model using LoRA (Low-Rank Adaptation) and serving it via a custom REST API. By leveraging KerasNLP and the JAX backend, we can write our training and inference code once and execute it natively on both local NVIDIA GPUs (like the RTX 6000) and Google Cloud TPUs.

⚡ Why the Keras 3 + JAX Stack?

Keras 3 was rewritten to act as a "super-connector" that can run on top of PyTorch, TensorFlow, or JAX without changing the code. By explicitly setting our backend to JAX (os.environ["KERAS_BACKEND"] = "jax")...
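The excerpt stops before the LoRA details, so here is a minimal NumPy sketch of the general LoRA idea, assuming the standard formulation: the pretrained weight W stays frozen, and only two small factors A and B are trained, giving y = xW + (alpha / r) * xAB. This is not KerasNLP's implementation; the class name LoRADense, the rank and alpha values, and the 2048 x 2048 layer size are illustrative assumptions. It shows why LoRA keeps fine-tuning "lightweight": the trainable parameter count grows with the rank r, not with the full weight matrix.

```python
import numpy as np


class LoRADense:
    """Hypothetical LoRA-wrapped dense layer (illustrative sketch only)."""

    def __init__(self, w_frozen, rank=4, alpha=8.0, seed=0):
        d_in, d_out = w_frozen.shape
        rng = np.random.default_rng(seed)
        self.w = w_frozen  # frozen pretrained kernel, shape (d_in, d_out)
        # A gets a small random init; B starts at zero, so the adapter
        # contributes nothing until training updates B.
        self.a = rng.normal(0.0, 0.01, size=(d_in, rank))
        self.b = np.zeros((rank, d_out))
        self.scale = alpha / rank  # LoRA scaling factor alpha / r

    def __call__(self, x):
        # y = x W + (alpha / r) * (x A) B, computed without forming A B
        return x @ self.w + self.scale * (x @ self.a) @ self.b

    def trainable_params(self):
        # Only A and B would receive gradient updates.
        return self.a.size + self.b.size

    def merged_weight(self):
        # For serving, the low-rank update can be folded into W,
        # so inference needs a single matmul per layer.
        return self.w + self.scale * self.a @ self.b


# Usage with an illustrative 2048 x 2048 projection.
w = np.random.default_rng(1).normal(size=(2048, 2048))
layer = LoRADense(w, rank=4)
x = np.ones((2, 2048))
assert np.allclose(layer(x), x @ w)  # B == 0, so output matches the frozen layer
print(f"{layer.trainable_params()} trainable vs {w.size} frozen")
# prints "16384 trainable vs 4194304 frozen"
```

In this sketch only about 0.39% of the layer's parameters are trainable, and the merged_weight step hints at why a fine-tuned adapter can be served without extra inference cost.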