Source code for quantllm.data.dataset_preprocessor

from datasets import Dataset
from typing import Optional, Dict, Any, Callable, Tuple
from transformers import PreTrainedTokenizer
from ..utils.logger import logger
from tqdm.auto import tqdm
import logging
import warnings

# Quieten third-party output as a side effect of importing this module:
# only ERROR-and-above records from the "tokenizers" logger are kept.
logging.getLogger("tokenizers").setLevel(level=logging.ERROR)
# NOTE(review): this installs a process-wide "ignore" filter for every warning
# category, so any program importing this module loses all warnings (including
# ones from unrelated libraries). Consider narrowing it to specific
# categories/modules.
warnings.simplefilter("ignore")

class DatasetPreprocessor:
    """Turn ``datasets.Dataset`` text splits into fixed-length tokenized splits.

    Each text is whitespace-normalised, tokenized to ``max_length + 1`` tokens
    and split into ``input_ids`` / ``attention_mask`` / ``labels`` where
    ``labels`` is the token sequence shifted left by one position.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer):
        """Store *tokenizer*, falling back to EOS as its padding token.

        Args:
            tokenizer: Tokenizer used for every split. It is mutated in place
                when it has no pad token: ``pad_token`` / ``pad_token_id`` are
                set to the EOS token so fixed-length padding is possible.
        """
        self.tokenizer = tokenizer
        self.logger = logger

        # Padding to "max_length" below needs a pad token; reuse EOS if absent.
        # NOTE(review): if the tokenizer also has no EOS token, pad_token stays
        # None and tokenization will fail later — confirm this cannot happen.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.logger.log_info("Set pad token to eos token")

    def validate_datasets(self, datasets):
        """Check that every non-``None`` entry of *datasets* is a ``Dataset``.

        Raises:
            ValueError: If an entry is neither ``None`` nor a ``Dataset``.
        """
        for dataset in datasets:
            if dataset is not None and not isinstance(dataset, Dataset):
                raise ValueError(f"Expected Dataset object, got {type(dataset)}")

    def preprocess_text(self, text: str) -> str:
        """Return *text* stripped, with internal whitespace runs collapsed to one space.

        Falsy input (``None``, ``""``) yields ``""``; other values are
        converted with ``str()`` first.
        """
        if not text:
            return ""
        text = str(text).strip()
        text = " ".join(text.split())  # Normalize whitespace
        return text

    def _tokenize_split(self, dataset, batch_fn, batch_size, split_name):
        """Map *batch_fn* over *dataset* in batches, dropping the original columns, with logging."""
        self.logger.log_info(f"Processing {split_name} dataset")
        tokenized = dataset.map(
            batch_fn,
            batched=True,
            batch_size=batch_size,
            remove_columns=dataset.column_names,
            desc=f"Tokenizing {split_name} set",
        )
        self.logger.log_success(f"Tokenized {split_name} dataset: {len(tokenized)} examples")
        return tokenized

    def tokenize_dataset(
        self,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
        max_length: int = 512,
        text_column: str = "text",
        label_column: Optional[str] = None,
        batch_size: int = 1000,
        label_pad_token_id: Optional[int] = None,
    ) -> Tuple[Dataset, Optional[Dataset], Optional[Dataset]]:
        """Tokenize the train/validation/test splits for next-token prediction.

        Every text goes through :meth:`preprocess_text`, is tokenized with
        padding and truncation to ``max_length + 1`` tokens, and is then split
        into ``input_ids = ids[:-1]``, ``attention_mask = mask[:-1]`` and
        ``labels = ids[1:]``, so all three hold exactly ``max_length`` entries.
        The original columns are removed and every produced split is formatted
        as PyTorch tensors.

        NOTE(review): ``labels`` are already shifted here. Models that shift
        labels internally (e.g. Hugging Face ``*ForCausalLM`` heads) would
        shift them a second time — confirm the training loop expects
        pre-shifted labels.

        Args:
            train_dataset: Training split (required).
            val_dataset: Optional validation split.
            test_dataset: Optional test split.
            max_length: Length of every output sequence.
            text_column: Name of the column holding the raw text.
            label_column: Currently unused; accepted for API compatibility.
            batch_size: Batch size passed to ``Dataset.map``.
            label_pad_token_id: When given (typically ``-100``), label
                positions whose target token is padding (attention-mask bit 0)
                are replaced by this value so a loss with a matching
                ``ignore_index`` skips them. ``None`` (default) keeps the
                padding token ids in ``labels``, as before.

        Returns:
            ``(train, val, test)`` tokenized datasets; ``val`` / ``test`` are
            ``None`` when the corresponding input was ``None``.

        Raises:
            ValueError: If an input is not a ``Dataset`` or a tokenized
                sequence does not have ``max_length`` entries.
        """
        try:
            self.validate_datasets([train_dataset, val_dataset, test_dataset])

            def process_and_tokenize_batch(examples):
                # With batched=True this is normally a list; wrap a lone value.
                texts = examples[text_column]
                if not isinstance(texts, list):
                    texts = [texts]
                texts = [self.preprocess_text(text) for text in texts]

                try:
                    tokenized = self.tokenizer(
                        texts,
                        padding="max_length",
                        truncation=True,
                        max_length=max_length + 1,  # one extra token for the shift
                        return_tensors=None,
                    )
                    full_ids = tokenized["input_ids"]
                    full_mask = tokenized["attention_mask"]

                    # Position i is trained to predict token i + 1.
                    input_ids = [ids[:-1] for ids in full_ids]
                    attention_mask = [mask[:-1] for mask in full_mask]
                    if label_pad_token_id is None:
                        labels = [ids[1:] for ids in full_ids]
                    else:
                        # A target is padding when its own attention-mask bit is 0.
                        labels = [
                            [
                                token if keep else label_pad_token_id
                                for token, keep in zip(ids[1:], mask[1:])
                            ]
                            for ids, mask in zip(full_ids, full_mask)
                        ]

                    # Explicit checks instead of ``assert`` so they survive ``python -O``.
                    length_checks = (
                        (input_ids, "Input sequence lengths don't match"),
                        (attention_mask, "Attention mask lengths don't match"),
                        (labels, "Label sequence lengths don't match"),
                    )
                    for sequences, message in length_checks:
                        if any(len(seq) != max_length for seq in sequences):
                            raise ValueError(message)

                    return {
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "labels": labels,
                    }
                except Exception as e:
                    self.logger.log_error(f"Error tokenizing batch: {str(e)}")
                    raise

            train_tokenized = self._tokenize_split(
                train_dataset, process_and_tokenize_batch, batch_size, "training"
            )

            val_tokenized = None
            if val_dataset is not None:
                val_tokenized = self._tokenize_split(
                    val_dataset, process_and_tokenize_batch, batch_size, "validation"
                )

            test_tokenized = None
            if test_dataset is not None:
                test_tokenized = self._tokenize_split(
                    test_dataset, process_and_tokenize_batch, batch_size, "test"
                )

            # ``is not None`` rather than truthiness: an empty Dataset has
            # len 0 and is falsy, which used to skip set_format for it.
            for split in (train_tokenized, val_tokenized, test_tokenized):
                if split is not None:
                    split.set_format("torch")

            return train_tokenized, val_tokenized, test_tokenized

        except Exception as e:
            self.logger.log_error(f"Error in tokenization: {str(e)}")
            raise