import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertConfig, BertModel, AutoTokenizer
from rdkit import Chem, RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold
import copy
from tqdm import tqdm
import os
from sklearn.metrics import roc_auc_score, root_mean_squared_error, mean_absolute_error
from itertools import compress
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity

RDLogger.DisableLog('rdApp.*')
torch.set_float32_matmul_precision('high')


# --- 0. SMILES Enumeration ---
class SmilesEnumerator:
    """Generates randomized SMILES strings for data augmentation."""

    def randomize_smiles(self, smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            return smiles


def compute_embedding_similarity(encoder, smiles_list, tokenizer, device, max_len=256):
    """Cosine similarity between the embedding of each SMILES and that of a randomized version of it."""
    encoder.eval()
    enumerator = SmilesEnumerator()
    embeddings_orig = []
    embeddings_aug = []
    with torch.no_grad():
        for smi in smiles_list:
            # Original SMILES encoding
            encoding_orig = tokenizer(
                smi, truncation=True, padding='max_length',
                max_length=max_len, return_tensors='pt'
            )
            # Augmented (randomized) SMILES encoding
            smi_aug = enumerator.randomize_smiles(smi)
            encoding_aug = tokenizer(
                smi_aug, truncation=True, padding='max_length',
                max_length=max_len, return_tensors='pt'
            )
            input_ids_orig = encoding_orig.input_ids.to(device)
            attention_mask_orig = encoding_orig.attention_mask.to(device)
            input_ids_aug = encoding_aug.input_ids.to(device)
            attention_mask_aug = encoding_aug.attention_mask.to(device)

            emb_orig = encoder(input_ids_orig, attention_mask_orig).cpu().numpy().flatten()
            emb_aug = encoder(input_ids_aug, attention_mask_aug).cpu().numpy().flatten()
            embeddings_orig.append(emb_orig)
            embeddings_aug.append(emb_aug)

    embeddings_orig = np.array(embeddings_orig)
    embeddings_aug = np.array(embeddings_aug)
    # Cosine similarity between each original and its augmented version
    similarities = np.array([
        cosine_similarity([embeddings_orig[i]], [embeddings_aug[i]])[0][0]
        for i in range(len(embeddings_orig))
    ])
    return similarities
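
# Illustrative sketch (not part of the original script and never called): a randomized SMILES
# produced by SmilesEnumerator should describe the same molecule as its input, so canonicalizing
# both strings yields identical output. The example molecule ('CCO') is an arbitrary assumption.
def _check_randomized_smiles(smiles='CCO'):
    enumerator = SmilesEnumerator()
    randomized = enumerator.randomize_smiles(smiles)
    # Canonicalize both forms and verify they encode the same molecule.
    canonical_orig = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
    canonical_rand = Chem.MolToSmiles(Chem.MolFromSmiles(randomized))
    assert canonical_orig == canonical_rand, "Randomized SMILES should encode the same molecule"
    return randomized
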
# --- 1. Data Loading ---
def load_lists_from_url(data):
    if data == 'bbbp':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
        smiles, labels = df.smiles, df.p_np
    elif data == 'clintox':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'hiv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
        smiles, labels = df.smiles, df.HIV_active
    elif data == 'sider':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'esol':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv')
        smiles = df.smiles
        labels = df['ESOL predicted log solubility in mols per litre']
    elif data == 'freesolv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv')
        smiles = df.smiles
        labels = df.calc
    elif data == 'lipophilicity':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv')
        smiles, labels = df.smiles, df['exp']
    elif data == 'tox21':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['mol_id', 'smiles'], axis=1)
    elif data == 'bace':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
        smiles, labels = df.mol, df.Class
    elif data == 'qm8':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv')
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1)
    else:
        raise ValueError(f"Unknown dataset: {data}")
    return smiles, labels


# --- 2. Scaffold Splitting ---
class ScaffoldSplitter:
    def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True):
        self.data = data
        self.seed = seed
        self.include_chirality = include_chirality
        self.train_frac = train_frac
        self.val_frac = val_frac
        self.test_frac = test_frac

    def generate_scaffold(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=self.include_chirality)
        return scaffold

    def scaffold_split(self):
        smiles, labels = load_lists_from_url(self.data)

        # Keep only molecules RDKit can parse; for multi-task sets also require complete labels.
        non_null = np.zeros(len(smiles), dtype=bool)
        if self.data in {'tox21', 'sider', 'clintox'}:
            for i in range(len(smiles)):
                if Chem.MolFromSmiles(smiles[i]) and labels.loc[i].isnull().sum() == 0:
                    non_null[i] = True
        else:
            for i in range(len(smiles)):
                if Chem.MolFromSmiles(smiles[i]):
                    non_null[i] = True

        smiles_list = list(compress(enumerate(smiles), non_null))
        rng = np.random.RandomState(self.seed)

        # Group molecule indices by Bemis-Murcko scaffold.
        scaffolds = defaultdict(list)
        for i, sms in smiles_list:
            scaffold = self.generate_scaffold(sms)
            scaffolds[scaffold].append(i)

        scaffold_sets = list(scaffolds.values())
        rng.shuffle(scaffold_sets)

        n_total_val = int(np.floor(self.val_frac * len(smiles_list)))
        n_total_test = int(np.floor(self.test_frac * len(smiles_list)))

        # Assign whole scaffold groups to validation, then test, then training.
        train_idx, val_idx, test_idx = [], [], []
        for scaffold_set in scaffold_sets:
            if len(val_idx) + len(scaffold_set) <= n_total_val:
                val_idx.extend(scaffold_set)
            elif len(test_idx) + len(scaffold_set) <= n_total_test:
                test_idx.extend(scaffold_set)
            else:
                train_idx.extend(scaffold_set)
        return train_idx, val_idx, test_idx
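
# Illustrative sketch (not part of the original script and never called): intended usage of
# ScaffoldSplitter. The dataset name and seed below are arbitrary assumptions, and calling this
# helper downloads the corresponding MoleculeNet CSV via load_lists_from_url.
def _demo_scaffold_split(dataset='bbbp', seed=42):
    splitter = ScaffoldSplitter(data=dataset, seed=seed)
    train_idx, val_idx, test_idx = splitter.scaffold_split()
    # Each scaffold group is assigned to exactly one partition, so the index sets are disjoint.
    assert not (set(train_idx) & set(val_idx))
    assert not (set(val_idx) & set(test_idx))
    assert not (set(train_idx) & set(test_idx))
    return train_idx, val_idx, test_idx
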
# --- 2a. Normal Random Split ---
def random_split_indices(n, seed=42, train_frac=0.8, val_frac=0.1, test_frac=0.1):
    np.random.seed(seed)
    indices = np.random.permutation(n)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    return train_idx.tolist(), val_idx.tolist(), test_idx.tolist()


# --- 3. PyTorch Dataset ---
class MoleculeDataset(Dataset):
    def __init__(self, smiles_list, labels, tokenizer, max_len=512):
        self.smiles_list = smiles_list
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        label = self.labels.iloc[idx]
        encoding = self.tokenizer(
            smiles, truncation=True, padding='max_length',
            max_length=self.max_len, return_tensors='pt'
        )
        item = {key: val.squeeze(0) for key, val in encoding.items()}
        if isinstance(label, pd.Series):
            label_values = label.values.astype(np.float32)
        else:
            label_values = np.array([label], dtype=np.float32)
        item['labels'] = torch.tensor(label_values, dtype=torch.float)
        return item


# --- 4. Model Architecture ---
def global_ap(x):
    return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)


class SimSonEncoder(nn.Module):
    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super(SimSonEncoder, self).__init__()
        self.config = config
        self.max_len = max_len
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = self.dropout(outputs.last_hidden_state)
        pooled = global_ap(hidden_states)
        return self.linear(pooled)


class SimSonClassifier(nn.Module):
    def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):
        super(SimSonClassifier, self).__init__()
        self.encoder = encoder
        self.clf = nn.Linear(encoder.max_len, num_labels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        x = self.encoder(input_ids, attention_mask)
        x = self.relu(self.dropout(x))
        logits = self.clf(x)
        return logits

    def load_encoder_params(self, state_dict_path):
        self.encoder.load_state_dict(torch.load(state_dict_path))
        print("Pretrained encoder parameters loaded.")
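
# Illustrative sketch (not part of the original script and never called): a shape check for the
# encoder/classifier stack using a tiny, randomly initialized BERT config. All sizes below are
# arbitrary assumptions chosen only to make the demo cheap to run.
def _demo_forward_shapes():
    cfg = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                     num_attention_heads=2, intermediate_size=64, max_position_embeddings=64)
    encoder = SimSonEncoder(cfg, max_len=64)          # pooled BERT states projected to 64-d embeddings
    model = SimSonClassifier(encoder, num_labels=2)   # linear head on top of the 64-d embedding
    input_ids = torch.randint(1, 100, (4, 64))        # batch of 4 sequences, 64 token ids each
    attention_mask = torch.ones_like(input_ids)
    logits = model(input_ids, attention_mask)         # expected shape: (batch, num_labels)
    assert logits.shape == (4, 2)
    return logits.shape
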
# --- 5. Training, Validation, and Testing Loops ---
def get_criterion(task_type, num_labels):
    if task_type == 'classification':
        return nn.BCEWithLogitsLoss()
    elif task_type == 'regression':
        return nn.MSELoss()
    else:
        raise ValueError(f"Unknown task type: {task_type}")


def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # scheduler.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


def eval_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss / len(dataloader)


def test_model(model, dataloader, device, task_type='classification'):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels']
            outputs = model(**inputs)
            # Sigmoid only makes sense for classification logits; regression outputs are kept as-is.
            preds = torch.sigmoid(outputs) if task_type == 'classification' else outputs
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.numpy())
    return np.concatenate(all_preds), np.concatenate(all_labels)


def calc_val_metrics(model, dataloader, criterion, device, task_type):
    model.eval()
    all_labels, all_preds = [], []
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            if task_type == 'classification':
                pred_probs = torch.sigmoid(outputs).cpu().numpy()
                all_preds.append(pred_probs)
                all_labels.append(labels.cpu().numpy())
            else:  # Regression
                preds = outputs.cpu().numpy()
                all_preds.append(preds)
                all_labels.append(labels.cpu().numpy())
    avg_loss = total_loss / len(dataloader)
    if task_type == 'classification':
        y_true = np.concatenate(all_labels)
        y_pred = np.concatenate(all_preds)
        try:
            score = roc_auc_score(y_true, y_pred, average='macro')
        except Exception:
            score = 0.0
        return avg_loss, score
    else:
        return avg_loss, None
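
# Illustrative sketch (not part of the original script and never called): one pass of train_epoch
# over synthetic token batches with a tiny, randomly initialized model. The config values, batch
# sizes, and learning rate below are arbitrary assumptions for the demo; a plain list of batch
# dicts stands in for a DataLoader.
def _demo_train_epoch():
    cfg = BertConfig(vocab_size=50, hidden_size=16, num_hidden_layers=1,
                     num_attention_heads=2, intermediate_size=32, max_position_embeddings=32)
    model = SimSonClassifier(SimSonEncoder(cfg, max_len=32), num_labels=1)
    batches = [{'input_ids': torch.randint(1, 50, (2, 32)),
                'attention_mask': torch.ones(2, 32, dtype=torch.long),
                'labels': torch.rand(2, 1)} for _ in range(3)]
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = get_criterion('classification', num_labels=1)
    # scheduler is unused inside train_epoch (its step() call is commented out), so None is fine here.
    return train_epoch(model, batches, optimizer, None, criterion, torch.device('cpu'))
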
# --- 6. Main Execution Block ---
def main():
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    DATASETS_TO_RUN = {
        # 'esol': {'task_type': 'regression', 'num_labels': 1, 'split': 'random'},
        # 'tox21': {'task_type': 'classification', 'num_labels': 12, 'split': 'random'},
        # 'hiv': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'},
        # Add more datasets here, e.g. 'bbbp': {'task_type': 'classification', 'num_labels': 1, 'split': 'random'},
        # 'sider': {'task_type': 'classification', 'num_labels': 27, 'split': 'random'},
        # 'bace': {'task_type': 'classification', 'num_labels': 1, 'split': 'random'},
        'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'random'},
        # 'bbbp': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'}
    }

    PATIENCE = 15
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 16
    MAX_LEN = 512

    TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    ENCODER_CONFIG = BertConfig(
        vocab_size=TOKENIZER.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    aggregated_results = {}
    for name, info in DATASETS_TO_RUN.items():
        print(f"\n{'='*20} Processing Dataset: {name.upper()} ({info['split']} split) {'='*20}")
        smiles, labels = load_lists_from_url(name)

        # Split selection
        if info.get('split', 'scaffold') == 'scaffold':
            splitter = ScaffoldSplitter(data=name, seed=42)
            train_idx, val_idx, test_idx = splitter.scaffold_split()
        elif info['split'] == 'random':
            train_idx, val_idx, test_idx = random_split_indices(len(smiles), seed=42)
        else:
            raise ValueError(f"Unknown split type for {name}: {info['split']}")

        train_smiles = smiles.iloc[train_idx].reset_index(drop=True)
        train_labels = labels.iloc[train_idx].reset_index(drop=True)
        val_smiles = smiles.iloc[val_idx].reset_index(drop=True)
        val_labels = labels.iloc[val_idx].reset_index(drop=True)
        test_smiles = smiles.iloc[test_idx].reset_index(drop=True)
        test_labels = labels.iloc[test_idx].reset_index(drop=True)
        print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}")

        train_dataset = MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN)
        val_dataset = MoleculeDataset(val_smiles, val_labels, TOKENIZER, MAX_LEN)
        test_dataset = MoleculeDataset(test_smiles, test_labels, TOKENIZER, MAX_LEN)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        encoder = SimSonEncoder(ENCODER_CONFIG, 512)
        encoder = torch.compile(encoder)
        model = SimSonClassifier(encoder, num_labels=info['num_labels']).to(DEVICE)
        model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')

        criterion = get_criterion(info['task_type'], info['num_labels'])
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0024)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.59298)

        # Early stopping on validation loss: keep the state dict with the lowest loss seen so far.
        best_val_loss = float('inf')
        best_model_state = None
        current_patience = 0
        for epoch in range(EPOCHS):
            train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, DEVICE)
            val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, DEVICE, info['task_type'])
            metric_str = f" | ROC AUC: {val_metric:.4f}" if val_metric is not None else ""
            print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}{metric_str}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = copy.deepcopy(model.state_dict())
                print(f"  -> New best model saved with validation loss: {best_val_loss:.4f}")
                current_patience = 0
            else:
                current_patience += 1
                if current_patience >= PATIENCE:
                    print(f'Early stopping after {PATIENCE} epochs without improvement')
                    break

        print("\nTesting with the best model...")
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        test_loss = eval_epoch(model, test_loader, criterion, DEVICE)
        print(f'Test loss: {test_loss}')
        test_preds, test_true = test_model(model, test_loader, DEVICE, info['task_type'])
        aggregated_results[name] = {
            'best_val_loss': best_val_loss,
            'test_predictions': test_preds,
            'test_labels': test_true
        }
        print(f"Finished testing for {name}.")

        test_smiles_list = list(test_smiles)
        similarities = compute_embedding_similarity(
            model.encoder, test_smiles_list, TOKENIZER, DEVICE, MAX_LEN
        )
        print(f"Similarity score: {similarities.mean():.4f}")

        if name == 'do_not_save':
            torch.save(model.encoder.state_dict(), 'moleculenet_clintox_encoder.bin')

    print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
    for name, result in aggregated_results.items():
        if name in ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']:
            auc = roc_auc_score(result['test_labels'], result['test_predictions'], average='macro')
            print(f'{name} ROC AUC: {auc}')
        if name in ['lipophilicity', 'esol', 'qm8']:
            rmse = root_mean_squared_error(result['test_labels'], result['test_predictions'])
            mae = mean_absolute_error(result['test_labels'], result['test_predictions'])
            print(f'{name} MAE: {mae}')
            print(f'{name} RMSE: {rmse}')

    print("\nScript finished.")


if __name__ == '__main__':
    main()
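
# Illustrative sketch (not part of the original script and never called): what average='macro'
# means for the multi-task ROC AUC reported above -- per-task AUCs are computed independently
# and then averaged. The toy label/score arrays below are arbitrary assumptions.
def _demo_macro_auc():
    y_true = np.array([[0, 1], [1, 0], [1, 1], [0, 0]])
    y_score = np.array([[0.1, 0.9], [0.8, 0.3], [0.7, 0.35], [0.2, 0.4]])
    per_task = [roc_auc_score(y_true[:, t], y_score[:, t]) for t in range(y_true.shape[1])]
    macro = roc_auc_score(y_true, y_score, average='macro')
    assert np.isclose(macro, np.mean(per_task))
    return macro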