import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score
from transformers import AutoTokenizer, BertConfig, BertModel
import os
from collections import defaultdict
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
import warnings

warnings.filterwarnings('ignore')


def global_ap(x, dim=1):
    """Global average pooling over the sequence dimension.

    Note: the mean runs over all positions, including padding tokens.
    """
    return torch.mean(x, dim=dim)


class SimSonClassifier(nn.Module):
    def __init__(self, config: BertConfig, max_len: int, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.config = config
        self.max_len = max_len
        self.num_labels = num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            # Fallback mask: assumes the pad token id is 0.
            attention_mask = input_ids.ne(0)

        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        hidden_states = self.dropout(outputs.last_hidden_state)

        # Mean-pool the token representations into one molecule embedding.
        pooled = global_ap(hidden_states)
        logits = self.classifier(pooled)
        return logits

    def load_encoder_weights(self, encoder_path):
        """Load pretrained SimSonEncoder weights into the classifier."""
        encoder_state = torch.load(encoder_path, map_location='cpu')

        # Keep only the BERT backbone (and dropout) weights; the classifier
        # head stays randomly initialized and is trained on the downstream task.
        classifier_state = {
            key: value for key, value in encoder_state.items()
            if key.startswith('bert.') or key.startswith('dropout.')
        }

        self.load_state_dict(classifier_state, strict=False)
        print(f"Loaded encoder weights from {encoder_path}")


def load_moleculenet_data(dataset_name):
    """Load a MoleculeNet dataset and return its SMILES strings and labels."""
    if dataset_name == 'bbbp':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
        smiles, labels = df.smiles, df.p_np
    elif dataset_name == 'clintox':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif dataset_name == 'hiv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
        smiles, labels = df.smiles, df.HIV_active
    elif dataset_name == 'sider':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif dataset_name == 'tox21':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['mol_id', 'smiles'], axis=1)
    elif dataset_name == 'bace':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
        smiles, labels = df.mol, df.Class
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    return smiles, labels
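

# `scaffold_split` is called in `run_experiment` below but was missing from this
# file. The sketch here is a standard deterministic Bemis-Murcko scaffold split
# with assumed 80/10/10 fractions (the usual MoleculeNet protocol): molecules
# sharing a scaffold always land in the same split, largest groups assigned first.
def scaffold_split(smiles_list, frac_train=0.8, frac_valid=0.1, frac_test=0.1):
    """Return (train_idx, valid_idx, test_idx) index lists for a scaffold split."""
    scaffold_to_indices = defaultdict(list)
    for idx, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES form their own singleton group.
            scaffold_to_indices[f'invalid_{idx}'].append(idx)
            continue
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=False)
        scaffold_to_indices[scaffold].append(idx)

    # Assign whole scaffold groups, largest first, to train, then valid;
    # frac_test is implicit: everything past valid_cutoff goes to test.
    groups = sorted(scaffold_to_indices.values(), key=len, reverse=True)
    train_cutoff = frac_train * len(smiles_list)
    valid_cutoff = (frac_train + frac_valid) * len(smiles_list)

    train_idx, valid_idx, test_idx = [], [], []
    for group in groups:
        if len(train_idx) + len(group) <= train_cutoff:
            train_idx.extend(group)
        elif len(train_idx) + len(valid_idx) + len(group) <= valid_cutoff:
            valid_idx.extend(group)
        else:
            test_idx.extend(group)

    return train_idx, valid_idx, test_idx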


class MoleculeDataset(Dataset):
    def __init__(self, smiles_list, labels, tokenizer, max_length=512):
        self.smiles = smiles_list
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        smiles = self.smiles[idx]

        encoding = self.tokenizer(
            smiles,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Single-task labels arrive as a pd.Series, multi-task labels as a
        # pd.DataFrame. Either way the label tensor is shaped (num_labels,),
        # so batches match the (batch_size, num_labels) logits.
        if isinstance(self.labels, pd.Series):
            label = torch.tensor([self.labels.iloc[idx]], dtype=torch.float32)
        else:
            label = torch.tensor(self.labels.iloc[idx].values, dtype=torch.float32)

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label
        }


def get_loss_fn(num_labels):
    """Binary cross-entropy on logits handles both single- and multi-label tasks."""
    return nn.BCEWithLogitsLoss()


def compute_metrics(predictions, labels, num_labels):
    """Compute ROC-AUC for single- or multi-label classification."""
    predictions = torch.sigmoid(predictions).cpu().numpy()
    labels = labels.cpu().numpy()

    if num_labels == 1:
        try:
            auc = roc_auc_score(labels.ravel(), predictions.ravel())
            return {'roc_auc': auc}
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present.
            return {'roc_auc': 0.5}
    else:
        # Average the per-task AUCs, falling back to 0.5 for tasks whose
        # evaluation labels contain only a single class.
        aucs = []
        for i in range(num_labels):
            try:
                aucs.append(roc_auc_score(labels[:, i], predictions[:, i]))
            except ValueError:
                aucs.append(0.5)
        return {'roc_auc': np.mean(aucs), 'individual_aucs': aucs}


def train_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        outputs = model(input_ids, attention_mask)
        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


def evaluate(model, dataloader, loss_fn, num_labels, device):
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)

            total_loss += loss.item()
            all_predictions.append(outputs)
            all_labels.append(labels)

    all_predictions = torch.cat(all_predictions)
    all_labels = torch.cat(all_labels)

    metrics = compute_metrics(all_predictions, all_labels, num_labels)
    avg_loss = total_loss / len(dataloader)

    return avg_loss, metrics


def run_experiment(dataset_name, config, tokenizer, encoder_path=None,
                   batch_size=32, learning_rate=1e-4, epochs=50, device='cuda'):
    """Run the complete fine-tuning experiment for one dataset."""
    print(f"\n=== Running experiment for {dataset_name.upper()} ===")

    smiles, labels = load_moleculenet_data(dataset_name)
    print(f"Loaded {len(smiles)} samples")

    # A pd.Series means a single binary task; a pd.DataFrame means multi-label.
    if isinstance(labels, pd.Series):
        num_labels = 1
    else:
        num_labels = labels.shape[1]
    print(f"Number of labels: {num_labels}")

    smiles_list = smiles.tolist()
    train_idx, valid_idx, test_idx = scaffold_split(smiles_list)

    print(f"Split sizes - Train: {len(train_idx)}, Valid: {len(valid_idx)}, Test: {len(test_idx)}")

    train_smiles = [smiles_list[i] for i in train_idx]
    valid_smiles = [smiles_list[i] for i in valid_idx]
    test_smiles = [smiles_list[i] for i in test_idx]

    # .iloc indexing works identically for a Series and a DataFrame of labels.
    train_labels = labels.iloc[train_idx]
    valid_labels = labels.iloc[valid_idx]
    test_labels = labels.iloc[test_idx]

    train_dataset = MoleculeDataset(train_smiles, train_labels, tokenizer)
    valid_dataset = MoleculeDataset(valid_smiles, valid_labels, tokenizer)
    test_dataset = MoleculeDataset(test_smiles, test_labels, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = SimSonClassifier(config, max_len=512, num_labels=num_labels).to(device)

    if encoder_path:
        model.load_encoder_weights(encoder_path)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = get_loss_fn(num_labels)

    best_valid_loss = float('inf')
    best_model_path = f'best_{dataset_name}_model.pth'

    for epoch in range(epochs):
        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device)
        valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, num_labels, device)

        # Keep the checkpoint with the lowest validation loss.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), best_model_path)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, "
                  f"Valid Loss = {valid_loss:.4f}, Valid AUC = {valid_metrics['roc_auc']:.4f}")

    # Evaluate the best checkpoint on the held-out test set.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    test_loss, test_metrics = evaluate(model, test_loader, loss_fn, num_labels, device)

    print(f"Final Test Results - Loss: {test_loss:.4f}, ROC-AUC: {test_metrics['roc_auc']:.4f}")

    os.remove(best_model_path)

    return {
        'dataset': dataset_name,
        'num_labels': num_labels,
        'test_loss': test_loss,
        'test_roc_auc': test_metrics['roc_auc'],
        'individual_aucs': test_metrics.get('individual_aucs', None)
    }


def main():
    """Run all experiments and collect the results."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    tokenizer_path = 'DeepChem/ChemBERTa-77M-MTR'
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    datasets = ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']

    encoder_path = 'simson_checkpoints_small/simson_model_single_gpu.bin'

    all_results = []
    for dataset in datasets:
        try:
            result = run_experiment(
                dataset,
                config,
                tokenizer,
                encoder_path=encoder_path,
                device=device
            )
            all_results.append(result)
        except Exception as e:
            print(f"Error with {dataset}: {e}")

    print("\n" + "=" * 60)
    print("FINAL RESULTS SUMMARY")
    print("=" * 60)

    results_df = pd.DataFrame(all_results)
    print(results_df.to_string(index=False))

    results_df.to_csv('moleculenet_results.csv', index=False)
    print("\nResults saved to moleculenet_results.csv")

    return results_df


if __name__ == "__main__":
    results = main()