{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "7678e528-e2d6-4ef4-bd11-8c745bbce7c4", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import Dataset, DataLoader\n", "from transformers import BertModel, BertConfig, AutoTokenizer\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.decomposition import PCA\n", "from sklearn.manifold import TSNE\n", "from sklearn.metrics.pairwise import cosine_similarity\n", "from sklearn.cluster import KMeans\n", "import umap.umap_ as umap\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "# Set style for better plots\n", "plt.style.use('seaborn-v0_8')\n", "sns.set_palette(\"husl\")\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "ad8d5a02-6162-41f7-a730-e31e353ed9b8", "metadata": {}, "outputs": [], "source": [ "class PrecomputedContrastiveSmilesDataset(Dataset):\n", " \"\"\"\n", " A Dataset class that reads pre-augmented SMILES pairs from a Parquet file.\n", " This is significantly faster as it offloads the expensive SMILES randomization\n", " to a one-time preprocessing step.\n", " \"\"\"\n", " def __init__(self, tokenizer, file_path: str, max_length: int = 512):\n", " self.tokenizer = tokenizer\n", " self.max_length = max_length\n", " \n", " # Load the entire dataset from the Parquet file into memory.\n", " # This is fast and efficient for subsequent access.\n", " print(f\"Loading pre-computed data from {file_path}...\")\n", " self.data = pd.read_parquet(file_path)\n", " print(\"Data loaded successfully.\")\n", "\n", " def __len__(self):\n", " \"\"\"Returns the total number of pairs in the dataset.\"\"\"\n", " return len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " \"\"\"\n", " Retrieves a pre-augmented pair, tokenizes it, and returns it\n", " in the format expected by the DataCollator.\n", " \"\"\"\n", " # Retrieve the pre-augmented pair from the DataFrame\n", " row = self.data.iloc[idx]\n", " smiles_1 = row['smiles_1']\n", " smiles_2 = row['smiles_2']\n", " \n", " # Tokenize the pair. This operation is fast and remains in the data loader.\n", " tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')\n", " tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')\n", " \n", " return {\n", " 'input_ids_1': torch.tensor(tokens_1['input_ids']),\n", " 'attention_mask_1': torch.tensor(tokens_1['attention_mask']),\n", " 'input_ids_2': torch.tensor(tokens_2['input_ids']),\n", " 'attention_mask_2': torch.tensor(tokens_2['attention_mask']),\n", " }\n", "\n", "def global_ap(x):\n", " return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)\n", "\n", "class SimSonEncoder(nn.Module):\n", " def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n", " super(SimSonEncoder, self).__init__()\n", " self.config = config\n", " self.max_len = max_len\n", " self.bert = BertModel(config, add_pooling_layer=False)\n", " self.linear = nn.Linear(config.hidden_size, max_len)\n", " self.dropout = nn.Dropout(dropout)\n", " \n", " def forward(self, input_ids, attention_mask=None):\n", " if attention_mask is None:\n", " attention_mask = input_ids.ne(self.config.pad_token_id)\n", " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", " hidden_states = self.dropout(outputs.last_hidden_state)\n", " pooled = global_ap(hidden_states)\n", " return self.linear(pooled)\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "476226dd-54f3-4d3e-adb4-cc30e922fd96", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OptimizedModule(\n", " (_orig_mod): SimSonEncoder(\n", " (bert): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(591, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-3): 4 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSdpaSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=2048, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=2048, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (linear): Linear(in_features=768, out_features=512, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n", "model_config = BertConfig(\n", " vocab_size=tokenizer.vocab_size,\n", " hidden_size=768,\n", " num_hidden_layers=4,\n", " num_attention_heads=12,\n", " intermediate_size=2048,\n", " max_position_embeddings=512\n", ")\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "model = SimSonEncoder(config=model_config, max_len=512).to(device)\n", "model = torch.compile(model)\n", "model.load_state_dict(torch.load('/home/jovyan/simson_training_bolgov/simson_checkpoints_polymer_1M/simson_model_single_gpu.bin'))\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 4, "id": "df28c332-c20b-4804-b4ca-97de9d652445", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test dataset shape: (5000, 2)\n", "Columns: ['smiles_1', 'smiles_2']\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
smiles_1smiles_2
0c1c(ccc(c1)OC(=O)CCCS(CCCc1ccc(C(O*)=O)cc1)(=O...c1cc(ccc1OC(CCCS(CCCc1ccc(C(=O)O*)cc1)(=O)=O)=...
1C(CC(=O)NCCC[Si](O*)(C)C)CCCCC(OCCCCCOC(Cc1cc(...CN(*)Cc1cccc(c1)CC(OCCCCCOC(CCCCCCC(=O)NCCC[Si...
2C(SCCNC(OCC*)=O)CSCCCOCCNC(=O)O*O=C(OCC*)NCCSCCSCCCOCCNC(O*)=O
3*CCCCOC(=O)CSCCCC[PH](C)(=O)O**CCCCOC(=O)CSCCCC[PH](C)(=O)O*
4C(O*)COCCOCCN(C(=O)OCCC*)C(=O)OCC(CC*)OC(=O)N(C(=O)OC)CCOCCOCCO*
\n", "
" ], "text/plain": [ " smiles_1 \\\n", "0 c1c(ccc(c1)OC(=O)CCCS(CCCc1ccc(C(O*)=O)cc1)(=O... \n", "1 C(CC(=O)NCCC[Si](O*)(C)C)CCCCC(OCCCCCOC(Cc1cc(... \n", "2 C(SCCNC(OCC*)=O)CSCCCOCCNC(=O)O* \n", "3 *CCCCOC(=O)CSCCCC[PH](C)(=O)O* \n", "4 C(O*)COCCOCCN(C(=O)OCCC*)C(=O)OC \n", "\n", " smiles_2 \n", "0 c1cc(ccc1OC(CCCS(CCCc1ccc(C(=O)O*)cc1)(=O)=O)=... \n", "1 CN(*)Cc1cccc(c1)CC(OCCCCCOC(CCCCCCC(=O)NCCC[Si... \n", "2 O=C(OCC*)NCCSCCSCCCOCCNC(O*)=O \n", "3 *CCCCOC(=O)CSCCCC[PH](C)(=O)O* \n", "4 C(CC*)OC(=O)N(C(=O)OC)CCOCCOCCO* " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data = pd.read_parquet('/home/jovyan/simson_training_bolgov/data/polymer_splits/test.parquet')\n", "print(f\"Test dataset shape: {test_data.shape}\")\n", "print(f\"Columns: {test_data.columns.tolist()}\")\n", "test_data.head()" ] }, { "cell_type": "code", "execution_count": 5, "id": "a7f2e2bc-12f5-443b-9d80-899542e07370", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating embeddings for original SMILES...\n", "Processed 2560 / 5000 SMILES\n", "Processed 5000 / 5000 SMILES\n", "Generating embeddings for augmented SMILES...\n", "Processed 2560 / 5000 SMILES\n", "Processed 5000 / 5000 SMILES\n", "Original embeddings shape: (5000, 512)\n", "Augmented embeddings shape: (5000, 512)\n" ] } ], "source": [ "def generate_embeddings(model, tokenizer, smiles_list, batch_size=256, max_length=512):\n", " \"\"\"Generate embeddings for a list of SMILES strings\"\"\"\n", " model.eval()\n", " embeddings = []\n", " \n", " with torch.no_grad():\n", " for i in range(0, len(smiles_list), batch_size):\n", " batch_smiles = smiles_list[i:i+batch_size]\n", " \n", " # Tokenize batch\n", " tokens = tokenizer(batch_smiles, \n", " max_length=max_length, \n", " truncation=True, \n", " padding='max_length', \n", " return_tensors='pt')\n", " \n", " # Move to device\n", " input_ids = tokens['input_ids'].to(device)\n", " attention_mask = tokens['attention_mask'].to(device)\n", " \n", " # Generate embeddings\n", " batch_embeddings = model(input_ids, attention_mask)\n", " embeddings.append(batch_embeddings.cpu().numpy())\n", " \n", " if (i // batch_size + 1) % 10 == 0:\n", " print(f\"Processed {i + len(batch_smiles)} / {len(smiles_list)} SMILES\")\n", " \n", " return np.vstack(embeddings)\n", "\n", "# Generate embeddings for original and augmented SMILES\n", "print(\"Generating embeddings for original SMILES...\")\n", "original_embeddings = generate_embeddings(model, tokenizer, test_data['smiles_1'].tolist())\n", "\n", "print(\"Generating embeddings for augmented SMILES...\")\n", "augmented_embeddings = generate_embeddings(model, tokenizer, test_data['smiles_2'].tolist())\n", "\n", "print(f\"Original embeddings shape: {original_embeddings.shape}\")\n", "print(f\"Augmented embeddings shape: {augmented_embeddings.shape}\")\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "c770d7f4-fa37-4519-bda2-9b084d2a4b32", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average cosine similarity between original and augmented SMILES: 0.9874\n", "Standard deviation: 0.0197\n", "Min similarity: 0.7691\n", "Max similarity: 1.0000\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def calculate_pairwise_similarities(embeddings1, embeddings2):\n", " \"\"\"Calculate cosine similarities between corresponding pairs\"\"\"\n", " similarities = []\n", " for i in range(len(embeddings1)):\n", " sim = cosine_similarity([embeddings1[i]], [embeddings2[i]])[0][0]\n", " similarities.append(sim)\n", " return np.array(similarities)\n", "\n", "# Calculate cosine similarities\n", "pairwise_similarities = calculate_pairwise_similarities(original_embeddings, augmented_embeddings)\n", "\n", "print(f\"Average cosine similarity between original and augmented SMILES: {np.mean(pairwise_similarities):.4f}\")\n", "print(f\"Standard deviation: {np.std(pairwise_similarities):.4f}\")\n", "print(f\"Min similarity: {np.min(pairwise_similarities):.4f}\")\n", "print(f\"Max similarity: {np.max(pairwise_similarities):.4f}\")\n", "\n", "# Plot similarity distribution\n", "plt.figure(figsize=(10, 6))\n", "plt.hist(pairwise_similarities, bins=50, alpha=0.7, edgecolor='black')\n", "plt.axvline(np.mean(pairwise_similarities), color='red', linestyle='--', \n", " label=f'Mean: {np.mean(pairwise_similarities):.4f}')\n", "plt.xlabel('Cosine Similarity')\n", "plt.ylabel('Frequency')\n", "plt.title('Distribution of Cosine Similarities Between Original and Augmented SMILES')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "d7fa6ed4-b546-4544-bb00-4d90a99678da", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "=== Embedding Space Analysis ===\n", "Embedding dimensionality: 512\n", "Average L2 norm (original): 11.9105\n", "Average L2 norm (augmented): 11.9072\n", "Average intra-class similarity (original): 0.0047\n", "Average intra-class similarity (augmented): 0.0048\n", "Average inter-class similarity: 0.0049\n", "\n", "=== Nearest Neighbor Analysis (k=5) ===\n", "Augmented SMILES in top-5 neighbors: 5000/5000 (100.0%)\n", "Augmented SMILES as top-1 neighbor: 4872/5000 (97.4%)\n" ] } ], "source": [ "# 1. Embedding Space Statistics\n", "def analyze_embedding_space(original_emb, augmented_emb):\n", " \"\"\"Analyze the embedding space properties\"\"\"\n", " print(\"=== Embedding Space Analysis ===\")\n", " \n", " # Dimensionality and norms\n", " print(f\"Embedding dimensionality: {original_emb.shape[1]}\")\n", " print(f\"Average L2 norm (original): {np.mean(np.linalg.norm(original_emb, axis=1)):.4f}\")\n", " print(f\"Average L2 norm (augmented): {np.mean(np.linalg.norm(augmented_emb, axis=1)):.4f}\")\n", " \n", " # Intra-class similarities\n", " orig_similarities = cosine_similarity(original_emb)\n", " aug_similarities = cosine_similarity(augmented_emb)\n", " \n", " # Remove diagonal (self-similarity)\n", " orig_similarities_off_diag = orig_similarities[np.triu_indices_from(orig_similarities, k=1)]\n", " aug_similarities_off_diag = aug_similarities[np.triu_indices_from(aug_similarities, k=1)]\n", " \n", " print(f\"Average intra-class similarity (original): {np.mean(orig_similarities_off_diag):.4f}\")\n", " print(f\"Average intra-class similarity (augmented): {np.mean(aug_similarities_off_diag):.4f}\")\n", " \n", " # Inter-class similarities\n", " inter_similarities = cosine_similarity(original_emb, augmented_emb)\n", " print(f\"Average inter-class similarity: {np.mean(inter_similarities):.4f}\")\n", "\n", "analyze_embedding_space(original_embeddings, augmented_embeddings)\n", "\n", "# 2. Nearest Neighbor Analysis\n", "def nearest_neighbor_analysis(original_emb, augmented_emb, k=5):\n", " \"\"\"Analyze nearest neighbors between original and augmented embeddings\"\"\"\n", " print(f\"\\n=== Nearest Neighbor Analysis (k={k}) ===\")\n", " \n", " # For each original embedding, find its k nearest neighbors in augmented set\n", " similarities = cosine_similarity(original_emb, augmented_emb)\n", " \n", " # Find cases where augmented version is among top-k neighbors\n", " correct_matches = 0\n", " top1_matches = 0\n", " \n", " for i in range(len(original_emb)):\n", " # Get similarity scores for i-th original embedding\n", " sim_scores = similarities[i]\n", " top_k_indices = np.argsort(sim_scores)[-k:][::-1]\n", " \n", " if i in top_k_indices:\n", " correct_matches += 1\n", " if np.argmax(sim_scores) == i:\n", " top1_matches += 1\n", " \n", " print(f\"Augmented SMILES in top-{k} neighbors: {correct_matches}/{len(original_emb)} ({100*correct_matches/len(original_emb):.1f}%)\")\n", " print(f\"Augmented SMILES as top-1 neighbor: {top1_matches}/{len(original_emb)} ({100*top1_matches/len(original_emb):.1f}%)\")\n", "\n", "nearest_neighbor_analysis(original_embeddings, augmented_embeddings)" ] }, { "cell_type": "markdown", "id": "f136743a-c742-469b-a9b8-9a4da8eb4c08", "metadata": {}, "source": [ "Объяснение\n", "\n", "* Embedding Space Analysis. Показывает, что различные величины (L2 norm - длина вектора, близости) практически идентичны для оригинальных, и для аугментированных молекул (показывая, что они отображаются практически одними и теми же для модели)\n", "* Nearest Neighbor Analysis (k=5). Показывает, что топ 5 ближайших векторов к любому из векторов - всегда его аугментация (кроме его самого, естественно). В 97.4 % случаях вектор, соответствующий аугментации, является ближайшим." ] }, { "cell_type": "markdown", "id": "133fb52e-bea8-4159-ab06-386998b71ed0", "metadata": {}, "source": [ "С разными SMILES - ровно обратная картина, как и должно быть (различные smiles являются очень разными)" ] }, { "cell_type": "code", "execution_count": 8, "id": "cc48f04e-b8bf-415b-89f0-b51d98f7e6b6", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "\n", "num_molecules = original_embeddings.shape[0]\n", "\n", "# Shuffle indices for unrelated molecules\n", "unrelated_indices = np.random.permutation(num_molecules)\n", "unrelated_embeddings = augmented_embeddings[unrelated_indices]\n", "\n", "# Compute pairwise cosine similarity between original and unrelated\n", "pairwise_unrelated_similarities = np.array([\n", " cosine_similarity([original_embeddings[i]], [unrelated_embeddings[i]])[0][0]\n", " for i in range(num_molecules)\n", "])\n", "\n", "\n", "plt.figure(figsize=(10,6))\n", "plt.hist(pairwise_unrelated_similarities, bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n", "mean_sim = pairwise_unrelated_similarities.mean()\n", "plt.axvline(mean_sim, color='red', linestyle='--', label=f'Mean = {mean_sim:.4f}')\n", "plt.xlabel('Cosine Similarity')\n", "plt.ylabel('Frequency')\n", "plt.title('Cosine Similarity Distribution of Different Molecules (Unrelated SMILES)')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "881eca62-7bf5-4b84-b0df-659942930a73", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean cosine similarity: 0.0037\n", "Std deviation: 0.1393\n", "Range: -0.4398 to 1.0000\n" ] } ], "source": [ "print(f\"Mean cosine similarity: {pairwise_unrelated_similarities.mean():.4f}\")\n", "print(f\"Std deviation: {pairwise_unrelated_similarities.std():.4f}\")\n", "print(f\"Range: {pairwise_unrelated_similarities.min():.4f} to {pairwise_unrelated_similarities.max():.4f}\")\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "051f4149-d3eb-472b-b6d0-7060662efb6a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unrelated in top-5 neighbors: 7/5000\n", "Unrelated as top-1 neighbor: 2/5000\n" ] } ], "source": [ "from sklearn.metrics.pairwise import cosine_similarity\n", "\n", "k = 5\n", "similarities = cosine_similarity(original_embeddings, unrelated_embeddings)\n", "correct_matches = sum(i in np.argsort(similarities[i])[-k:] for i in range(num_molecules))\n", "top1_matches = sum(np.argmax(similarities[i]) == i for i in range(num_molecules))\n", "\n", "print(f\"Unrelated in top-{k} neighbors: {correct_matches}/{num_molecules}\")\n", "print(f\"Unrelated as top-1 neighbor: {top1_matches}/{num_molecules}\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:.mlspace-bolgov_simson_training]", "language": "python", "name": "conda-env-.mlspace-bolgov_simson_training-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }