A complete, opinionated recipe for adapting Gemma 4 E2B to your domain — from multimodal dataset construction and QLoRA configuration through training loop debugging, evaluation, and production deployment.
Prerequisites
- NVIDIA GPU ≥ 24 GB VRAM
- System RAM ≥ 32 GB
- Storage (SSD) ≥ 50 GB free
- Training dataset 1K–100K samples
- Python ≥ 3.11
- CUDA ≥ 12.1
- transformers ≥ 4.51
- peft ≥ 0.12
- trl ≥ 0.12
- bitsandbytes ≥ 0.44
- accelerate ≥ 0.34
- wandb or mlflow (any version)
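These version floors can be verified before the first training run; a minimal pre-flight check using only the standard library (the REQUIRED mapping simply mirrors the list above; wandb/mlflow are omitted since any version works):

```python
import re
from importlib import metadata

# Minimum versions from the prerequisites list above
REQUIRED = {
    "transformers": "4.51",
    "peft": "0.12",
    "trl": "0.12",
    "bitsandbytes": "0.44",
    "accelerate": "0.34",
}

def version_tuple(v: str) -> tuple[int, ...]:
    """Parse '4.51.3' -> (4, 51, 3); ignores suffixes like 'rc1'."""
    parts = []
    for p in v.split("."):
        m = re.match(r"\d+", p)
        if not m:
            break
        parts.append(int(m.group()))
    return tuple(parts)

def check_env(required: dict[str, str] = REQUIRED) -> list[str]:
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    for pkg, floor in required.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            problems.append(f"{pkg}: not installed (need >= {floor})")
            continue
        if version_tuple(installed) < version_tuple(floor):
            problems.append(f"{pkg}: {installed} < required {floor}")
    return problems
```

Run check_env() once at the top of your training script and fail fast on a non-empty result.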
Why Fine-Tune E2B Specifically?
The instruct model is already excellent — why spend compute on fine-tuning? Three reasons: domain adaptation, format compliance, and modality specialization.
The base model covers 140+ languages and general reasoning, but it doesn't know your ontology, your entity schema, or your document templates. Fine-tuning bridges that gap with dramatically fewer tokens than few-shot prompting. Because the model is small, a properly fine-tuned E2B can match or exceed a raw 27B model on narrow domain tasks.
| Scenario | Strategy | Dataset Size | Expected Gain |
|---|---|---|---|
| Domain vocabulary / jargon | SFT with curated Q&A pairs | 1K–10K | +15–30% |
| Strict output format (JSON/XML) | SFT on format-exemplar pairs | 500–2K | +40–60% |
| Custom visual document layout | Multimodal SFT with screenshots | 2K–20K pages | +20–35% |
| Specific audio accent / domain | ASR fine-tune on audio+transcript | 10h–100h audio | −30–60% WER |
| Agentic reasoning preference | DPO/ORPO on tool-call trajectories | 5K–50K pairs | +25–50% task completion |
Dataset Construction — The Most Important Step
Poor data kills fine-tunes. 500 high-quality, diverse, correctly-formatted samples beat 50,000 scraped, noisy ones every time. Here's how to build a validated multimodal SFT dataset:
from dataclasses import dataclass, field
from PIL import Image
import hashlib
from datasets import Dataset, DatasetDict
@dataclass
class TrainingSample:
"""A single multimodal training conversation."""
id: str
messages: list[dict]
images: list[Image.Image] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
def validate(self) -> list[str]:
errors = []
roles = [m["role"] for m in self.messages]
if roles[-1] != "assistant":
errors.append("Last turn must be assistant")
if "user" not in roles:
errors.append("Must have at least one user turn")
# CRITICAL: no thinking tokens in training targets
for i, msg in enumerate(self.messages):
if msg["role"] == "assistant":
content = str(msg.get("content", ""))
if "channel" in content:
errors.append(
f"Turn {i}: thinking tokens in training target — strip them!"
)
return errors
class DatasetBuilder:
def __init__(self):
self.samples: list[TrainingSample] = []
def add_text_qa(self, question: str, answer: str, system: str = ""):
messages = []
if system:
messages.append({"role": "system", "content": system})
messages.extend([
{"role": "user", "content": question},
{"role": "assistant", "content": answer},
])
self.samples.append(TrainingSample(id=self._id(question), messages=messages))
def add_multimodal_qa(self, images: list, question: str, answer: str):
# Images MUST come before text in each turn
user_content = [{"type": "image"} for _ in images]
user_content.append({"type": "text", "text": question})
self.samples.append(TrainingSample(
id=self._id(question),
messages=[
{"role": "user", "content": user_content},
{"role": "assistant", "content": answer},
],
images=images,
))
    def validate_all(self) -> dict:
        errors = {}
        for s in self.samples:
            if errs := s.validate():  # validate once, keep only failing samples
                errors[s.id] = errs
        return {"total": len(self.samples), "invalid": len(errors), "errors": errors}
def build_hf_dataset(self, val_split: float = 0.1) -> DatasetDict:
records = [{"id": s.id, "messages": s.messages, "images": s.images}
for s in self.samples]
full = Dataset.from_list(records)
split = full.train_test_split(test_size=val_split, seed=42)
return DatasetDict({"train": split["train"], "validation": split["test"]})
def _id(self, text: str) -> str:
return hashlib.md5((str(len(self.samples)) + text).encode()).hexdigest()[:12]
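Near-duplicate questions quietly shrink your effective dataset and inflate eval scores when duplicates straddle the split. Before building the DatasetDict, it is worth deduplicating on normalized question text; a minimal exact-match sketch operating on the same message-dict structure the builder produces (the normalize rules are an assumption — extend them for your domain):

```python
import hashlib
import unicodedata

def normalize(text: str) -> str:
    """Lowercase, NFKC-normalize, and collapse whitespace before hashing."""
    text = unicodedata.normalize("NFKC", text).lower()
    return " ".join(text.split())

def dedup_by_question(samples: list[dict]) -> list[dict]:
    """Keep the first sample for each normalized user-question text.

    Assumes each sample has the `messages` structure built above, with the
    question in the last user turn (a plain string or a list of content parts).
    """
    seen: set[str] = set()
    kept = []
    for s in samples:
        user_turns = [m for m in s["messages"] if m["role"] == "user"]
        content = user_turns[-1]["content"]
        if isinstance(content, list):  # multimodal turn: join the text parts
            content = " ".join(p.get("text", "") for p in content)
        key = hashlib.sha1(normalize(content).encode()).hexdigest()
        if key not in seen:
            seen.add(key)
            kept.append(s)
    return kept
```

Exact-match dedup catches copy-paste duplicates only; for paraphrase-level near-duplicates you would need embedding similarity, which is out of scope here.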
The 4 Non-Negotiable Rules

1. Strip all thinking tokens from training targets
   Gemma 4 thinking content must never appear in assistant training targets. The validator above catches this. Thinking is inference-time only — including it in training data corrupts the model's behavior.

2. Use the official chat template for tokenization
   Always call processor.apply_chat_template(messages, tokenize=False). Manually constructing the prompt string will miss special tokens and produce a poisoned training signal.

3. Mask prompt tokens in the loss
   Only compute cross-entropy on assistant turns. Use DataCollatorForCompletionOnlyLM from TRL. Training on input tokens wastes gradient budget and can cause prompt memorization.

4. Images must precede their referring text
   Gemma 4's vision encoder processes image tokens before text tokens. Violating this at train time creates a distribution mismatch that silently degrades inference quality.
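The loss-masking rule deserves a concrete picture. TRL's DataCollatorForCompletionOnlyLM locates the response template inside each tokenized sequence and masks everything before it; a simplified stand-alone sketch on plain token-id lists (the response_marker argument is a hypothetical stand-in for the chat template's real assistant-turn token ids):

```python
IGNORE_INDEX = -100  # the index PyTorch cross-entropy ignores

def mask_prompt_tokens(input_ids: list[int], response_marker: list[int]) -> list[int]:
    """Return labels with everything up to and including the last response
    marker set to IGNORE_INDEX, so loss covers only the completion tokens.
    """
    labels = list(input_ids)
    # Find the last occurrence of the marker subsequence
    start = -1
    for i in range(len(input_ids) - len(response_marker), -1, -1):
        if input_ids[i:i + len(response_marker)] == response_marker:
            start = i + len(response_marker)
            break
    if start == -1:
        # No assistant turn found: mask the whole sequence. Downstream code
        # should detect and reject fully-masked samples (see the flatline
        # failure mode later in this guide).
        return [IGNORE_INDEX] * len(labels)
    for i in range(start):
        labels[i] = IGNORE_INDEX
    return labels
```

In practice, use the TRL collator rather than this sketch — it also handles multi-turn conversations and padding.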
QLoRA Configuration — The Parameters That Matter
QLoRA (Quantized Low-Rank Adaptation) keeps this fine-tune comfortably under 24 GB VRAM at batch size 4. The key insight: 4-bit base weights plus bf16 LoRA adapters give you the representational capacity you need at a fraction of the compute cost.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig, AutoProcessor, AutoModelForCausalLM
import torch
def build_qlora_model(base_model_id: str = "google/gemma-4-E2B-it"):
# 4-bit quantization config
# double_quant: quantize the quantization constants too (~0.4 GB savings)
# nf4: NormalFloat4, theoretically optimal for normally-distributed weights
# compute_dtype: actual math in bf16 despite 4-bit storage
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
processor = AutoProcessor.from_pretrained(base_model_id)
model = AutoModelForCausalLM.from_pretrained(
base_model_id,
quantization_config=bnb_config,
device_map="auto",
attn_implementation="flash_attention_2", # 20-40% speed boost
)
# Cast LayerNorm & embeddings to bf16 to avoid dtype mismatch errors
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
)
# Target ALL projection matrices — omitting v_proj or o_proj is a common
# mistake that prevents the model from rewriting attention outputs
lora_config = LoraConfig(
r=64,
lora_alpha=128, # convention: alpha = 2 × r
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_dropout=0.05,
bias="none",
use_rslora=True, # rank-stabilized LoRA: normalize by √r
modules_to_save=["embed_tokens", "lm_head"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Expected: trainable params: 168,034,304 || trainable%: 3.18
return processor, model
Always set use_rslora=True. Standard LoRA scales adapter outputs by alpha/r, making larger-rank adapters numerically unstable. rsLoRA normalizes by alpha/√r, allowing higher ranks (r=64, r=128) without needing to re-tune the learning rate.
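The two scaling rules are easy to compare numerically; a quick sketch of the factors named above, using the alpha = 2 × r convention from the config:

```python
import math

def lora_scale(alpha: int, r: int) -> float:
    """Standard LoRA: adapter output scaled by alpha / r."""
    return alpha / r

def rslora_scale(alpha: int, r: int) -> float:
    """Rank-stabilized LoRA: adapter output scaled by alpha / sqrt(r)."""
    return alpha / math.sqrt(r)

# With alpha = 2 * r, the standard factor is a constant 2.0 at every rank,
# while the rsLoRA factor grows as 2 * sqrt(r) — e.g. 16.0 at r=64 —
# which is what lets higher ranks train without re-tuning the learning rate.
for r in (8, 64, 128):
    print(r, lora_scale(2 * r, r), round(rslora_scale(2 * r, r), 2))
```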
Choosing LoRA Rank
| Task Complexity | Recommended r | Param Overhead | Training Time (1K steps) |
|---|---|---|---|
| Simple format compliance | r=8 | ~21M params | ~8 min |
| Domain knowledge injection | r=32 | ~84M params | ~20 min |
| Complex reasoning adaptation | r=64 | ~168M params | ~35 min |
| Full behavioral overhaul | r=128 | ~336M params | ~65 min |
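The parameter-overhead column follows from the LoRA construction: each adapted d_out × d_in projection gains an r × d_in A matrix and a d_out × r B matrix, i.e. r × (d_in + d_out) extra parameters. A sketch of the arithmetic with hypothetical layer shapes — the real E2B dimensions are not documented here, so the printed total is illustrative, not the table's figure:

```python
def lora_param_count(shapes: list[tuple[int, int]], r: int, n_layers: int) -> int:
    """Total adapter parameters at rank r across all transformer blocks.

    shapes: (d_in, d_out) for each targeted projection within one block.
    """
    per_block = sum(r * (d_in + d_out) for d_in, d_out in shapes)
    return per_block * n_layers

# Hypothetical shapes for illustration: hidden=2048, intermediate=8192,
# all seven projections targeted as in the LoraConfig above.
hidden, inter = 2048, 8192
shapes = [
    (hidden, hidden),  # q_proj
    (hidden, hidden),  # k_proj (real models often use smaller KV dims)
    (hidden, hidden),  # v_proj
    (hidden, hidden),  # o_proj
    (hidden, inter),   # gate_proj
    (hidden, inter),   # up_proj
    (inter, hidden),   # down_proj
]
print(lora_param_count(shapes, r=8, n_layers=24))  # 9,043,968 with these shapes
```

Note the linearity in r: doubling the rank exactly doubles the adapter parameters, which is why the table's overhead column doubles row to row.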
Multimodal Fine-Tuning — Vision & Audio Adapters
For truly custom visual domains (microscopy, satellite imagery, proprietary chart styles) adapt the vision encoder alongside the LM backbone. The critical addition is multi_modal_projector in modules_to_save:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
def build_multimodal_qlora(
model_id: str,
adapt_vision_encoder: bool = False, # True for custom visual domains
lm_backbone_r: int = 64,
):
bnb = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(
model_id, quantization_config=bnb, device_map="auto",
attn_implementation="flash_attention_2"
)
model = prepare_model_for_kbit_training(model)
target_mods = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
]
if adapt_vision_encoder:
# Use lower rank (r=16) for encoders — they're more brittle
target_mods.extend([
"vision_tower.encoder.layers.*.self_attn.q_proj",
"vision_tower.encoder.layers.*.self_attn.v_proj",
])
lora_config = LoraConfig(
r=lm_backbone_r,
lora_alpha=lm_backbone_r * 2,
target_modules=target_mods,
use_rslora=True,
# multi_modal_projector maps vision embeddings into the LM token space.
# ALWAYS include it — without it, visual representations stay misaligned
# and the model fails to improve on visual tasks regardless of training steps.
modules_to_save=["embed_tokens", "lm_head", "multi_modal_projector"],
)
return get_peft_model(model, lora_config)
def multimodal_data_collator(processor, samples: list[dict]) -> dict:
"""Dynamic token budget + label masking for mixed-modality batches."""
texts, images = [], []
for s in samples:
texts.append(processor.apply_chat_template(
s["messages"], tokenize=False, add_generation_prompt=False
))
if s.get("images"):
images.extend(s["images"])
# Scale token budget to image resolution automatically
max_px = max((img.width * img.height for img in images), default=0)
budget = 1120 if max_px > 400_000 else 280 if max_px > 50_000 else 70
batch = processor(
text=texts, images=images or None, return_tensors="pt",
padding=True, truncation=True, max_length=4096,
image_token_budget=budget,
)
# Mask all non-completion tokens in the loss
batch["labels"] = batch["input_ids"].clone()
batch["labels"][batch["attention_mask"] == 0] = -100
return batch
The multi_modal_projector is a small linear network that maps vision encoder embeddings into the LM token space. Fine-tuning the LM backbone without adapting this bridge leaves visual representations misaligned — the model literally can't connect what it sees to what it says.
The Training Loop — With TRL's SFTTrainer
from trl import SFTTrainer, SFTConfig
from datasets import DatasetDict
def build_trainer(model, processor, dataset: DatasetDict) -> SFTTrainer:
args = SFTConfig(
output_dir="./gemma4-e2b-finetuned",
# Batch & steps
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=8, # effective batch = 16
num_train_epochs=3,
# Optimizer
# paged_adamw_8bit saves ~2 GB on optimizer state memory
# cosine decay beats linear for fine-tuning in practice
optim="paged_adamw_8bit",
learning_rate=1e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.05,
weight_decay=0.01,
max_grad_norm=1.0,
# Precision
bf16=True, tf32=True, fp16=False,
# Logging & checkpointing
logging_steps=10,
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
report_to="wandb",
# Sequence packing: pack multiple short samples into one 4096-token
        # sequence. For short Q&A datasets this boosts training throughput 3-4x.
packing=True,
max_seq_length=4096,
dataset_text_field=None,
remove_unused_columns=False,
)
return SFTTrainer(
model=model,
args=args,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
processing_class=processor,
data_collator=lambda b: multimodal_data_collator(processor, b),
)
Learning Rate — The Critical Hyperparameter
| LR | Observed Behavior | Verdict |
|---|---|---|
| 1e-3 | Loss spikes, instability, NaN in ~100 steps | ❌ Never use |
| 5e-4 | Fast convergence but overfits small datasets | ⚠️ Only if dataset > 20K |
| 1e-4 | Stable convergence, good generalization | ✅ Default starting point |
| 5e-5 | Slow but very stable, avoids forgetting | ✅ Small datasets (< 2K) |
| 1e-5 | Barely learns — adapters stuck near initialization | ❌ Too low for LoRA |
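The learning rate you pick is the peak of the schedule, not a constant: with lr_scheduler_type="cosine" and warmup_ratio=0.05 as configured above, the effective LR ramps up linearly then decays. A minimal stand-alone sketch of that shape (the transformers schedulers implement this for you; this version is for intuition only):

```python
import math

def lr_at_step(step: int, total_steps: int, peak_lr: float,
               warmup_ratio: float = 0.05, min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then cosine decay down to min_lr."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
```

So a 1e-4 "learning rate" over 1,000 steps actually spends most of training below 1e-4 — worth remembering when comparing runs with different step counts.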
Debugging Training — What Failure Looks Like
Fine-tuning failures have recognizable signatures. Attach this callback to catch the five most common ones automatically:
import torch
from transformers import TrainerCallback
from collections import deque
class DiagnosticCallback(TrainerCallback):
"""
Catches the 5 most common Gemma 4 fine-tuning failure modes:
1. NaN loss — BF16 overflow, corrupt inputs, or zero-length audio
2. Loss explosion — LR too high or gradient norm overflow
3. Loss flatline — ALL labels masked (silent killer), LR too low
4. Gradient spikes — reduce max_grad_norm to 0.5
5. High eval loss — catastrophic forgetting; reduce epochs or lora_r
"""
def __init__(self):
self.loss_history = deque(maxlen=20)
self.step = 0
def on_log(self, args, state, control, logs=None, **kwargs):
loss = (logs or {}).get("loss")
grad = (logs or {}).get("grad_norm")
if loss is None: return
self.loss_history.append(loss)
self.step += 1
if loss != loss: # NaN check
print(f"[DIAG] NaN loss at step {self.step}. "
"Check: bf16 overflow, zero-size audio chunks, corrupt labels.")
elif loss > 8.0 and self.step > 50:
print(f"[DIAG] Loss explosion ({loss:.2f}). Reduce LR by 2×.")
elif len(self.loss_history) == 20:
std = torch.tensor(list(self.loss_history)).std().item()
if std < 0.002 and loss > 2.5:
print(f"[DIAG] Loss flatline at {loss:.4f} (std={std:.4f}). "
"Likely cause: ALL labels masked in DataCollator. "
"Check (labels != -100).sum().item() on your first batch.")
if grad and grad > 10.0:
print(f"[DIAG] Large gradient norm ({grad:.1f}). "
"Reduce max_grad_norm to 0.5 if this persists.")
If your labels tensor is entirely -100, the loss returns exactly 0.0 — not NaN, not an error. The model never learns, the run looks healthy, and you discover the problem only after wasting GPU-hours. Always verify with (labels != -100).sum().item() on your very first batch before starting a full run.
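That pre-flight check takes a few lines; a framework-agnostic sketch that accepts the nested label lists a collator produces (on tensors, (labels != -100).sum().item() computes the same count):

```python
def assert_labels_supervised(labels: list[list[int]], ignore_index: int = -100) -> int:
    """Raise if a batch contains no supervised tokens; return the count otherwise."""
    supervised = sum(1 for row in labels for t in row if t != ignore_index)
    if supervised == 0:
        raise ValueError(
            "All labels are masked (-100): the model would train on nothing. "
            "Check the response-template marker and the collator's masking logic."
        )
    return supervised
```

Call it on the first batch out of your dataloader, before trainer.train(), and the silent-zero-loss failure mode becomes a loud one.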
Evaluation Protocol — Beyond Perplexity
Eval loss decreasing does not mean your model is improving at the actual task. Build task-specific evaluation that measures what you actually care about, and always check for catastrophic forgetting:
import json, re, torch
from collections import defaultdict
class FineTuneEvaluator:
def __init__(self, processor, model, base_model=None):
self.processor = processor
self.model = model
self.base_model = base_model # for forgetting measurement
def eval_json_compliance(self, samples: list[dict]) -> dict:
"""Measures: parse rate, schema compliance, and value accuracy."""
results = defaultdict(list)
for s in samples:
output = self._generate(s["messages"])
expected = s["expected_json"]
try:
parsed = json.loads(output.strip())
results["parse_rate"].append(1.0)
except json.JSONDecodeError:
# Try salvaging by stripping markdown code fences
clean = re.sub(r"```(?:json)?\n?(.*?)```", r"\1", output, flags=re.DOTALL)
try:
parsed = json.loads(clean.strip())
results["parse_rate"].append(0.5) # partial credit
                except json.JSONDecodeError:
results["parse_rate"].append(0.0)
continue
missing = set(expected) - set(parsed)
results["schema_rate"].append(1.0 - len(missing) / len(expected))
matches = sum(1 for k in expected
if str(parsed.get(k, "")).strip() == str(expected[k]).strip())
results["value_acc"].append(matches / len(expected))
return {k: sum(v) / len(v) for k, v in results.items()}
def eval_forgetting(self, mmlu_samples: list[dict], k: int = 200) -> dict:
"""
Catastrophic forgetting check.
        Delta below -5 pp on MMLU = red flag; roll back or reduce epochs.
"""
if not self.base_model:
return {"error": "No base model provided"}
ft_score = self._score_mcq(self.model, mmlu_samples[:k])
base_score = self._score_mcq(self.base_model, mmlu_samples[:k])
delta = ft_score - base_score
return {
"finetuned_acc": ft_score, "base_acc": base_score,
"delta": delta,
"status": (
"CATASTROPHIC_FORGETTING" if delta < -0.05 else
"MILD_DEGRADATION" if delta < -0.02 else "OK"
),
}
def _generate(self, messages) -> str:
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = self.processor(text=text, return_tensors="pt").to(self.model.device)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
out = self.model.generate(**inputs, max_new_tokens=512)
return self.processor.decode(out[0][input_len:], skip_special_tokens=True)
Merging, Quantizing & Shipping to Production
After training you have a base model + LoRA adapter checkpoint. Merge them into a single model for faster inference — no PEFT dependency required at serve time:
from peft import AutoPeftModelForCausalLM
from transformers import AutoProcessor, AutoModelForCausalLM
from pathlib import Path
import torch
def merge_lora_into_base(adapter_dir: str, output_dir: str):
"""
Merge LoRA adapters into base model weights.
Memory: ~9.6 GB RAM at merge time. Use device_map="cpu" to avoid OOM.
Result: a standard HF model — loads without any PEFT imports.
"""
model = AutoPeftModelForCausalLM.from_pretrained(
adapter_dir, torch_dtype=torch.bfloat16, device_map="cpu"
)
merged = model.merge_and_unload(
        safe_merge=True,  # check adapter weights for NaNs before merging
progressbar=True,
)
out = Path(output_dir)
out.mkdir(parents=True, exist_ok=True)
merged.save_pretrained(str(out), safe_serialization=True, max_shard_size="4GB")
# Always save the processor too — critical for multimodal models
AutoProcessor.from_pretrained(adapter_dir).save_pretrained(str(out))
print(f"Merged model saved to {output_dir}")
# After merging: load identically to the base model, no PEFT needed
processor = AutoProcessor.from_pretrained("./gemma4-e2b-merged")
model = AutoModelForCausalLM.from_pretrained(
"./gemma4-e2b-merged",
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="flash_attention_2",
)
model.eval()
Pre-Deployment Checklist
(1) Run forgetting eval — confirm the MMLU delta stays above −2 pp (no meaningful forgetting).
(2) Test thinking mode on/off — both paths must produce valid output.
(3) Benchmark throughput — merged + Flash Attention 2 should reach 110+ tok/s on RTX 4090.
(4) Check multi-turn — verify thinking tokens don't leak into conversation history.
(5) Smoke test all modalities — text, image, audio, tool-call — even if you only trained on text.