import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score
from transformers import AutoTokenizer, BertConfig, BertModel
import os
from collections import defaultdict
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
import warnings
warnings.filterwarnings('ignore')
# Global average pooling over the sequence (token) dimension
def global_ap(x, dim=1):
return torch.mean(x, dim=dim)
class SimSonClassifier(nn.Module):
def __init__(self, config: BertConfig, max_len: int, num_labels: int, dropout: float = 0.1):
super(SimSonClassifier, self).__init__()
self.config = config
self.max_len = max_len
self.num_labels = num_labels
# BERT encoder (same as SimSonEncoder)
self.bert = BertModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(dropout)
# Classification head
self.classifier = nn.Linear(config.hidden_size, num_labels)
def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            # Assume the pad token id is 0 when no explicit mask is provided
            attention_mask = input_ids.ne(0)
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask
)
hidden_states = outputs.last_hidden_state
hidden_states = self.dropout(hidden_states)
        # Global average pooling over the sequence dimension
        # (note: padding positions are included in the mean)
        pooled = global_ap(hidden_states)
# Classification output
logits = self.classifier(pooled)
return logits
def load_encoder_weights(self, encoder_path):
"""Load pretrained SimSonEncoder weights into the classifier"""
encoder_state = torch.load(encoder_path, map_location='cpu')
# Create mapping from encoder to classifier state dict
classifier_state = {}
for key, value in encoder_state.items():
if key.startswith('bert.') or key.startswith('dropout.'):
classifier_state[key] = value
# Load only the matching weights
self.load_state_dict(classifier_state, strict=False)
print(f"Loaded encoder weights from {encoder_path}")
def load_moleculenet_data(dataset_name):
"""Load MoleculeNet dataset and return SMILES and labels"""
if dataset_name == 'bbbp':
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
smiles, labels = df.smiles, df.p_np
elif dataset_name == 'clintox':
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip')
smiles = df.smiles
labels = df.drop(['smiles'], axis=1)
elif dataset_name == 'hiv':
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
smiles, labels = df.smiles, df.HIV_active
elif dataset_name == 'sider':
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
smiles = df.smiles
labels = df.drop(['smiles'], axis=1)
elif dataset_name == 'tox21':
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
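        # Drop molecules with any missing task label so the label matrix is dense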
df = df.dropna(axis=0, how='any').reset_index(drop=True)
smiles = df.smiles
labels = df.drop(['mol_id', 'smiles'], axis=1)
elif dataset_name == 'bace':
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
smiles, labels = df.mol, df.Class
else:
raise ValueError(f"Dataset {dataset_name} not supported")
return smiles, labels
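# `scaffold_split` is called in `run_experiment` below but was not defined in
# this file. The implementation here is a minimal Bemis-Murcko scaffold split
# in the usual MoleculeNet style; the 80/10/10 fractions and largest-group-first
# assignment are assumptions, not recovered from the original code.
def scaffold_split(smiles_list, frac_train=0.8, frac_valid=0.1):
    """Group molecules by Murcko scaffold and assign whole groups to
    train/valid/test so that no scaffold crosses a split boundary.
    The test set receives the remainder (~10%)."""
    scaffold_to_indices = defaultdict(list)
    for idx, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # Unparseable SMILES become singleton groups
            scaffold_to_indices[smi].append(idx)
        else:
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=False)
            scaffold_to_indices[scaffold].append(idx)
    # Assign the largest scaffold groups first so the split is deterministic
    groups = sorted(scaffold_to_indices.values(), key=len, reverse=True)
    n_total = len(smiles_list)
    train_cutoff = frac_train * n_total
    valid_cutoff = (frac_train + frac_valid) * n_total
    train_idx, valid_idx, test_idx = [], [], []
    for group in groups:
        if len(train_idx) + len(group) <= train_cutoff:
            train_idx.extend(group)
        elif len(train_idx) + len(valid_idx) + len(group) <= valid_cutoff:
            valid_idx.extend(group)
        else:
            test_idx.extend(group)
    return train_idx, valid_idx, test_idx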
class MoleculeDataset(Dataset):
def __init__(self, smiles_list, labels, tokenizer, max_length=512):
self.smiles = smiles_list
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.smiles)
def __getitem__(self, idx):
smiles = self.smiles[idx]
# Tokenize SMILES
encoding = self.tokenizer(
smiles,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
# Handle labels
        if isinstance(self.labels, pd.Series):
            # Single-label: wrap in a length-1 list so batches collate to
            # shape (batch, 1), matching the classifier's (batch, num_labels) logits
            label = torch.tensor([self.labels.iloc[idx]], dtype=torch.float32)
        else:  # DataFrame (multi-label)
            label = torch.tensor(self.labels.iloc[idx].values, dtype=torch.float32)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': label
}
def get_loss_fn(num_labels):
    """BCE-with-logits covers both the single-label and multi-label cases,
    so the same loss is returned regardless of num_labels."""
    return nn.BCEWithLogitsLoss()
def compute_metrics(predictions, labels, num_labels):
"""Compute ROC-AUC for single or multi-label classification"""
predictions = torch.sigmoid(predictions).cpu().numpy()
labels = labels.cpu().numpy()
    if num_labels == 1:
        # Single label: flatten the (N, 1) arrays for sklearn
        try:
            auc = roc_auc_score(labels.ravel(), predictions.ravel())
            return {'roc_auc': auc}
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present
            return {'roc_auc': 0.5}
else:
# Multi-label
aucs = []
        for i in range(num_labels):
            try:
                auc = roc_auc_score(labels[:, i], predictions[:, i])
                aucs.append(auc)
            except ValueError:
                # Columns where only one class is present cannot be scored
                aucs.append(0.5)
return {'roc_auc': np.mean(aucs), 'individual_aucs': aucs}
def train_epoch(model, dataloader, optimizer, loss_fn, device):
model.train()
total_loss = 0
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
optimizer.zero_grad()
outputs = model(input_ids, attention_mask)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
def evaluate(model, dataloader, loss_fn, num_labels, device):
model.eval()
total_loss = 0
all_predictions = []
all_labels = []
with torch.no_grad():
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
outputs = model(input_ids, attention_mask)
loss = loss_fn(outputs, labels)
total_loss += loss.item()
all_predictions.append(outputs)
all_labels.append(labels)
all_predictions = torch.cat(all_predictions)
all_labels = torch.cat(all_labels)
metrics = compute_metrics(all_predictions, all_labels, num_labels)
avg_loss = total_loss / len(dataloader)
return avg_loss, metrics
def run_experiment(dataset_name, config, tokenizer, encoder_path=None,
batch_size=32, learning_rate=1e-4, epochs=50, device='cuda'):
"""Run complete experiment for one dataset"""
print(f"\n=== Running experiment for {dataset_name.upper()} ===")
# Load data
smiles, labels = load_moleculenet_data(dataset_name)
print(f"Loaded {len(smiles)} samples")
# Determine number of labels
if isinstance(labels, pd.Series):
num_labels = 1
else:
num_labels = labels.shape[1]
print(f"Number of labels: {num_labels}")
# Scaffold split
smiles_list = smiles.tolist()
train_idx, valid_idx, test_idx = scaffold_split(smiles_list)
print(f"Split sizes - Train: {len(train_idx)}, Valid: {len(valid_idx)}, Test: {len(test_idx)}")
# Create datasets
train_smiles = [smiles_list[i] for i in train_idx]
valid_smiles = [smiles_list[i] for i in valid_idx]
test_smiles = [smiles_list[i] for i in test_idx]
    # .iloc works identically for a Series (single label) and a DataFrame
    # (multi-label), so no branching is needed here
    train_labels = labels.iloc[list(train_idx)]
    valid_labels = labels.iloc[list(valid_idx)]
    test_labels = labels.iloc[list(test_idx)]
# Create data loaders
train_dataset = MoleculeDataset(train_smiles, train_labels, tokenizer)
valid_dataset = MoleculeDataset(valid_smiles, valid_labels, tokenizer)
test_dataset = MoleculeDataset(test_smiles, test_labels, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Initialize model
model = SimSonClassifier(config, max_len=512, num_labels=num_labels).to(device)
# Load encoder weights if provided
if encoder_path:
model.load_encoder_weights(encoder_path)
# Setup training
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = get_loss_fn(num_labels)
best_valid_loss = float('inf')
best_model_path = f'best_{dataset_name}_model.pth'
# Training loop
for epoch in range(epochs):
train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device)
valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, num_labels, device)
# Save best model
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), best_model_path)
if epoch % 10 == 0:
print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, "
f"Valid Loss = {valid_loss:.4f}, Valid AUC = {valid_metrics['roc_auc']:.4f}")
# Load best model and test
    model.load_state_dict(torch.load(best_model_path, map_location=device))
test_loss, test_metrics = evaluate(model, test_loader, loss_fn, num_labels, device)
print(f"Final Test Results - Loss: {test_loss:.4f}, ROC-AUC: {test_metrics['roc_auc']:.4f}")
# Cleanup
os.remove(best_model_path)
return {
'dataset': dataset_name,
'num_labels': num_labels,
'test_loss': test_loss,
'test_roc_auc': test_metrics['roc_auc'],
'individual_aucs': test_metrics.get('individual_aucs', None)
}
def main():
"""Main function to run all experiments"""
# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
    # SMILES tokenizer: reuse ChemBERTa's tokenizer from the Hugging Face Hub
    tokenizer_path = 'DeepChem/ChemBERTa-77M-MTR'
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # Per the original note, only the hidden size is slightly larger than in
    # the reference setup; everything else is the same
config = BertConfig(
vocab_size=tokenizer.vocab_size,
hidden_size=768,
num_hidden_layers=4,
num_attention_heads=12,
intermediate_size=2048,
max_position_embeddings=512
)
# Datasets to test
datasets = ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']
# Path to your pretrained encoder (optional)
encoder_path = 'simson_checkpoints_small/simson_model_single_gpu.bin'
# Run experiments
all_results = []
for dataset in datasets:
try:
result = run_experiment(
dataset,
config,
tokenizer,
encoder_path=encoder_path,
device=device
)
all_results.append(result)
except Exception as e:
print(f"Error with {dataset}: {e}")
# Aggregate and display results
print("\n" + "="*60)
print("FINAL RESULTS SUMMARY")
print("="*60)
results_df = pd.DataFrame(all_results)
print(results_df.to_string(index=False))
# Save results
results_df.to_csv('moleculenet_results.csv', index=False)
print(f"\nResults saved to moleculenet_results.csv")
return results_df
if __name__ == "__main__":
results = main()