{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "6db37aa6", "metadata": {}, "source": [ "# MoL-MoE Foundation Models - Multi Output (K=4)" ] }, { "cell_type": "code", "execution_count": null, "id": "f13dd822", "metadata": {}, "outputs": [], "source": [ "# System\n", "import warnings\n", "import sys\n", "sys.path.insert(1, '../')\n", "sys.path.insert(2, '../experts')\n", "sys.path.insert(3, '../moe')\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# Deep learning\n", "import torch.nn.functional as F\n", "import torch\n", "from torch import nn\n", "from moe import MoE, train\n", "from models import Net\n", "\n", "# Machine learning\n", "from xgboost import XGBClassifier\n", "from sklearn.metrics import roc_auc_score\n", "\n", "# Data\n", "import pandas as pd\n", "import numpy as np\n", "\n", "# Chemistry\n", "from rdkit import Chem\n", "from rdkit.Chem import PandasTools\n", "from rdkit.Chem import Descriptors\n", "PandasTools.RenderImagesInAllDataFrames(True)\n", "\n", "def normalize_smiles(smi, canonical=True, isomeric=False):\n", " try:\n", " normalized = Chem.MolToSmiles(\n", " Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric\n", " )\n", " except:\n", " normalized = None\n", " return normalized\n", "\n", "torch.manual_seed(0)\n", "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "markdown", "id": "5afbebeb", "metadata": {}, "source": [ "## Load Foundation Models" ] }, { "cell_type": "code", "execution_count": null, "id": "e531965e", "metadata": {}, "outputs": [], "source": [ "from experts.selfies_ted.load import SELFIES\n", "\n", "model_selfies = SELFIES()\n", "model_selfies.load()" ] }, { "cell_type": "code", "execution_count": null, "id": "6cee782c", "metadata": {}, "outputs": [], "source": [ "from experts.mhg_model.load import load\n", "\n", "mhg_gnn = load()" ] }, { "cell_type": "code", "execution_count": null, "id": "93e1dbfa", "metadata": {}, "outputs": [], "source": [ "from experts.smi_ted_light.load import load_smi_ted, MolTranBertTokenizer\n", "\n", "smi_ted = load_smi_ted()" ] }, { "cell_type": "markdown", "id": "91c14dd5", "metadata": {}, "source": [ "## Load datasets" ] }, { "cell_type": "code", "execution_count": null, "id": "d3945b2d", "metadata": {}, "outputs": [], "source": [ "train_df = pd.read_csv(\"../data/moleculenet/bbbp/train.csv\")\n", "valid_df = pd.read_csv(\"../data/moleculenet/bbbp/valid.csv\")\n", "test_df = pd.read_csv(\"../data/moleculenet/bbbp/test.csv\")" ] }, { "cell_type": "code", "execution_count": null, "id": "60dbdc78", "metadata": { "scrolled": true }, "outputs": [], "source": [ "train_df['canon_smiles'] = train_df['smiles'].apply(normalize_smiles)\n", "train_df = train_df.dropna(subset='canon_smiles')\n", "print(train_df.shape)\n", "train_df.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "a5e69c82", "metadata": { "scrolled": true }, "outputs": [], "source": [ "valid_df['canon_smiles'] = valid_df['smiles'].apply(normalize_smiles)\n", "valid_df = valid_df.dropna(subset='canon_smiles')\n", "print(valid_df.shape)\n", "valid_df.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "bbb1cdd5", "metadata": { "scrolled": true }, "outputs": [], "source": [ "test_df['canon_smiles'] = test_df['smiles'].apply(normalize_smiles)\n", "test_df = test_df.dropna(subset='canon_smiles')\n", "print(test_df.shape)\n", "test_df.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "b4f1f377", "metadata": {}, "outputs": [], "source": [ 
"smiles_col = 'canon_smiles'\n", "target = 'p_np'\n", "\n", "# training\n", "X_train = train_df[smiles_col].to_list()\n", "y_train = train_df[target]\n", "\n", "# validation\n", "X_valid = valid_df[smiles_col].to_list()\n", "y_valid = valid_df[target]\n", "\n", "# test\n", "X_test = test_df[smiles_col].to_list()\n", "y_test = test_df[target]" ] }, { "cell_type": "markdown", "id": "09c0c01c", "metadata": {}, "source": [ "## Training MoE" ] }, { "cell_type": "code", "execution_count": null, "id": "cf69fc9b", "metadata": {}, "outputs": [], "source": [ "# arguments\n", "input_size = 768\n", "output_size = 2048\n", "num_experts = 12\n", "k = 4\n", "batch_size = 16\n", "learning_rate = 3e-5\n", "epochs = 100\n", "\n", "# experts\n", "models = [\n", " smi_ted, smi_ted, smi_ted, smi_ted, # SMI-TED\n", " model_selfies, model_selfies, model_selfies, model_selfies, # SELFIES-BART\n", " mhg_gnn, mhg_gnn, mhg_gnn, mhg_gnn # MHG-GNN\n", "]\n", "\n", "# instantiate the MoE layer\n", "net = Net(smiles_embed_dim=2048, dropout=0.2, output_dim=2)\n", "tokenizer = MolTranBertTokenizer('../experts/smi_ted_light/bert_vocab_curated.txt')\n", "moe_model = MoE(input_size, \n", " output_size, \n", " num_experts, \n", " models=models, \n", " tokenizer=tokenizer, \n", " tok_emb=smi_ted.encoder.tok_emb, \n", " k=k, \n", " noisy_gating=False, \n", " verbose=False).to(DEVICE)\n", "\n", "net.apply(smi_ted._init_weights)\n", "\n", "loss_fn = nn.CrossEntropyLoss()\n", "params = list(moe_model.parameters()) + list(net.parameters())\n", "optim = torch.optim.AdamW(params, lr=learning_rate)\n", "\n", "train_loader = torch.utils.data.DataLoader(list(zip(X_train, y_train)), batch_size=batch_size,\n", " shuffle=True, num_workers=1)\n", "\n", "# train\n", "moe_model, net = train(train_loader, moe_model, net, loss_fn, optim, epochs)" ] }, { "cell_type": "markdown", "id": "57aa1cc0", "metadata": {}, "source": [ "## Evaluate (using auxiliary Net)" ] }, { "cell_type": "code", "execution_count": null, "id": "b9e18da2", "metadata": { "scrolled": true }, "outputs": [], "source": [ "moe_model.eval()\n", "net.eval()\n", "\n", "with torch.no_grad():\n", " out, _ = moe_model(X_test, verbose=False)\n", " preds = net(out)\n", " preds_cpu = F.softmax(preds, dim=1)[:, 1]\n", " print('Prediction probabilities:', preds_cpu[:30])" ] }, { "cell_type": "code", "execution_count": null, "id": "f204d19d", "metadata": {}, "outputs": [], "source": [ "roc_auc = roc_auc_score(y_test, preds_cpu.detach().numpy())\n", "print(f\"ROC-AUC Score: {roc_auc:.4f}\")" ] }, { "cell_type": "markdown", "id": "3af9e51e", "metadata": {}, "source": [ "# Training XGBoost from MoE" ] }, { "cell_type": "code", "execution_count": null, "id": "cb0d9f5a", "metadata": {}, "outputs": [], "source": [ "# extract embeddings\n", "moe_model.eval()\n", "net.eval()\n", "\n", "with torch.no_grad():\n", " xgb_train, _ = moe_model(X_train, verbose=True)\n", " xgb_valid, _ = moe_model(X_valid, verbose=True)\n", " xgb_test, _ = moe_model(X_test, verbose=True)\n", " \n", "xgb_train = xgb_train.detach().numpy()\n", "xgb_valid = xgb_valid.detach().numpy()\n", "xgb_test = xgb_test.detach().numpy()" ] }, { "cell_type": "code", "execution_count": null, "id": "a099f2f1", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n", "from sklearn.metrics import roc_auc_score\n", "from xgboost import XGBClassifier\n", "import numpy as np\n", "\n", "# Define lists to store ROC-AUC scores and model 
instances\n", "roc_auc_scores = []\n", "\n", "# Loop over seeds from 0 to 90 in steps of 10\n", "for seed in range(0, 91, 10):\n", " # Define XGBoost parameters with different values for each seed\n", " xgb_params = {\n", " 'learning_rate': [0.01, 0.4, 0.6, 0.8],\n", " 'max_depth': [6, 8, 10, 12],\n", " 'n_estimators': [1500, 2000, 2200]\n", " }\n", "\n", " # Initialize XGBoost classifier\n", " xgb_classifier = XGBClassifier()\n", "\n", " # Perform RandomizedSearchCV to find optimal hyperparameters\n", " random_search = RandomizedSearchCV(estimator=xgb_classifier, param_distributions=xgb_params, n_iter=10, scoring='roc_auc', cv=3, random_state=seed)\n", " random_search.fit(xgb_train, y_train)\n", "\n", " # Get best estimator and predict probabilities\n", " best_estimator = random_search.best_estimator_\n", " y_prob = best_estimator.predict_proba(xgb_test)[:, 1]\n", "\n", " # Evaluate ROC-AUC score\n", " roc_auc = roc_auc_score(y_test, y_prob)\n", " roc_auc_scores.append(roc_auc)\n", "\n", " print(f\"Seed {seed}: ROC-AUC Score: {roc_auc:.4f}\")\n", "\n", "# Calculate standard deviation and average ROC-AUC score\n", "std_dev = np.std(roc_auc_scores)\n", "avg_roc_auc = np.mean(roc_auc_scores)\n", "\n", "# Plot ROC-AUC scores\n", "plt.figure(figsize=(8, 6))\n", "plt.errorbar(range(0, 91, 10), roc_auc_scores, yerr=std_dev, fmt='o', color='b')\n", "plt.hlines(avg_roc_auc, xmin=-1, xmax=91, colors='r', linestyles='dashed', label=f'Average ROC-AUC: {avg_roc_auc:.4f}')\n", "plt.xlabel('Seed')\n", "plt.ylabel('ROC-AUC Score')\n", "plt.title('ROC-AUC Scores with Standard Deviation')\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }