
The Gemma 4 E2B Fine-Tuning Cookbook



A complete, opinionated recipe for adapting Gemma 4 E2B to your domain — from multimodal dataset construction and QLoRA configuration through training loop debugging, evaluation, and production deployment.

April 2026  ·  ~28 min read  ·  HuggingFace / TRL / PEFT
Recipe at a Glance
Serves: 1 fine-tuned model
Ingredients (hardware & environment)
  • NVIDIA GPU ≥ 24 GB VRAM
  • System RAM ≥ 32 GB
  • Storage (SSD) ≥ 50 GB free
  • Training dataset 1K–100K samples
  • Python ≥ 3.11
  • CUDA ≥ 12.1
Ingredients (libraries)
  • transformers ≥ 4.51
  • peft ≥ 0.12
  • trl ≥ 0.12
  • bitsandbytes ≥ 0.44
  • accelerate ≥ 0.34
  • wandb / mlflow any
01

Why Fine-Tune E2B Specifically?

The instruct model is already excellent — why spend compute on fine-tuning? Three reasons: domain adaptation, format compliance, and modality specialization.

The base model covers 140+ languages and general reasoning, but it doesn't know your ontology, your entity schema, or your document templates. Fine-tuning bridges that gap with dramatically fewer tokens than few-shot prompting. Because the model is small, a properly fine-tuned E2B can match or exceed a raw 27B model on narrow domain tasks.

| Scenario | Strategy | Dataset Size | Expected Gain |
|---|---|---|---|
| Domain vocabulary / jargon | SFT with curated Q&A pairs | 1K–10K | +15–30% |
| Strict output format (JSON/XML) | SFT on format-exemplar pairs | 500–2K | +40–60% |
| Custom visual document layout | Multimodal SFT with screenshots | 2K–20K pages | +20–35% |
| Specific audio accent / domain | ASR fine-tune on audio + transcript | 10h–100h audio | −30–60% WER |
| Agentic reasoning preference | DPO/ORPO on tool-call trajectories | 5K–50K pairs | +25–50% task completion |
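For the DPO/ORPO row, preference data is structured as prompt/chosen/rejected triples rather than single completions. A hypothetical sketch of one record (the field names follow the common TRL convention, but verify against your trainer's expected schema; the tool-call content is invented for illustration):

```python
# One hypothetical preference record for tool-call trajectory tuning.
# "chosen" is the trajectory you want reinforced, "rejected" the one you don't.
preference_pair = {
    "prompt": "Book the cheapest flight from SFO to JFK on May 3.",
    "chosen": '{"tool": "search_flights", "args": {"from": "SFO", "to": "JFK", "sort": "price"}}',
    "rejected": "I am unable to book flights.",
}


def is_valid_pair(pair: dict) -> bool:
    """Minimal schema check before feeding pairs to a preference trainer."""
    return all(
        isinstance(pair.get(k), str) and pair[k]
        for k in ("prompt", "chosen", "rejected")
    )
```

Run a check like this over the whole dataset before training; a single malformed pair can crash a DPO run hours in.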
02

Dataset Construction — The Most Important Step

Poor data kills fine-tunes. 500 high-quality, diverse, correctly-formatted samples beat 50,000 scraped, noisy ones every time. Here's how to build a validated multimodal SFT dataset:

dataset_builder.py
Python
from dataclasses import dataclass, field
from PIL import Image
import hashlib
from datasets import Dataset, DatasetDict


@dataclass
class TrainingSample:
    """A single multimodal training conversation."""
    id: str
    messages: list[dict]
    images: list[Image.Image] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)

    def validate(self) -> list[str]:
        errors = []
        if not self.messages:
            return ["Sample has no messages"]
        roles = [m["role"] for m in self.messages]
        if roles[-1] != "assistant":
            errors.append("Last turn must be assistant")
        if "user" not in roles:
            errors.append("Must have at least one user turn")

        # CRITICAL: no thinking tokens in training targets
        for i, msg in enumerate(self.messages):
            if msg["role"] == "assistant":
                content = str(msg.get("content", ""))
                # Crude substring heuristic: substitute your model's actual
                # thinking-token markers here
                if "channel" in content:
                    errors.append(
                        f"Turn {i}: thinking tokens in training target — strip them!"
                    )
        return errors


class DatasetBuilder:
    def __init__(self):
        self.samples: list[TrainingSample] = []

    def add_text_qa(self, question: str, answer: str, system: str = ""):
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.extend([
            {"role": "user",      "content": question},
            {"role": "assistant", "content": answer},
        ])
        self.samples.append(TrainingSample(id=self._id(question), messages=messages))

    def add_multimodal_qa(self, images: list, question: str, answer: str):
        # Images MUST come before text in each turn
        user_content = [{"type": "image"} for _ in images]
        user_content.append({"type": "text", "text": question})
        self.samples.append(TrainingSample(
            id=self._id(question),
            messages=[
                {"role": "user",      "content": user_content},
                {"role": "assistant", "content": answer},
            ],
            images=images,
        ))

    def validate_all(self) -> dict:
        errors = {}
        for s in self.samples:
            if errs := s.validate():  # validate once, keep only failures
                errors[s.id] = errs
        return {"total": len(self.samples), "invalid": len(errors), "errors": errors}

    def build_hf_dataset(self, val_split: float = 0.1) -> DatasetDict:
        records = [{"id": s.id, "messages": s.messages, "images": s.images}
                   for s in self.samples]
        full  = Dataset.from_list(records)
        split = full.train_test_split(test_size=val_split, seed=42)
        return DatasetDict({"train": split["train"], "validation": split["test"]})

    def _id(self, text: str) -> str:
        return hashlib.md5((str(len(self.samples)) + text).encode()).hexdigest()[:12]

The 4 Non-Negotiable Rules

  • Strip all thinking tokens from training targets

    Gemma 4 thinking content must never appear in assistant training targets. The validator above catches this. Thinking is inference-time-only — including it in training data corrupts the model's behavior.

  • Use the official chat template for tokenization

    Always call processor.apply_chat_template(messages, tokenize=False). Manually constructing the prompt string will miss special tokens and produce poisoned training signal.

  • Mask prompt tokens in the loss

    Only compute cross-entropy on assistant turns. Use DataCollatorForCompletionOnlyLM from TRL. Training on input tokens wastes gradient budget and can cause prompt memorization.

  • Images must precede their referring text

    Gemma 4's vision encoder processes image tokens before text tokens. Violating this at train time creates a distribution mismatch that degrades inference quality — silently.
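Rules 1 and 4 are mechanical enough to lint automatically. A minimal sketch, assuming the list-of-dicts content format used by DatasetBuilder above (plain-string contents are treated as text-only and skipped):

```python
def check_turn_order(messages: list[dict]) -> list[str]:
    """Lint one sample: assistant-last, and images before text in each turn."""
    problems = []
    if not messages or messages[-1]["role"] != "assistant":
        problems.append("last turn must be assistant")
    for i, msg in enumerate(messages):
        content = msg.get("content")
        if isinstance(content, list):
            kinds = [part["type"] for part in content]
            if "text" in kinds and "image" in kinds:
                last_image = max(j for j, k in enumerate(kinds) if k == "image")
                # every image part must precede the first text part
                if last_image > kinds.index("text"):
                    problems.append(f"turn {i}: image after text")
    return problems
```

Run it over every sample before tokenization; ordering violations are silent at train time and only show up as degraded inference quality.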

03

QLoRA Configuration — The Parameters That Matter

QLoRA (Quantized Low-Rank Adaptation) stays comfortably under 24 GB of VRAM at batch size 4. The key insight: 4-bit base weights plus bf16 LoRA adapters give you the representational capacity you need at a fraction of the memory footprint.
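The budget arithmetic behind that claim can be sketched roughly. This is illustrative byte-counting only: it ignores activations, the KV cache, and CUDA overhead, so treat the result as a lower bound:

```python
def qlora_vram_floor_gb(base_params: float, adapter_params: float) -> float:
    """Rough lower bound on QLoRA memory, in GiB.

    Byte counts (approximate):
      - 4-bit (nf4) base weights: ~0.5 bytes/param
      - bf16 LoRA adapters:        2 bytes per trainable param
      - 8-bit Adam moments:       ~2 bytes per trainable param (two 8-bit states)
      - bf16 adapter gradients:    2 bytes per trainable param
    """
    base = base_params * 0.5
    adapters = adapter_params * (2 + 2 + 2)
    return (base + adapters) / 2**30
```

The point of the exercise: base weights dominate, and only the small adapter fraction pays the full training-state tax, which is why the whole run fits where full fine-tuning would not.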

qlora_config.py
Python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig, AutoProcessor, AutoModelForCausalLM
import torch


def build_qlora_model(base_model_id: str = "google/gemma-4-E2B-it"):
    # 4-bit quantization config
    # double_quant: quantize the quantization constants too (~0.4 GB savings)
    # nf4: NormalFloat4, theoretically optimal for normally-distributed weights
    # compute_dtype: actual math in bf16 despite 4-bit storage
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    processor = AutoProcessor.from_pretrained(base_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map="auto",
        attn_implementation="flash_attention_2",  # 20-40% speed boost
    )

    # Cast LayerNorm & embeddings to bf16 to avoid dtype mismatch errors
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

    # Target ALL projection matrices — omitting v_proj or o_proj is a common
    # mistake that prevents the model from rewriting attention outputs
    lora_config = LoraConfig(
        r=64,
        lora_alpha=128,        # convention: alpha = 2 × r
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        use_rslora=True,       # rank-stabilized LoRA: normalize by √r
        modules_to_save=["embed_tokens", "lm_head"],
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # Expected: trainable params: 168,034,304 || trainable%: 3.18
    return processor, model
⚠️ rsLoRA vs Standard LoRA

Always set use_rslora=True. Standard LoRA scales adapter outputs by alpha/r, which shrinks the effective update as rank grows and stalls learning at larger ranks. rsLoRA normalizes by alpha/√r instead, allowing higher ranks (r=64, r=128) without needing to re-tune the learning rate.
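The two scaling rules are easy to compare directly. A sketch of just the scaling factors (not the full adapter math):

```python
import math


def lora_scale(alpha: int, r: int, rslora: bool) -> float:
    """Effective multiplier applied to the adapter output B @ A @ x."""
    return alpha / math.sqrt(r) if rslora else alpha / r


# Example at r=64, alpha=128 (the alpha = 2*r convention used above):
#   standard: 128 / 64       = 2.0
#   rsLoRA:   128 / sqrt(64) = 16.0
```

With the alpha = 2·r convention, the standard scale is pinned at 2.0 for every rank, while the rsLoRA scale grows as 2·√r, so high-rank adapters retain a meaningful update magnitude.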

Choosing LoRA Rank

| Task Complexity | Recommended r | Param Overhead | Training Time (1K steps) |
|---|---|---|---|
| Simple format compliance | r=8 | ~21M params | ~8 min |
| Domain knowledge injection | r=32 | ~84M params | ~20 min |
| Complex reasoning adaptation | r=64 | ~168M params | ~35 min |
| Full behavioral overhaul | r=128 | ~336M params | ~65 min |
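The Param Overhead column follows directly from LoRA's structure: each adapted d_out × d_in matrix gains an r × d_in A matrix and a d_out × r B matrix, i.e. r·(d_in + d_out) extra parameters, so the overhead is linear in r. A sketch with illustrative shapes (real projection dimensions are model-specific):

```python
def lora_param_count(r: int, shapes: list[tuple[int, int]], n_layers: int) -> int:
    """Extra parameters LoRA adds: r * (d_in + d_out) per adapted matrix,
    summed over the targeted matrices in each layer."""
    return n_layers * sum(r * (d_in + d_out) for d_in, d_out in shapes)
```

Linearity in r is the useful takeaway here; it matches the table, where the r=64 row is exactly 8× the r=8 row.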
04

Multimodal Fine-Tuning — Vision & Audio Adapters

For truly custom visual domains (microscopy, satellite imagery, proprietary chart styles), adapt the vision encoder alongside the LM backbone. The critical addition is multi_modal_projector in modules_to_save:

multimodal_finetune.py
Python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch


def build_multimodal_qlora(
    model_id: str,
    adapt_vision_encoder: bool = False,  # True for custom visual domains
    lm_backbone_r: int = 64,
):
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb, device_map="auto",
        attn_implementation="flash_attention_2"
    )
    model = prepare_model_for_kbit_training(model)

    target_mods = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    if adapt_vision_encoder:
        # Encoders are more brittle: prefer a lower rank (e.g. r=16) for these
        # modules via LoraConfig's rank_pattern. As written, one shared rank
        # applies to every target.
        target_mods.extend([
            "vision_tower.encoder.layers.*.self_attn.q_proj",
            "vision_tower.encoder.layers.*.self_attn.v_proj",
        ])

    lora_config = LoraConfig(
        r=lm_backbone_r,
        lora_alpha=lm_backbone_r * 2,
        target_modules=target_mods,
        use_rslora=True,
        # multi_modal_projector maps vision embeddings into the LM token space.
        # ALWAYS include it — without it, visual representations stay misaligned
        # and the model fails to improve on visual tasks regardless of training steps.
        modules_to_save=["embed_tokens", "lm_head", "multi_modal_projector"],
    )
    return get_peft_model(model, lora_config)


def multimodal_data_collator(processor, samples: list[dict]) -> dict:
    """Dynamic token budget + label masking for mixed-modality batches."""
    texts, images = [], []
    for s in samples:
        texts.append(processor.apply_chat_template(
            s["messages"], tokenize=False, add_generation_prompt=False
        ))
        if s.get("images"):
            images.extend(s["images"])

    # Scale token budget to image resolution automatically
    max_px = max((img.width * img.height for img in images), default=0)
    budget = 1120 if max_px > 400_000 else 280 if max_px > 50_000 else 70

    batch = processor(
        text=texts, images=images or None, return_tensors="pt",
        padding=True, truncation=True, max_length=4096,
        image_token_budget=budget,
    )
    # Mask padding tokens in the loss. For true completion-only training,
    # additionally mask the prompt spans (e.g. with TRL's
    # DataCollatorForCompletionOnlyLM).
    batch["labels"] = batch["input_ids"].clone()
    batch["labels"][batch["attention_mask"] == 0] = -100
    return batch
✅ Key Insight: multi_modal_projector

This small linear network maps vision encoder embeddings into the LM token space. Fine-tuning the LM backbone without adapting this bridge leaves visual representations misaligned — the model literally can't connect what it sees to what it says.
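After get_peft_model, it is worth confirming the projector actually ended up trainable. A small duck-typed helper sketch: pass model.named_parameters(), and any (name, param) pairs exposing .requires_grad and .numel() will work:

```python
from collections import Counter


def trainable_breakdown(named_params) -> dict:
    """Trainable parameter counts grouped by top-level module name.

    Expects an iterable of (name, param) pairs, e.g. model.named_parameters().
    """
    counts = Counter()
    for name, param in named_params:
        if param.requires_grad:
            counts[name.split(".")[0]] += param.numel()
    return dict(counts)
```

If "multi_modal_projector" is missing from the returned dict after building the multimodal QLoRA model, the bridge was never unfrozen and visual training will silently stall.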

05

The Training Loop — With TRL's SFTTrainer

train.py
Python
from trl import SFTTrainer, SFTConfig
from datasets import DatasetDict


def build_trainer(model, processor, dataset: DatasetDict) -> SFTTrainer:
    args = SFTConfig(
        output_dir="./gemma4-e2b-finetuned",

        # Batch & steps
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,   # effective batch = 16
        num_train_epochs=3,

        # Optimizer
        # paged_adamw_8bit saves ~2 GB on optimizer state memory
        # cosine decay beats linear for fine-tuning in practice
        optim="paged_adamw_8bit",
        learning_rate=1e-4,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        weight_decay=0.01,
        max_grad_norm=1.0,

        # Precision
        bf16=True, tf32=True, fp16=False,

        # Logging & checkpointing
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="wandb",

        # Sequence packing: pack multiple short samples into one 4096-token
        # sequence. For short Q&A datasets this boosts training throughput 3-4x.
        packing=True,
        max_seq_length=4096,
        dataset_text_field=None,
        remove_unused_columns=False,
    )

    return SFTTrainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        processing_class=processor,
        data_collator=lambda b: multimodal_data_collator(processor, b),
    )

Learning Rate — The Critical Hyperparameter

| LR | Observed Behavior | Verdict |
|---|---|---|
| 1e-3 | Loss spikes, instability, NaN in ~100 steps | ❌ Never use |
| 5e-4 | Fast convergence but overfits small datasets | ⚠️ Only if dataset > 20K |
| 1e-4 | Stable convergence, good generalization | ✅ Default starting point |
| 5e-5 | Slow but very stable, avoids forgetting | ✅ Small datasets (< 2K) |
| 1e-5 | Barely learns — adapters stuck near initialization | ❌ Too low for LoRA |
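The warmup_ratio=0.05 plus cosine combination in the SFTConfig above traces out a specific LR curve. A simplified sketch of what transformers' get_cosine_schedule_with_warmup computes (per-step, ignoring num_cycles and other options):

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 1e-4,
              warmup_ratio: float = 0.05, min_lr: float = 0.0) -> float:
    """Linear warmup to base_lr, then cosine decay toward min_lr."""
    warmup = int(total_steps * warmup_ratio)
    if step < warmup:
        return base_lr * step / max(1, warmup)       # linear ramp
    progress = (step - warmup) / max(1, total_steps - warmup)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
```

Plotting this for your actual step count is a quick way to sanity-check that warmup isn't eating a large fraction of a short run.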
06

Debugging Training — What Failure Looks Like

Fine-tuning failures have recognizable signatures. Attach this callback to catch the five most common ones automatically:

debug_callbacks.py
Python
import torch
from transformers import TrainerCallback
from collections import deque


class DiagnosticCallback(TrainerCallback):
    """
    Catches the 5 most common Gemma 4 fine-tuning failure modes:
      1. NaN loss        — BF16 overflow, corrupt inputs, or zero-length audio
      2. Loss explosion  — LR too high or gradient norm overflow
      3. Loss flatline   — ALL labels masked (silent killer), LR too low
      4. Gradient spikes — reduce max_grad_norm to 0.5
      5. High eval loss  — catastrophic forgetting; reduce epochs or lora_r
    """
    def __init__(self):
        self.loss_history = deque(maxlen=20)
        self.step = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        loss = (logs or {}).get("loss")
        grad = (logs or {}).get("grad_norm")
        if loss is None: return

        self.loss_history.append(loss)
        self.step += 1

        if loss != loss:  # NaN check
            print(f"[DIAG] NaN loss at step {self.step}. "
                  "Check: bf16 overflow, zero-size audio chunks, corrupt labels.")

        elif loss > 8.0 and self.step > 50:
            print(f"[DIAG] Loss explosion ({loss:.2f}). Reduce LR by 2×.")

        elif len(self.loss_history) == 20:
            std = torch.tensor(list(self.loss_history)).std().item()
            if std < 0.002 and loss > 2.5:
                print(f"[DIAG] Loss flatline at {loss:.4f} (std={std:.4f}). "
                      "Likely cause: ALL labels masked in DataCollator. "
                      "Check (labels != -100).sum().item() on your first batch.")

        if grad and grad > 10.0:
            print(f"[DIAG] Large gradient norm ({grad:.1f}). "
                  "Reduce max_grad_norm to 0.5 if this persists.")
🔴 Silent Killer: All Labels Masked

If your labels tensor is entirely -100, the loss returns exactly 0.0 — not NaN, not an error. The model never learns, the run looks healthy, and you discover the problem only after wasting GPU-hours. Always verify with (labels != -100).sum().item() on your very first batch before starting a full run.
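That one-line check is cheap to wrap into a guard you run on the first batch. A sketch operating on batch["labels"].tolist() so it stays framework-agnostic:

```python
def assert_labels_supervised(labels: list[list[int]]) -> int:
    """Pre-flight check for the silent killer above.

    Pass batch["labels"].tolist() from your very first batch. Raises if every
    token is masked; otherwise returns the supervised-token count.
    """
    supervised = sum(1 for row in labels for tok in row if tok != -100)
    if supervised == 0:
        raise ValueError("All labels are -100 — the model would train on nothing.")
    return supervised
```

Call it once before trainer.train(); the cost is milliseconds, the failure it prevents costs GPU-hours.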

07

Evaluation Protocol — Beyond Perplexity

Eval loss decreasing does not mean your model is improving at the actual task. Build task-specific evaluation that measures what you actually care about, and always check for catastrophic forgetting:

evaluator.py
Python
import json, re, torch
from collections import defaultdict


class FineTuneEvaluator:
    def __init__(self, processor, model, base_model=None):
        self.processor  = processor
        self.model      = model
        self.base_model = base_model  # for forgetting measurement

    def eval_json_compliance(self, samples: list[dict]) -> dict:
        """Measures: parse rate, schema compliance, and value accuracy."""
        results = defaultdict(list)
        for s in samples:
            output   = self._generate(s["messages"])
            expected = s["expected_json"]

            try:
                parsed = json.loads(output.strip())
                results["parse_rate"].append(1.0)
            except json.JSONDecodeError:
                # Try salvaging by stripping markdown code fences
                clean = re.sub(r"```(?:json)?\n?(.*?)```", r"\1", output, flags=re.DOTALL)
                try:
                    parsed = json.loads(clean.strip())
                    results["parse_rate"].append(0.5)  # partial credit
                except json.JSONDecodeError:
                    results["parse_rate"].append(0.0)
                    continue

            missing = set(expected) - set(parsed)
            results["schema_rate"].append(1.0 - len(missing) / len(expected))

            matches = sum(1 for k in expected
                         if str(parsed.get(k, "")).strip() == str(expected[k]).strip())
            results["value_acc"].append(matches / len(expected))

        return {k: sum(v) / len(v) for k, v in results.items()}

    def eval_forgetting(self, mmlu_samples: list[dict], k: int = 200) -> dict:
        """
        Catastrophic forgetting check.
        Delta below -5pp on MMLU = red flag; roll back or reduce epochs.
        """
        if not self.base_model:
            return {"error": "No base model provided"}
        ft_score   = self._score_mcq(self.model,      mmlu_samples[:k])
        base_score = self._score_mcq(self.base_model, mmlu_samples[:k])
        delta = ft_score - base_score
        return {
            "finetuned_acc": ft_score, "base_acc": base_score,
            "delta": delta,
            "status": (
                "CATASTROPHIC_FORGETTING" if delta < -0.05 else
                "MILD_DEGRADATION"        if delta < -0.02 else "OK"
            ),
        }

    def _generate(self, messages) -> str:
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs    = self.processor(text=text, return_tensors="pt").to(self.model.device)
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            out = self.model.generate(**inputs, max_new_tokens=512)
        return self.processor.decode(out[0][input_len:], skip_special_tokens=True)
08

Merging, Quantizing & Shipping to Production

After training you have a base model + LoRA adapter checkpoint. Merge them into a single model for faster inference — no PEFT dependency required at serve time:

merge_and_export.py
Python
from peft import AutoPeftModelForCausalLM
from transformers import AutoProcessor, AutoModelForCausalLM
from pathlib import Path
import torch


def merge_lora_into_base(adapter_dir: str, output_dir: str):
    """
    Merge LoRA adapters into base model weights.
    Memory: ~9.6 GB RAM at merge time. Use device_map="cpu" to avoid OOM.
    Result: a standard HF model — loads without any PEFT imports.
    """
    model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_dir, torch_dtype=torch.bfloat16, device_map="cpu"
    )
    merged = model.merge_and_unload(
        safe_merge=True,   # detect rank-collapse before merging
        progressbar=True,
    )

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    merged.save_pretrained(str(out), safe_serialization=True, max_shard_size="4GB")

    # Always save the processor too — critical for multimodal models
    AutoProcessor.from_pretrained(adapter_dir).save_pretrained(str(out))
    print(f"Merged model saved to {output_dir}")


# After merging: load identically to the base model, no PEFT needed
processor = AutoProcessor.from_pretrained("./gemma4-e2b-merged")
model = AutoModelForCausalLM.from_pretrained(
    "./gemma4-e2b-merged",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
model.eval()
✅ Pre-Deployment Checklist

(1) Run forgetting eval — confirm the MMLU delta stays above -2pp.
(2) Test thinking mode on/off — both paths must produce valid output.
(3) Benchmark throughput — merged + Flash Attention 2 should reach 110+ tok/s on RTX 4090.
(4) Check multi-turn — verify thinking tokens don't leak into conversation history.
(5) Smoke test all modalities — text, image, audio, tool-call — even if you only trained on text.
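The five checks are easy to wire into a single deployment gate. A generic harness sketch; the check callables themselves (the throughput benchmark, the modality smoke tests, and so on) are yours to supply:

```python
from typing import Callable


def run_checklist(checks: dict[str, Callable[[], bool]]) -> dict[str, str]:
    """Run each pre-deployment check; ship only if every entry is PASS.

    A check that raises is recorded as an error rather than crashing the gate.
    """
    results = {}
    for name, check in checks.items():
        try:
            results[name] = "PASS" if check() else "FAIL"
        except Exception as exc:  # a crashing check is itself a failure
            results[name] = f"ERROR: {exc}"
    return results
```

Keeping the gate as data (a dict of named callables) makes it trivial to add model-specific checks later without touching the harness.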
