# simson_base/simson_modeling/create_splits.py
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split


def concatenate_and_split_parquet(
    input_dir: str,
    output_dir: str,
    val_size: int = 10000,
    test_size: int = 5000,
    random_state: int = 42
):
    """
    Concatenate all parquet files in a directory and split them into train/val/test sets.

    Args:
        input_dir: Path to the directory containing the parquet files
        output_dir: Path to the directory where the split files will be saved
        val_size: Number of samples in the validation set (default: 10000)
        test_size: Number of samples in the test set (default: 5000)
        random_state: Random seed for reproducibility
    """
    # Create the output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Find all parquet files in the input directory
    input_path = Path(input_dir)
    parquet_files = list(input_path.glob("*.parquet"))
    if not parquet_files:
        raise ValueError(f"No parquet files found in {input_dir}")
    print(f"Found {len(parquet_files)} parquet files")
    # Read and concatenate all parquet files
    print("Reading and concatenating parquet files...")
    dataframes = []
    for file_path in parquet_files:
        print(f"Reading {file_path.name}...")
        df = pd.read_parquet(file_path)
        dataframes.append(df)

    # Concatenate all dataframes
    combined_df = pd.concat(dataframes, ignore_index=True)
    print(f"Combined dataset shape: {combined_df.shape}")
    # Check that there are enough samples for the requested splits
    total_samples = len(combined_df)
    required_samples = val_size + test_size
    if total_samples < required_samples:
        raise ValueError(
            f"Not enough samples. Required: {required_samples}, Available: {total_samples}"
        )

    # Shuffle the data
    combined_df = combined_df.sample(frac=1, random_state=random_state).reset_index(drop=True)
    # Split the data
    print("Splitting data...")

    # First split: separate the test set
    temp_df, test_df = train_test_split(
        combined_df,
        test_size=test_size,
        random_state=random_state
    )

    # Second split: separate the validation set from the remaining data
    train_df, val_df = train_test_split(
        temp_df,
        test_size=val_size,
        random_state=random_state
    )
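
    # Note: because test_size and val_size are integers, train_test_split treats
    # them as absolute row counts rather than fractions, so the training set ends
    # up with total - val_size - test_size rows. For example (hypothetical
    # numbers), 100,000 combined rows with the defaults would yield
    # 85,000 / 10,000 / 5,000 train / validation / test rows.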
    print(f"Training set shape: {train_df.shape}")
    print(f"Validation set shape: {val_df.shape}")
    print(f"Test set shape: {test_df.shape}")

    # Save the splits as parquet files
    output_path = Path(output_dir)
    train_path = output_path / "train.parquet"
    val_path = output_path / "validation.parquet"
    test_path = output_path / "test.parquet"

    print("Saving split datasets...")
    train_df.to_parquet(train_path, index=False)
    val_df.to_parquet(val_path, index=False)
    test_df.to_parquet(test_path, index=False)

    print("Files saved to:")
    print(f" Training: {train_path}")
    print(f" Validation: {val_path}")
    print(f" Test: {test_path}")

    return train_df, val_df, test_df


# Alternative version using PyArrow for better performance with large files
def concatenate_and_split_parquet_arrow(
    input_dir: str,
    output_dir: str,
    val_size: int = 10000,
    test_size: int = 5000,
    random_state: int = 42
):
    """
    Same functionality as concatenate_and_split_parquet, but uses PyArrow for
    better performance with large files.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    # Create the output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Find all parquet files
    input_path = Path(input_dir)
    parquet_files = list(input_path.glob("*.parquet"))
    if not parquet_files:
        raise ValueError(f"No parquet files found in {input_dir}")
    print(f"Found {len(parquet_files)} parquet files")

    # Read and concatenate using PyArrow
    print("Reading and concatenating parquet files...")
    tables = []
    for file_path in parquet_files:
        print(f"Reading {file_path.name}...")
        table = pq.read_table(file_path)
        tables.append(table)

    # Concatenate tables (all files must share the same schema)
    combined_table = pa.concat_tables(tables)
    combined_df = combined_table.to_pandas()
    print(f"Combined dataset shape: {combined_df.shape}")
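
    # Alternative (not in the original script): recent versions of PyArrow can
    # read an entire directory of parquet files in a single call via the
    # pyarrow.dataset module, which avoids the explicit per-file loop above,
    # e.g. (assuming pyarrow.dataset is available):
    #   import pyarrow.dataset as ds
    #   combined_table = ds.dataset(input_dir, format="parquet").to_table()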

    # The rest of the function mirrors the pandas version
    total_samples = len(combined_df)
    required_samples = val_size + test_size
    if total_samples < required_samples:
        raise ValueError(
            f"Not enough samples. Required: {required_samples}, Available: {total_samples}"
        )

    # Shuffle and split
    combined_df = combined_df.sample(frac=1, random_state=random_state).reset_index(drop=True)
    temp_df, test_df = train_test_split(
        combined_df, test_size=test_size, random_state=random_state
    )
    train_df, val_df = train_test_split(
        temp_df, test_size=val_size, random_state=random_state
    )

    print(f"Training set shape: {train_df.shape}")
    print(f"Validation set shape: {val_df.shape}")
    print(f"Test set shape: {test_df.shape}")

    # Save using PyArrow; preserve_index=False keeps the pandas index out of the
    # output files, matching the index=False behavior of the pandas version
    output_path = Path(output_dir)
    pq.write_table(pa.Table.from_pandas(train_df, preserve_index=False), output_path / "train.parquet")
    pq.write_table(pa.Table.from_pandas(val_df, preserve_index=False), output_path / "validation.parquet")
    pq.write_table(pa.Table.from_pandas(test_df, preserve_index=False), output_path / "test.parquet")
    print(f"Files saved to {output_dir}")

    return train_df, val_df, test_df


# Example usage
if __name__ == "__main__":
    input_directory = "data"
    output_directory = "data/polymer_splits"

    # Using the pandas version
    train_df, val_df, test_df = concatenate_and_split_parquet(
        input_dir=input_directory,
        output_dir=output_directory,
        val_size=10000,
        test_size=5000,
        random_state=42
    )

    # Or using the PyArrow version for better performance
    # train_df, val_df, test_df = concatenate_and_split_parquet_arrow(
    #     input_dir=input_directory,
    #     output_dir=output_directory,
    #     val_size=10000,
    #     test_size=5000,
    #     random_state=42
    # )
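
    # A minimal sketch (not part of the original script) of loading the saved
    # splits back, assuming the Hugging Face `datasets` library is installed;
    # plain pd.read_parquet on the three files works just as well.
    # from datasets import load_dataset
    # splits = load_dataset(
    #     "parquet",
    #     data_files={
    #         "train": f"{output_directory}/train.parquet",
    #         "validation": f"{output_directory}/validation.parquet",
    #         "test": f"{output_directory}/test.parquet",
    #     },
    # )
    # print(splits)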