Source code for quantllm.data.load_dataset

from datasets import load_dataset, Dataset
from typing import Optional, Dict, Any, Union
import os
from ..utils.logger import logger
from tqdm.auto import tqdm
import logging
import warnings

# Quiet the chattier third-party libraries: only WARNING and above from
# these loggers will be emitted.
for _noisy_logger_name in ("datasets", "huggingface_hub"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name  # do not leave the loop variable in the module namespace

# NOTE: no category or module is given, so this installs an "ignore" filter
# for every warning, process-wide, as soon as this module is imported.
warnings.filterwarnings("ignore")

class LoadDataset:
    """Convenience wrapper around the HuggingFace ``datasets`` loaders.

    Each public method logs start, success and failure through
    ``self.logger`` and wraps the underlying call in a single-step ``tqdm``
    progress bar. Exceptions are logged and then re-raised unchanged.
    """

    # Extensions recognised when ``file_type="auto"``, mapped to the builder
    # name that is passed to ``datasets.load_dataset``.
    _EXTENSION_TO_BUILDER = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json",  # JSON Lines files are read by the "json" builder
        "txt": "text",
        "parquet": "parquet",
    }

    def __init__(self):
        """Initialize the dataset loader."""
        self.logger = logger

    @staticmethod
    def _resolve_file_type(file_path: str, file_type: str = "auto") -> str:
        """Return the builder name to use for ``file_path``.

        Args:
            file_path (str): Path to the dataset file.
            file_type (str): Explicit builder name, or ``"auto"`` to infer it
                from the (case-insensitive) file extension.

        Returns:
            str: ``file_type`` unchanged when it is not ``"auto"``; otherwise
            the builder name mapped from the file extension.

        Raises:
            ValueError: If ``file_type`` is ``"auto"`` and the extension is
                not one of the supported ones.
        """
        if file_type != "auto":
            return file_type

        # splitext keeps the leading dot; strip it before the lookup.
        extension = os.path.splitext(file_path)[1][1:].lower()
        try:
            return LoadDataset._EXTENSION_TO_BUILDER[extension]
        except KeyError:
            raise ValueError(f"Unsupported file extension: {extension}") from None

    def load_hf_dataset(
        self,
        dataset_name: str,
        split: Optional[str] = None,
        streaming: bool = False,
        **kwargs
    ) -> Dataset:
        """Load a dataset from HuggingFace with a custom progress bar.

        Args:
            dataset_name (str): Name of the dataset.
            split (str, optional): Dataset split to load.
            streaming (bool): Whether to use streaming.
            **kwargs: Additional arguments forwarded to
                ``datasets.load_dataset``.

        Returns:
            Dataset: Whatever ``datasets.load_dataset`` returns. Per the
            ``datasets`` documentation this is a ``DatasetDict`` when
            ``split`` is ``None`` and an iterable variant when
            ``streaming=True``.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        try:
            self.logger.log_info(f"Loading dataset: {dataset_name}")

            # Single-step bar: it only signals "in progress" / "done".
            with tqdm(total=1, desc=f"Downloading {dataset_name}", unit="dataset") as pbar:
                dataset = load_dataset(
                    dataset_name,
                    split=split,
                    streaming=streaming,
                    **kwargs
                )
                pbar.update(1)

            self.logger.log_info(f"Successfully loaded dataset: {dataset_name}")
            return dataset
        except Exception as e:
            self.logger.log_error(f"Error loading dataset: {e}")
            raise

    def load_local_dataset(
        self,
        file_path: str,
        file_type: str = "auto",
        **kwargs
    ) -> Dataset:
        """Load a dataset from a local file with a progress bar.

        Args:
            file_path (str): Path to the dataset file.
            file_type (str): Type of file (auto, csv, json, text, parquet).
                With ``"auto"`` the type is inferred from the extension
                (``csv``, ``json``, ``jsonl``, ``txt``, ``parquet``).
            **kwargs: Additional arguments forwarded to
                ``datasets.load_dataset`` (for example ``split``).

        Returns:
            Dataset: Whatever ``datasets.load_dataset`` returns for the file.
            Per the ``datasets`` documentation this is a ``DatasetDict``
            unless a ``split`` is passed through ``kwargs``.

        Raises:
            ValueError: If ``file_type`` is ``"auto"`` and the file extension
                is not supported.
            Exception: Any other loading error is logged and re-raised.
        """
        try:
            self.logger.log_info(f"Loading local dataset: {file_path}")

            builder = self._resolve_file_type(file_path, file_type)

            # Single-step bar: it only signals "in progress" / "done".
            with tqdm(total=1, desc=f"Loading {file_path}", unit="file") as pbar:
                dataset = load_dataset(
                    builder,
                    data_files=file_path,
                    **kwargs
                )
                pbar.update(1)

            self.logger.log_info(f"Successfully loaded local dataset: {file_path}")
            return dataset
        except Exception as e:
            self.logger.log_error(f"Error loading local dataset: {e}")
            raise

    def load_custom_dataset(
        self,
        data: Union[Dict, list],
        **kwargs
    ) -> Dataset:
        """Create a dataset from in-memory data with progress indication.

        Args:
            data (Union[Dict, list]): Dataset data. A dict is treated as
                column-oriented and passed to ``Dataset.from_dict``; a list
                is treated as a list of row dicts and passed to
                ``Dataset.from_list``.
            **kwargs: Additional arguments for dataset creation.

        Returns:
            Dataset: Created dataset.

        Raises:
            Exception: Any error raised while building the dataset is logged
                and re-raised.
        """
        try:
            self.logger.log_info("Creating custom dataset")

            # Single-step bar: it only signals "in progress" / "done".
            with tqdm(total=1, desc="Creating dataset", unit="dataset") as pbar:
                if isinstance(data, list):
                    # from_dict cannot consume a list of records.
                    dataset = Dataset.from_list(data, **kwargs)
                else:
                    dataset = Dataset.from_dict(data, **kwargs)
                pbar.update(1)

            self.logger.log_info("Successfully created custom dataset")
            return dataset
        except Exception as e:
            self.logger.log_error(f"Error creating custom dataset: {e}")
            raise