from datasets import Dataset, DatasetDict
from typing import Optional, Tuple, Union
import numpy as np
from ..utils.logger import logger
from tqdm.auto import tqdm
import logging
# Raise the level of the Hugging Face `datasets` library logger to WARNING so
# records below WARNING logged through it are dropped.
logging.getLogger(name="datasets").setLevel(level=logging.WARNING)
class DatasetSplitter:
    """Split Hugging Face ``Dataset`` objects into train/validation/test parts.

    The splitting methods report failures through ``self.logger.log_error``
    and then re-raise the original exception.
    """

    def __init__(self) -> None:
        """Initialize dataset splitter with the package-level logger."""
        self.logger = logger

    def _get_dataset_from_dict(
        self, dataset: Union[Dataset, DatasetDict], split: str = "train"
    ) -> Dataset:
        """Return ``dataset[split]`` for a ``DatasetDict``; otherwise return ``dataset`` as-is.

        Raises:
            ValueError: If ``dataset`` is a ``DatasetDict`` that lacks ``split``.
        """
        if isinstance(dataset, DatasetDict):
            if split in dataset:
                return dataset[split]
            raise ValueError(f"DatasetDict does not contain split '{split}'")
        return dataset

    def validate_split_params(
        self, train_size: float, val_size: float, test_size: Optional[float] = None
    ) -> None:
        """Validate split proportions.

        Args:
            train_size (float): Training fraction, strictly between 0 and 1.
            val_size (float): Validation fraction, strictly between 0 and 1.
            test_size (float, optional): Test fraction, strictly between 0 and 1.
                If ``None``, the test set is the remainder, which must be positive.

        Raises:
            ValueError: If a proportion is out of range, if ``train_size + val_size``
                leaves no remainder when ``test_size`` is ``None``, or if the three
                sizes do not sum to 1.0 within a tolerance of 0.01.
        """
        if train_size <= 0 or train_size >= 1:
            raise ValueError(f"train_size must be between 0 and 1, got {train_size}")
        if val_size <= 0 or val_size >= 1:
            raise ValueError(f"val_size must be between 0 and 1, got {val_size}")
        if test_size is None:
            # An implied remainder always makes the total exactly 1.0, so a sum
            # check would be meaningless; verify the remainder is positive instead.
            if 1 - train_size - val_size <= 0:
                raise ValueError(
                    "train_size + val_size must be less than 1 when test_size is None, "
                    f"got {train_size + val_size}"
                )
            return
        if test_size <= 0 or test_size >= 1:
            raise ValueError(f"test_size must be between 0 and 1, got {test_size}")
        total = train_size + val_size + test_size
        if not (0.99 <= total <= 1.01):  # Allow small floating point differences
            raise ValueError(f"Split sizes must sum to 1.0, got {total}")

    def train_test_split(
        self,
        dataset: Union[Dataset, DatasetDict],
        test_size: float = 0.2,
        shuffle: bool = True,
        seed: int = 42,
        split: str = "train",
        **kwargs,
    ) -> Tuple[Dataset, Dataset]:
        """
        Split dataset into train and test sets.

        Args:
            dataset (Dataset or DatasetDict): Dataset to split
            test_size (float): Size of test set
            shuffle (bool): Whether to shuffle
            seed (int): Random seed
            split (str): Which split to use if dataset is a DatasetDict
            **kwargs: Additional arguments forwarded to ``Dataset.train_test_split``

        Returns:
            Tuple[Dataset, Dataset]: Train and test datasets
        """
        try:
            dataset = self._get_dataset_from_dict(dataset, split)
            self.logger.log_info("Splitting dataset into train and test sets")
            split_dataset = dataset.train_test_split(
                test_size=test_size,
                shuffle=shuffle,
                seed=seed,
                **kwargs,
            )
            self.logger.log_info("Successfully split dataset")
            return split_dataset["train"], split_dataset["test"]
        except Exception as e:
            self.logger.log_error(f"Error splitting dataset: {str(e)}")
            raise

    def train_val_test_split(
        self,
        dataset: Union[Dataset, DatasetDict],
        train_size: float = 0.8,
        val_size: float = 0.1,
        test_size: float = 0.1,
        shuffle: bool = True,
        seed: int = 42,
        split: str = "train",
    ) -> Tuple[Dataset, Dataset, Dataset]:
        """
        Split dataset into train, validation and test sets with progress indication.

        Args:
            dataset (Dataset or DatasetDict): Dataset to split
            train_size (float): Proportion of training set
            val_size (float): Proportion of validation set
            test_size (float): Proportion of test set
            shuffle (bool): Whether to shuffle the dataset (rows are permuted
                with ``np.random.default_rng(seed)`` before splitting)
            seed (int): Random seed
            split (str): Which split to use if dataset is a DatasetDict

        Returns:
            Tuple[Dataset, Dataset, Dataset]: Train, validation and test datasets

        Raises:
            ValueError: If the proportions do not sum to 1 (``np.isclose``) or the
                requested split is missing from a DatasetDict.
        """
        try:
            dataset = self._get_dataset_from_dict(dataset, split)

            total = train_size + val_size + test_size
            if not np.isclose(total, 1.0):
                raise ValueError(f"Split proportions must sum to 1, got {total}")

            self.logger.log_info("Splitting dataset...")

            if shuffle:
                with tqdm(total=1, desc="Shuffling dataset", unit="operation") as pbar:
                    indices = np.arange(len(dataset))
                    rng = np.random.default_rng(seed)
                    rng.shuffle(indices)
                    # Apply the permutation to the rows; the splits below take
                    # contiguous ranges, so without this no shuffling would happen.
                    dataset = dataset.select(indices)
                    pbar.update(1)

            with tqdm(total=2, desc="Creating splits", unit="split") as pbar:
                # First split: train vs rest (shuffle=False: order already final).
                train_rest = dataset.train_test_split(
                    train_size=train_size,
                    seed=seed,
                    shuffle=False,
                )
                train_dataset = train_rest["train"]
                rest_dataset = train_rest["test"]
                pbar.update(1)

                # Second split: the fraction of the remainder that goes to validation.
                val_ratio = val_size / (val_size + test_size)
                val_test = rest_dataset.train_test_split(
                    train_size=val_ratio,
                    seed=seed,
                    shuffle=False,
                )
                val_dataset = val_test["train"]
                test_dataset = val_test["test"]
                pbar.update(1)

            self.logger.log_info(
                f"Split sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}"
            )
            return train_dataset, val_dataset, test_dataset
        except Exception as e:
            self.logger.log_error(f"Error splitting dataset: {str(e)}")
            raise

    def train_val_split(
        self,
        dataset: Union[Dataset, DatasetDict],
        train_size: float = 0.8,
        shuffle: bool = True,
        seed: int = 42,
        split: str = "train",
    ) -> Tuple[Dataset, Dataset]:
        """Split dataset into train and validation sets.

        Args:
            dataset (Dataset or DatasetDict): Dataset to split
            train_size (float): Proportion of training set
            shuffle (bool): Whether to shuffle
            seed (int): Random seed
            split (str): Which split to use if dataset is a DatasetDict

        Returns:
            Tuple[Dataset, Dataset]: Train and validation datasets, as a real
            tuple (indexable as well as unpackable).
        """
        try:
            dataset = self._get_dataset_from_dict(dataset, split)
            split_dataset = dataset.train_test_split(
                train_size=train_size,
                shuffle=shuffle,
                seed=seed,
            )
            return split_dataset["train"], split_dataset["test"]
        except Exception as e:
            self.logger.log_error(f"Error splitting dataset: {str(e)}")
            raise

    def k_fold_split(self, dataset, n_splits: int = 5, shuffle: bool = True, seed: int = 42):
        """Create k-fold cross validation splits.

        Args:
            dataset (Dataset): Dataset to split (a DatasetDict is rejected)
            n_splits (int): Number of folds, at least 2
            shuffle (bool): Whether KFold shuffles before assigning folds
            seed (int): Random seed, used only when ``shuffle`` is True

        Returns:
            list of (Dataset, Dataset): One ``(train, val)`` pair per fold, built
            with ``Dataset.select`` so the original features are kept.

        Raises:
            ValueError: If ``dataset`` is not a Dataset or ``n_splits < 2``.
        """
        try:
            if not isinstance(dataset, Dataset):
                raise ValueError(f"Expected Dataset object, got {type(dataset)}")
            if n_splits < 2:
                raise ValueError(f"n_splits must be at least 2, got {n_splits}")
            from sklearn.model_selection import KFold

            # scikit-learn raises if random_state is set while shuffle=False.
            kf = KFold(
                n_splits=n_splits,
                shuffle=shuffle,
                random_state=seed if shuffle else None,
            )
            # Fold assignment depends only on the row count, so split positions
            # directly instead of round-tripping through pandas.
            folds = [
                (dataset.select(train_idx), dataset.select(val_idx))
                for train_idx, val_idx in kf.split(np.arange(len(dataset)))
            ]
            self.logger.log_info(f"Created {n_splits}-fold cross validation splits")
            return folds
        except Exception as e:
            self.logger.log_error(f"Error creating k-fold splits: {str(e)}")
            raise