Trainer API

QuantLLM provides a training API with built-in support for quantization, parameter-efficient fine-tuning (such as LoRA), checkpointing, evaluation, and progress tracking.

Fine-Tuning Trainer

class quantllm.trainer.trainer.FineTuningTrainer(model, training_config, train_dataloader, eval_dataloader=None, checkpoint_manager=None, hub_manager=None, device=None, use_wandb=False, wandb_config=None)

Bases: object

__init__(model, training_config, train_dataloader, eval_dataloader=None, checkpoint_manager=None, hub_manager=None, device=None, use_wandb=False, wandb_config=None)

Initialize the trainer, which runs a PyTorch-based training loop.

Parameters:
  • model (nn.Module) – The model to train

  • training_config (TrainingConfig) – Training configuration

  • train_dataloader (DataLoader) – Training data loader

  • eval_dataloader (DataLoader, optional) – Evaluation data loader

  • checkpoint_manager (CheckpointManager, optional) – Checkpoint manager

  • hub_manager (HubManager, optional) – Hub manager used to push the trained model to the Hub

  • device (str or torch.device, optional) – Device to train on

  • use_wandb (bool) – Whether to use Weights & Biases

  • wandb_config (Dict[str, Any], optional) – Weights & Biases configuration

train_step(batch, scaler)

Single training step.
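`train_step` handles one batch, and `TrainingConfig` exposes `gradient_accumulation_steps`, so not every call necessarily ends in an optimizer update. The plain-Python sketch below shows the usual accumulation schedule. It assumes (this reference does not say) that the optimizer steps after every `gradient_accumulation_steps` micro-batches and once more for any leftover batches at the end of an epoch; `optimizer_step_indices` is a hypothetical helper, not part of QuantLLM.

```python
from typing import List


def optimizer_step_indices(num_batches: int, accumulation_steps: int) -> List[int]:
    """Return the 0-based micro-batch indices after which the optimizer steps.

    Hypothetical helper: assumes one optimizer step every `accumulation_steps`
    micro-batches, plus a final step that flushes any leftover micro-batches.
    """
    if num_batches < 0 or accumulation_steps < 1:
        raise ValueError("num_batches must be >= 0 and accumulation_steps >= 1")
    steps = []
    for i in range(num_batches):
        is_boundary = (i + 1) % accumulation_steps == 0  # accumulated a full group
        is_last = i == num_batches - 1                   # flush the remainder
        if is_boundary or is_last:
            steps.append(i)
    return steps


# 10 micro-batches with accumulation of 4: steps after batches 3, 7 and 9
print(optimizer_step_indices(10, 4))  # [3, 7, 9]
```

In a real loop each micro-batch loss is usually divided by the accumulation factor before `backward()`, so the summed gradients match those of one larger batch.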

train()

Train the model.

save_model(output_dir)

Save the model and training state.

load_model(input_dir)

Load the model and training state.
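`save_model` and `load_model` persist both the model and the training state, but this reference does not document the on-disk layout. The sketch below illustrates only the training-state half of that round trip with the standard library: counters such as epoch and global step written to JSON and read back. The file name `trainer_state.json` and both helper functions are hypothetical; model weights would normally be saved separately, for example with `torch.save(model.state_dict(), path)`.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict

STATE_FILE = "trainer_state.json"  # hypothetical name, not QuantLLM's actual layout


def save_training_state(output_dir: str, state: Dict[str, Any]) -> Path:
    """Write JSON-serializable training state (epoch, step, ...) into output_dir."""
    path = Path(output_dir)
    path.mkdir(parents=True, exist_ok=True)
    target = path / STATE_FILE
    target.write_text(json.dumps(state, indent=2))
    return target


def load_training_state(input_dir: str) -> Dict[str, Any]:
    """Read back the state written by save_training_state."""
    return json.loads((Path(input_dir) / STATE_FILE).read_text())


with tempfile.TemporaryDirectory() as tmp:
    save_training_state(tmp, {"epoch": 2, "global_step": 150})
    restored = load_training_state(tmp)
    print(restored["global_step"])  # 150
```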

Model Evaluator

class quantllm.trainer.evaluator.ModelEvaluator(model, eval_dataloader, metrics=None, device=None)

Bases: object

__init__(model, eval_dataloader, metrics=None, device=None)

Initialize the model evaluator.

Parameters:
  • model (nn.Module) – The model to evaluate

  • eval_dataloader (DataLoader) – Evaluation data loader

  • metrics (List[Callable], optional) – List of metric functions

  • device (str, optional) – Device to evaluate on
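`metrics` accepts a list of callables, but their expected signature is not documented here. As an illustration of a typical language-model metric, the sketch below computes perplexity, the exponential of the mean per-token cross-entropy (in nats). It takes a plain list of losses, so treat `perplexity` as a hypothetical helper rather than a ready-made `ModelEvaluator` metric.

```python
import math
from typing import Sequence


def perplexity(token_losses: Sequence[float]) -> float:
    """Perplexity = exp(mean per-token cross-entropy, in nats).

    Hypothetical metric: ModelEvaluator's expected callable signature is not
    documented, so this takes a plain sequence of per-token losses.
    """
    if not token_losses:
        raise ValueError("need at least one loss value")
    mean_loss = sum(token_losses) / len(token_losses)
    return math.exp(mean_loss)


print(perplexity([0.0, 0.0]))                            # 1.0
print(round(perplexity([math.log(4), math.log(4)]), 6))  # 4.0
```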

evaluate()

Evaluate the model on the evaluation dataset.

Return type:

Dict[str, float]

evaluate_on_specific_batch(batch)

Evaluate the model on a specific batch of data.

Return type:

Dict[str, float]
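`evaluate` returns dataset-level metrics and `evaluate_on_specific_batch` returns metrics for one batch, both as `Dict[str, float]`. Whether `evaluate` is built from per-batch results is not stated here; the sketch below assumes it is and shows one common way to combine them, a mean weighted by batch size so that a short final batch does not skew the result. `aggregate_batch_metrics` is a hypothetical helper, not part of QuantLLM.

```python
from typing import Dict, List, Tuple


def aggregate_batch_metrics(results: List[Tuple[Dict[str, float], int]]) -> Dict[str, float]:
    """Combine (per-batch metrics, batch size) pairs into weighted averages.

    Assumes every batch reports the same metric names.
    """
    totals: Dict[str, float] = {}
    count = 0
    for metrics, size in results:
        count += size
        for name, value in metrics.items():
            totals[name] = totals.get(name, 0.0) + value * size
    if count == 0:
        return {}
    return {name: total / count for name, total in totals.items()}


# Two full batches of 8 and a short final batch of 4
batches = [({"loss": 2.0}, 8), ({"loss": 1.0}, 8), ({"loss": 4.0}, 4)]
print(aggregate_batch_metrics(batches))  # {'loss': 2.0}
```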

Training Logger

Example Usage

Complete Training Pipeline

from quantllm import (
    Model, ModelConfig,
    FineTuningTrainer, TrainingConfig,
    CheckpointManager
)

# Configure the model: 4-bit weights, LoRA adapters, gradient checkpointing
config = ModelConfig(
    model_name="facebook/opt-125m",
    load_in_4bit=True,           # quantize weights to 4 bits to save memory
    use_lora=True,               # train low-rank adapters instead of all weights
    gradient_checkpointing=True  # recompute activations during backward to save memory
)

# Build `model` from `config` with the Model API, and prepare `train_loader`
# and `val_loader` as PyTorch DataLoaders (neither step is shown here).

# Training hyperparameters
training_config = TrainingConfig(
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=8,
    gradient_accumulation_steps=4,
    # Scheduling, evaluation, checkpointing, and logging cadence
    warmup_ratio=0.1,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="epoch",
    logging_steps=10,
    # Mixed precision training
    fp16=True,
    # Multi-GPU (DDP): skip the search for parameters that receive no gradient
    ddp_find_unused_parameters=False
)

# Keep at most three checkpoints
checkpoint_manager = CheckpointManager(
    checkpoint_dir="./checkpoints",
    save_total_limit=3
)

# Create the trainer (FineTuningTrainer takes no `logger` argument)
trainer = FineTuningTrainer(
    model=model,
    training_config=training_config,
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
    checkpoint_manager=checkpoint_manager
)

# Run the training loop
trainer.train()
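`warmup_ratio=0.1` above expresses warmup as a fraction of training rather than a step count. The sketch below shows the arithmetic under two assumptions this reference does not confirm: that the ratio is applied to the total number of optimizer steps, and that the schedule is a linear warmup followed by linear decay (the same shape as Hugging Face Transformers' `get_linear_schedule_with_warmup`). The dataset size of 1,000 batches per epoch is made up for illustration, and both functions are hypothetical helpers.

```python
import math


def total_optimizer_steps(batches_per_epoch: int, accumulation_steps: int, num_epochs: int) -> int:
    """Optimizer updates over the whole run (leftover micro-batches still step)."""
    return math.ceil(batches_per_epoch / accumulation_steps) * num_epochs


def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Ramp linearly from 0 to base_lr, then decay linearly back to 0 (assumed schedule)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Values from the pipeline above; 1,000 batches per epoch is a made-up dataset size.
total_steps = total_optimizer_steps(batches_per_epoch=1000, accumulation_steps=4, num_epochs=3)
warmup_steps = math.ceil(0.1 * total_steps)  # warmup_ratio=0.1
print(total_steps, warmup_steps)  # 750 75
print(linear_warmup_lr(75, 2e-4, warmup_steps, total_steps))  # 0.0002 (peak)
```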

Basic Training

from quantllm import FineTuningTrainer, TrainingConfig

config = TrainingConfig(
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=8,
    gradient_accumulation_steps=4
)

trainer = FineTuningTrainer(
    model=model,
    training_config=config,
    train_dataloader=train_loader,
    eval_dataloader=val_loader
)
trainer.train()
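With `batch_size=8` and `gradient_accumulation_steps=4`, each optimizer update sees more examples than a single batch holds. The sketch below spells out that arithmetic. It assumes `batch_size` is per device and that data-parallel training multiplies it by the number of devices, which is the usual convention but is not stated in this reference; `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(batch_size: int, accumulation_steps: int, num_devices: int = 1) -> int:
    """Number of examples contributing to one optimizer update (assumed convention)."""
    if min(batch_size, accumulation_steps, num_devices) < 1:
        raise ValueError("all arguments must be >= 1")
    return batch_size * accumulation_steps * num_devices


print(effective_batch_size(8, 4))     # 32 on one device
print(effective_batch_size(8, 4, 2))  # 64 across two data-parallel devices
```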

With Progress Tracking

from quantllm import FineTuningTrainer

# The documented constructor has no `logger` argument; enable Weights & Biases
# tracking with `use_wandb` (and optionally `wandb_config`).
trainer = FineTuningTrainer(
    model=model,
    training_config=config,
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
    use_wandb=True
)
trainer.train()
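Whatever backend records it, progress tracking usually comes down to reporting a smoothed loss at a fixed cadence, such as the `logging_steps=10` used in the pipeline above. The sketch below averages the loss over each window of `logging_steps` steps and reports it when the window closes. `LossTracker` is a hypothetical helper that illustrates the idea; it is not QuantLLM's `TrainingLogger`.

```python
from typing import List, Optional


class LossTracker:
    """Average the loss over each window of `logging_steps` steps (hypothetical helper)."""

    def __init__(self, logging_steps: int) -> None:
        if logging_steps < 1:
            raise ValueError("logging_steps must be >= 1")
        self.logging_steps = logging_steps
        self.step = 0
        self._window: List[float] = []

    def update(self, loss: float) -> Optional[float]:
        """Record one step's loss; return the window average when a log is due."""
        self.step += 1
        self._window.append(loss)
        if self.step % self.logging_steps == 0:
            avg = sum(self._window) / len(self._window)
            self._window.clear()  # start a fresh window
            return avg
        return None


tracker = LossTracker(logging_steps=2)
for loss in [4.0, 2.0, 3.0, 1.0]:
    avg = tracker.update(loss)
    if avg is not None:
        print(f"step {tracker.step}: loss {avg:.2f}")
# step 2: loss 3.00
# step 4: loss 2.00
```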

Model Evaluation

from quantllm import ModelEvaluator

evaluator = ModelEvaluator(
    model=model,
    eval_dataloader=test_loader
)
metrics = evaluator.evaluate()

Checkpoint Management

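from quantllm import CheckpointManager, FineTuningTrainer

# Keep at most three checkpoints in ./checkpoints
checkpoint_manager = CheckpointManager(
    checkpoint_dir="./checkpoints",
    save_total_limit=3
)

trainer = FineTuningTrainer(
    model=model,
    training_config=config,
    train_dataloader=train_loader,
    checkpoint_manager=checkpoint_manager
)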
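`save_total_limit=3` caps how many checkpoints are retained. A common way to enforce such a cap is to delete the oldest checkpoints once the limit is exceeded; the sketch below does that with the standard library. It assumes checkpoints are directories named `checkpoint-<step>`, a naming scheme this reference does not document, and `rotate_checkpoints` is a hypothetical helper rather than `CheckpointManager`'s actual logic.

```python
import shutil
import tempfile
from pathlib import Path
from typing import List


def rotate_checkpoints(checkpoint_dir: str, save_total_limit: int) -> List[str]:
    """Delete the oldest checkpoint-<step> directories beyond save_total_limit.

    Assumes a numeric <step> suffix. Returns the remaining names, oldest first.
    """
    root = Path(checkpoint_dir)
    ckpts = sorted(
        (p for p in root.iterdir() if p.is_dir() and p.name.startswith("checkpoint-")),
        key=lambda p: int(p.name.split("-", 1)[1]),  # numeric, so 20 sorts before 100
    )
    excess = max(len(ckpts) - save_total_limit, 0)
    for old in ckpts[:excess]:
        shutil.rmtree(old)
    return [p.name for p in ckpts[excess:]]


with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 200, 300, 400, 500):
        (Path(tmp) / f"checkpoint-{step}").mkdir()
    kept = rotate_checkpoints(tmp, save_total_limit=3)
    print(kept)  # ['checkpoint-300', 'checkpoint-400', 'checkpoint-500']
```

Sorting on the parsed step number rather than the directory name matters: a plain string sort would place `checkpoint-100` before `checkpoint-20`.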

Custom Training Loop

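from quantllm import FineTuningTrainer

class CustomTrainer(FineTuningTrainer):
    # Override the documented per-batch hook, train_step(batch, scaler)
    def train_step(self, batch, scaler):
        # Custom training logic goes here, before or after the default step
        return super().train_step(batch, scaler)

# FineTuningTrainer documents no per-batch validation hook; for custom
# evaluation, use ModelEvaluator (e.g. evaluate_on_specific_batch).

trainer = CustomTrainer(
    model=model,
    training_config=config,
    train_dataloader=train_loader,
    eval_dataloader=val_loader
)
trainer.train()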
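A subclass hook only takes effect if its name and signature match the method the base class's loop actually calls. The self-contained sketch below uses a hypothetical `MockTrainer` (not QuantLLM code) whose `train()` calls `train_step(batch, scaler)`, mirroring the documented method names: a subclass that defines `training_step` is silently ignored, while one that overrides `train_step` and delegates to `super()` runs on every batch.

```python
from typing import Any, List


class MockTrainer:
    """Hypothetical stand-in whose loop calls self.train_step(batch, scaler)."""

    def __init__(self, batches: List[Any]) -> None:
        self.batches = batches
        self.seen: List[Any] = []

    def train_step(self, batch: Any, scaler: Any) -> float:
        self.seen.append(batch)  # default per-batch work
        return 0.0

    def train(self) -> None:
        for batch in self.batches:
            self.train_step(batch, scaler=None)


class MisnamedHook(MockTrainer):
    def training_step(self, batch: Any) -> None:  # never called by train()
        raise RuntimeError("unreachable")


class CountingHook(MockTrainer):
    def __init__(self, batches: List[Any]) -> None:
        super().__init__(batches)
        self.calls = 0

    def train_step(self, batch: Any, scaler: Any) -> float:
        self.calls += 1                           # custom logic
        return super().train_step(batch, scaler)  # keep the default behaviour


misnamed = MisnamedHook([1, 2, 3])
misnamed.train()  # runs without error: training_step is silently ignored

hooked = CountingHook([1, 2, 3])
hooked.train()
print(hooked.calls, hooked.seen)  # 3 [1, 2, 3]
```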