Trainer API
QuantLLM provides a training API with built-in support for quantized models, parameter-efficient fine-tuning (LoRA), checkpointing, and progress tracking.
Fine-Tuning Trainer
- class quantllm.trainer.trainer.FineTuningTrainer(model, training_config, train_dataloader, eval_dataloader=None, checkpoint_manager=None, hub_manager=None, device=None, use_wandb=False, wandb_config=None)
Bases: object
- __init__(model, training_config, train_dataloader, eval_dataloader=None, checkpoint_manager=None, hub_manager=None, device=None, use_wandb=False, wandb_config=None)
Initialize the trainer, which runs a PyTorch-based training loop.
- Parameters:
model (nn.Module) – The model to train
training_config (TrainingConfig) – Training configuration
train_dataloader (DataLoader) – Training data loader
eval_dataloader (DataLoader, optional) – Evaluation data loader
checkpoint_manager (CheckpointManager, optional) – Checkpoint manager
hub_manager (HubManager, optional) – Hub manager for model pushing
device (str or torch.device, optional) – Device to train on
use_wandb (bool) – Whether to use Weights & Biases
wandb_config (Dict[str, Any], optional) – Weights & Biases configuration
Model Evaluator
- class quantllm.trainer.evaluator.ModelEvaluator(model, eval_dataloader, metrics=None, device=None)
Bases: object
- __init__(model, eval_dataloader, metrics=None, device=None)
Initialize the model evaluator.
- Parameters:
model (nn.Module) – The model to evaluate
eval_dataloader (DataLoader) – Evaluation data loader
metrics (List[Callable], optional) – List of metric functions
device (str, optional) – Device to evaluate on
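The `metrics` parameter takes a list of callables, but the call signature those callables must follow is not documented here. The sketch below assumes a hypothetical metric that receives flat sequences of predicted and reference token ids and returns a float, with padded label positions marked `-100` (the Hugging Face convention). Check the QuantLLM source for the real contract before passing something like `metrics=[token_accuracy]`.

```python
from typing import Sequence

# Assumption: padded label positions are marked with -100, as in Hugging Face.
IGNORE_INDEX = -100

def token_accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Hypothetical metric: fraction of non-ignored positions where prediction == label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    # Drop positions whose label is the ignore marker (e.g. padding).
    pairs = [(p, y) for p, y in zip(predictions, labels) if y != IGNORE_INDEX]
    if not pairs:
        return 0.0
    correct = sum(1 for p, y in pairs if p == y)
    return correct / len(pairs)

# Three positions are counted (the last is ignored); two of them match.
print(token_accuracy([5, 7, 9, 2], [5, 8, 9, IGNORE_INDEX]))
```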
Training Logger
Example Usage
Complete Training Pipeline
from quantllm import (
    Model, ModelConfig,
    FineTuningTrainer, TrainingConfig,
    TrainingLogger, CheckpointManager
)
# Logger for formatted progress output
logger = TrainingLogger()
# Model settings: 4-bit weights, LoRA adapters, gradient checkpointing
config = ModelConfig(
    model_name="facebook/opt-125m",
    load_in_4bit=True,  # quantize weights to 4 bits to reduce memory use
    use_lora=True,  # train small low-rank adapters instead of all weights
    gradient_checkpointing=True  # recompute activations in the backward pass: less memory, more compute
)
# Training hyperparameters
training_config = TrainingConfig(
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=8,
    gradient_accumulation_steps=4,
    # Learning-rate warmup, evaluation, saving and logging schedule
    warmup_ratio=0.1,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="epoch",
    logging_steps=10,
    # Mixed precision training
    fp16=True,
    # Multi-GPU (DDP): do not scan for unused parameters on every step
    ddp_find_unused_parameters=False
)
# Set up checkpointing: keep at most three checkpoints in ./checkpoints
checkpoint_manager = CheckpointManager(
    checkpoint_dir="./checkpoints",
    save_total_limit=3
)
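# Initialize and train. `model` (an nn.Module built from `config`),
# `train_loader` and `val_loader` (DataLoaders) are assumed to exist already.
trainer = FineTuningTrainer(
    model=model,
    training_config=training_config,
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
    logger=logger,  # not in the signature documented above; remove if your version raises TypeError
    checkpoint_manager=checkpoint_manager
)
# Run the training loop
trainer.train()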
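A `warmup_ratio` of 0.1 conventionally means the learning rate ramps up over the first 10% of optimizer steps. How QuantLLM shapes the rest of the schedule is not stated here; the sketch below assumes the linear warmup followed by linear decay used by Hugging Face's `get_linear_schedule_with_warmup`, with the warmup length rounded up as in `TrainingArguments`. The function names are hypothetical.

```python
import math

def warmup_steps(total_steps: int, warmup_ratio: float) -> int:
    # Assumption: warmup length is rounded up, as Hugging Face TrainingArguments does.
    return math.ceil(total_steps * warmup_ratio)

def linear_warmup_lr(step: int, base_lr: float, total_steps: int, warmup_ratio: float) -> float:
    """Learning rate at optimizer step `step` (0-based): linear warmup, then linear decay to 0."""
    n_warmup = warmup_steps(total_steps, warmup_ratio)
    if step < n_warmup:
        # Ramp linearly from 0 up to base_lr.
        return base_lr * step / max(1, n_warmup)
    # Decay linearly from base_lr at the end of warmup down to 0 at total_steps.
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - n_warmup))

total = 300  # hypothetical number of optimizer steps for the whole run
for s in (0, 15, 30, 165, 300):
    print(s, linear_warmup_lr(s, base_lr=2e-4, total_steps=total, warmup_ratio=0.1))
```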
Basic Training
from quantllm import FineTuningTrainer, TrainingConfig
config = TrainingConfig(
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=8,
    gradient_accumulation_steps=4
)
trainer = FineTuningTrainer(
    model=model,
    training_config=config,
    train_dataloader=train_loader,
    eval_dataloader=val_loader
)
trainer.train()
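With `batch_size=8` and `gradient_accumulation_steps=4`, gradients from four micro-batches are accumulated before each optimizer update, so each update sees 8 × 4 = 32 samples per process. The hypothetical helpers below compute this and the resulting number of optimizer steps; whether QuantLLM also performs an update for a trailing partial accumulation window at the end of an epoch is an assumption.

```python
import math

def effective_batch_size(batch_size: int, grad_accum_steps: int, world_size: int = 1) -> int:
    # Samples that contribute to one optimizer update, summed over all processes.
    return batch_size * grad_accum_steps * world_size

def optimizer_steps(num_batches_per_epoch: int, grad_accum_steps: int, num_epochs: int) -> int:
    # Assumption: a trailing partial accumulation window still triggers an update.
    return math.ceil(num_batches_per_epoch / grad_accum_steps) * num_epochs

# batch_size and accumulation from the config above; 1000 batches per epoch is a made-up figure.
print(effective_batch_size(8, 4))   # 32
print(optimizer_steps(1000, 4, 3))  # 750
```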
With Progress Tracking
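from quantllm import FineTuningTrainer, TrainingLogger
logger = TrainingLogger()
trainer = FineTuningTrainer(
    model=model,
    training_config=config,  # TrainingConfig from the Basic Training example
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
    logger=logger  # not in the signature documented above; remove if your version raises TypeError
)
trainer.train()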
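Progress tracking usually reports a smoothed loss rather than every step's raw value. The sketch below uses only the standard `logging` module to log the mean loss over each window of `logging_steps` steps (the `logging_steps=10` setting from the pipeline example). `RunningLossLogger` is a hypothetical helper for illustration, not the `TrainingLogger` API.

```python
import logging
from typing import List, Tuple

logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger("train")

class RunningLossLogger:
    """Hypothetical helper (not TrainingLogger): logs the mean loss every `logging_steps` steps."""

    def __init__(self, logging_steps: int = 10):
        self.logging_steps = logging_steps
        self._window: List[float] = []            # losses since the last log line
        self.history: List[Tuple[int, float]] = []  # (step, mean loss) pairs that were logged

    def update(self, step: int, loss: float) -> None:
        self._window.append(loss)
        if step % self.logging_steps == 0:
            mean_loss = sum(self._window) / len(self._window)
            self.history.append((step, mean_loss))
            log.info("step %d | loss %.4f", step, mean_loss)
            self._window.clear()

tracker = RunningLossLogger(logging_steps=10)
for step in range(1, 31):
    tracker.update(step, loss=1.0 / step)  # fake, steadily decreasing loss
```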
Model Evaluation
from quantllm import ModelEvaluator
evaluator = ModelEvaluator(
    model=model,
    eval_dataloader=test_loader
)
metrics = evaluator.evaluate()
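The keys of the value returned by `evaluate()` are not documented here. For causal language models, evaluation loss and perplexity are the usual headline numbers. The sketch below, independent of QuantLLM, shows why per-batch mean losses should be weighted by each batch's token count before exponentiating; `aggregate_eval_loss` is a hypothetical helper.

```python
import math
from typing import Iterable, Tuple

def aggregate_eval_loss(batches: Iterable[Tuple[float, int]]) -> Tuple[float, float]:
    """Combine per-batch (mean_loss, num_tokens) pairs into a corpus-level loss and perplexity."""
    total_loss = 0.0
    total_tokens = 0
    for mean_loss, num_tokens in batches:
        total_loss += mean_loss * num_tokens  # undo the per-batch averaging
        total_tokens += num_tokens
    if total_tokens == 0:
        raise ValueError("no tokens to evaluate")
    loss = total_loss / total_tokens
    return loss, math.exp(loss)  # perplexity = exp(mean token-level cross-entropy)

# A naive mean of the two batch losses would give 2.5; token weighting gives 2.75.
loss, ppl = aggregate_eval_loss([(2.0, 100), (3.0, 300)])
print(loss, ppl)
```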
Checkpoint Management
from quantllm import CheckpointManager, FineTuningTrainer
checkpoint_manager = CheckpointManager(
    checkpoint_dir="./checkpoints",
    save_total_limit=3
)
trainer = FineTuningTrainer(
    model=model,
    checkpoint_manager=checkpoint_manager,
    ...
)
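`save_total_limit=3` caps how many checkpoints are kept on disk. The sketch below shows one common way such a cap is enforced, modelled on the Hugging Face Trainer convention of `checkpoint-<step>` subdirectories. The directory layout and the `rotate_checkpoints` helper are assumptions for illustration, not QuantLLM's `CheckpointManager` code.

```python
import shutil
import tempfile
from pathlib import Path
from typing import List

def rotate_checkpoints(checkpoint_dir: Path, save_total_limit: int) -> List[Path]:
    """Delete the oldest checkpoint-<step> directories, keeping the newest `save_total_limit`."""
    ckpts = sorted(
        (p for p in checkpoint_dir.glob("checkpoint-*") if p.is_dir()),
        # Sort numerically by step: as strings, "checkpoint-1000" < "checkpoint-200".
        key=lambda p: int(p.name.rsplit("-", 1)[1]),
    )
    excess = max(0, len(ckpts) - save_total_limit)
    for old in ckpts[:excess]:
        shutil.rmtree(old)
    return ckpts[excess:]

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for step in (100, 200, 300, 400, 1000):
        (root / f"checkpoint-{step}").mkdir()
    kept_names = [p.name for p in rotate_checkpoints(root, save_total_limit=3)]
    remaining = sorted(p.name for p in root.iterdir())

print(kept_names)  # ['checkpoint-300', 'checkpoint-400', 'checkpoint-1000']
```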
Custom Training Loop
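from quantllm import FineTuningTrainer
class CustomTrainer(FineTuningTrainer):
    def training_step(self, batch):
        # Custom training logic (forward pass, loss, backward pass)
        pass
    def validation_step(self, batch):
        # Custom validation logic (forward pass and metrics, no gradients)
        pass
trainer = CustomTrainer(model=model, ...)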
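Subclassing works when the base class's loop delegates each batch to overridable hook methods (the template-method pattern). Whether `FineTuningTrainer.train()` actually calls hooks named `training_step` and `validation_step` is not confirmed by the reference above. The self-contained sketch below uses a hypothetical `MiniTrainer` base class and a toy subclass that fits a single weight, only to show how such hooks plug into a loop.

```python
from typing import Iterable, List

class MiniTrainer:
    """Hypothetical stand-in for a trainer whose loop delegates each batch to hooks."""

    def __init__(self, train_batches: Iterable, eval_batches: Iterable = ()):
        self.train_batches = list(train_batches)
        self.eval_batches = list(eval_batches)

    def training_step(self, batch) -> float:
        raise NotImplementedError  # subclasses supply the per-batch update

    def validation_step(self, batch) -> float:
        raise NotImplementedError  # subclasses supply the per-batch evaluation

    def train(self, num_epochs: int = 1) -> List[float]:
        """Run the loop; return the mean validation loss after each epoch."""
        epoch_val = []
        for _ in range(num_epochs):
            for batch in self.train_batches:
                self.training_step(batch)
            if self.eval_batches:
                losses = [self.validation_step(b) for b in self.eval_batches]
                epoch_val.append(sum(losses) / len(losses))
        return epoch_val

class SquaredErrorTrainer(MiniTrainer):
    """Toy subclass: fits one weight w so that w * x is close to y, via gradient descent."""

    def __init__(self, *args, lr: float = 0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.w = 0.0
        self.lr = lr

    def training_step(self, batch) -> float:
        x, y = batch
        err = self.w * x - y
        self.w -= self.lr * 2 * err * x  # gradient of err**2 with respect to w
        return err ** 2

    def validation_step(self, batch) -> float:
        x, y = batch
        return (self.w * x - y) ** 2

data = [(1.0, 2.0), (2.0, 4.0), (0.5, 1.0)]  # generated by y = 2x
toy_trainer = SquaredErrorTrainer(data, eval_batches=data, lr=0.1)
val_losses = toy_trainer.train(num_epochs=20)
print(round(toy_trainer.w, 3), val_losses[-1])
```

Keeping the loop in the base class and only the per-batch logic in the subclass means scheduling, logging, and checkpointing code is written once and shared by every custom trainer.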