File size: 150,692 Bytes
592e96e |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3d5d52d1-4874-44b5-b532-ef03da47644a",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from rdkit import Chem\n",
"from rdkit.Chem import Descriptors, rdMolDescriptors, Crippen, Lipinski\n",
"from tqdm import tqdm\n",
"import warnings\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.model_selection import train_test_split\n",
"import random\n",
"from concurrent.futures import ProcessPoolExecutor\n",
"import multiprocessing\n",
"\n",
"def analyze_polymer_features_rdkit(smiles):\n",
"    \"\"\"Compute a flat dictionary of RDKit descriptors for one SMILES string.\n",
"\n",
"    Parameters:\n",
"        smiles: SMILES string of the molecule (a '*' in it is reported by\n",
"            the 'has_polymer_notation' feature).\n",
"\n",
"    Returns:\n",
"        dict mapping feature name to value, or None when `smiles` is not a\n",
"        string (e.g. None or a pandas NaN) or RDKit cannot parse it.\n",
"    \"\"\"\n",
"    # A missing value would make Chem.MolFromSmiles raise inside a worker\n",
"    # process and abort the whole parallel map, so report it as unparsable.\n",
"    if not isinstance(smiles, str):\n",
"        return None\n",
"\n",
"    mol = Chem.MolFromSmiles(smiles)\n",
"    if mol is None:\n",
"        return None\n",
"\n",
"    def ratio(count, total):\n",
"        # Fraction count / total, or 0 when total is not positive.\n",
"        return count / total if total > 0 else 0\n",
"\n",
"    features = {}\n",
"\n",
"    # Basic molecular properties\n",
"    features['mol_weight'] = Descriptors.MolWt(mol)\n",
"    features['exact_mol_weight'] = Descriptors.ExactMolWt(mol)\n",
"    features['num_heavy_atoms'] = mol.GetNumHeavyAtoms()\n",
"    features['num_atoms'] = num_atoms = mol.GetNumAtoms()\n",
"    features['num_bonds'] = num_bonds = mol.GetNumBonds()\n",
"\n",
"    # Hydrogen bonding features\n",
"    features['num_hbond_donors'] = Descriptors.NumHDonors(mol)\n",
"    features['num_hbond_acceptors'] = Descriptors.NumHAcceptors(mol)\n",
"    features['num_heteroatoms'] = Descriptors.NumHeteroatoms(mol)\n",
"\n",
"    # Structural complexity\n",
"    features['num_rotatable_bonds'] = Descriptors.NumRotatableBonds(mol)\n",
"    features['num_saturated_rings'] = Descriptors.NumSaturatedRings(mol)\n",
"    features['num_aromatic_rings'] = Descriptors.NumAromaticRings(mol)\n",
"    features['num_aliphatic_rings'] = Descriptors.NumAliphaticRings(mol)\n",
"    features['ring_count'] = Descriptors.RingCount(mol)\n",
"    features['fraction_csp3'] = Descriptors.FractionCSP3(mol)\n",
"\n",
"    # Surface area and polarity\n",
"    features['tpsa'] = Descriptors.TPSA(mol)\n",
"    # NOTE(review): looks like the same quantity as 'tpsa'; both columns are kept for existing consumers.\n",
"    features['polar_surface_area'] = rdMolDescriptors.CalcTPSA(mol)\n",
"\n",
"    # Lipophilicity and molar refractivity\n",
"    features['logp'] = Descriptors.MolLogP(mol)\n",
"    # NOTE(review): presumably identical to 'logp' - confirm before relying on both.\n",
"    features['crippen_logp'] = Crippen.MolLogP(mol)\n",
"    features['crippen_mr'] = Crippen.MolMR(mol)  # Molar refractivity\n",
"\n",
"    # Shape (kappa) and connectivity (chi) indices\n",
"    features['kappa1'] = Descriptors.Kappa1(mol)\n",
"    features['kappa2'] = Descriptors.Kappa2(mol)\n",
"    features['kappa3'] = Descriptors.Kappa3(mol)\n",
"    features['chi0v'] = Descriptors.Chi0v(mol)\n",
"    features['chi1v'] = Descriptors.Chi1v(mol)\n",
"    features['chi2v'] = Descriptors.Chi2v(mol)\n",
"\n",
"    # Topological indices\n",
"    features['balaban_j'] = Descriptors.BalabanJ(mol)\n",
"    features['bertz_ct'] = Descriptors.BertzCT(mol)  # Complexity index\n",
"\n",
"    # Electron counts\n",
"    features['num_radical_electrons'] = Descriptors.NumRadicalElectrons(mol)\n",
"    features['num_valence_electrons'] = Descriptors.NumValenceElectrons(mol)\n",
"\n",
"    # Atom type counts\n",
"    atom_counts = {}\n",
"    for atom in mol.GetAtoms():\n",
"        symbol = atom.GetSymbol()\n",
"        atom_counts[symbol] = atom_counts.get(symbol, 0) + 1\n",
"\n",
"    # One count and one ratio (relative to num_atoms) per tracked element\n",
"    for element in ['C', 'N', 'O', 'S', 'P', 'F', 'Cl', 'Br', 'I']:\n",
"        element_count = atom_counts.get(element, 0)\n",
"        features[f'count_{element}'] = element_count\n",
"        features[f'ratio_{element}'] = ratio(element_count, num_atoms)\n",
"\n",
"    # Bond type analysis; bond types outside these four are not counted\n",
"    bond_types = {'SINGLE': 0, 'DOUBLE': 0, 'TRIPLE': 0, 'AROMATIC': 0}\n",
"    for bond in mol.GetBonds():\n",
"        bond_type = str(bond.GetBondType())\n",
"        if bond_type in bond_types:\n",
"            bond_types[bond_type] += 1\n",
"\n",
"    for bond_type, count in bond_types.items():\n",
"        bond_key = bond_type.lower()\n",
"        features[f'num_{bond_key}_bonds'] = count\n",
"        features[f'ratio_{bond_key}_bonds'] = ratio(count, num_bonds)\n",
"\n",
"    # Hybridization analysis. The feature names say 'carbons', but the loop\n",
"    # counts atoms of every element; the names are kept for existing consumers.\n",
"    hybridization_counts = {'SP': 0, 'SP2': 0, 'SP3': 0, 'SP3D': 0, 'SP3D2': 0}\n",
"    for atom in mol.GetAtoms():\n",
"        hyb = str(atom.GetHybridization())\n",
"        if hyb in hybridization_counts:\n",
"            hybridization_counts[hyb] += 1\n",
"\n",
"    for hyb_type, count in hybridization_counts.items():\n",
"        hyb_key = hyb_type.lower()\n",
"        features[f'num_{hyb_key}_carbons'] = count\n",
"        features[f'ratio_{hyb_key}_carbons'] = ratio(count, num_atoms)\n",
"\n",
"    # Formal charge analysis\n",
"    formal_charges = [atom.GetFormalCharge() for atom in mol.GetAtoms()]\n",
"    features['total_formal_charge'] = sum(formal_charges)\n",
"    features['abs_total_formal_charge'] = sum(abs(charge) for charge in formal_charges)\n",
"    features['max_formal_charge'] = max(formal_charges, default=0)\n",
"    features['min_formal_charge'] = min(formal_charges, default=0)\n",
"\n",
"    # Aromaticity features\n",
"    aromatic_atoms = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())\n",
"    features['num_aromatic_atoms'] = aromatic_atoms\n",
"    features['aromatic_ratio'] = ratio(aromatic_atoms, num_atoms)\n",
"\n",
"    # Ring size analysis; every ring feature is 0 for a molecule without rings\n",
"    ring_sizes = [len(ring) for ring in mol.GetRingInfo().AtomRings()]\n",
"    features['avg_ring_size'] = sum(ring_sizes) / len(ring_sizes) if ring_sizes else 0\n",
"    features['max_ring_size'] = max(ring_sizes, default=0)\n",
"    features['min_ring_size'] = min(ring_sizes, default=0)\n",
"    for size in (3, 4, 5, 6, 7):\n",
"        features[f'num_{size}_rings'] = ring_sizes.count(size)\n",
"    features['num_large_rings'] = sum(1 for size in ring_sizes if size > 7)\n",
"\n",
"    # Features read directly from the SMILES text\n",
"    features['has_polymer_notation'] = '*' in smiles\n",
"    features['smiles_length'] = len(smiles)\n",
"    features['branch_count'] = smiles.count('(')\n",
"    features['branch_ratio'] = ratio(smiles.count('('), len(smiles))\n",
"\n",
"    return features\n",
"\n",
"def add_features(df, num_workers=None):\n",
"    \"\"\"\n",
"    Append RDKit descriptor columns to `df`, computing them in parallel.\n",
"\n",
"    Parameters:\n",
"        df: pandas DataFrame with 'Smiles' column\n",
"        num_workers: Number of worker processes (defaults to number of CPU cores)\n",
"\n",
"    Returns:\n",
"        A new DataFrame holding the columns of `df` followed by one column\n",
"        per descriptor. Rows whose SMILES could not be analysed get NaN in\n",
"        the descriptor columns.\n",
"    \"\"\"\n",
"    if num_workers is None:\n",
"        num_workers = multiprocessing.cpu_count()\n",
"\n",
"    smiles_list = df['Smiles'].tolist()\n",
"\n",
"    # Ship SMILES to the workers in batches: with one task per string the\n",
"    # inter-process hand-off costs more than the descriptor calculation.\n",
"    chunksize = max(1, len(smiles_list) // (max(1, num_workers) * 4))\n",
"\n",
"    with ProcessPoolExecutor(max_workers=num_workers) as executor:\n",
"        # executor.map yields results in input order, so tqdm tracks progress\n",
"        results = executor.map(analyze_polymer_features_rdkit, smiles_list, chunksize=chunksize)\n",
"        features_list = list(tqdm(results,\n",
"                                  total=len(smiles_list),\n",
"                                  desc='Computing RDKit descriptors'))\n",
"\n",
"    # The analyser returns None for SMILES it cannot handle; an empty dict\n",
"    # keeps that row (as NaN) instead of breaking the DataFrame constructor.\n",
"    records = [feats if feats is not None else {} for feats in features_list]\n",
"\n",
"    # Reuse the caller's index so the column-wise concat pairs every descriptor\n",
"    # row with its own SMILES even when `df` was filtered or re-ordered.\n",
"    features_df = pd.DataFrame(records, index=df.index)\n",
"\n",
"    return pd.concat([df, features_df], axis=1)\n",
"\n",
"def get_list_dif(l1, l2):\n",
"    \"\"\"Return the distinct items of `l1` that do not occur in `l2`.\n",
"\n",
"    Items keep the order of their first appearance in `l1`, so the result is\n",
"    the same on every run (a plain set difference has arbitrary order).\n",
"    \"\"\"\n",
"    excluded = set(l2)\n",
"    # dict.fromkeys removes repeats while preserving insertion order\n",
"    return list(dict.fromkeys(item for item in l1 if item not in excluded))\n",
"\n",
"# Usage example:\n",
"# df_with_features = add_features(df, num_workers=4)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "155598af-79f3-4933-8b5c-1fd11f64b870",
"metadata": {},
"outputs": [],
"source": [
"# Load the CSV and discard its 'Unnamed: 0' column.\n",
"# NOTE(review): the next code cell rebinds `df` to a different file - confirm this cell is still needed.\n",
"synth_db_path = '/home/jovyan/simson_training_bolgov/regression/PI_Tg_P308K_synth_db_chem.csv'\n",
"df = pd.read_csv(synth_db_path).drop(columns=['Unnamed: 0'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c69cc497-9fb6-4f74-96eb-257d7aa4a91a",
"metadata": {},
"outputs": [],
"source": [
"# Read the training table and expose its 'SMILES' column under the name\n",
"# 'Smiles' as well, which is the column add_features reads.\n",
"df = pd.read_csv('/home/jovyan/simson_training_bolgov/kaggle_comp/train.csv')\n",
"df = df.assign(Smiles=df['SMILES'])\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b076c55-d6ef-4780-af97-5fccd5062661",
"metadata": {},
"outputs": [],
"source": [
"# Work on at most the first 10,000 rows.\n",
"sample_df = df.head(10_000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96313883-c2ca-4eb8-9ec7-9aaca8dba077",
"metadata": {},
"outputs": [],
"source": [
"# Compute the descriptor columns for the sample (uses one worker per CPU core by default).\n",
"features_df = add_features(df=sample_df)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "41c7f85a-ea65-42e5-b315-ef304ba311c4",
"metadata": {},
"outputs": [],
"source": [
"# Descriptor columns that get standardised further down.\n",
"# NOTE(review): count_C / ratio_C, count_N / ratio_N and smiles_length are\n",
"# produced by analyze_polymer_features_rdkit but are not listed here - confirm\n",
"# the omission is intentional.\n",
"selected_features = [\n",
"    # size\n",
"    'mol_weight', 'exact_mol_weight', 'num_heavy_atoms', 'num_atoms', 'num_bonds',\n",
"    # hydrogen bonding and heteroatoms\n",
"    'num_hbond_donors', 'num_hbond_acceptors', 'num_heteroatoms',\n",
"    # flexibility and ring counts\n",
"    'num_rotatable_bonds', 'num_saturated_rings', 'num_aromatic_rings',\n",
"    'num_aliphatic_rings', 'ring_count', 'fraction_csp3',\n",
"    # polarity and lipophilicity\n",
"    'tpsa', 'polar_surface_area', 'logp', 'crippen_logp', 'crippen_mr',\n",
"    # shape, connectivity and complexity indices\n",
"    'kappa1', 'kappa2', 'kappa3', 'chi0v', 'chi1v', 'chi2v', 'balaban_j', 'bertz_ct',\n",
"    # electron counts\n",
"    'num_radical_electrons', 'num_valence_electrons',\n",
"    # element counts and ratios\n",
"    'count_O', 'ratio_O', 'count_S', 'ratio_S', 'count_P', 'ratio_P',\n",
"    'count_F', 'ratio_F', 'count_Cl', 'ratio_Cl', 'count_Br', 'ratio_Br',\n",
"    'count_I', 'ratio_I',\n",
"    # bond types\n",
"    'num_single_bonds', 'ratio_single_bonds', 'num_double_bonds', 'ratio_double_bonds',\n",
"    'num_triple_bonds', 'ratio_triple_bonds', 'num_aromatic_bonds', 'ratio_aromatic_bonds',\n",
"    # hybridisation\n",
"    'num_sp_carbons', 'ratio_sp_carbons', 'num_sp2_carbons', 'ratio_sp2_carbons',\n",
"    'num_sp3_carbons', 'ratio_sp3_carbons', 'num_sp3d_carbons', 'ratio_sp3d_carbons',\n",
"    'num_sp3d2_carbons', 'ratio_sp3d2_carbons',\n",
"    # formal charge\n",
"    'total_formal_charge', 'abs_total_formal_charge', 'max_formal_charge', 'min_formal_charge',\n",
"    # aromaticity\n",
"    'num_aromatic_atoms', 'aromatic_ratio',\n",
"    # ring sizes\n",
"    'avg_ring_size', 'max_ring_size', 'min_ring_size',\n",
"    'num_3_rings', 'num_4_rings', 'num_5_rings', 'num_6_rings', 'num_7_rings', 'num_large_rings',\n",
"    # features read from the SMILES text\n",
"    'has_polymer_notation', 'branch_count', 'branch_ratio',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fc31605d-cc21-4533-b04e-f8acdaef1a65",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['id', 'SMILES', 'Tg', 'FFV', 'Tc', 'Density', 'Rg', 'Smiles',\n",
" 'mol_weight', 'exact_mol_weight', 'num_heavy_atoms', 'num_atoms',\n",
" 'num_bonds', 'num_hbond_donors', 'num_hbond_acceptors',\n",
" 'num_heteroatoms', 'num_rotatable_bonds', 'num_saturated_rings',\n",
" 'num_aromatic_rings', 'num_aliphatic_rings', 'ring_count',\n",
" 'fraction_csp3', 'tpsa', 'polar_surface_area', 'logp', 'crippen_logp',\n",
" 'crippen_mr', 'kappa1', 'kappa2', 'kappa3', 'chi0v', 'chi1v', 'chi2v',\n",
" 'balaban_j', 'bertz_ct', 'num_radical_electrons',\n",
" 'num_valence_electrons', 'count_C', 'ratio_C', 'count_N', 'ratio_N',\n",
" 'count_O', 'ratio_O', 'count_S', 'ratio_S', 'count_P', 'ratio_P',\n",
" 'count_F', 'ratio_F', 'count_Cl', 'ratio_Cl', 'count_Br', 'ratio_Br',\n",
" 'count_I', 'ratio_I', 'num_single_bonds', 'ratio_single_bonds',\n",
" 'num_double_bonds', 'ratio_double_bonds', 'num_triple_bonds',\n",
" 'ratio_triple_bonds', 'num_aromatic_bonds', 'ratio_aromatic_bonds',\n",
" 'num_sp_carbons', 'ratio_sp_carbons', 'num_sp2_carbons',\n",
" 'ratio_sp2_carbons', 'num_sp3_carbons', 'ratio_sp3_carbons',\n",
" 'num_sp3d_carbons', 'ratio_sp3d_carbons', 'num_sp3d2_carbons',\n",
" 'ratio_sp3d2_carbons', 'total_formal_charge', 'abs_total_formal_charge',\n",
" 'max_formal_charge', 'min_formal_charge', 'num_aromatic_atoms',\n",
" 'aromatic_ratio', 'avg_ring_size', 'max_ring_size', 'min_ring_size',\n",
" 'num_3_rings', 'num_4_rings', 'num_5_rings', 'num_6_rings',\n",
" 'num_7_rings', 'num_large_rings', 'has_polymer_notation',\n",
" 'smiles_length', 'branch_count', 'branch_ratio'],\n",
" dtype='object')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Standardise every selected column in place and keep the fitted scaler of\n",
"# each one: scalers[k] belongs to selected_features[k].\n",
"# NOTE(review): the scalers are fitted on the whole table, i.e. before the\n",
"# train/validation split performed later in the notebook.\n",
"scalers = []\n",
"for col in selected_features:\n",
"    column_values = features_df[col].to_numpy().reshape(-1, 1)\n",
"    scaler = StandardScaler()\n",
"    features_df[col] = scaler.fit_transform(column_values).ravel()\n",
"    scalers.append(scaler)\n",
"\n",
"features_df.columns"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f2f1a614-0ba7-4a01-9731-532afc1d14e0",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'features_df' is not defined",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m new_features = []\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m feature \u001b[38;5;129;01min\u001b[39;00m selected_features:\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m unique_list = \u001b[43mfeatures_df\u001b[49m[feature].unique()\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(unique_list) > \u001b[32m300\u001b[39m:\n\u001b[32m 6\u001b[39m new_features.append(feature)\n",
"\u001b[31mNameError\u001b[39m: name 'features_df' is not defined"
]
}
],
"source": [
"# Keep the descriptors that take more than 300 distinct values (a missing\n",
"# value counts as one value), then add the 'Smiles' column name.\n",
"new_features = [\n",
"    feature\n",
"    for feature in selected_features\n",
"    if features_df[feature].nunique(dropna=False) > 300\n",
"]\n",
"new_features.append('Smiles')\n",
"print(new_features)\n",
"len(new_features), len(selected_features)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28cbac75-8a9f-4292-aedb-11f33f5a6056",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "c065d950-7a63-4424-9923-1072d2e2268c",
"metadata": {},
"outputs": [],
"source": [
"# Persist the descriptor table without its index; a later cell reloads this file.\n",
"features_df.to_csv(path_or_buf='7k_w_descriptors.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "069a9021-d440-4bf1-9882-a2af25f2e801",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>SMILES</th>\n",
" <th>Tg</th>\n",
" <th>FFV</th>\n",
" <th>Tc</th>\n",
" <th>Density</th>\n",
" <th>Rg</th>\n",
" <th>Smiles</th>\n",
" <th>mol_weight</th>\n",
" <th>exact_mol_weight</th>\n",
" <th>...</th>\n",
" <th>num_3_rings</th>\n",
" <th>num_4_rings</th>\n",
" <th>num_5_rings</th>\n",
" <th>num_6_rings</th>\n",
" <th>num_7_rings</th>\n",
" <th>num_large_rings</th>\n",
" <th>has_polymer_notation</th>\n",
" <th>smiles_length</th>\n",
" <th>branch_count</th>\n",
" <th>branch_ratio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>87817</td>\n",
" <td>*CC(*)c1ccccc1C(=O)OCCCCCC</td>\n",
" <td>NaN</td>\n",
" <td>0.374645</td>\n",
" <td>0.205667</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*CC(*)c1ccccc1C(=O)OCCCCCC</td>\n",
" <td>-0.875755</td>\n",
" <td>-0.875617</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>-0.626991</td>\n",
" <td>-0.788904</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>26</td>\n",
" <td>-0.985221</td>\n",
" <td>-0.813832</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>106919</td>\n",
" <td>*Nc1ccc([C@H](CCC)c2ccc(C3(c4ccc([C@@H](CCC)c5...</td>\n",
" <td>NaN</td>\n",
" <td>0.370410</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*Nc1ccc([C@H](CCC)c2ccc(C3(c4ccc([C@@H](CCC)c5...</td>\n",
" <td>0.651876</td>\n",
" <td>0.651916</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>-0.626991</td>\n",
" <td>0.736852</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>82</td>\n",
" <td>0.336345</td>\n",
" <td>-0.286141</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>388772</td>\n",
" <td>*Oc1ccc(S(=O)(=O)c2ccc(Oc3ccc(C4(c5ccc(Oc6ccc(...</td>\n",
" <td>NaN</td>\n",
" <td>0.378860</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*Oc1ccc(S(=O)(=O)c2ccc(Oc3ccc(C4(c5ccc(Oc6ccc(...</td>\n",
" <td>2.336573</td>\n",
" <td>2.336165</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>-0.626991</td>\n",
" <td>2.644047</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>134</td>\n",
" <td>1.657910</td>\n",
" <td>-0.109289</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>519416</td>\n",
" <td>*Nc1ccc(-c2c(-c3ccc(C)cc3)c(-c3ccc(C)cc3)c(N*)...</td>\n",
" <td>NaN</td>\n",
" <td>0.387324</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*Nc1ccc(-c2c(-c3ccc(C)cc3)c(-c3ccc(C)cc3)c(N*)...</td>\n",
" <td>0.417716</td>\n",
" <td>0.417722</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>-0.626991</td>\n",
" <td>1.118291</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>79</td>\n",
" <td>0.556606</td>\n",
" <td>0.132247</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>539187</td>\n",
" <td>*Oc1ccc(OC(=O)c2cc(OCCCCCCCCCOCC3CCCN3c3ccc([N...</td>\n",
" <td>NaN</td>\n",
" <td>0.355470</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*Oc1ccc(OC(=O)c2cc(OCCCCCCCCCOCC3CCCN3c3ccc([N...</td>\n",
" <td>2.178003</td>\n",
" <td>2.178499</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>1.501149</td>\n",
" <td>0.355413</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>118</td>\n",
" <td>0.556606</td>\n",
" <td>-0.830501</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7968</th>\n",
" <td>2146592435</td>\n",
" <td>*Oc1cc(CCCCCCCC)cc(OC(=O)c2cccc(C(*)=O)c2)c1</td>\n",
" <td>NaN</td>\n",
" <td>0.367498</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*Oc1cc(CCCCCCCC)cc(OC(=O)c2cccc(C(*)=O)c2)c1</td>\n",
" <td>-0.375261</td>\n",
" <td>-0.375084</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>-0.626991</td>\n",
" <td>-0.407465</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>44</td>\n",
" <td>-0.324438</td>\n",
" <td>0.124891</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7969</th>\n",
" <td>2146810552</td>\n",
" <td>*C(=O)OCCN(CCOC(=O)c1ccc2c(c1)C(=O)N(c1cccc(N3...</td>\n",
" <td>NaN</td>\n",
" <td>0.353280</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*C(=O)OCCN(CCOC(=O)c1ccc2c(c1)C(=O)N(c1cccc(N3...</td>\n",
" <td>1.284275</td>\n",
" <td>1.284737</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>1.501149</td>\n",
" <td>0.736852</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>110</td>\n",
" <td>1.217388</td>\n",
" <td>0.008668</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7970</th>\n",
" <td>2147191531</td>\n",
" <td>*c1cc(C(=O)NCCCCCCCC)cc(N2C(=O)c3ccc(-c4ccc5c(...</td>\n",
" <td>NaN</td>\n",
" <td>0.369411</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*c1cc(C(=O)NCCCCCCCC)cc(N2C(=O)c3ccc(-c4ccc5c(...</td>\n",
" <td>0.329570</td>\n",
" <td>0.329823</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>1.501149</td>\n",
" <td>-0.026026</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>73</td>\n",
" <td>0.336345</td>\n",
" <td>0.021405</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7971</th>\n",
" <td>2147435020</td>\n",
" <td>*C=C(*)c1ccccc1C</td>\n",
" <td>261.662355</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*C=C(*)c1ccccc1C</td>\n",
" <td>-1.359802</td>\n",
" <td>-1.359728</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>-0.626991</td>\n",
" <td>-0.788904</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>16</td>\n",
" <td>-1.205481</td>\n",
" <td>-1.182617</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7972</th>\n",
" <td>2147438299</td>\n",
" <td>*c1ccc(OCCCCCCCCCCCOC(=O)CCCCC(=O)OCCCCCCCCCCC...</td>\n",
" <td>NaN</td>\n",
" <td>0.374049</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>*c1ccc(OCCCCCCCCCCCOC(=O)CCCCC(=O)OCCCCCCCCCCC...</td>\n",
" <td>1.160667</td>\n",
" <td>1.160653</td>\n",
" <td>...</td>\n",
" <td>-0.048476</td>\n",
" <td>-0.069289</td>\n",
" <td>0.437079</td>\n",
" <td>-0.407465</td>\n",
" <td>-0.051542</td>\n",
" <td>-0.047917</td>\n",
" <td>0.0</td>\n",
" <td>72</td>\n",
" <td>-0.324438</td>\n",
" <td>-1.005054</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>7973 rows × 92 columns</p>\n",
"</div>"
],
"text/plain": [
" id SMILES \\\n",
"0 87817 *CC(*)c1ccccc1C(=O)OCCCCCC \n",
"1 106919 *Nc1ccc([C@H](CCC)c2ccc(C3(c4ccc([C@@H](CCC)c5... \n",
"2 388772 *Oc1ccc(S(=O)(=O)c2ccc(Oc3ccc(C4(c5ccc(Oc6ccc(... \n",
"3 519416 *Nc1ccc(-c2c(-c3ccc(C)cc3)c(-c3ccc(C)cc3)c(N*)... \n",
"4 539187 *Oc1ccc(OC(=O)c2cc(OCCCCCCCCCOCC3CCCN3c3ccc([N... \n",
"... ... ... \n",
"7968 2146592435 *Oc1cc(CCCCCCCC)cc(OC(=O)c2cccc(C(*)=O)c2)c1 \n",
"7969 2146810552 *C(=O)OCCN(CCOC(=O)c1ccc2c(c1)C(=O)N(c1cccc(N3... \n",
"7970 2147191531 *c1cc(C(=O)NCCCCCCCC)cc(N2C(=O)c3ccc(-c4ccc5c(... \n",
"7971 2147435020 *C=C(*)c1ccccc1C \n",
"7972 2147438299 *c1ccc(OCCCCCCCCCCCOC(=O)CCCCC(=O)OCCCCCCCCCCC... \n",
"\n",
" Tg FFV Tc Density Rg \\\n",
"0 NaN 0.374645 0.205667 NaN NaN \n",
"1 NaN 0.370410 NaN NaN NaN \n",
"2 NaN 0.378860 NaN NaN NaN \n",
"3 NaN 0.387324 NaN NaN NaN \n",
"4 NaN 0.355470 NaN NaN NaN \n",
"... ... ... ... ... .. \n",
"7968 NaN 0.367498 NaN NaN NaN \n",
"7969 NaN 0.353280 NaN NaN NaN \n",
"7970 NaN 0.369411 NaN NaN NaN \n",
"7971 261.662355 NaN NaN NaN NaN \n",
"7972 NaN 0.374049 NaN NaN NaN \n",
"\n",
" Smiles mol_weight \\\n",
"0 *CC(*)c1ccccc1C(=O)OCCCCCC -0.875755 \n",
"1 *Nc1ccc([C@H](CCC)c2ccc(C3(c4ccc([C@@H](CCC)c5... 0.651876 \n",
"2 *Oc1ccc(S(=O)(=O)c2ccc(Oc3ccc(C4(c5ccc(Oc6ccc(... 2.336573 \n",
"3 *Nc1ccc(-c2c(-c3ccc(C)cc3)c(-c3ccc(C)cc3)c(N*)... 0.417716 \n",
"4 *Oc1ccc(OC(=O)c2cc(OCCCCCCCCCOCC3CCCN3c3ccc([N... 2.178003 \n",
"... ... ... \n",
"7968 *Oc1cc(CCCCCCCC)cc(OC(=O)c2cccc(C(*)=O)c2)c1 -0.375261 \n",
"7969 *C(=O)OCCN(CCOC(=O)c1ccc2c(c1)C(=O)N(c1cccc(N3... 1.284275 \n",
"7970 *c1cc(C(=O)NCCCCCCCC)cc(N2C(=O)c3ccc(-c4ccc5c(... 0.329570 \n",
"7971 *C=C(*)c1ccccc1C -1.359802 \n",
"7972 *c1ccc(OCCCCCCCCCCCOC(=O)CCCCC(=O)OCCCCCCCCCCC... 1.160667 \n",
"\n",
" exact_mol_weight ... num_3_rings num_4_rings num_5_rings \\\n",
"0 -0.875617 ... -0.048476 -0.069289 -0.626991 \n",
"1 0.651916 ... -0.048476 -0.069289 -0.626991 \n",
"2 2.336165 ... -0.048476 -0.069289 -0.626991 \n",
"3 0.417722 ... -0.048476 -0.069289 -0.626991 \n",
"4 2.178499 ... -0.048476 -0.069289 1.501149 \n",
"... ... ... ... ... ... \n",
"7968 -0.375084 ... -0.048476 -0.069289 -0.626991 \n",
"7969 1.284737 ... -0.048476 -0.069289 1.501149 \n",
"7970 0.329823 ... -0.048476 -0.069289 1.501149 \n",
"7971 -1.359728 ... -0.048476 -0.069289 -0.626991 \n",
"7972 1.160653 ... -0.048476 -0.069289 0.437079 \n",
"\n",
" num_6_rings num_7_rings num_large_rings has_polymer_notation \\\n",
"0 -0.788904 -0.051542 -0.047917 0.0 \n",
"1 0.736852 -0.051542 -0.047917 0.0 \n",
"2 2.644047 -0.051542 -0.047917 0.0 \n",
"3 1.118291 -0.051542 -0.047917 0.0 \n",
"4 0.355413 -0.051542 -0.047917 0.0 \n",
"... ... ... ... ... \n",
"7968 -0.407465 -0.051542 -0.047917 0.0 \n",
"7969 0.736852 -0.051542 -0.047917 0.0 \n",
"7970 -0.026026 -0.051542 -0.047917 0.0 \n",
"7971 -0.788904 -0.051542 -0.047917 0.0 \n",
"7972 -0.407465 -0.051542 -0.047917 0.0 \n",
"\n",
" smiles_length branch_count branch_ratio \n",
"0 26 -0.985221 -0.813832 \n",
"1 82 0.336345 -0.286141 \n",
"2 134 1.657910 -0.109289 \n",
"3 79 0.556606 0.132247 \n",
"4 118 0.556606 -0.830501 \n",
"... ... ... ... \n",
"7968 44 -0.324438 0.124891 \n",
"7969 110 1.217388 0.008668 \n",
"7970 73 0.336345 0.021405 \n",
"7971 16 -1.205481 -1.182617 \n",
"7972 72 -0.324438 -1.005054 \n",
"\n",
"[7973 rows x 92 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Reload the descriptor table from the CSV written earlier in the notebook.\n",
"features_df = pd.read_csv(filepath_or_buffer='7k_w_descriptors.csv')\n",
"features_df"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "49998b8a-3925-4383-917a-116f70187d46",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
}
],
"source": [
"# Change in row count caused by dropping fully duplicated rows\n",
"# (0 means there are none; a negative number is the count removed).\n",
"new_len = features_df.drop_duplicates().shape[0]\n",
"old_len = features_df.shape[0]\n",
"print(new_len - old_len)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c2f08ca9-21f6-4a79-ab94-80556b8dab1d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|βββββββββββββββββββββββββββββββββββββ| 6378/6378 [00:01<00:00, 3492.45it/s]\n",
"100%|βββββββββββββββββββββββββββββββββββββ| 1595/1595 [00:00<00:00, 3576.37it/s]\n"
]
}
],
"source": [
"import torch\n",
"from tqdm import tqdm\n",
"import copy\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def create_splits(df, test_size=0.2, random_state=42):\n",
"    \"\"\"Randomly split `df` row-wise into a training and a held-out frame.\n",
"\n",
"    Parameters:\n",
"        df: DataFrame to split.\n",
"        test_size: share of rows placed in the held-out frame.\n",
"        random_state: seed of the shuffle; pass None for a different split\n",
"            on every call.\n",
"\n",
"    Returns:\n",
"        (train, test) tuple of DataFrames.\n",
"    \"\"\"\n",
"    # A fixed seed keeps row positions stable across kernel restarts. That\n",
"    # matters because MolecularContrastiveDataset can cache its pairs as\n",
"    # positions in the sample list built from this split.\n",
"    train, test = train_test_split(df, test_size=test_size, random_state=random_state)\n",
"    return train, test\n",
"\n",
"def create_samples(df, features):\n",
"    \"\"\"Turn every row of `df` into a {'Smiles', 'property_tensor'} sample.\n",
"\n",
"    Parameters:\n",
"        df: DataFrame with a 'Smiles' column and every numeric column named\n",
"            in `features`.\n",
"        features: column names packed into the property tensor; a 'Smiles'\n",
"            entry, if present, is skipped. The list itself is not modified.\n",
"\n",
"    Returns:\n",
"        list of dicts, one per row and in row order, holding the SMILES\n",
"        string under 'Smiles' and a 1-D float32 tensor under\n",
"        'property_tensor'.\n",
"    \"\"\"\n",
"    property_columns = [name for name in features if name != 'Smiles']\n",
"    # One bulk conversion replaces a per-row Series lookup via iterrows()\n",
"    property_matrix = df[property_columns].to_numpy(dtype='float32')\n",
"    # torch.tensor copies each row, so the samples do not share storage\n",
"    return [\n",
"        {'Smiles': smiles, 'property_tensor': torch.tensor(row_values)}\n",
"        for smiles, row_values in zip(df['Smiles'].tolist(), property_matrix)\n",
"    ]\n",
"\n",
"# Split the descriptor table, renumber both parts from 0, and convert each\n",
"# part into a list of samples (training part first).\n",
"train, val = (\n",
"    part.reset_index(drop=True)\n",
"    for part in create_splits(features_df.reset_index(drop=True))\n",
")\n",
"\n",
"train_list, val_list = (create_samples(part, new_features) for part in (train, val))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2fdb3171-deda-4c1f-ae4b-853d781ffdd5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|βββββββββββββββββββββββββββββββββββββββ| 20/20 [00:00<00:00, 106050.67it/s]\n"
]
}
],
"source": [
"from sklearn.metrics.pairwise import cosine_similarity\n",
"\n",
"# Exploratory check on the first 20 training samples: pairwise cosine\n",
"# similarity of their property tensors.\n",
"prop_vectors = [el['property_tensor'] for el in train_list[:20]]\n",
"\n",
"sim_matrix = cosine_similarity(prop_vectors)\n",
" \n",
"n = len(prop_vectors)\n",
"positive_pairs, negative_candidates = [], []\n",
"sims = []\n",
"\n",
"# Cut-offs: a pair above 0.9 is recorded as positive, a pair below 0.2 as a\n",
"# negative candidate; pairs in between go into neither list.\n",
"positive_threshold = 0.9\n",
"negative_threshold = 0.2\n",
"\n",
"# Only j > i is visited, so every unordered pair is examined once.\n",
"for i in tqdm(range(n)):\n",
" for j in range(i + 1, n):\n",
" sim = sim_matrix[i, j]\n",
"\n",
" if sim > positive_threshold:\n",
" positive_pairs.append((i, j, sim))\n",
" elif sim < negative_threshold:\n",
" negative_candidates.append((i, j, sim))\n",
" sims.append(float(sim))\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "54f29e98-7c32-441c-bb1b-cdaf3fd1df49",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 126)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (number of positive pairs, number of negative candidates)\n",
"tuple(map(len, (positive_pairs, negative_candidates)))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "22e0f46e-2673-4840-95fd-f98914e57b78",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f8e7e795220>]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"# Line plot of the similarity values collected in `sims`, labelled so the\n",
"# figure can be read on its own; the trailing ';' hides the return value.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(sims)\n",
"ax.set(title='Pairwise descriptor cosine similarity', xlabel='position in sims', ylabel='cosine similarity');"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "79e7e873-7950-4123-ab13-299360ae19ca",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from transformers import BertConfig, BertModel, AutoTokenizer\n",
"import pickle\n",
"import numpy as np\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"\n",
"def global_ap(x):\n",
"    \"\"\"Average `x` over dimension 1 after merging all later dimensions.\n",
"\n",
"    `x` is viewed as (x.size(0), x.size(1), -1) and the mean is taken over\n",
"    dimension 1, so the result has shape (x.size(0), product of the\n",
"    remaining dimensions).\n",
"    \"\"\"\n",
"    leading, pooled_axis = x.size(0), x.size(1)\n",
"    merged = x.view(leading, pooled_axis, -1)\n",
"    return merged.mean(dim=1)\n",
"\n",
"class SimSonEncoder(nn.Module):\n",
"    \"\"\"BERT encoder followed by mean pooling and a linear projection.\n",
"\n",
"    A BertModel is built from `config` without its pooling layer. Its last\n",
"    hidden states go through dropout, are averaged over dimension 1 by\n",
"    `global_ap`, and are projected from `config.hidden_size` to `max_len`\n",
"    features.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n",
"        super().__init__()\n",
"        self.config = config\n",
"        self.max_len = max_len\n",
"\n",
"        self.bert = BertModel(config, add_pooling_layer=False)\n",
"        self.linear = nn.Linear(config.hidden_size, max_len)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, input_ids, attention_mask=None):\n",
"        \"\"\"Encode token ids into one vector of size `max_len` per input row.\n",
"\n",
"        When `attention_mask` is omitted, every position whose id differs\n",
"        from 0 is attended to.\n",
"        \"\"\"\n",
"        # NOTE(review): id 0 is assumed to be the padding token - confirm against the tokenizer.\n",
"        mask = input_ids.ne(0) if attention_mask is None else attention_mask\n",
"\n",
"        encoded = self.bert(input_ids=input_ids, attention_mask=mask)\n",
"        token_states = self.dropout(encoded.last_hidden_state)\n",
"\n",
"        # NOTE(review): the mean is taken over all positions, masked ones included.\n",
"        return self.linear(global_ap(token_states))\n",
"\n",
"def initialize_model_and_tokenizer(device=None):\n",
"    \"\"\"Initialize BERT model from config and ChemBERTa tokenizer\n",
"\n",
"    Parameters:\n",
"        device: device the model is moved to; defaults to 'cuda' when a GPU\n",
"            is available and to 'cpu' otherwise.\n",
"\n",
"    Returns:\n",
"        (model, tokenizer) tuple. The model is built from a fresh BertConfig,\n",
"        not loaded from a checkpoint.\n",
"    \"\"\"\n",
"    if device is None:\n",
"        # Fall back to the CPU so the notebook also runs without a GPU\n",
"        device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n",
"    # The vocabulary size follows the tokenizer so every token id has an embedding\n",
"    config = BertConfig(\n",
"        vocab_size=tokenizer.vocab_size,\n",
"        hidden_size=768,\n",
"        num_hidden_layers=4,\n",
"        num_attention_heads=12,\n",
"        intermediate_size=2048,\n",
"        max_position_embeddings=512,\n",
"    )\n",
"    model = SimSonEncoder(config=config, max_len=512).to(device)\n",
"    return model, tokenizer\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8a3adaff-da65-46b4-b9ee-95851d786a67",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"\n",
"class MolecularContrastiveDataset(Dataset):\n",
"    def __init__(self, data_list, tokenizer, positive_threshold=0.9, cache_path=None, split_type='train'):\n",
"        \"\"\"\n",
"        Dataset that only contains positive pairs for NT-Xent contrastive learning.\n",
"\n",
"        Pairs are read from `cache_path` when that file exists and is not\n",
"        empty; otherwise they are computed and, if a cache path was given,\n",
"        written to it.\n",
"        \"\"\"\n",
"        self.data_list = data_list\n",
"        self.tokenizer = tokenizer\n",
"        self.positive_threshold = positive_threshold\n",
"        self.cache_path = cache_path\n",
"        self.split_type = split_type\n",
"\n",
"        # A zero-byte cache file is treated like a missing one\n",
"        cache_is_usable = (\n",
"            bool(cache_path)\n",
"            and os.path.exists(cache_path)\n",
"            and os.path.getsize(cache_path) > 0\n",
"        )\n",
"        if cache_is_usable:\n",
"            print(f\"Loading cached pairs from {cache_path}\")\n",
"            self._load_pairs()\n",
"            return\n",
"\n",
"        print(\"Computing positive pairs only...\")\n",
"        self._compute_positive_pairs()\n",
"        if cache_path:\n",
"            self._save_pairs()\n",
" \n",
"    def _compute_positive_pairs(self):\n",
"        \"\"\"\n",
"        Compute ONLY positive pairs based on descriptor similarity.\n",
"\n",
"        For every anchor molecule i, at most `pairs_per_molecule` partners\n",
"        j > i whose cosine similarity to i exceeds `self.positive_threshold`\n",
"        are kept, lowest j first. Sets `self.pairs` (list of (i, j) positions\n",
"        in `self.data_list`) and `self.descriptor_similarities` (the matching\n",
"        similarity values, in the same order).\n",
"\n",
"        Raises:\n",
"            ValueError: if no pair exceeds the threshold.\n",
"        \"\"\"\n",
"        # --- 1. Cosine-similarity matrix ---------------------------------------\n",
"        prop_vectors = torch.stack(\n",
"            [item['property_tensor'] for item in self.data_list]\n",
"        ).numpy()\n",
"        sim_matrix = cosine_similarity(prop_vectors)\n",
"\n",
"        n = len(self.data_list)\n",
"        positive_pairs = []\n",
"        pairs_per_molecule = 1  # STRICTLY ONE FOR CREATING PROPER NEGATIVE PAIRS\n",
"\n",
"        # --- 2. Collect only positive pairs ------------------------------------\n",
"        print(f'Collecting positive pairs with similarity threshold {self.positive_threshold}')\n",
"        for i in tqdm(range(n)):\n",
"            # The limit is applied per anchor. A running counter shared by all\n",
"            # anchors would let some of them collect two partners.\n",
"            later_sims = sim_matrix[i, i + 1:]\n",
"            hits = np.flatnonzero(later_sims > self.positive_threshold)[:pairs_per_molecule]\n",
"            for offset in hits:\n",
"                j = i + 1 + int(offset)\n",
"                positive_pairs.append((i, j, sim_matrix[i, j]))\n",
"\n",
"        # NOTE(review): a molecule can still be picked as partner j by several anchors.\n",
"\n",
"        # --- 3. Store only positive pairs --------------------------------------\n",
"        if len(positive_pairs) == 0:\n",
"            raise ValueError('No positive pairs found \u2014 lower the positive_threshold.')\n",
"\n",
"        # No shuffling - we want consistent positive pairs\n",
"        self.pairs = [(i, j) for i, j, _ in positive_pairs]\n",
"        self.descriptor_similarities = [sim for _, _, sim in positive_pairs]\n",
"\n",
"        print(f'Generated {len(self.pairs)} positive pairs')\n",
"\n",
" def _save_pairs(self):\n",
" \"\"\"Save computed pairs to cache file\"\"\"\n",
" cache_data = {\n",
" 'pairs': self.pairs,\n",
" 'descriptor_similarities': self.descriptor_similarities\n",
" }\n",
" with open(self.cache_path, 'wb') as f:\n",
" pickle.dump(cache_data, f)\n",
" print(f\"Cached pairs saved to {self.cache_path}\")\n",
" \n",
" def _load_pairs(self):\n",
" \"\"\"Load pairs from cache file\"\"\"\n",
" with open(self.cache_path, 'rb') as f:\n",
" cache_data = pickle.load(f)\n",
" \n",
" self.pairs = cache_data['pairs']\n",
" self.descriptor_similarities = cache_data['descriptor_similarities']\n",
" \n",
" def __len__(self):\n",
" return len(self.pairs)\n",
" \n",
" def __getitem__(self, idx):\n",
" i, j = self.pairs[idx]\n",
" desc_sim = self.descriptor_similarities[idx]\n",
" \n",
" # Get SMILES for both molecules\n",
" smiles_i = self.data_list[i]['Smiles']\n",
" smiles_j = self.data_list[j]['Smiles']\n",
" if self.split_type == 'val':\n",
" print(f'POSITIVE PAIR SMILES: \\n{smiles_i} \\n {smiles_j}')\n",
" # Tokenize SMILES\n",
" tokens_i = self.tokenizer(\n",
" smiles_i, \n",
" return_tensors='pt', \n",
" padding='max_length', \n",
" truncation=True, \n",
" max_length=256\n",
" )\n",
" tokens_j = self.tokenizer(\n",
" smiles_j, \n",
" return_tensors='pt', \n",
" padding='max_length', \n",
" truncation=True, \n",
" max_length=256\n",
" )\n",
" \n",
" # Remove batch dimension\n",
" tokens_i = {key: val.squeeze(0) for key, val in tokens_i.items()}\n",
" tokens_j = {key: val.squeeze(0) for key, val in tokens_j.items()}\n",
" \n",
" # Get property vectors\n",
" prop_vec_i = self.data_list[i]['property_tensor']\n",
" prop_vec_j = self.data_list[j]['property_tensor']\n",
" \n",
" return {\n",
" 'tokens_i': tokens_i,\n",
" 'tokens_j': tokens_j,\n",
" 'descriptor_similarity': torch.tensor(desc_sim, dtype=torch.float32),\n",
" 'property_tensor_i': prop_vec_i,\n",
" 'property_tensor_j': prop_vec_j\n",
" }\n",
"\n",
"\n",
"def contrastive_collate_fn(batch):\n",
" \"\"\"\n",
" Collate function that creates proper NT-Xent batches:\n",
" - Element 0 and 1 are positive pairs\n",
" - Element 2 and 3 are positive pairs \n",
" - etc.\n",
" \"\"\"\n",
" batch_size = len(batch)\n",
" \n",
" # Ensure even batch size for proper pairing\n",
" if batch_size % 2 != 0:\n",
" batch = batch[:-1] # Drop last element if odd\n",
" batch_size = len(batch)\n",
" \n",
" # Interleave: [sample1_i, sample1_j, sample2_i, sample2_j, ...]\n",
" tokens_list = []\n",
" desc_similarities = []\n",
" \n",
" for i in range(0, batch_size, 1):\n",
" # Add first molecule of pair i\n",
" tokens_list.append(batch[i]['tokens_i'])\n",
" desc_similarities.append(batch[i]['descriptor_similarity'])\n",
" \n",
" # Add second molecule of pair i (positive pair)\n",
" tokens_list.append(batch[i]['tokens_j'])\n",
" desc_similarities.append(batch[i]['descriptor_similarity']) # Same similarity for both elements in pair\n",
" \n",
" # Stack all tokens\n",
" tokens = {}\n",
" for key in tokens_list[0].keys():\n",
" tokens[key] = torch.stack([item[key] for item in tokens_list])\n",
" \n",
" desc_similarities_tensor = torch.stack(desc_similarities)\n",
" \n",
" return {\n",
" 'tokens': tokens,\n",
" 'descriptor_similarities': desc_similarities_tensor,\n",
" }\n",
"\n",
"\n",
"def create_dataloaders(train_list, val_list, tokenizer, batch_size=32, \n",
" positive_threshold=0.85, cache_dir=\"cache\"):\n",
" \"\"\"Create train and validation dataloaders for NT-Xent\"\"\"\n",
" os.makedirs(cache_dir, exist_ok=True)\n",
" \n",
" # Ensure even batch size for proper pairing\n",
" if batch_size % 2 != 0:\n",
" batch_size += 1\n",
" print(f\"Adjusted batch_size to {batch_size} (must be even for NT-Xent)\")\n",
" \n",
" train_cache = os.path.join(cache_dir, 'train_positive_pairs.pkl')\n",
" val_cache = os.path.join(cache_dir, 'val_positive_pairs.pkl')\n",
" \n",
" train_dataset = MolecularContrastiveDataset(\n",
" train_list, tokenizer, positive_threshold=positive_threshold, cache_path=train_cache\n",
" )\n",
" val_dataset = MolecularContrastiveDataset(\n",
" val_list, tokenizer, positive_threshold=positive_threshold, cache_path=val_cache, split_type='val',\n",
" )\n",
" \n",
" train_loader = DataLoader(\n",
" train_dataset, batch_size=batch_size, shuffle=True, collate_fn=contrastive_collate_fn, drop_last=True, pin_memory=True\n",
" )\n",
" val_loader = DataLoader(\n",
" val_dataset, batch_size=batch_size, shuffle=False, collate_fn=contrastive_collate_fn, drop_last=True, pin_memory=True\n",
" )\n",
" \n",
" return train_loader, val_loader\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f956a50b-85a5-49df-b7c6-6e40dce160e1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model initialized with 23,299,840 trainable parameters\n"
]
}
],
"source": [
"def nt_xent_loss_with_temp_scaling(embeddings1, embeddings2, descriptor_similarity, base_temp=0.02):\n",
" batch_size = embeddings1.shape[0]\n",
" device = embeddings1.device\n",
" #individual_temperatures = sigmoid_temp_scaling(descriptor_similarity, base_temp)\n",
" #temperature = individual_temperatures.mean() # Single temperature for the whole batch\n",
" temperature = base_temp\n",
" # Normalize projections\n",
" z_i = F.normalize(embeddings1, p=2, dim=1)\n",
" z_j = F.normalize(embeddings2, p=2, dim=1)\n",
" \n",
" # Concatenate for similarity matrix calculation\n",
" representations = torch.cat([z_i, z_j], dim=0)\n",
" # Calculate cosine similarity between all pairs\n",
" similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)\n",
" #similarity_matrix = torch.clamp(similarity_matrix, min=-0.999, max=0.999)\n",
" sim_ij = torch.diag(similarity_matrix, batch_size)\n",
" sim_ji = torch.diag(similarity_matrix, -batch_size)\n",
" positives = torch.cat([sim_ij, sim_ji], dim=0)\n",
" \n",
" # Create a mask to exclude self-comparisons\n",
" nominator = torch.exp(positives / temperature)\n",
" mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool, device=device)).float()\n",
" denominator = mask * torch.exp(similarity_matrix / temperature)\n",
" \n",
" # Calculate the final loss\n",
" loss = -torch.log(nominator / torch.sum(denominator, dim=1))\n",
" if torch.isnan(loss).any():\n",
" print(similarity_matrix)\n",
" print(f\"Temperature: {temperature}\")\n",
" print(f\"Nominator range: {nominator.min().item():.6f} to {nominator.max().item():.6f}\")\n",
" \n",
" return torch.sum(loss) / (2 * batch_size)\n",
"\n",
"\n",
"def sigmoid_temp_scaling(descriptor_similarity, base_temp=0.05, steepness=10.0, midpoint=0.5):\n",
" \"\"\"Smooth sigmoid-based temperature scaling\"\"\"\n",
" sigmoid_factor = torch.sigmoid(steepness * (descriptor_similarity - midpoint))\n",
" temperature = base_temp * (2.0 - sigmoid_factor)\n",
" return temperature\n",
"\n",
"\n",
"def train_step(batch, model, optimizer, device, scheduler, base_temp=0.1):\n",
" \"\"\"Single training step for NT-Xent\"\"\"\n",
" model.train()\n",
" optimizer.zero_grad()\n",
" \n",
" # Move batch to device\n",
" tokens = {k: v.to(device) for k, v in batch['tokens'].items()}\n",
" desc_similarities = batch['descriptor_similarities'].to(device)\n",
" \n",
" # Forward pass - get embeddings for all samples\n",
" outputs = model(**tokens) # i1, j1, i2, j2 ...\n",
" embeddings = outputs\n",
" \n",
" # Split embeddings: even indices are embeddings1, odd indices are embeddings2\n",
" embeddings1 = embeddings[::2] # [0, 2, 4, ...]\n",
" embeddings2 = embeddings[1::2] # [1, 3, 5, ...]\n",
" \n",
" # Get descriptor similarities for each pair (take every other one since they're duplicated)\n",
" pair_desc_similarities = desc_similarities[::2]\n",
" #print(f'FIRST TRAIN EMBED: {embeddings1}')\n",
" #print(f'SECOND TRAIN EMBED: {embeddings2}')\n",
" #print(f'COSINE SIM BETWEEN THEM TRAIN: {F.cosine_similarity(embeddings1, embeddings2, dim=1)}')\n",
" # Calculate NT-Xent loss\n",
" loss = nt_xent_loss_with_temp_scaling(embeddings1, embeddings2, pair_desc_similarities, base_temp=base_temp)\n",
" \n",
" # Backward pass\n",
" loss.backward()\n",
" optimizer.step()\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
" scheduler.step()\n",
" return loss.item()\n",
"\n",
"def val_step(batch, model, device, base_temp=0.1):\n",
" \"\"\"Single validation step for NT-Xent\"\"\"\n",
" model.eval()\n",
" with torch.no_grad():\n",
" # Move batch to device\n",
" tokens = {k: v.to(device) for k, v in batch['tokens'].items()}\n",
" desc_similarities = batch['descriptor_similarities'].to(device)\n",
" \n",
" # Forward pass\n",
" outputs = model(**tokens)\n",
" embeddings = outputs\n",
" \n",
" # Split embeddings\n",
" embeddings1 = embeddings[::2]\n",
" embeddings2 = embeddings[1::2]\n",
" \n",
" # Get descriptor similarities for pairs\n",
" pair_desc_similarities = desc_similarities[::2]\n",
" \n",
" print(f'FIRST VAL EMBED: {embeddings1}')\n",
" print(f'SECOND VAL EMBED: {embeddings2}')\n",
" print(f'COSINE SIM BETWEEN THEM: {F.cosine_similarity(embeddings1, embeddings2, dim=1)}')\n",
" #print(f'SECOND VAL EMBED: {embeddings2}')\n",
" loss = nt_xent_loss_with_temp_scaling(embeddings1, embeddings2, pair_desc_similarities, base_temp=base_temp)\n",
" print(f'VAL LOSS: {loss}')\n",
" \n",
" return loss.item()\n",
"\n",
"def train_epoch(train_loader, model, optimizer, scheduler, base_temp=0.01):\n",
" \"\"\"Train for one epoch\"\"\"\n",
" total_loss = 0\n",
" num_batches = 0\n",
" \n",
" progress_bar = tqdm(train_loader, desc=\"Training\")\n",
" \n",
" for batch in progress_bar:\n",
" loss = train_step(batch, model, optimizer, 'cuda', scheduler, base_temp=base_temp)\n",
" total_loss += loss\n",
" num_batches += 1\n",
" \n",
" # Calculate running average loss\n",
" avg_loss = total_loss / num_batches\n",
" \n",
" # Update progress bar with current loss info\n",
" progress_bar.set_postfix({\n",
" 'Loss': f'{loss:.4f}',\n",
" 'Avg Loss': f'{avg_loss:.4f}'\n",
" })\n",
" \n",
" return total_loss / num_batches if num_batches > 0 else 0\n",
"\n",
"\n",
"def validate_epoch(val_loader, model, base_temp=0.01):\n",
" \"\"\"Validate for one epoch\"\"\"\n",
" total_loss = 0\n",
" num_batches = 0\n",
" print('nah twin')\n",
" return 0\n",
" for batch in val_loader:\n",
" loss = val_step(batch, model, 'cuda', base_temp=base_temp)\n",
" total_loss += loss\n",
" num_batches += 1\n",
" \n",
" return total_loss / num_batches if num_batches > 0 else 0\n",
"\n",
"def training_loop(train_loader, val_loader, model, tokenizer, epochs=50, patience=5, lr=1e-4, base_temp=0.02,\n",
" device_name='cuda', save_path='best_model.pt'):\n",
" \"\"\"Main training loop with early stopping\"\"\"\n",
" device = torch.device(device_name if torch.cuda.is_available() else 'cpu')\n",
" print(f\"Using device: {device}\")\n",
" \n",
" # Initialize model and optimizer\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
" optimizer.zero_grad()\n",
"\n",
" total_steps = epochs * len(train_loader)\n",
" scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_mult=1, T_0=total_steps)\n",
" # Early stopping variables\n",
" best_val_loss = float('inf')\n",
" no_improve_epochs = 0\n",
" \n",
" print(\"Starting training...\")\n",
" \n",
" for epoch in range(epochs):\n",
" # Training\n",
" with torch.autocast(dtype=torch.float16, device_type='cuda'):\n",
" train_loss = train_epoch(train_loader, model, optimizer, scheduler, base_temp=base_temp)\n",
" print('END TRAIN')\n",
" # Validation\n",
" val_loss = validate_epoch(val_loader, model)\n",
" \n",
" print(f\"Epoch {epoch + 1}/{epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}\")\n",
" \n",
" # Early stopping check\n",
" if val_loss < best_val_loss:\n",
" best_val_loss = val_loss\n",
" no_improve_epochs = 0\n",
" # Save best model\n",
" torch.save(model.state_dict(), save_path)\n",
" print(f\"New best model saved with val loss: {val_loss:.4f}\")\n",
" else:\n",
" no_improve_epochs += 1\n",
" print(f\"No improvement for {no_improve_epochs} epochs\")\n",
" \n",
" if no_improve_epochs >= patience:\n",
" print(f\"Early stopping triggered after {epoch + 1} epochs\")\n",
" break\n",
" \n",
" # Load best model\n",
" print(f\"Loading best model from {save_path}\")\n",
" model.load_state_dict(torch.load(save_path))\n",
" model.eval()\n",
" \n",
" print(f\"Training completed. Best validation loss: {best_val_loss:.4f}\")\n",
"\n",
"\n",
"model, tokenizer = initialize_model_and_tokenizer()\n",
"#model.load_state_dict(torch.load('/home/jovyan/simson_training_bolgov/regression/actual_encoder_state.pkl', weights_only=False))\n",
"print(f\"Model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters\")\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "c73e2bba-59c1-4b41-b2ff-235526dd2912",
"metadata": {},
"outputs": [],
"source": [
"!rm -rf cache"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0072c8f5-c5e9-4590-9544-c73cf1fac1e8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Computing positive pairs only...\n",
"Collecting positive pairs with similarity threshold 0.8\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|ββββββββββββββββββββββββββββββββββββ| 6378/6378 [00:00<00:00, 55896.48it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated 12538 positive pairs\n",
"Cached pairs saved to cache/train_positive_pairs.pkl\n",
"Computing positive pairs only...\n",
"Collecting positive pairs with similarity threshold 0.8\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|βββββββββββββββββββββββββββββββββββββ| 100/100 [00:00<00:00, 206209.64it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated 129 positive pairs\n",
"Cached pairs saved to cache/val_positive_pairs.pkl\n",
"Using device: cuda\n",
"Starting training...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|β| 1567/1567 [00:31<00:00, 49.37it/s, Loss=0.9300, Avg Loss=0.989\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"END TRAIN\n",
"nah twin\n",
"Epoch 1/10: Train Loss = 0.9891, Val Loss = 0.0000\n",
"New best model saved with val loss: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 100%|β| 1567/1567 [00:31<00:00, 50.05it/s, Loss=2.7072, Avg Loss=2.712\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"END TRAIN\n",
"nah twin\n",
"Epoch 2/10: Train Loss = 2.7125, Val Loss = 0.0000\n",
"No improvement for 1 epochs\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training: 77%|β| 1204/1567 [00:24<00:07, 49.73it/s, Loss=2.7080, Avg Loss=2.708"
]
}
],
"source": [
"train_loader, val_loader = create_dataloaders(\n",
" train_list, val_list[:100], tokenizer, \n",
" batch_size=128, positive_threshold=0.8\n",
")\n",
"\n",
"training_loop(\n",
" train_loader, val_loader, model, tokenizer,\n",
" epochs=10, patience=5, lr=1e-3, \n",
" device_name='cuda', base_temp=0.1\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58343b16-1bdb-4476-ac61-e797fbc661d2",
"metadata": {},
"outputs": [],
"source": [
"print(train_list[:5], '\\n\\n', val_list[:5])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47561022-5f57-4b7b-b903-ef1f8773f903",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fcef978-3630-4201-9301-6963a8560517",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:.mlspace-bolgov_simson_training]",
"language": "python",
"name": "conda-env-.mlspace-bolgov_simson_training-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|