# create_augmented_dataset.py
# Provenance (recovered from a Hugging Face file listing):
#   simson_base/simson_modeling/create_augmented_dataset.py.save
#   uploaded by Defetya via huggingface_hub, commit 592e96e (verified)
import pandas as pd
from tqdm import tqdm
from rdkit import Chem, RDLogger
from datasets import load_dataset
from multiprocessing import Pool, cpu_count
import os
# Suppress RDKit console output for cleaner logs
RDLogger.DisableLog('rdApp.*')
class SmilesEnumerator:
    """
    A simple class to encapsulate the SMILES randomization logic.
    Needed for multiprocessing to work correctly with instance methods.
    """

    def randomize_smiles(self, smiles):
        """Return a randomized (non-canonical) SMILES string for *smiles*.

        Parameters
        ----------
        smiles : str
            Input SMILES string.

        Returns
        -------
        str
            A randomized rendering of the same molecule, or the original
            string unchanged if RDKit cannot parse or serialize it.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            # doRandom=True with canonical=False yields a different atom
            # ordering on each call, producing distinct augmented views.
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit are no longer swallowed; RDKit failures still
            # fall back to returning the input unchanged.
            return smiles
def create_augmented_pair(smiles_string):
    """
    Worker function: produce two independently randomized renderings
    of the same input SMILES string.

    Returns
    -------
    tuple[str, str]
        A pair of randomized SMILES for the same molecule.
    """
    # A fresh enumerator per call keeps the worker function
    # self-contained and trivially picklable for multiprocessing.
    randomize = SmilesEnumerator().randomize_smiles
    return randomize(smiles_string), randomize(smiles_string)
def main():
    """
    Run the parallel SMILES-pair preprocessing pipeline.

    Loads SMILES strings from a Hugging Face dataset, generates two
    independently randomized SMILES per molecule across all CPU cores,
    and writes the resulting pairs to a Parquet file.
    """
    # --- Configuration ---
    # Hugging Face dataset to load
    dataset_name = 'jablonkagroup/pubchem-smiles-molecular-formula'
    # Column containing the SMILES strings
    smiles_column_name = 'smiles'
    # Output file path (tail slice: rows 110M through the end)
    output_path = 'data/pubchem_computed_110_end_M.parquet'

    # --- Data Loading ---
    print(f"Loading dataset '{dataset_name}'...")
    dataset = load_dataset(dataset_name, split='train')
    # BUG FIX: the original used `range(110_000_000, )`, which is just
    # range(110_000_000) and selects the FIRST 110M rows — the opposite
    # of the intended "row 110M to end" slice implied by the filename.
    dataset = dataset.select(range(110_000_000, len(dataset)))
    smiles_list = dataset[smiles_column_name]
    print(f"Successfully fetched {len(smiles_list)} SMILES strings.")

    # --- Parallel Processing ---
    # Use all available CPU cores for maximum speed
    num_workers = cpu_count()
    print(f"Starting SMILES augmentation with {num_workers} worker processes...")
    # A Pool of processes runs `create_augmented_pair` in parallel;
    # chunksize batches the millions of tiny tasks to cut IPC overhead.
    with Pool(num_workers) as p:
        results = list(
            tqdm(
                p.imap(create_augmented_pair, smiles_list, chunksize=1024),
                total=len(smiles_list),
                desc="Augmenting Pairs",
            )
        )

    # --- Saving Data ---
    print("Processing complete. Converting to DataFrame...")
    # Convert the list of (smiles_1, smiles_2) tuples into a DataFrame
    df = pd.DataFrame(results, columns=['smiles_1', 'smiles_2'])
    # Ensure the output directory exists before writing
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"Saving augmented pairs to '{output_path}'...")
    # Parquet gives efficient columnar storage and fast reloads
    df.to_parquet(output_path)
    print("All done. Your pre-computed dataset is ready!")
# Entry-point guard: required on platforms using spawn-based
# multiprocessing so worker processes can safely re-import this module
# without re-entering main().
if __name__ == '__main__':
    main()