|
import pandas as pd |
|
from tqdm import tqdm |
|
from rdkit import Chem, RDLogger |
|
from datasets import load_dataset |
|
from multiprocessing import Pool, cpu_count |
|
import os |
|
|
|
|
|
# Mute every RDKit application log channel ("rdApp.*"). Otherwise, e.g., each
# unparsable SMILES would print a parse error and flood the console during the
# bulk augmentation run.
RDLogger.DisableLog("rdApp.*")
|
|
|
class SmilesEnumerator:
    """Stateless helper that produces randomized SMILES strings with RDKit.

    It holds no state, so an instance is cheap to create wherever it is
    needed (including inside worker processes).
    """

    def randomize_smiles(self, smiles):
        """Return a randomized, non-canonical SMILES for the same molecule.

        Args:
            smiles: SMILES string to randomize.

        Returns:
            The output of ``Chem.MolToSmiles(mol, doRandom=True,
            canonical=False)`` for the parsed molecule. If RDKit cannot
            parse *smiles*, or raises while processing it, *smiles* is
            returned unchanged so a single bad record never aborts a
            bulk run.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles signals unparsable input by returning None
                # rather than raising; fall back to the original string.
                return smiles
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False)
        except Exception:
            # Best-effort fallback. Catch Exception (not a bare except) so
            # KeyboardInterrupt / SystemExit still propagate and a long
            # multiprocessing job can actually be interrupted.
            return smiles
|
|
|
def create_augmented_pair(smiles_string):
    """Return two independently randomized SMILES strings for one input.

    Used as the per-item worker for the process pool: a single
    ``SmilesEnumerator`` is created and its ``randomize_smiles`` is called
    twice on the same input. The two results come from separate calls, so
    they are not guaranteed to differ.

    Args:
        smiles_string: SMILES string to augment.

    Returns:
        A 2-tuple ``(smiles_1, smiles_2)``.
    """
    randomize = SmilesEnumerator().randomize_smiles
    first_view = randomize(smiles_string)
    second_view = randomize(smiles_string)
    return first_view, second_view
|
|
|
def main(
    *,
    dataset_name='jablonkagroup/pubchem-smiles-molecular-formula',
    smiles_column_name='smiles',
    output_path='data/pubchem_computed_110_end_M.parquet',
    start_index=110_000_000,
    end_index=None,
    num_workers=None,
    chunksize=1_000,
):
    """Augment a slice of a SMILES dataset in parallel and save it as Parquet.

    Loads the ``train`` split of *dataset_name*, keeps rows
    ``[start_index, end_index)``, and maps ``create_augmented_pair`` over the
    SMILES column with a process pool. The resulting pairs are written to
    *output_path* as a two-column (``smiles_1``, ``smiles_2``) Parquet file.

    Args:
        dataset_name: Hugging Face dataset identifier passed to ``load_dataset``.
        smiles_column_name: Column holding the SMILES strings.
        output_path: Destination Parquet file; its parent directory is created
            if needed.
        start_index: First row (inclusive) to process.
        end_index: Row to stop before (exclusive). ``None`` means the end of the
            dataset; larger values are clamped to the dataset length.
        num_workers: Number of worker processes. ``None`` uses ``cpu_count()``.
        chunksize: Number of SMILES sent to a worker per task. Output order does
            not depend on this value.

    Raises:
        ValueError: If ``start_index`` is negative or past the selected end.
    """
    print(f"Loading dataset '{dataset_name}'...")
    dataset = load_dataset(dataset_name, split='train')

    # The original call ``select(range(110_000_000, ))`` is ``range(0, 110_000_000)``.
    # It took the FIRST 110M rows instead of the tail the output filename
    # ("110_end") describes. Select [start_index, end) explicitly instead.
    total_rows = len(dataset)
    stop = total_rows if end_index is None else min(end_index, total_rows)
    if not 0 <= start_index <= stop:
        raise ValueError(
            f"start_index={start_index} is outside the valid range [0, {stop}] "
            f"for a dataset with {total_rows} rows"
        )
    dataset = dataset.select(range(start_index, stop))

    smiles_list = dataset[smiles_column_name]
    print(f"Successfully fetched {len(smiles_list)} SMILES strings.")

    if num_workers is None:
        num_workers = cpu_count()
    print(f"Starting SMILES augmentation with {num_workers} worker processes...")

    with Pool(num_workers) as pool:
        # imap's default chunksize of 1 costs one inter-process round-trip per
        # molecule. Batching keeps the result order identical while cutting
        # that overhead substantially on millions of items.
        pair_iter = pool.imap(create_augmented_pair, smiles_list, chunksize=chunksize)
        results = list(tqdm(pair_iter, total=len(smiles_list), desc="Augmenting Pairs"))

    print("Processing complete. Converting to DataFrame...")
    df = pd.DataFrame(results, columns=['smiles_1', 'smiles_2'])

    output_dir = os.path.dirname(output_path)
    if output_dir:
        # os.makedirs('') raises, so only create a directory when the path names one.
        os.makedirs(output_dir, exist_ok=True)

    print(f"Saving augmented pairs to '{output_path}'...")
    df.to_parquet(output_path)

    print("All done. Your pre-computed dataset is ready!")
|
|
|
if __name__ == "__main__":
    # Entry-point guard: multiprocessing start methods that re-import this
    # module in each worker (e.g. "spawn") must not re-run the pipeline.
    main()
|
|
|
|