{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "7678e528-e2d6-4ef4-bd11-8c745bbce7c4", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import Dataset, DataLoader\n", "from transformers import BertModel, BertConfig, AutoTokenizer\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.decomposition import PCA\n", "from sklearn.manifold import TSNE\n", "from sklearn.metrics.pairwise import cosine_similarity\n", "from sklearn.cluster import KMeans\n", "import umap.umap_ as umap\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "# Set style for better plots\n", "plt.style.use('seaborn-v0_8')\n", "sns.set_palette(\"husl\")\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "ad8d5a02-6162-41f7-a730-e31e353ed9b8", "metadata": {}, "outputs": [], "source": [ "class PrecomputedContrastiveSmilesDataset(Dataset):\n", " \"\"\"\n", " A Dataset class that reads pre-augmented SMILES pairs from a Parquet file.\n", " This is significantly faster as it offloads the expensive SMILES randomization\n", " to a one-time preprocessing step.\n", " \"\"\"\n", " def __init__(self, tokenizer, file_path: str, max_length: int = 512):\n", " self.tokenizer = tokenizer\n", " self.max_length = max_length\n", " \n", " # Load the entire dataset from the Parquet file into memory.\n", " # This is fast and efficient for subsequent access.\n", " print(f\"Loading pre-computed data from {file_path}...\")\n", " self.data = pd.read_parquet(file_path)\n", " print(\"Data loaded successfully.\")\n", "\n", " def __len__(self):\n", " \"\"\"Returns the total number of pairs in the dataset.\"\"\"\n", " return len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " \"\"\"\n", " Retrieves a pre-augmented pair, tokenizes it, and returns it\n", " in the format expected by the DataCollator.\n", " \"\"\"\n", " # Retrieve the pre-augmented pair from the DataFrame\n", " row = self.data.iloc[idx]\n", " smiles_1 = row['smiles_1']\n", " smiles_2 = row['smiles_2']\n", " \n", " # Tokenize the pair. 
This operation is fast and remains in the data loader.\n", " tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')\n", " tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')\n", " \n", " return {\n", " 'input_ids_1': torch.tensor(tokens_1['input_ids']),\n", " 'attention_mask_1': torch.tensor(tokens_1['attention_mask']),\n", " 'input_ids_2': torch.tensor(tokens_2['input_ids']),\n", " 'attention_mask_2': torch.tensor(tokens_2['attention_mask']),\n", " }\n", "\n", "def global_ap(x):\n", " return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)\n", "\n", "class SimSonEncoder(nn.Module):\n", " def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n", " super(SimSonEncoder, self).__init__()\n", " self.config = config\n", " self.max_len = max_len\n", " self.bert = BertModel(config, add_pooling_layer=False)\n", " self.linear = nn.Linear(config.hidden_size, max_len)\n", " self.dropout = nn.Dropout(dropout)\n", " \n", " def forward(self, input_ids, attention_mask=None):\n", " if attention_mask is None:\n", " attention_mask = input_ids.ne(self.config.pad_token_id)\n", " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", " hidden_states = self.dropout(outputs.last_hidden_state)\n", " pooled = global_ap(hidden_states)\n", " return self.linear(pooled)\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "476226dd-54f3-4d3e-adb4-cc30e922fd96", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OptimizedModule(\n", " (_orig_mod): SimSonEncoder(\n", " (bert): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(591, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-3): 4 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSdpaSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=2048, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=2048, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (linear): Linear(in_features=768, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n", "model_config = BertConfig(\n", " vocab_size=tokenizer.vocab_size,\n", " hidden_size=768,\n", " num_hidden_layers=4,\n", " num_attention_heads=12,\n", " intermediate_size=2048,\n", " max_position_embeddings=512\n", ")\n", 
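"\n", "# BertConfig defaults pad_token_id to 0, and SimSonEncoder.forward falls back to\n", "# config.pad_token_id when no attention_mask is supplied. If a tokenizer whose pad id\n", "# differs from 0 is ever swapped in, align the config before building the model, e.g.\n", "# model_config.pad_token_id = tokenizer.pad_token_id\n",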
"\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "model = SimSonEncoder(config=model_config, max_len=512).to(device)\n", "model = torch.compile(model)\n", "model.load_state_dict(torch.load('/home/jovyan/simson_training_bolgov/simson_checkpoints_polymer_1M/simson_model_single_gpu.bin'))\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 4, "id": "df28c332-c20b-4804-b4ca-97de9d652445", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test dataset shape: (5000, 2)\n", "Columns: ['smiles_1', 'smiles_2']\n" ] }, { "data": { "text/html": [ "
\n", " | smiles_1 | \n", "smiles_2 | \n", "
---|---|---|
0 | \n", "c1c(ccc(c1)OC(=O)CCCS(CCCc1ccc(C(O*)=O)cc1)(=O... | \n", "c1cc(ccc1OC(CCCS(CCCc1ccc(C(=O)O*)cc1)(=O)=O)=... | \n", "
1 | \n", "C(CC(=O)NCCC[Si](O*)(C)C)CCCCC(OCCCCCOC(Cc1cc(... | \n", "CN(*)Cc1cccc(c1)CC(OCCCCCOC(CCCCCCC(=O)NCCC[Si... | \n", "
2 | \n", "C(SCCNC(OCC*)=O)CSCCCOCCNC(=O)O* | \n", "O=C(OCC*)NCCSCCSCCCOCCNC(O*)=O | \n", "
3 | \n", "*CCCCOC(=O)CSCCCC[PH](C)(=O)O* | \n", "*CCCCOC(=O)CSCCCC[PH](C)(=O)O* | \n", "
4 | \n", "C(O*)COCCOCCN(C(=O)OCCC*)C(=O)OC | \n", "C(CC*)OC(=O)N(C(=O)OC)CCOCCOCCO* | \n", "