import os
import warnings
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score
from transformers import AutoTokenizer, BertModel, BertConfig
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold

warnings.filterwarnings('ignore')


def global_ap(x, mask=None, dim=1):
    """Global average pooling over the sequence dimension.

    (The original notes this helper lives elsewhere in the codebase.)
    Extended here to optionally respect the attention mask, so that
    padded positions do not dilute the mean when sequences are padded
    to max_length.
    """
    if mask is None:
        return torch.mean(x, dim=dim)
    mask = mask.unsqueeze(-1).float()
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1e-9)


class SimSonClassifier(nn.Module):
    def __init__(self, config: BertConfig, max_len: int, num_labels: int,
                 dropout: float = 0.1):
        super(SimSonClassifier, self).__init__()
        self.config = config
        self.max_len = max_len
        self.num_labels = num_labels

        # BERT encoder (same architecture as SimSonEncoder)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(dropout)

        # Classification head: one logit per task
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            # Assumes the tokenizer uses 0 as the padding id
            attention_mask = input_ids.ne(0)

        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = self.dropout(outputs.last_hidden_state)

        # Masked global average pooling over the token dimension
        pooled = global_ap(hidden_states, attention_mask)

        # Classification output
        return self.classifier(pooled)

    def load_encoder_weights(self, encoder_path):
        """Load pretrained SimSonEncoder weights into the classifier."""
        encoder_state = torch.load(encoder_path, map_location='cpu')

        # Keep only the keys shared with the classifier (the BERT backbone
        # and the dropout module); the classification head stays randomly
        # initialized.
        classifier_state = {
            key: value for key, value in encoder_state.items()
            if key.startswith('bert.') or key.startswith('dropout.')
        }
        self.load_state_dict(classifier_state, strict=False)
        print(f"Loaded encoder weights from {encoder_path}")


def load_moleculenet_data(dataset_name):
    """Load a MoleculeNet dataset and return SMILES and labels."""
    if dataset_name == 'bbbp':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
        smiles, labels = df.smiles, df.p_np
    elif dataset_name == 'clintox':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz',
                         compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif dataset_name == 'hiv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
        smiles, labels = df.smiles, df.HIV_active
    elif dataset_name == 'sider':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz',
                         compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif dataset_name == 'tox21':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz',
                         compression='gzip')
        # Drop molecules with any missing task label (a simple alternative
        # to masking the loss per task)
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['mol_id', 'smiles'], axis=1)
    elif dataset_name == 'bace':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
        smiles, labels = df.mol, df.Class
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    return smiles, labels


class MoleculeDataset(Dataset):
    def __init__(self, smiles_list, labels, tokenizer, max_length=512):
        self.smiles = smiles_list
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        smiles = self.smiles[idx]

        # Tokenize SMILES
        encoding = self.tokenizer(
            smiles,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Single-label targets get shape (1,) so batches match the
        # (batch, 1) logits of the classifier; multi-label rows keep one
        # entry per task.
        if isinstance(self.labels, pd.Series):
            label = torch.tensor([self.labels.iloc[idx]], dtype=torch.float32)
        else:  # DataFrame (multi-label)
            label = torch.tensor(self.labels.iloc[idx].values, dtype=torch.float32)

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label
        }
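
# Illustrative usage (hypothetical helper, not part of the original script
# and not invoked anywhere): shows how MoleculeDataset tokenizes a single
# SMILES string. The tokenizer name matches the one used in main(); calling
# this would download it from the Hugging Face hub.
def _example_dataset_item():
    tok = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    ds = MoleculeDataset(['CCO'], pd.Series([1.0]), tok, max_length=128)
    item = ds[0]
    print(item['input_ids'].shape)       # torch.Size([128]), padded/truncated
    print(item['attention_mask'].sum())  # number of non-pad tokens
    print(item['labels'])                # tensor([1.])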
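
# The experiment runner below calls scaffold_split(), which the original
# script leaves undefined. This is a minimal sketch of the deterministic
# Bemis-Murcko scaffold split commonly used for MoleculeNet benchmarks,
# assuming the usual 80/10/10 fractions; swap in your own splitter if the
# original used different logic.
def generate_scaffold(smiles, include_chirality=False):
    """Return the Bemis-Murcko scaffold SMILES for a molecule."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Group unparsable SMILES under an empty scaffold
        return ''
    return MurckoScaffold.MurckoScaffoldSmiles(mol=mol,
                                               includeChirality=include_chirality)


def scaffold_split(smiles_list, frac_train=0.8, frac_valid=0.1, frac_test=0.1):
    """Split indices so molecules sharing a scaffold stay in the same set."""
    scaffolds = defaultdict(list)
    for idx, smiles in enumerate(smiles_list):
        scaffolds[generate_scaffold(smiles)].append(idx)

    # Assign scaffold groups, largest first, filling train then valid;
    # whatever remains (roughly frac_test of the data) goes to test.
    scaffold_sets = sorted(scaffolds.values(), key=len, reverse=True)
    n_total = len(smiles_list)
    train_idx, valid_idx, test_idx = [], [], []
    for scaffold_set in scaffold_sets:
        if len(train_idx) + len(scaffold_set) <= frac_train * n_total:
            train_idx.extend(scaffold_set)
        elif len(valid_idx) + len(scaffold_set) <= frac_valid * n_total:
            valid_idx.extend(scaffold_set)
        else:
            test_idx.extend(scaffold_set)
    return train_idx, valid_idx, test_idx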


def get_loss_fn(num_labels):
    """BCEWithLogitsLoss covers both cases: a single task and multi-label
    classification with one logit per task."""
    return nn.BCEWithLogitsLoss()


def compute_metrics(predictions, labels, num_labels):
    """Compute ROC-AUC for single- or multi-label classification."""
    predictions = torch.sigmoid(predictions).cpu().numpy()
    labels = labels.cpu().numpy()

    if num_labels == 1:  # Single label
        try:
            auc = roc_auc_score(labels.ravel(), predictions.ravel())
            return {'roc_auc': auc}
        except ValueError:  # only one class present in this split
            return {'roc_auc': 0.5}
    else:  # Multi-label
        aucs = []
        for i in range(num_labels):
            try:
                aucs.append(roc_auc_score(labels[:, i], predictions[:, i]))
            except ValueError:  # task has a single class in this split
                aucs.append(0.5)
        return {'roc_auc': np.mean(aucs), 'individual_aucs': aucs}


def train_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


def evaluate(model, dataloader, loss_fn, num_labels, device):
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

            all_predictions.append(outputs)
            all_labels.append(labels)

    all_predictions = torch.cat(all_predictions)
    all_labels = torch.cat(all_labels)
    metrics = compute_metrics(all_predictions, all_labels, num_labels)
    avg_loss = total_loss / len(dataloader)

    return avg_loss, metrics


def run_experiment(dataset_name, config, tokenizer, encoder_path=None,
                   batch_size=32, learning_rate=1e-4, epochs=50, device='cuda'):
    """Run the complete experiment for one dataset."""
    print(f"\n=== Running experiment for {dataset_name.upper()} ===")

    # Load data
    smiles, labels = load_moleculenet_data(dataset_name)
    print(f"Loaded {len(smiles)} samples")

    # Determine number of labels (Series = single task, DataFrame = multi-task)
    num_labels = 1 if isinstance(labels, pd.Series) else labels.shape[1]
    print(f"Number of labels: {num_labels}")

    # Scaffold split
    smiles_list = smiles.tolist()
    train_idx, valid_idx, test_idx = scaffold_split(smiles_list)
    print(f"Split sizes - Train: {len(train_idx)}, "
          f"Valid: {len(valid_idx)}, Test: {len(test_idx)}")

    # Slice SMILES and labels (.iloc works for both Series and DataFrame)
    train_smiles = [smiles_list[i] for i in train_idx]
    valid_smiles = [smiles_list[i] for i in valid_idx]
    test_smiles = [smiles_list[i] for i in test_idx]
    train_labels = labels.iloc[list(train_idx)]
    valid_labels = labels.iloc[list(valid_idx)]
    test_labels = labels.iloc[list(test_idx)]

    # Create data loaders
    train_dataset = MoleculeDataset(train_smiles, train_labels, tokenizer)
    valid_dataset = MoleculeDataset(valid_smiles, valid_labels, tokenizer)
    test_dataset = MoleculeDataset(test_smiles, test_labels, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = SimSonClassifier(config, max_len=512, num_labels=num_labels).to(device)

    # Load pretrained encoder weights if provided
    if encoder_path:
        model.load_encoder_weights(encoder_path)

    # Setup training
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = get_loss_fn(num_labels)

    best_valid_loss = float('inf')
    best_model_path = f'best_{dataset_name}_model.pth'

    # Training loop with best-checkpoint selection on validation loss
    for epoch in range(epochs):
        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device)
        valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn,
                                             num_labels, device)

        # Save best model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), best_model_path)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, "
                  f"Valid Loss = {valid_loss:.4f}, "
                  f"Valid AUC = {valid_metrics['roc_auc']:.4f}")

    # Load the best checkpoint and evaluate on the test set
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    test_loss, test_metrics = evaluate(model, test_loader, loss_fn,
                                       num_labels, device)

    print(f"Final Test Results - Loss: {test_loss:.4f}, "
          f"ROC-AUC: {test_metrics['roc_auc']:.4f}")

    # Cleanup
    os.remove(best_model_path)

    return {
        'dataset': dataset_name,
        'num_labels': num_labels,
        'test_loss': test_loss,
        'test_roc_auc': test_metrics['roc_auc'],
        'individual_aucs': test_metrics.get('individual_aucs', None)
    }
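
# Optional sanity check (hypothetical helper, not part of the original
# pipeline): pushes a random batch through SimSonClassifier with a tiny
# config to verify the logit shape without downloading any data or weights.
def _smoke_test_classifier():
    cfg = BertConfig(vocab_size=600, hidden_size=64, num_hidden_layers=1,
                     num_attention_heads=4, intermediate_size=128,
                     max_position_embeddings=128)
    model = SimSonClassifier(cfg, max_len=128, num_labels=12)
    input_ids = torch.randint(1, 600, (2, 128))  # ids >= 1 avoid the pad id 0
    logits = model(input_ids)  # attention mask is inferred from the pad id
    assert logits.shape == (2, 12)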


def main():
    """Run all experiments."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize tokenizer and config
    tokenizer_path = 'DeepChem/ChemBERTa-77M-MTR'
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Only the hidden size is slightly larger; everything else is the same
    config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    # Datasets to test
    datasets = ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']

    # Path to the pretrained encoder (optional)
    encoder_path = 'simson_checkpoints_small/simson_model_single_gpu.bin'

    # Run experiments
    all_results = []
    for dataset in datasets:
        try:
            result = run_experiment(
                dataset, config, tokenizer,
                encoder_path=encoder_path,
                device=device
            )
            all_results.append(result)
        except Exception as e:
            print(f"Error with {dataset}: {e}")

    # Aggregate and display results
    print("\n" + "=" * 60)
    print("FINAL RESULTS SUMMARY")
    print("=" * 60)

    results_df = pd.DataFrame(all_results)
    print(results_df.to_string(index=False))

    # Save results
    results_df.to_csv('moleculenet_results.csv', index=False)
    print("\nResults saved to moleculenet_results.csv")

    return results_df


if __name__ == "__main__":
    results = main()