import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertConfig, BertModel, AutoTokenizer
from rdkit import Chem, RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold
import copy
from tqdm import tqdm
import os
from sklearn.metrics import roc_auc_score, root_mean_squared_error, mean_absolute_error
from itertools import compress
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity

RDLogger.DisableLog('rdApp.*')

torch.set_float32_matmul_precision('high')


class SmilesEnumerator:
    """Generates randomized (non-canonical) SMILES strings for data augmentation."""

    def randomize_smiles(self, smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            # Fall back to the original string if RDKit cannot parse or rewrite it.
            return smiles
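

# Checks how invariant the encoder is to SMILES enumeration: each molecule is embedded
# from its original SMILES and from one randomized SMILES, and the per-molecule cosine
# similarity between the two embeddings is returned.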
def compute_embedding_similarity(encoder, smiles_list, tokenizer, device, max_len=256):
    encoder.eval()
    enumerator = SmilesEnumerator()

    embeddings_orig = []
    embeddings_aug = []

    with torch.no_grad():
        for smi in smiles_list:
            encoding_orig = tokenizer(
                smi,
                truncation=True,
                padding='max_length',
                max_length=max_len,
                return_tensors='pt'
            )

            smi_aug = enumerator.randomize_smiles(smi)
            encoding_aug = tokenizer(
                smi_aug,
                truncation=True,
                padding='max_length',
                max_length=max_len,
                return_tensors='pt'
            )

            input_ids_orig = encoding_orig.input_ids.to(device)
            attention_mask_orig = encoding_orig.attention_mask.to(device)
            input_ids_aug = encoding_aug.input_ids.to(device)
            attention_mask_aug = encoding_aug.attention_mask.to(device)

            emb_orig = encoder(input_ids_orig, attention_mask_orig).cpu().numpy().flatten()
            emb_aug = encoder(input_ids_aug, attention_mask_aug).cpu().numpy().flatten()

            embeddings_orig.append(emb_orig)
            embeddings_aug.append(emb_aug)

    embeddings_orig = np.array(embeddings_orig)
    embeddings_aug = np.array(embeddings_aug)

    similarities = np.array([
        cosine_similarity([embeddings_orig[i]], [embeddings_aug[i]])[0][0]
        for i in range(len(embeddings_orig))
    ])
    return similarities


def load_lists_from_url(data):
    if data == 'bbbp':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
        smiles, labels = df.smiles, df.p_np
    elif data == 'clintox':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'hiv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
        smiles, labels = df.smiles, df.HIV_active
    elif data == 'sider':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'esol':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv')
        smiles = df.smiles
        labels = df['ESOL predicted log solubility in mols per litre']
    elif data == 'freesolv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv')
        smiles = df.smiles
        labels = df.calc
    elif data == 'lipophilicity':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv')
        smiles, labels = df.smiles, df['exp']
    elif data == 'tox21':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['mol_id', 'smiles'], axis=1)
    elif data == 'bace':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
        smiles, labels = df.mol, df.Class
    elif data == 'qm8':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv')
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1)
    else:
        raise ValueError(f"Unknown dataset: {data}")
    return smiles, labels
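

# Bemis-Murcko scaffold split: all molecules sharing a scaffold are assigned to the same
# split, so validation and test scaffolds are not seen during training.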
class ScaffoldSplitter:
    def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True):
        self.data = data
        self.seed = seed
        self.include_chirality = include_chirality
        self.train_frac = train_frac
        self.val_frac = val_frac
        self.test_frac = test_frac

    def generate_scaffold(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=self.include_chirality)
        return scaffold

    def scaffold_split(self):
        smiles, labels = load_lists_from_url(self.data)
        # Keep only molecules RDKit can parse (and, for multi-task sets, rows without missing labels).
        valid = np.zeros(len(smiles), dtype=bool)

        if self.data in {'tox21', 'sider', 'clintox'}:
            for i in range(len(smiles)):
                if Chem.MolFromSmiles(smiles[i]) and labels.loc[i].isnull().sum() == 0:
                    valid[i] = True
        else:
            for i in range(len(smiles)):
                if Chem.MolFromSmiles(smiles[i]):
                    valid[i] = True

        smiles_list = list(compress(enumerate(smiles), valid))
        rng = np.random.RandomState(self.seed)

        # Group molecule indices by scaffold, then shuffle the scaffold groups.
        scaffolds = defaultdict(list)
        for i, sms in smiles_list:
            scaffold = self.generate_scaffold(sms)
            scaffolds[scaffold].append(i)

        scaffold_sets = list(scaffolds.values())
        rng.shuffle(scaffold_sets)
        n_total_val = int(np.floor(self.val_frac * len(smiles_list)))
        n_total_test = int(np.floor(self.test_frac * len(smiles_list)))
        train_idx, val_idx, test_idx = [], [], []

        # Fill validation and test up to their target sizes; remaining scaffolds go to train.
        for scaffold_set in scaffold_sets:
            if len(val_idx) + len(scaffold_set) <= n_total_val:
                val_idx.extend(scaffold_set)
            elif len(test_idx) + len(scaffold_set) <= n_total_test:
                test_idx.extend(scaffold_set)
            else:
                train_idx.extend(scaffold_set)
        return train_idx, val_idx, test_idx


def random_split_indices(n, seed=42, train_frac=0.8, val_frac=0.1, test_frac=0.1):
    np.random.seed(seed)
    indices = np.random.permutation(n)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    return train_idx.tolist(), val_idx.tolist(), test_idx.tolist()
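

# Wraps a list of SMILES plus labels; each item is tokenized with the ChemBERTa tokenizer
# and labels are returned as float tensors (one value per task for multi-label datasets).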
class MoleculeDataset(Dataset):
    def __init__(self, smiles_list, labels, tokenizer, max_len=512):
        self.smiles_list = smiles_list
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        label = self.labels.iloc[idx]

        encoding = self.tokenizer(
            smiles,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        item = {key: val.squeeze(0) for key, val in encoding.items()}
        if isinstance(label, pd.Series):
            label_values = label.values.astype(np.float32)
        else:
            label_values = np.array([label], dtype=np.float32)
        item['labels'] = torch.tensor(label_values, dtype=torch.float)
        return item


def global_ap(x):
    # Global average pooling over the token dimension: (batch, seq_len, hidden) -> (batch, hidden).
    return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)
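

# A small BERT encoder whose token representations are mean-pooled and projected to a
# max_len-dimensional embedding; this is the module the pretrained checkpoint is loaded into.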
class SimSonEncoder(nn.Module):
    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super(SimSonEncoder, self).__init__()
        self.config = config
        self.max_len = max_len
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = self.dropout(outputs.last_hidden_state)
        pooled = global_ap(hidden_states)
        return self.linear(pooled)


class SimSonClassifier(nn.Module):
    def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):
        super(SimSonClassifier, self).__init__()
        self.encoder = encoder
        self.clf = nn.Linear(encoder.max_len, num_labels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        x = self.encoder(input_ids, attention_mask)
        x = self.relu(self.dropout(x))
        logits = self.clf(x)
        return logits

    def load_encoder_params(self, state_dict_path):
        self.encoder.load_state_dict(torch.load(state_dict_path))
        print("Pretrained encoder parameters loaded.")


def get_criterion(task_type, num_labels):
    if task_type == 'classification':
        return nn.BCEWithLogitsLoss()
    elif task_type == 'regression':
        return nn.MSELoss()
    else:
        raise ValueError(f"Unknown task type: {task_type}")
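

# Fine-tuning loop helpers: train_epoch runs one optimization pass over the training set,
# eval_epoch returns the mean loss on a held-out loader, test_model collects predictions,
# and calc_val_metrics additionally computes ROC AUC for classification tasks.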
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # Advance the learning-rate schedule once per epoch.
    scheduler.step()
    return total_loss / len(dataloader)


def eval_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss / len(dataloader)


def test_model(model, dataloader, device, task_type='classification'):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels']
            outputs = model(**inputs)
            # Apply the sigmoid only for classification; regression outputs are used as-is.
            preds = torch.sigmoid(outputs) if task_type == 'classification' else outputs
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.numpy())
    return np.concatenate(all_preds), np.concatenate(all_labels)


def calc_val_metrics(model, dataloader, criterion, device, task_type):
    model.eval()
    all_labels, all_preds = [], []
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            if task_type == 'classification':
                pred_probs = torch.sigmoid(outputs).cpu().numpy()
                all_preds.append(pred_probs)
                all_labels.append(labels.cpu().numpy())
            else:
                preds = outputs.cpu().numpy()
                all_preds.append(preds)
                all_labels.append(labels.cpu().numpy())
    avg_loss = total_loss / len(dataloader)
    if task_type == 'classification':
        y_true = np.concatenate(all_labels)
        y_pred = np.concatenate(all_preds)
        try:
            score = roc_auc_score(y_true, y_pred, average='macro')
        except Exception:
            score = 0.0
        return avg_loss, score
    else:
        return avg_loss, None
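

# Fine-tunes the SimSon encoder + linear head on each configured MoleculeNet task with
# early stopping on the validation loss, then reports test metrics and the embedding
# similarity between original and enumerated SMILES on the test set.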
def main():
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    DATASETS_TO_RUN = {
        'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'random'},
    }
    PATIENCE = 15
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 16
    MAX_LEN = 512

    TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    ENCODER_CONFIG = BertConfig(
        vocab_size=TOKENIZER.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    aggregated_results = {}

    for name, info in DATASETS_TO_RUN.items():
        print(f"\n{'='*20} Processing Dataset: {name.upper()} ({info['split']} split) {'='*20}")
        smiles, labels = load_lists_from_url(name)

        if info.get('split', 'scaffold') == 'scaffold':
            splitter = ScaffoldSplitter(data=name, seed=42)
            train_idx, val_idx, test_idx = splitter.scaffold_split()
        elif info['split'] == 'random':
            train_idx, val_idx, test_idx = random_split_indices(len(smiles), seed=42)
        else:
            raise ValueError(f"Unknown split type for {name}: {info['split']}")

        train_smiles = smiles.iloc[train_idx].reset_index(drop=True)
        train_labels = labels.iloc[train_idx].reset_index(drop=True)
        val_smiles = smiles.iloc[val_idx].reset_index(drop=True)
        val_labels = labels.iloc[val_idx].reset_index(drop=True)
        test_smiles = smiles.iloc[test_idx].reset_index(drop=True)
        test_labels = labels.iloc[test_idx].reset_index(drop=True)
        print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}")

        train_dataset = MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN)
        val_dataset = MoleculeDataset(val_smiles, val_labels, TOKENIZER, MAX_LEN)
        test_dataset = MoleculeDataset(test_smiles, test_labels, TOKENIZER, MAX_LEN)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        encoder = SimSonEncoder(ENCODER_CONFIG, 512)
        encoder = torch.compile(encoder)
        model = SimSonClassifier(encoder, num_labels=info['num_labels']).to(DEVICE)
        model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')
        criterion = get_criterion(info['task_type'], info['num_labels'])
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0024)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.59298)

        best_val_loss = float('inf')
        best_model_state = None
        current_patience = 0
        for epoch in range(EPOCHS):
            train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, DEVICE)
            val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, DEVICE, info['task_type'])
            metric_str = f"{val_metric:.4f}" if val_metric is not None else "n/a"
            print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | ROC AUC: {metric_str}")

            # Keep the checkpoint with the lowest validation loss; stop after PATIENCE epochs without improvement.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = copy.deepcopy(model.state_dict())
                print(f" -> New best model saved with validation loss: {best_val_loss:.4f}")
                current_patience = 0
            else:
                current_patience += 1
                if current_patience >= PATIENCE:
                    print(f'Early stopping after {PATIENCE} epochs without improvement')
                    break

        print("\nTesting with the best model...")
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        test_loss = eval_epoch(model, test_loader, criterion, DEVICE)
        print(f'Test loss: {test_loss}')
        test_preds, test_true = test_model(model, test_loader, DEVICE, info['task_type'])

        aggregated_results[name] = {
            'best_val_loss': best_val_loss,
            'test_predictions': test_preds,
            'test_labels': test_true
        }
        print(f"Finished testing for {name}.")
        test_smiles_list = list(test_smiles)
        similarities = compute_embedding_similarity(
            model.encoder, test_smiles_list, TOKENIZER, DEVICE, MAX_LEN
        )
        print(f"Similarity score: {similarities.mean():.4f}")
        if name == 'do_not_save':
            torch.save(model.encoder.state_dict(), 'moleculenet_clintox_encoder.bin')

    print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
    for name, result in aggregated_results.items():
        if name in ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']:
            auc = roc_auc_score(result['test_labels'], result['test_predictions'], average='macro')
            print(f'{name} ROC AUC: {auc}')

        if name in ['lipophilicity', 'esol', 'qm8']:
            rmse = root_mean_squared_error(result['test_labels'], result['test_predictions'])
            mae = mean_absolute_error(result['test_labels'], result['test_predictions'])
            print(f'{name} MAE: {mae}')
            print(f'{name} RMSE: {rmse}')

    print("\nScript finished.")


if __name__ == '__main__':
    main()