# Source code for quantllm.trainer.trainer

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from typing import Optional, Dict, Any, List, Union, Callable
from pathlib import Path
import numpy as np
from tqdm import tqdm
import os
from datetime import datetime

# Conditionally import wandb
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

from ..config.training_config import TrainingConfig
from ..utils.logger import logger
from ..hub.checkpoint_manager import CheckpointManager
from ..hub.hub_manager import HubManager

class FineTuningTrainer:
    """PyTorch training loop for fine-tuning a model.

    The trainer owns the optimizer, the optional learning-rate scheduler and
    (on CUDA devices) a gradient scaler for mixed-precision training. It
    drives epochs over ``train_dataloader``, optionally evaluates on
    ``eval_dataloader`` and saves checkpoints through ``checkpoint_manager``.
    """

    def __init__(
        self,
        model: nn.Module,
        training_config: "TrainingConfig",
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        checkpoint_manager: Optional["CheckpointManager"] = None,
        hub_manager: Optional["HubManager"] = None,
        device: Optional[Union[str, torch.device]] = None,
        use_wandb: bool = False,
        wandb_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the trainer with PyTorch-based training loop.

        Args:
            model (nn.Module): The model to train
            training_config (TrainingConfig): Training configuration
            train_dataloader (DataLoader): Training data loader
            eval_dataloader (DataLoader, optional): Evaluation data loader
            checkpoint_manager (CheckpointManager, optional): Checkpoint manager
            hub_manager (HubManager, optional): Hub manager for model pushing
            device (str or torch.device, optional): Device to train on.
                Defaults to CUDA when available, otherwise CPU.
            use_wandb (bool): Whether to use Weights & Biases. Ignored when
                the ``wandb`` package is not installed.
            wandb_config (Dict[str, Any], optional): Weights & Biases
                configuration. Recognised keys: ``"project"``, ``"name"`` and
                ``"token"`` (API key; when absent, wandb's own credential
                lookup is used).
        """
        self.model = model
        self.config = training_config
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.checkpoint_manager = checkpoint_manager
        self.hub_manager = hub_manager
        self.use_wandb = use_wandb and WANDB_AVAILABLE
        self.wandb_config = wandb_config or {}
        # _setup_wandb reads this attribute, so it must always exist.
        self.wandb_token = self.wandb_config.get("token")

        # Handle device setup
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        logger.log_info(f"Using device: {self.device}")
        self.model.to(self.device)

        # The scheduler wraps the optimizer, so the optimizer comes first.
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Compare the device *type*: torch.device("cuda:0") is not equal to
        # torch.device("cuda"), which would silently disable the scaler.
        self.scaler = torch.cuda.amp.GradScaler() if self.device.type == "cuda" else None

        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_metric = float('inf')
        self.patience_counter = 0

        # Setup Weights & Biases if enabled
        if self.use_wandb:
            self._setup_wandb()

    def _create_optimizer(self) -> optim.Optimizer:
        """Create the optimizer named by ``config.optimizer``.

        Only parameters with ``requires_grad=True`` are optimized.

        Returns:
            optim.Optimizer: AdamW, Adam or SGD (momentum 0.9).

        Raises:
            ValueError: If ``config.optimizer`` is not one of
                ``"adamw"``, ``"adam"`` or ``"sgd"`` (case-insensitive).
        """
        # Exclude frozen parameters.
        params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
        optimizer_name = self.config.optimizer.lower()

        if optimizer_name == "adamw":
            return optim.AdamW(
                params_to_optimize,
                lr=self.config.learning_rate,
                weight_decay=self.config.weight_decay,
                betas=(0.9, 0.999),
                eps=1e-8,
            )
        if optimizer_name == "adam":
            return optim.Adam(
                params_to_optimize,
                lr=self.config.learning_rate,
                weight_decay=self.config.weight_decay,
                betas=(0.9, 0.999),
                eps=1e-8,
            )
        if optimizer_name == "sgd":
            return optim.SGD(
                params_to_optimize,
                lr=self.config.learning_rate,
                momentum=0.9,
                weight_decay=self.config.weight_decay,
            )
        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

    def _create_scheduler(self) -> Optional[Union[LambdaLR, ReduceLROnPlateau]]:
        """Create the learning-rate scheduler named by ``config.scheduler``.

        Returns:
            A ``LambdaLR`` implementing linear warmup followed by linear decay
            for ``"linear"``, a ``ReduceLROnPlateau`` (mode ``min``, factor
            0.1, patience 3) for ``"plateau"``, and ``None`` for any other
            value.
        """
        scheduler_name = self.config.scheduler.lower()

        if scheduler_name == "linear":
            warmup_steps = self.config.warmup_steps
            total_steps = self.config.num_epochs * len(self.train_dataloader)

            def lr_lambda(current_step: int) -> float:
                # Ramp from 0 to 1 during warmup, then decay linearly to 0.
                if current_step < warmup_steps:
                    return float(current_step) / float(max(1, warmup_steps))
                return max(
                    0.0,
                    float(total_steps - current_step)
                    / float(max(1, total_steps - warmup_steps)),
                )

            return LambdaLR(self.optimizer, lr_lambda)

        if scheduler_name == "plateau":
            # No ``verbose`` argument: it is deprecated in recent PyTorch.
            # Learning-rate reductions are logged by _after_evaluation.
            return ReduceLROnPlateau(self.optimizer, mode='min', factor=0.1, patience=3)

        return None

    def _setup_wandb(self):
        """Log in to Weights & Biases and start a run.

        Any failure disables wandb logging (``self.use_wandb = False``)
        instead of aborting training.
        """
        try:
            # Force a re-login only when an explicit key was supplied;
            # otherwise let wandb reuse whatever credentials it already has.
            logged_in = wandb.login(
                key=self.wandb_token,
                relogin=self.wandb_token is not None,
            )
            if logged_in:
                logger.log_info("Logged in to Weights & Biases")
                wandb.init(
                    project=self.wandb_config.get("project", "quantllm"),
                    name=self.wandb_config.get(
                        "name", f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    ),
                    config=self.config.to_dict(),
                )
            else:
                logger.log_warning("Failed to log in to Weights & Biases. Continuing without wandb logging.")
                self.use_wandb = False
        except Exception as e:
            logger.log_warning(f"Error setting up Weights & Biases: {str(e)}. Continuing without wandb logging.")
            self.use_wandb = False

    def _prepare_batch(self, batch) -> Dict[str, Any]:
        """Normalize a batch to a keyword dict and move tensors to the device.

        Args:
            batch: Either a mapping of model keyword arguments, or a
                tuple/list ordered as ``(input_ids, attention_mask[, labels])``.

        Returns:
            Dict[str, Any]: Keyword arguments for ``self.model``; tensor
            values are on ``self.device``, other values are passed through.
        """
        if isinstance(batch, (tuple, list)):
            batch = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[2] if len(batch) > 2 else None,
            }
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def _log_metrics(self, metrics: Dict[str, float]) -> None:
        """Send metrics to the project logger and, when enabled, to wandb."""
        logger.log_metrics(metrics, step=self.global_step)
        if self.use_wandb:
            wandb.log(metrics, step=self.global_step)

    def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run a forward pass and return ``outputs.loss``.

        Accepts the same batch layouts as :meth:`train_step`.
        """
        outputs = self.model(**self._prepare_batch(batch))
        return outputs.loss

    def train_step(self, batch, scaler=None):
        """Single training step.

        Runs forward/backward, optional gradient clipping, the optimizer
        step, a per-step scheduler update (for the ``"linear"`` scheduler)
        and increments ``self.global_step``.

        Args:
            batch: Mapping of model keyword arguments, or a tuple/list
                ordered as ``(input_ids, attention_mask[, labels])``.
            scaler: Gradient scaler for CUDA mixed precision. Defaults to the
                trainer's own ``self.scaler``.

        Returns:
            float: The loss value for this batch.

        Raises:
            Exception: Any error from the step is logged and re-raised.
        """
        try:
            batch = self._prepare_batch(batch)
            if scaler is None:
                scaler = self.scaler

            if self.device.type == "cuda" and scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = self.model(**batch)
                    loss = outputs.loss

                # Backward pass with gradient scaling
                scaler.scale(loss).backward()
                if self.config.max_grad_norm is not None:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                scaler.step(self.optimizer)
                scaler.update()
            else:
                # CPU or MPS training - no autocast needed
                outputs = self.model(**batch)
                loss = outputs.loss

                # Standard backward pass
                loss.backward()
                if self.config.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.optimizer.step()

            self.optimizer.zero_grad()

            # The warmup/decay schedule is defined per optimizer step.
            # ReduceLROnPlateau is stepped after evaluation instead.
            if isinstance(self.scheduler, LambdaLR):
                self.scheduler.step()
            self.global_step += 1

            return loss.item()
        except Exception as e:
            logger.log_error(f"Error in training step: {str(e)}")
            raise

    def train(self):
        """Train the model.

        Runs ``config.num_epochs`` epochs over ``train_dataloader``. Metrics
        are logged every ``config.logging_steps`` batches and checkpoints
        saved every ``config.save_steps`` batches / ``config.save_epochs``
        epochs; evaluation runs every ``config.eval_epochs`` epochs. A value
        of 0 disables the corresponding action.

        Raises:
            Exception: Any training error is logged and re-raised.
        """
        try:
            # Not every nn.Module carries a Hugging Face style ``config``.
            model_config = getattr(self.model, "config", None)
            model_name = getattr(model_config, "name_or_path", None) or self.model.__class__.__name__

            # Log training start
            logger.log_training_start(
                model_name=model_name,
                dataset_name=self.train_dataloader.dataset.__class__.__name__,
                config=self.config.to_dict(),
            )

            # Disable model caching when using gradient checkpointing
            if getattr(model_config, 'gradient_checkpointing', False):
                model_config.use_cache = False
                logger.log_info("Disabled model caching due to gradient checkpointing")

            num_batches = len(self.train_dataloader)
            # Defined up front so the final report works with zero epochs.
            avg_loss = float('nan')

            for epoch in range(self.config.num_epochs):
                self.model.train()
                total_loss = 0.0

                # Log epoch start
                logger.log_epoch_start(epoch + 1, self.config.num_epochs)

                # Training loop
                with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{self.config.num_epochs}") as pbar:
                    for step, batch in enumerate(self.train_dataloader):
                        # Reuse the trainer's scaler (None off-CUDA) instead
                        # of building a new one on every call to train().
                        loss = self.train_step(batch, self.scaler)
                        total_loss += loss
                        batches_done = step + 1

                        # Update progress bar
                        pbar.update(1)
                        pbar.set_postfix({'loss': f'{loss:.4f}'})

                        # Log steps if configured
                        if self.config.logging_steps > 0 and batches_done % self.config.logging_steps == 0:
                            self._log_metrics({"loss": total_loss / batches_done})

                        # Save checkpoint if configured
                        if self.config.save_steps > 0 and batches_done % self.config.save_steps == 0:
                            self._save_checkpoint(epoch, step, {"loss": total_loss / batches_done})

                # Epoch end processing
                avg_loss = total_loss / num_batches if num_batches else float('nan')
                self.epoch = epoch + 1
                logger.log_epoch_complete(epoch + 1, {"avg_loss": avg_loss})

                # Run evaluation if configured (_evaluate logs its own
                # completion, so it is not logged again here).
                if self.config.eval_epochs > 0 and (epoch + 1) % self.config.eval_epochs == 0:
                    self._after_evaluation(self._evaluate())

                # Save epoch checkpoint if configured
                if self.config.save_epochs > 0 and (epoch + 1) % self.config.save_epochs == 0:
                    self._save_checkpoint(epoch, None, {"epoch": epoch + 1, "avg_loss": avg_loss})

            # Log final training metrics
            final_metrics = {
                "final_loss": avg_loss,
                "total_epochs": self.config.num_epochs,
                "total_steps": self.global_step,
            }
            logger.log_training_complete(final_metrics)
        except Exception as e:
            logger.log_error(f"Training error: {str(e)}")
            raise

    def _after_evaluation(self, eval_metrics: Dict[str, float]) -> None:
        """Update best metric, wandb and the plateau scheduler after an eval run."""
        eval_loss = eval_metrics.get("eval_loss")
        if eval_loss is None:
            return

        if self.use_wandb:
            wandb.log(eval_metrics, step=self.global_step)

        if eval_loss < self.best_metric:
            self.best_metric = eval_loss

        if isinstance(self.scheduler, ReduceLROnPlateau):
            previous_lrs = [group["lr"] for group in self.optimizer.param_groups]
            self.scheduler.step(eval_loss)
            current_lrs = [group["lr"] for group in self.optimizer.param_groups]
            if current_lrs != previous_lrs:
                logger.log_info(f"Reduced learning rate to {current_lrs}")

    def _evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the validation set.

        Returns:
            Dict[str, float]: ``{"eval_loss": mean_batch_loss}``, or an empty
            dict when there is no evaluation dataloader or it yields no
            batches.
        """
        if self.eval_dataloader is None:
            return {}

        logger.log_evaluation_start()
        was_training = self.model.training
        self.model.eval()

        total_loss = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for batch in tqdm(self.eval_dataloader, desc="Evaluating"):
                    total_loss += self._compute_loss(batch).item()
                    num_batches += 1
        finally:
            # Hand the model back in the mode it was received in.
            self.model.train(was_training)

        if num_batches == 0:
            logger.log_warning("Evaluation dataloader produced no batches; skipping evaluation metrics.")
            return {}

        metrics = {"eval_loss": total_loss / num_batches}
        logger.log_evaluation_complete(metrics)
        return metrics

    def _save_checkpoint(self, epoch: int, step: Optional[int] = None, metrics: Optional[Dict[str, float]] = None):
        """Save a checkpoint through the checkpoint manager, if one is set.

        Args:
            epoch (int): Zero-based epoch index.
            step (int, optional): Zero-based batch index within the epoch, or
                ``None`` for an end-of-epoch checkpoint.
            metrics (Dict[str, float], optional): Metrics to store alongside
                the checkpoint. The mapping is copied, not modified.
        """
        if self.checkpoint_manager is None:
            return

        # Copy so the bookkeeping keys do not leak into the caller's dict.
        checkpoint_metrics = dict(metrics or {})
        checkpoint_metrics.update({
            "epoch": epoch + 1,
            "step": step if step is not None else "end_of_epoch",
            "global_step": self.global_step,
        })

        path = self.checkpoint_manager.save_checkpoint(
            model=self.model,
            tokenizer=None,  # We don't save tokenizer with checkpoints
            epoch=epoch,
            metrics=checkpoint_metrics,
        )
        logger.log_checkpoint_save(path, checkpoint_metrics)

    def save_model(self, output_dir: Union[str, Path]):
        """Save the model and training state.

        Writes the model to ``<output_dir>/model`` via ``save_pretrained`` and
        the optimizer/scheduler/progress state to
        ``<output_dir>/training_state.pt``.

        Args:
            output_dir (str or Path): Destination directory; created if
                missing.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save_pretrained(output_dir / "model")

        # Save training state
        training_state = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict() if self.scheduler else None,
            "best_metric": self.best_metric,
        }
        torch.save(training_state, output_dir / "training_state.pt")

    def load_model(self, input_dir: Union[str, Path]):
        """Load the model and training state.

        Reads the layout written by :meth:`save_model`. ``self.model`` is
        replaced by the reloaded model, and the optimizer and scheduler are
        rebuilt around it before their saved state is restored.

        Args:
            input_dir (str or Path): Directory previously passed to
                :meth:`save_model`.
        """
        input_dir = Path(input_dir)

        # Remember which parameters were frozen so the rebuilt optimizer
        # covers the same parameter set as the one that was saved.
        frozen_names = {
            name for name, param in self.model.named_parameters() if not param.requires_grad
        }

        # Load model
        self.model = self.model.from_pretrained(input_dir / "model")
        for name, param in self.model.named_parameters():
            if name in frozen_names:
                param.requires_grad_(False)
        self.model.to(self.device)

        # The previous optimizer still points at the old model's tensors.
        # Rebuild it (and the scheduler wrapping it) for the new parameters.
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Load training state. map_location lets a state saved on one device
        # be restored on another (e.g. CUDA -> CPU).
        training_state = torch.load(input_dir / "training_state.pt", map_location=self.device)
        self.global_step = training_state["global_step"]
        self.epoch = training_state["epoch"]
        self.optimizer.load_state_dict(training_state["optimizer_state_dict"])
        if self.scheduler is not None and training_state["scheduler_state_dict"]:
            self.scheduler.load_state_dict(training_state["scheduler_state_dict"])
        self.best_metric = training_state["best_metric"]