Memory Optimization
This document covers the memory optimization techniques used in our LoRA fine-tuning implementation, focusing on practical approaches that enable efficient training of large language models on consumer GPUs.
💾 Memory Usage Overview
Actual Memory Breakdown
For LoRA fine-tuning of Phi-4-mini-instruct (4B parameters) on a consumer GPU:
Base Model (4-bit quantized): ~3.5GB
LoRA Adapters (r=16): ~50MB
Optimizer States (AdamW): ~100MB
Gradients: ~50MB
Activation Cache: ~1-2GB
Input Batch (batch_size=4): ~500MB-1GB
---
Total: ~5-7GB
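As a quick sanity check, the line items above can be summed directly. This is a minimal sketch: the dictionary keys are illustrative names, and the values are the estimates from the breakdown, with single-valued items entered as a (low, high) pair of equal numbers. Real usage typically also includes CUDA context and allocator overhead that is not itemized here.

```python
# Line items from the breakdown above, as (low, high) estimates in GB.
# Key names are illustrative; the numbers are the estimates listed above.
breakdown_gb = {
    "base_model_4bit": (3.5, 3.5),
    "lora_adapters": (0.05, 0.05),
    "optimizer_states": (0.10, 0.10),
    "gradients": (0.05, 0.05),
    "activation_cache": (1.0, 2.0),
    "input_batch": (0.5, 1.0),
}


def total_range(items):
    """Sum the low and high estimates separately."""
    low = sum(lo for lo, _ in items.values())
    high = sum(hi for _, hi in items.values())
    return low, high


low, high = total_range(breakdown_gb)
print(f"Itemized total: {low:.1f}-{high:.1f} GB")
```

The quantized base model dominates the budget, while the activation cache and input batch are the parts that grow with batch size and sequence length.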
Our Memory Optimization Strategy
We use a layered approach that leverages existing libraries:
- 4-bit Quantization (BitsAndBytesConfig)
- LoRA Parameter Efficiency (PEFT)
- Gradient Checkpointing (Built into Trainer)
🔧 Implementation
1. 4-bit Quantization with BitsAndBytesConfig
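Our primary memory saver is 4-bit quantization of the base model, the approach behind quantized LoRA (QLoRA), implemented through the bitsandbytes library and configured with BitsAndBytesConfig. The pre-trained base/instruction model is loaded at 4-bit precision rather than 16-bit (or 32-bit) precision, which reduces the memory needed for the model weights by ~75% relative to 16-bit. Only the frozen base weights are quantized. The LoRA weights and their gradients stay at 16-bit or higher floating-point precision (the configuration below computes in float16).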
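import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed,
)


def setup_model(model_name: str, seed: int):
    """Setup model with 4-bit quantization for memory efficiency."""
    # Assumption: the seed is meant to make runs reproducible
    set_seed(seed)
    # 4-bit quantization configuration
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # Use 4-bit quantization
        bnb_4bit_compute_dtype=torch.float16,  # Compute in float16
        bnb_4bit_use_double_quant=True,  # Double quantization for extra savings
        bnb_4bit_quant_type="nf4",  # NormalFloat4 - best for neural networks
    )
    # Load quantized model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",  # Automatic device placement
        quantization_config=bnb_config,
        dtype=torch.float16,  # Older transformers releases name this torch_dtype
    )
    # Assumption: the tokenizer is loaded from the same checkpoint as the model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


# Memory savings: ~75% reduction in model memory (relative to 16-bit weights)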
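The ~75% figure follows from the bit widths alone. The sketch below is a back-of-the-envelope estimate, not a measurement: it assumes a round 4B parameters (per the overview), counts only the weights, and ignores quantization constants and any layers that stay unquantized, so real footprints will be somewhat higher. The helper name `weight_memory_gb` is made up for illustration.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 4e9  # roughly 4B parameters, per the overview

fp16_gb = weight_memory_gb(NUM_PARAMS, 16)
int4_gb = weight_memory_gb(NUM_PARAMS, 4)
reduction = 1 - int4_gb / fp16_gb

print(f"16-bit weights: {fp16_gb:.1f} GB")
print(f"4-bit weights:  {int4_gb:.1f} GB")
print(f"Reduction:      {reduction:.0%}")
```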
2. LoRA Parameter Efficiency
Only train a small fraction of parameters:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


def setup_lora(model, cfg):
    """Apply LoRA - only a small fraction (well under 1%) of parameters become trainable."""
    # Prepare quantized model for LoRA
    model = prepare_model_for_kbit_training(model)
    model.config.use_cache = False  # Disable KV cache for training
    # LoRA configuration
    peft_config = LoraConfig(
        r=cfg.r,  # Low rank (16) - smaller = less memory
        lora_alpha=cfg.alpha,  # Scaling factor (32)
        target_modules=cfg.target_modules,  # Which layers to adapt
        lora_dropout=cfg.dropout,  # Regularization
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model


# Example output (these counts correspond to a ~14.9B-parameter model,
# not the 4B model discussed above):
# trainable params: 83,886,080 || all params: 14,888,534,016 || trainable%: 0.56%
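To see why so few parameters become trainable: for one linear layer with a d_out x d_in weight matrix, LoRA adds two small matrices, A (r x d_in) and B (d_out x r), for r * (d_in + d_out) extra parameters. The sketch below uses a hypothetical 4096 x 4096 layer purely for illustration (the dimensions are not taken from any model config) and then re-derives the percentage shown in the example output.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one linear layer.

    A has shape (r, d_in) and B has shape (d_out, r).
    """
    return r * (d_in + d_out)


# Hypothetical layer size for illustration only; not taken from a model config.
d_in, d_out, r = 4096, 4096, 16
full = d_in * d_out
added = lora_params(d_in, d_out, r)
print(f"full layer: {full:,} params, LoRA adds: {added:,} ({added / full:.2%})")

# Re-derive the percentage from the example output above.
trainable, total = 83_886_080, 14_888_534_016
pct = 100 * trainable / total
print(f"trainable%: {pct:.2f}")
```

Halving r halves the adapter size, which is the sense in which a smaller rank means less memory.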
3. Gradient Checkpointing
During training, the forward pass stores each layer's intermediate outputs (activations) so that the backward pass can compute gradients quickly. However, these activation tensors are large and may cause the GPU to run out of memory (OOM).
Gradient checkpointing is a compromise where activations are stored only at selected checkpoints (for example, every few layers) to save memory. During the backward pass, the missing activations are recomputed from the nearest checkpoint. This trades extra computation, and therefore some training speed, for lower GPU memory use.
Gradient checkpointing is automatically enabled in our transformers Trainer configuration.
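The idea can be illustrated with a toy, framework-free sketch. It is not how transformers implements the feature: it assumes a chain of scalar layers x_{i+1} = tanh(w_i * x_i) with the final output as the loss, and all function names are made up for illustration. One version keeps every intermediate layer output; the other keeps only every k-th one and recomputes each segment during the backward pass.

```python
import math


def layer(w, x):
    return math.tanh(w * x)


def backward_through(weights, acts, start, grad_out, grads):
    """Backprop through the layers whose inputs/outputs are in acts.

    acts[j] is the input of layer start + j and acts[j + 1] is its output.
    Returns the gradient with respect to acts[0].
    """
    for j in range(len(acts) - 2, -1, -1):
        i = start + j
        local = grad_out * (1.0 - acts[j + 1] ** 2)  # d tanh(z)/dz = 1 - tanh(z)^2
        grads[i] = local * acts[j]  # dL/dw_i
        grad_out = local * weights[i]  # dL/dx_i
    return grad_out


def grads_full(weights, x0):
    """Store every layer output, then backprop once."""
    acts = [x0]
    for w in weights:
        acts.append(layer(w, acts[-1]))
    grads = [0.0] * len(weights)
    backward_through(weights, acts, 0, 1.0, grads)
    return grads, len(acts)


def grads_checkpointed(weights, x0, every):
    """Store only every `every`-th layer input; recompute the rest per segment."""
    n = len(weights)
    checkpoints = {}
    x = x0
    for i, w in enumerate(weights):
        if i % every == 0:
            checkpoints[i] = x
        x = layer(w, x)
    grads = [0.0] * n
    grad_out = 1.0
    peak = 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, n)
        acts = [checkpoints[start]]  # recompute this segment from its checkpoint
        for i in range(start, end):
            acts.append(layer(weights[i], acts[-1]))
        # acts[0] is the checkpoint itself, so do not count it twice
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        grad_out = backward_through(weights, acts, start, grad_out, grads)
    return grads, peak


weights = [0.5 + 0.1 * i for i in range(12)]
full, full_stored = grads_full(weights, 0.3)
ckpt, ckpt_stored = grads_checkpointed(weights, 0.3, every=4)
assert all(math.isclose(a, b) for a, b in zip(full, ckpt))
print(f"stored values: full={full_stored}, checkpointed={ckpt_stored}")
```

Both versions produce the same gradients. The checkpointed one holds fewer values at its peak but runs each layer's forward computation a second time. In the Hugging Face Trainer, the real mechanism is typically switched on with TrainingArguments(gradient_checkpointing=True) or model.gradient_checkpointing_enable().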
Combined, these three techniques keep the memory footprint of fine-tuning within reach of a consumer GPU while maintaining model performance and training stability.