import os
from multiprocessing import Pool, cpu_count

import pandas as pd
from tqdm import tqdm
from rdkit import Chem, RDLogger
from datasets import load_dataset

# Suppress RDKit console output for cleaner logs
RDLogger.DisableLog('rdApp.*')


class SmilesEnumerator:
    """
    Encapsulates the SMILES randomization logic.

    Kept as a class (rather than a bare function) so it pickles cleanly
    for use inside multiprocessing worker processes.
    """

    def randomize_smiles(self, smiles):
        """Return a randomized, non-canonical SMILES string for ``smiles``.

        Falls back to the original string when RDKit cannot parse the
        input or raises during generation, so a single bad molecule
        never aborts the whole preprocessing run.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            # Return a randomized, non-canonical SMILES string
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit still propagate out of worker processes.
            return smiles


def create_augmented_pair(smiles_string):
    """
    Worker function: take one SMILES string and return a tuple of two
    independently randomized versions of it (an augmentation pair).
    """
    enumerator = SmilesEnumerator()
    return (
        enumerator.randomize_smiles(smiles_string),
        enumerator.randomize_smiles(smiles_string),
    )


def main():
    """Run the parallel SMILES-pair preprocessing pipeline end to end."""
    # --- Configuration ---
    # Dataset to load from the Hugging Face Hub
    dataset_name = 'jablonkagroup/pubchem-smiles-molecular-formula'
    # Column containing the SMILES strings
    smiles_column_name = 'smiles'
    # Output file path
    output_path = 'data/pubchem_computed_110_end_M.parquet'

    # --- Data Loading ---
    print(f"Loading dataset '{dataset_name}'...")
    # NOTE: this downloads the full split; `load_dataset` here is NOT streaming.
    dataset = load_dataset(dataset_name, split='train')
    # BUG FIX: `range(110_000_000, )` is just `range(110_000_000)`, i.e.
    # rows 0..110M-1. The output filename "110_end" indicates the intent
    # was rows 110M through the end of the split.
    dataset = dataset.select(range(110_000_000, len(dataset)))
    smiles_list = dataset[smiles_column_name]
    print(f"Successfully fetched {len(smiles_list)} SMILES strings.")

    # --- Parallel Processing ---
    # Use all available CPU cores for maximum speed
    num_workers = cpu_count()
    print(f"Starting SMILES augmentation with {num_workers} worker processes...")

    # A Pool of processes runs `create_augmented_pair` in parallel.
    # `chunksize` batches work items to cut per-item IPC overhead, which
    # dominates at this input scale; imap keeps results streaming so the
    # tqdm progress bar stays live.
    with Pool(num_workers) as p:
        results = list(tqdm(p.imap(create_augmented_pair, smiles_list, chunksize=1024),
                            total=len(smiles_list),
                            desc="Augmenting Pairs"))

    # --- Saving Data ---
    print("Processing complete. Converting to DataFrame...")
    # Convert the list of (smiles_1, smiles_2) tuples into a DataFrame
    df = pd.DataFrame(results, columns=['smiles_1', 'smiles_2'])

    # Ensure the output directory exists before writing
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Saving augmented pairs to '{output_path}'...")
    # Parquet for efficient storage and downstream loading
    df.to_parquet(output_path)
    print("All done. Your pre-computed dataset is ready!")


if __name__ == '__main__':
    main()