{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f72ec62c-c849-45b2-9321-b913c9f32979", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Encoder parameters loaded\n", "Encoder parameters frozen\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from transformers import BertConfig, BertModel, AutoTokenizer\n", "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, EarlyStoppingCallback\n", "from torch.utils.data import Dataset\n", "import pandas as pd\n", "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", "\n", "def global_ap(x):\n", " return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)\n", "\n", "class SimSonEncoder(nn.Module):\n", " def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n", " super(SimSonEncoder, self).__init__()\n", " self.config = config\n", " self.max_len = max_len\n", " self.bert = BertModel(config, add_pooling_layer=False)\n", " self.linear = nn.Linear(config.hidden_size, max_len)\n", " self.dropout = nn.Dropout(dropout)\n", " \n", " def forward(self, input_ids, attention_mask=None):\n", " if attention_mask is None:\n", " attention_mask = input_ids.ne(0)\n", " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", " hidden_states = outputs.last_hidden_state\n", " hidden_states = self.dropout(hidden_states)\n", " pooled = global_ap(hidden_states)\n", " out = self.linear(pooled)\n", " return out\n", "\n", "class SimSonDecoder(nn.Module):\n", " def __init__(self, embedding_dim, hidden_dim, vocab_size, max_len):\n", " super(SimSonDecoder, self).__init__()\n", " self.embedding_dim = embedding_dim\n", " self.hidden_dim = hidden_dim\n", " self.vocab_size = vocab_size\n", " self.max_len = max_len\n", " \n", " # Project embedding to hidden dimension\n", " self.embedding_projection = nn.Linear(embedding_dim, hidden_dim)\n", " \n", " # Token embeddings for decoder input\n", " self.token_embeddings = nn.Embedding(vocab_size, hidden_dim)\n", " self.position_embeddings = nn.Embedding(max_len, hidden_dim)\n", " \n", " # Transformer decoder layers\n", " decoder_layer = nn.TransformerDecoderLayer(\n", " d_model=hidden_dim,\n", " nhead=12,\n", " dim_feedforward=2048,\n", " dropout=0.1,\n", " batch_first=True\n", " )\n", " self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", " \n", " # Output projection to vocabulary\n", " self.output_projection = nn.Linear(hidden_dim, vocab_size)\n", " self.dropout = nn.Dropout(0.1)\n", " \n", " # Layer normalization\n", " self.layer_norm = nn.LayerNorm(hidden_dim)\n", " \n", " def forward(self, molecule_embeddings, decoder_input_ids, decoder_attention_mask=None):\n", " batch_size, seq_len = decoder_input_ids.shape\n", " device = decoder_input_ids.device\n", " \n", " # Project molecule embeddings to memory for cross-attention\n", " memory = self.embedding_projection(molecule_embeddings) # [batch, hidden_dim]\n", " memory = memory.unsqueeze(1) # [batch, 1, hidden_dim]\n", " \n", " # Create decoder input embeddings\n", " token_emb = self.token_embeddings(decoder_input_ids) # [batch, seq_len, hidden_dim]\n", " \n", " # Add positional embeddings\n", " positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)\n", " pos_emb = self.position_embeddings(positions)\n", " \n", " decoder_inputs = self.layer_norm(token_emb + pos_emb)\n", " decoder_inputs = self.dropout(decoder_inputs)\n", " \n", " # Create causal mask for decoder\n", " tgt_mask = self._generate_square_subsequent_mask(seq_len).to(device)\n", " \n", " # Create attention mask for decoder inputs\n", " if decoder_attention_mask is not None:\n", " # Convert padding mask to additive mask\n", " tgt_key_padding_mask = (decoder_attention_mask == 0)\n", " else:\n", " tgt_key_padding_mask = None\n", " \n", " # Apply transformer decoder\n", " decoder_output = self.transformer_decoder(\n", " tgt=decoder_inputs,\n", " memory=memory,\n", " tgt_mask=tgt_mask,\n", " tgt_key_padding_mask=tgt_key_padding_mask\n", " )\n", " \n", " # Project to vocabulary\n", " logits = self.output_projection(decoder_output)\n", " \n", " return logits\n", " \n", " def _generate_square_subsequent_mask(self, sz):\n", " \"\"\"Generate causal mask for decoder\"\"\"\n", " mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)\n", " return mask\n", "\n", "# Combined Encoder-Decoder Model\n", "class SimSonEncoderDecoder(nn.Module):\n", " def __init__(self, encoder, decoder, tokenizer):\n", " super().__init__()\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " self.tokenizer = tokenizer\n", "\n", " def forward(self, input_ids, attention_mask, labels=None):\n", " # Encode SMILES to embeddings\n", " embeddings = self.encoder(input_ids, attention_mask)\n", " \n", " if labels is not None:\n", " # During training, use teacher forcing\n", " # Shift labels for decoder input (remove last token, add BOS token)\n", " decoder_input_ids = torch.cat([\n", " torch.full((labels.shape[0], 1), self.tokenizer.cls_token_id, \n", " dtype=labels.dtype, device=labels.device),\n", " labels[:, :-1]\n", " ], dim=1)\n", " \n", " # Generate decoder attention mask\n", " decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id).long()\n", " \n", " # Decode embeddings to SMILES\n", " logits = self.decoder(embeddings, decoder_input_ids, decoder_attention_mask)\n", " \n", " # Calculate loss\n", " loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)\n", " loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))\n", " \n", " return {\"loss\": loss, \"logits\": logits}\n", " else:\n", " # During inference\n", " return self.generate(embeddings)\n", " \n", " def generate(self, embeddings, max_length=512):\n", " \"\"\"Generate SMILES from embeddings\"\"\"\n", " batch_size = embeddings.shape[0]\n", " device = embeddings.device\n", " \n", " # Start with CLS token\n", " generated = torch.full((batch_size, 1), self.tokenizer.cls_token_id, \n", " dtype=torch.long, device=device)\n", " \n", " for _ in range(max_length - 1):\n", " decoder_attention_mask = (generated != self.tokenizer.pad_token_id).long()\n", " logits = self.decoder(embeddings, generated, decoder_attention_mask)\n", " \n", " # Get next token probabilities\n", " next_token_logits = logits[:, -1, :]\n", " next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)\n", " \n", " # Append to generated sequence\n", " generated = torch.cat([generated, next_tokens], dim=1)\n", " \n", " # Stop if all sequences have generated EOS token\n", " if torch.all(next_tokens.squeeze() == self.tokenizer.sep_token_id):\n", " break\n", " \n", " return generated\n", "\n", "# Initialize tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n", "\n", "# Initialize encoder config\n", "config = BertConfig(\n", " vocab_size=tokenizer.vocab_size,\n", " hidden_size=768,\n", " num_hidden_layers=4,\n", " num_attention_heads=4,\n", " intermediate_size=2048,\n", " max_position_embeddings=512\n", ")\n", "\n", "# Initialize and load encoder\n", "encoder = SimSonEncoder(config=config, max_len=512)\n", "\n", "# Load encoder parameters (extract encoder from regression model)\n", "regression_state_dict = torch.load('/home/jovyan/simson_training_bolgov/regression/better_regression_states/best_state.bin', weights_only=False)\n", "\n", "encoder_state_dict = {}\n", "for key, value in regression_state_dict.items():\n", " key = key[len('_orig_mod.'):]\n", " if key.startswith('encoder.'):\n", " encoder_state_dict[key[8 + len('_orig_mod.'):]] = value\n", "\n", "print(\"Encoder parameters loaded\")\n", "\n", "# Freeze encoder parameters\n", "for param in encoder.parameters():\n", " param.requires_grad = False\n", "print(\"Encoder parameters frozen\")\n", "\n", "# Initialize decoder\n", "decoder = SimSonDecoder(\n", " embedding_dim=512, # encoder.max_len\n", " hidden_dim=768, # config.hidden_size\n", " vocab_size=tokenizer.vocab_size,\n", " max_len=512\n", ")\n", "\n", "# Create combined model\n", "model = SimSonEncoderDecoder(encoder, decoder, tokenizer)\n" ] }, { "cell_type": "markdown", "id": "27946411-ddb4-48f2-ab5d-baec24e53954", "metadata": {}, "source": [ "regression_params = torch.load('/home/jovyan/simson_training_bolgov/regression/simson_encoder_decoder.bin', weights_only=False)\n", "\n", "uncompiled_state_dict = {}\n", "for key, value in regression_params.items():\n", " print(key)\n", " key = key[len('_orig_mod.'):]\n", " uncompiled_state_dict[key] = value\n", " print(key)\n", " break\n", " \n", "torch.save(uncompiled_state_dict, '/home/jovyan/simson_training_bolgov/regression/encoder_decoder_uncompiled.bin')" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bd4f17d-e369-4e4c-be51-a91cb0b85a0d", "metadata": {}, "outputs": [], "source": [ "# Load dataset\n", "df = pd.read_csv('/home/jovyan/simson_training_bolgov/regression/PI_Tg_P308K_synth_db_chem.csv')" ] }, { "cell_type": "code", "execution_count": null, "id": "6a40fac8-9c46-4c96-8a47-94ea32803b21", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 3, "id": "5951cb13-e188-4a2c-8b42-9f8125e96165", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training samples: 6390602\n", "Validation samples: 336348\n" ] } ], "source": [ "class SMILESReconstructionDataset(Dataset):\n", " def __init__(self, dataframe, tokenizer, max_length=256):\n", " self.data = dataframe\n", " self.tokenizer = tokenizer\n", " self.max_length = max_length\n", " \n", " def __len__(self):\n", " return len(self.data)\n", " \n", " def __getitem__(self, idx):\n", " # Get SMILES string (adjust column name as needed)\n", " smiles = self.data.iloc[idx]['smiles'] if 'smiles' in self.data.columns else self.data.iloc[idx]['Smiles']\n", " \n", " # Tokenize input SMILES\n", " encoding = self.tokenizer(\n", " smiles,\n", " max_length=self.max_length,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " )\n", " \n", " return {\n", " 'input_ids': encoding['input_ids'].squeeze(0),\n", " 'attention_mask': encoding['attention_mask'].squeeze(0),\n", " 'labels': encoding['input_ids'].squeeze(0) # Same as input for reconstruction\n", " }\n", "\n", "def create_stratified_splits_regression(\n", " df,\n", " label_cols,\n", " n_bins=10,\n", " val_frac=0.05,\n", " seed=42\n", "):\n", " \n", " values = df[label_cols].values\n", " # Each label gets its own bins, based on the overall distribution\n", " bins = [np.unique(np.quantile(values[:,i], np.linspace(0, 1, n_bins+1))) for i in range(len(label_cols))]\n", " # Assign each row to a bin for each label\n", " inds = [\n", " np.digitize(values[:,i], bins[i][1:-1], right=False) # exclude leftmost/rightmost for in-bin, avoids all bin edges as bins\n", " for i in range(len(label_cols))\n", " ]\n", " # Combine into a single integer stratification variable (tuple or max or sum...)\n", " strat_col = np.maximum.reduce(inds) # This ensures high bin in one = high bin overall\n", " # Use sklearn's train_test_split with stratify\n", " train_idx, val_idx = train_test_split(\n", " df.index.values,\n", " test_size=val_frac,\n", " random_state=seed,\n", " shuffle=True,\n", " stratify=strat_col\n", " )\n", " train = df.loc[train_idx].reset_index(drop=True)\n", " val = df.loc[val_idx].reset_index(drop=True)\n", " return train, val\n", "\n", "\n", "tokenizer_path = 'DeepChem/ChemBERTa-77M-MTR'\n", "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n", "\n", "# Same splits\n", "train, test = create_stratified_splits_regression(\n", " df,\n", " label_cols=['CO2', 'CH4'], \n", " n_bins=10,\n", " val_frac=0.05,\n", " seed=42\n", ")\n", "\n", "train_dataset = SMILESReconstructionDataset(train, tokenizer)\n", "val_dataset = SMILESReconstructionDataset(test, tokenizer)\n", "val_df = test\n", "train_df = train\n", "print(f\"Training samples: {len(train_dataset)}\")\n", "print(f\"Validation samples: {len(val_dataset)}\")\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "2b79c8b4-4f5d-4752-b1a7-08ed2d0bb7f5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_330276/2978564159.py:51: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Seq2SeqTrainer.__init__`. Use `processing_class` instead.\n", " trainer = Seq2SeqTrainer(\n" ] } ], "source": [ "# Custom data collator for encoder-decoder\n", "class EncoderDecoderDataCollator:\n", " def __init__(self, tokenizer):\n", " self.tokenizer = tokenizer\n", " \n", " def __call__(self, batch):\n", " # Stack all batch elements\n", " input_ids = torch.stack([item['input_ids'] for item in batch])\n", " attention_mask = torch.stack([item['attention_mask'] for item in batch])\n", " labels = torch.stack([item['labels'] for item in batch])\n", " \n", " return {\n", " 'input_ids': input_ids,\n", " 'attention_mask': attention_mask,\n", " 'labels': labels\n", " }\n", "\n", "\n", "data_collator = EncoderDecoderDataCollator(tokenizer)\n", "\n", "early_stopping_callback = EarlyStoppingCallback(\n", " early_stopping_patience=10,\n", ")\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir='/home/jovyan/simson_training_bolgov/regression/decoder_checkpoints',\n", " eval_strategy='steps',\n", " eval_steps=10_000,\n", " logging_steps=10_000,\n", " save_steps=10_000,\n", " save_total_limit=3,\n", " learning_rate=5e-5,\n", " per_device_train_batch_size=256,\n", " per_device_eval_batch_size=256,\n", " gradient_accumulation_steps=1,\n", " num_train_epochs=5,\n", " warmup_steps=50,\n", " weight_decay=0.01,\n", " logging_dir='./logs',\n", " report_to='none',\n", " load_best_model_at_end=True,\n", " metric_for_best_model='eval_loss',\n", " greater_is_better=False,\n", " fp16=True,\n", " dataloader_pin_memory=True,\n", " remove_unused_columns=False,\n", " predict_with_generate=False \n", ")\n", "\n", "# Initialize trainer\n", "trainer = Seq2SeqTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=val_dataset,\n", " data_collator=data_collator,\n", " tokenizer=tokenizer\n", ")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "88eccafd-c269-4bc6-8290-0bdb5d184c31", "metadata": {}, "outputs": [], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "9ac6754e-dd6d-4282-b661-0b11e65cadc3", "metadata": {}, "outputs": [], "source": [ "final_state = trainer.model.state_dict()\n", "torch.save(final_state, '/home/jovyan/simson_training_bolgov/regression/simson_encoder_decoder_better_regression.bin')" ] }, { "cell_type": "code", "execution_count": 5, "id": "18637615-64f4-42d5-a431-0765df5bc024", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model initialized with 72,264,271 trainable parameters\n" ] } ], "source": [ "import torch.nn.functional as F\n", "\n", "\n", "def global_ap(x):\n", " return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)\n", "\n", "class SimSonEncoder(nn.Module):\n", " def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n", " super(SimSonEncoder, self).__init__()\n", " self.config = config\n", " self.max_len = max_len\n", " self.bert = BertModel(config, add_pooling_layer=False)\n", " self.linear = nn.Linear(config.hidden_size, max_len)\n", " self.dropout = nn.Dropout(dropout)\n", " \n", " def forward(self, input_ids, attention_mask=None):\n", " if attention_mask is None:\n", " attention_mask = input_ids.ne(0)\n", " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", " hidden_states = outputs.last_hidden_state\n", " hidden_states = self.dropout(hidden_states)\n", " pooled = global_ap(hidden_states)\n", " out = self.linear(pooled)\n", " return out\n", "\n", "class SimSonDecoder(nn.Module):\n", " def __init__(self, embedding_dim, hidden_dim, vocab_size, max_len):\n", " super(SimSonDecoder, self).__init__()\n", " self.embedding_dim = embedding_dim\n", " self.hidden_dim = hidden_dim\n", " self.vocab_size = vocab_size\n", " self.max_len = max_len\n", " \n", " # Project embedding to hidden dimension\n", " self.embedding_projection = nn.Linear(embedding_dim, hidden_dim)\n", " \n", " # Token embeddings for decoder input\n", " self.token_embeddings = nn.Embedding(vocab_size, hidden_dim)\n", " self.position_embeddings = nn.Embedding(max_len, hidden_dim)\n", " \n", " # Transformer decoder layers\n", " decoder_layer = nn.TransformerDecoderLayer(\n", " d_model=hidden_dim,\n", " nhead=12,\n", " dim_feedforward=2048,\n", " dropout=0.1,\n", " batch_first=True\n", " )\n", " self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", " \n", " # Output projection to vocabulary\n", " self.output_projection = nn.Linear(hidden_dim, vocab_size)\n", " self.dropout = nn.Dropout(0.1)\n", " \n", " # Layer normalization\n", " self.layer_norm = nn.LayerNorm(hidden_dim)\n", " \n", " def forward(self, molecule_embeddings, decoder_input_ids, decoder_attention_mask=None):\n", " batch_size, seq_len = decoder_input_ids.shape\n", " device = decoder_input_ids.device\n", " \n", " # Project molecule embeddings to memory for cross-attention\n", " memory = self.embedding_projection(molecule_embeddings) # [batch, hidden_dim]\n", " memory = memory.unsqueeze(1) # [batch, 1, hidden_dim]\n", " \n", " # Create decoder input embeddings\n", " token_emb = self.token_embeddings(decoder_input_ids) # [batch, seq_len, hidden_dim]\n", " \n", " # Add positional embeddings\n", " positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)\n", " pos_emb = self.position_embeddings(positions)\n", " \n", " decoder_inputs = self.layer_norm(token_emb + pos_emb)\n", " decoder_inputs = self.dropout(decoder_inputs)\n", " \n", " # Create causal mask for decoder\n", " tgt_mask = self._generate_square_subsequent_mask(seq_len).to(device)\n", " \n", " # Create attention mask for decoder inputs\n", " if decoder_attention_mask is not None:\n", " # Convert padding mask to additive mask\n", " tgt_key_padding_mask = (decoder_attention_mask == 0)\n", " else:\n", " tgt_key_padding_mask = None\n", " \n", " # Apply transformer decoder\n", " decoder_output = self.transformer_decoder(\n", " tgt=decoder_inputs,\n", " memory=memory,\n", " tgt_mask=tgt_mask,\n", " tgt_key_padding_mask=tgt_key_padding_mask\n", " )\n", " \n", " # Project to vocabulary\n", " logits = self.output_projection(decoder_output)\n", " \n", " return logits\n", " \n", " def _generate_square_subsequent_mask(self, sz):\n", " \"\"\"Generate causal mask for decoder\"\"\"\n", " mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)\n", " return mask\n", "\n", "# Combined Encoder-Decoder Model\n", "class SimSonEncoderDecoder(nn.Module):\n", " def __init__(self, encoder, decoder, tokenizer):\n", " super().__init__()\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " self.tokenizer = tokenizer\n", " \n", " def forward(self, input_ids, attention_mask, labels=None):\n", " # Encode SMILES to embeddings\n", " embeddings = self.encoder(input_ids, attention_mask)\n", " \n", " if labels is not None:\n", " # During training, use teacher forcing\n", " # Shift labels for decoder input (remove last token, add BOS token)\n", " decoder_input_ids = torch.cat([\n", " torch.full((labels.shape[0], 1), self.tokenizer.cls_token_id, \n", " dtype=labels.dtype, device=labels.device),\n", " labels[:, :-1]\n", " ], dim=1)\n", " \n", " # Generate decoder attention mask\n", " decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id).long()\n", " \n", " # Decode embeddings to SMILES\n", " logits = self.decoder(embeddings, decoder_input_ids, decoder_attention_mask)\n", " \n", " # Calculate loss\n", " loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)\n", " loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))\n", " \n", " return {\"loss\": loss, \"logits\": logits}\n", " else:\n", " # During inference\n", " return self.generate(embeddings)\n", " \n", " def generate(self, embeddings, max_length=512, temperature=1.0):\n", "\n", " batch_size = embeddings.shape[0]\n", " device = embeddings.device\n", " \n", " # Start with the CLS token for all sequences in the batch\n", " generated = torch.full((batch_size, 1), self.tokenizer.cls_token_id, \n", " dtype=torch.long, device=device)\n", " \n", " # --- THE FIX: Keep track of which sequences have finished ---\n", " # A boolean tensor, initially all False.\n", " is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)\n", " \n", " for _ in range(max_length - 1):\n", " # Pass the current generated sequences to the decoder\n", " decoder_attention_mask = (generated != self.tokenizer.pad_token_id).long()\n", " logits = self.decoder(embeddings, generated, decoder_attention_mask)\n", " \n", " # Focus on the logits for the last token in each sequence\n", " next_token_logits = logits[:, -1, :]\n", " \n", " # Apply temperature sampling\n", " if temperature == 0.0:\n", " next_tokens = torch.argmax(next_token_logits, dim=-1)\n", " else:\n", " probs = F.softmax(next_token_logits / temperature, dim=-1)\n", " next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)\n", " \n", " # --- THE FIX: Prevent finished sequences from generating new tokens ---\n", " # If a sequence is already finished, its next token should be a pad token.\n", " # Otherwise, use the newly generated token.\n", " next_tokens = torch.where(is_finished, self.tokenizer.pad_token_id, next_tokens)\n", " \n", " # Append the new tokens to the generated sequences\n", " generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=1)\n", " \n", " # --- THE FIX: Update the `is_finished` status for any sequence that just produced an EOS token ---\n", " is_finished |= (next_tokens == self.tokenizer.sep_token_id)\n", " \n", " # --- THE FIX: Stop if all sequences in the batch are finished ---\n", " if torch.all(is_finished):\n", " break\n", " \n", " return generated\n", "\n", "\n", "# Initialize tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n", "\n", "# Initialize encoder config\n", "config = BertConfig(\n", " vocab_size=tokenizer.vocab_size,\n", " hidden_size=768,\n", " num_hidden_layers=4,\n", " num_attention_heads=12,\n", " intermediate_size=2048,\n", " max_position_embeddings=512\n", ")\n", "\n", "# Initialize and load encoder\n", "encoder = SimSonEncoder(config=config, max_len=512)\n", "\n", "# Initialize decoder\n", "decoder = SimSonDecoder(\n", " embedding_dim=512, # encoder.max_len\n", " hidden_dim=768, # config.hidden_size\n", " vocab_size=tokenizer.vocab_size,\n", " max_len=512\n", ")\n", "\n", "# Create combined model\n", "model = SimSonEncoderDecoder(encoder, decoder, tokenizer)\n", "model.load_state_dict(torch.load('/home/jovyan/simson_training_bolgov/regression/simson_encoder_decoder.bin', weights_only=False))\n", "model.cuda()\n", "print(f\"Model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "3f250d5d-0573-4f5b-a665-44ca3ecf135d", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import joblib\n", "import numpy as np\n", "from transformers import BertConfig, AutoTokenizer, BertModel\n", "from tqdm import tqdm\n", "\n", "# Load tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n", "\n", "# Initialize model configuration\n", "config = BertConfig(\n", " vocab_size=tokenizer.vocab_size,\n", " hidden_size=768,\n", " num_hidden_layers=4,\n", " num_attention_heads=12,\n", " intermediate_size=2048,\n", " max_position_embeddings=512\n", ")\n", "\n", "# Initialize regression model\n", "def global_ap(x):\n", " return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)\n", "\n", "class SimSonEncoder(nn.Module):\n", " def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n", " super(SimSonEncoder, self).__init__()\n", " self.config = config\n", " self.max_len = max_len\n", " self.bert = BertModel(config, add_pooling_layer=False)\n", " self.linear = nn.Linear(config.hidden_size, max_len)\n", " self.dropout = nn.Dropout(dropout)\n", " \n", " def forward(self, input_ids, attention_mask=None):\n", " if attention_mask is None:\n", " attention_mask = input_ids.ne(0)\n", " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", " hidden_states = outputs.last_hidden_state\n", " hidden_states = self.dropout(hidden_states)\n", " pooled = global_ap(hidden_states)\n", " out = self.linear(pooled)\n", " return out\n", "\n", "class SimSonClassifier(nn.Module):\n", " def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):\n", " super(SimSonClassifier, self).__init__()\n", " self.encoder = encoder\n", " self.clf = nn.Linear(encoder.max_len, num_labels)\n", " self.relu = nn.ReLU()\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, input_ids, attention_mask=None, labels=None):\n", " x = self.encoder(input_ids, attention_mask)\n", " x = self.relu(self.dropout(x))\n", " x = self.clf(x)\n", " return x\n", "\n", "compiling = False\n", "# Load model and parameters\n", "encoder = SimSonEncoder(config=config, max_len=512)\n", "if compiling:\n", " encoder = torch.compile(encoder)\n", "encoder_expert = SimSonEncoder(config=config, max_len=512)\n", "\n", "regression_model = SimSonClassifier(encoder=encoder, num_labels=6)\n", "regression_expert_model = SimSonClassifier(encoder=encoder_expert, num_labels=6)\n", "\n", "if compiling:\n", " regression_model = torch.compile(regression_model)\n", "\n", "# Load trained parameters\n", "regression_params = torch.load('/home/jovyan/simson_training_bolgov/regression/better_regression_states/best_state.bin', weights_only=False)\n", "regression_expert_params = torch.load('/home/jovyan/simson_training_bolgov/regression/high_regression_old_scalers_simson.pth', weights_only=False)\n", "\n", "\n", "regression_model.load_state_dict(regression_params)\n", "regression_expert_model.load_state_dict(regression_expert_params)\n", "\n", "regression_model.eval()\n", "regression_model.cuda()\n", "\n", "regression_expert_model.eval()\n", "regression_expert_model.cuda()\n", "\n", "# Load scalers\n", "scalers = joblib.load('/home/jovyan/simson_training_bolgov/regression/scalers')\n", "scaler_ch4 = scalers[-2] # CH4 scaler\n", "scaler_co2 = scalers[-1] # CO2 scaler" ] }, { "cell_type": "code", "execution_count": 7, "id": "7a775a04-26de-434f-a138-bafb84b661ea", "metadata": {}, "outputs": [], "source": [ "from rdkit import Chem\n", "from rdkit.Chem import MolToSmiles\n", "from tqdm import tqdm\n", "import torch\n", "\n", "def validate_reconstruction(model, tokenizer, test_smiles_list, num_samples=20, max_length=256):\n", " \"\"\"\n", " Evaluate reconstruction quality.\n", " • exact_match ........ literal string equality\n", " • canonical_match .... same molecule after canonicalisation\n", " • generated_valid .... RDKit is able to parse generated SMILES\n", " \"\"\"\n", " model.eval()\n", " device = next(model.parameters()).device\n", "\n", " results = []\n", " with torch.no_grad():\n", " for i, smiles in enumerate(test_smiles_list[:num_samples]):\n", " # ---------- encode ----------\n", " tokens = tokenizer(\n", " smiles,\n", " max_length=max_length,\n", " truncation=True,\n", " padding=\"max_length\",\n", " return_tensors=\"pt\"\n", " ).to(device)\n", "\n", " embedding = model.encoder(tokens[\"input_ids\"], tokens[\"attention_mask\"])\n", "\n", " # ---------- decode ----------\n", " gen_ids = model.generate(embedding)\n", " gen_smiles = tokenizer.decode(gen_ids[0], skip_special_tokens=True)\n", "\n", " # ---------- validity ----------\n", " mol_orig = Chem.MolFromSmiles(smiles)\n", " mol_gen = Chem.MolFromSmiles(gen_smiles)\n", " orig_valid = mol_orig is not None\n", " gen_valid = mol_gen is not None\n", "\n", " # ---------- canonical comparison ----------\n", " if orig_valid and gen_valid:\n", " can_orig = MolToSmiles(mol_orig, canonical=True)\n", " can_gen = MolToSmiles(mol_gen, canonical=True)\n", " canonical_match = (can_orig == can_gen)\n", " else:\n", " canonical_match = False\n", "\n", " results.append({\n", " \"original\": smiles,\n", " \"reconstructed\": gen_smiles,\n", " \"exact_match\": smiles == gen_smiles,\n", " \"canonical_match\": canonical_match,\n", " \"generated_valid\": gen_valid\n", " })\n", "\n", " print(f\"\\nSample {i+1}\")\n", " print(f\" Original: {smiles}\")\n", " print(f\" Reconstructed: {gen_smiles}\")\n", " print(f\" Valid (gen): {gen_valid}\")\n", " print(f\" Same molecule: {canonical_match}\")\n", " print(\"-\" * 60)\n", "\n", " # ---------- aggregated metrics ----------\n", " exact_acc = sum(r[\"exact_match\"] for r in results) / len(results)\n", " canonical_acc = sum(r[\"canonical_match\"] for r in results) / len(results)\n", " validity_rate = sum(r[\"generated_valid\"] for r in results) / len(results)\n", "\n", " print(f\"\\nExact-string match accuracy ...... {exact_acc:.2%}\")\n", " print(f\"Canonical match accuracy ......... {canonical_acc:.2%}\")\n", " print(f\"Generated validity rate .......... {validity_rate:.2%}\")\n", "\n", " return results\n", "\n", "test_smiles = val_df['smiles' if 'smiles' in val_df.columns else 'Smiles'].tolist()\n", "#validation_results = validate_reconstruction(model, tokenizer, test_smiles, num_samples=20)\n" ] }, { "cell_type": "code", "execution_count": 115, "id": "21d72f76-9bc6-480d-8bf1-205c34ed7e2e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████| 999/999 [04:08<00:00, 4.02it/s]\n" ] } ], "source": [ "from sklearn.metrics import pairwise_distances\n", "from sklearn.preprocessing import MinMaxScaler\n", "\n", "\n", "features = ['CO2', 'CH4']\n", "\n", "# --- 3. Max–Min selection ----------------------------------------------------\n", "def select_diverse_maxmin(data, cols, k=1_000):\n", " X = data[cols].to_numpy()\n", " \n", " # scale to [0,1] so CO₂ and CH₄ get equal weight\n", " X = MinMaxScaler().fit_transform(X)\n", " \n", " n = len(X)\n", " if k >= n: # nothing to do\n", " return data\n", " \n", " # start with a random seed point\n", " selected = [np.random.randint(0, n)]\n", " \n", " # pre-allocate distance cache\n", " min_dist = pairwise_distances(\n", " X, X[selected], metric='euclidean'\n", " ).ravel() # distance to first point\n", " \n", " for _ in tqdm(range(1, k)):\n", " # pick the point with the largest distance to the current set\n", " idx = np.argmax(min_dist)\n", " selected.append(idx)\n", " \n", " # update distance cache (keep the shortest distance to any selected pt)\n", " dist_to_new = pairwise_distances(X, X[[idx]], metric='euclidean').ravel()\n", " min_dist = np.minimum(min_dist, dist_to_new)\n", " \n", " return data.iloc[selected].reset_index(drop=True)\n", "\n", "sample_df = select_diverse_maxmin(df, features, k=1_000)" ] }, { "cell_type": "code", "execution_count": 8, "id": "ecae2eaa-dcd0-4a0d-a8d6-347c254285d0", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "import pandas as pd\n", "from rdkit import Chem\n", "from rdkit.Chem import MolToSmiles\n", "from scipy.stats import pearsonr, spearmanr\n", "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n", "from tqdm import tqdm\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "def display_property_fidelity_results(results_df, metrics):\n", " \"\"\"Display comprehensive property fidelity analysis\"\"\"\n", " \n", " valid_results = results_df[results_df['reconstructed_valid']].copy()\n", " total_molecules = len(results_df)\n", " valid_molecules = len(valid_results)\n", " identical_molecules = results_df['chemically_identical'].sum()\n", " \n", " print(f\"\\n{'='*80}\")\n", " print(f\"PROPERTY RECONSTRUCTION FIDELITY ANALYSIS\")\n", " print(f\"{'='*80}\")\n", " \n", " print(f\"Dataset Overview:\")\n", " print(f\" Total molecules tested: {total_molecules}\")\n", " print(f\" Valid reconstructions: {valid_molecules} ({valid_molecules/total_molecules*100:.1f}%)\")\n", " print(f\" Chemically identical: {identical_molecules} ({identical_molecules/total_molecules*100:.1f}%)\")\n", " \n", " if valid_molecules > 0:\n", " print(f\"\\n{'='*60}\")\n", " print(f\"PROPERTY CORRELATION METRICS\")\n", " print(f\"{'='*60}\")\n", " \n", " print(f\"CH₄ Permeability:\")\n", " print(f\" Pearson correlation: {metrics['CH4']['pearson_correlation']:.4f}\")\n", " print(f\" Spearman correlation: {metrics['CH4']['spearman_correlation']:.4f}\")\n", " print(f\" R² score: {metrics['CH4']['r2_score']:.4f}\")\n", " print(f\" MAE: {metrics['CH4']['mae']:.4f}\")\n", " print(f\" RMSE: {metrics['CH4']['rmse']:.4f}\")\n", " print(f\" Mean relative error: {metrics['CH4']['mean_relative_error_pct']:.2f}%\")\n", " print(f\" Median relative error: {metrics['CH4']['median_relative_error_pct']:.2f}%\")\n", " \n", " print(f\"\\nCO₂ Permeability:\")\n", " print(f\" Pearson correlation: {metrics['CO2']['pearson_correlation']:.4f}\")\n", " print(f\" Spearman correlation: {metrics['CO2']['spearman_correlation']:.4f}\")\n", " print(f\" R² score: {metrics['CO2']['r2_score']:.4f}\")\n", " print(f\" MAE: {metrics['CO2']['mae']:.4f}\")\n", " print(f\" RMSE: {metrics['CO2']['rmse']:.4f}\")\n", " print(f\" Mean relative error: {metrics['CO2']['mean_relative_error_pct']:.2f}%\")\n", " print(f\" Median relative error: {metrics['CO2']['median_relative_error_pct']:.2f}%\")\n", " \n", " # Property preservation quality assessment\n", " print(f\"\\n{'='*60}\")\n", " print(f\"PROPERTY PRESERVATION ASSESSMENT\")\n", " print(f\"{'='*60}\")\n", " \n", " # Define quality thresholds\n", " excellent_threshold = 5.0 # <5% relative error\n", " good_threshold = 15.0 # <15% relative error\n", " acceptable_threshold = 30.0 # <30% relative error\n", " \n", " ch4_excellent = (valid_results['CH4_relative_error'] < excellent_threshold).sum()\n", " ch4_good = (valid_results['CH4_relative_error'] < good_threshold).sum()\n", " ch4_acceptable = (valid_results['CH4_relative_error'] < acceptable_threshold).sum()\n", " \n", " co2_excellent = (valid_results['CO2_relative_error'] < excellent_threshold).sum()\n", " co2_good = (valid_results['CO2_relative_error'] < good_threshold).sum()\n", " co2_acceptable = (valid_results['CO2_relative_error'] < acceptable_threshold).sum()\n", " \n", " print(f\"CH₄ Property Preservation Quality:\")\n", " print(f\" Excellent (<{excellent_threshold}% error): {ch4_excellent}/{valid_molecules} ({ch4_excellent/valid_molecules*100:.1f}%)\")\n", " print(f\" Good (<{good_threshold}% error): {ch4_good}/{valid_molecules} ({ch4_good/valid_molecules*100:.1f}%)\")\n", " print(f\" Acceptable (<{acceptable_threshold}% error): {ch4_acceptable}/{valid_molecules} ({ch4_acceptable/valid_molecules*100:.1f}%)\")\n", " \n", " print(f\"\\nCO₂ Property Preservation Quality:\")\n", " print(f\" Excellent (<{excellent_threshold}% error): {co2_excellent}/{valid_molecules} ({co2_excellent/valid_molecules*100:.1f}%)\")\n", " print(f\" Good (<{good_threshold}% error): {co2_good}/{valid_molecules} ({co2_good/valid_molecules*100:.1f}%)\")\n", " print(f\" Acceptable (<{acceptable_threshold}% error): {co2_acceptable}/{valid_molecules} ({co2_acceptable/valid_molecules*100:.1f}%)\")\n", "\n", "\n", "def evaluate_property_reconstruction_fidelity(model, regression_model, tokenizer, \n", " scaler_ch4, scaler_co2, test_smiles_list,\n", " batch_size=16, max_length=256):\n", " \"\"\"\n", " Evaluate how well reconstructed SMILES preserve chemical properties\n", " \n", " Args:\n", " model: Trained encoder-decoder model\n", " regression_model: Trained regression model for property prediction\n", " tokenizer: SMILES tokenizer\n", " scaler_ch4, scaler_co2: Property scalers\n", " test_smiles_list: List of original SMILES to test\n", " batch_size: Batch size for processing\n", " max_length: Maximum SMILES length\n", " \n", " Returns:\n", " results_df: DataFrame with detailed property comparison results\n", " metrics: Dictionary of aggregate evaluation metrics\n", " \"\"\"\n", " \n", " model.eval()\n", " regression_model.eval()\n", " device = next(model.parameters()).device\n", " \n", " results = []\n", " \n", " print(f\"Evaluating property reconstruction fidelity for {len(test_smiles_list)} molecules...\")\n", " \n", " # Process in batches for memory efficiency\n", " for i in tqdm(range(0, len(test_smiles_list), batch_size), desc=\"Processing batches\"):\n", " batch_end = min(i + batch_size, len(test_smiles_list))\n", " batch_smiles = test_smiles_list[i:batch_end]\n", " \n", " with torch.no_grad():\n", " # Step 1: Generate embeddings from original SMILES\n", " original_tokens = tokenizer(\n", " batch_smiles,\n", " max_length=max_length,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " ).to(device)\n", " \n", " # Get embeddings\n", " embeddings = model.encoder(original_tokens['input_ids'].cuda(), original_tokens['attention_mask'].cuda())\n", " \n", " # Step 2: Reconstruct SMILES from embeddings\n", " reconstructed_ids = model.generate(embeddings, temperature=1.5)\n", " \n", " # Step 3: Predict properties for both original and reconstructed\n", " # Original properties\n", " orig_predictions = regression_model(original_tokens['input_ids'], original_tokens['attention_mask'])\n", " orig_ch4_scaled = orig_predictions[:, -2].cpu().numpy().reshape(-1, 1)\n", " orig_co2_scaled = orig_predictions[:, -1].cpu().numpy().reshape(-1, 1)\n", " orig_ch4 = scaler_ch4.inverse_transform(orig_ch4_scaled).flatten()\n", " orig_co2 = scaler_co2.inverse_transform(orig_co2_scaled).flatten()\n", " \n", " # Reconstructed properties\n", " reconstructed_smiles = [tokenizer.decode(seq, skip_special_tokens=True) for seq in reconstructed_ids]\n", " # Validate reconstructed SMILES and predict properties\n", " for j, (orig_smiles, recon_smiles) in enumerate(zip(batch_smiles, reconstructed_smiles)):\n", " # Chemical validity checks\n", " orig_mol = Chem.MolFromSmiles(orig_smiles)\n", " recon_mol = Chem.MolFromSmiles(recon_smiles)\n", " \n", " orig_valid = orig_mol is not None\n", " recon_valid = recon_mol is not None\n", " \n", " # Canonical SMILES comparison\n", " if orig_valid and recon_valid:\n", " orig_canonical = MolToSmiles(orig_mol, canonical=True)\n", " recon_canonical = MolToSmiles(recon_mol, canonical=True)\n", " is_identical = orig_canonical == recon_canonical\n", " else:\n", " orig_canonical = \"INVALID\" if not orig_valid else MolToSmiles(orig_mol, canonical=True)\n", " recon_canonical = \"INVALID\" if not recon_valid else MolToSmiles(recon_mol, canonical=True)\n", " is_identical = False\n", " \n", " # Property prediction for reconstructed SMILES\n", " if recon_valid:\n", " recon_tokens = tokenizer(\n", " recon_smiles,\n", " max_length=max_length,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " ).to(device)\n", " \n", " recon_predictions = regression_model(recon_tokens['input_ids'], recon_tokens['attention_mask'])\n", " recon_ch4_scaled = recon_predictions[0, -2].cpu().numpy().reshape(-1, 1)\n", " recon_co2_scaled = recon_predictions[0, -1].cpu().numpy().reshape(-1, 1)\n", " recon_ch4 = scaler_ch4.inverse_transform(recon_ch4_scaled)[0, 0]\n", " recon_co2 = scaler_co2.inverse_transform(recon_co2_scaled)[0, 0]\n", " else:\n", " recon_ch4 = np.nan\n", " recon_co2 = np.nan\n", " \n", " # Calculate property differences\n", " ch4_absolute_error = abs(orig_ch4[j] - recon_ch4) if recon_valid else np.nan\n", " co2_absolute_error = abs(orig_co2[j] - recon_co2) if recon_valid else np.nan\n", " \n", " ch4_relative_error = (ch4_absolute_error / orig_ch4[j] * 100) if (recon_valid and orig_ch4[j] != 0) else np.nan\n", " co2_relative_error = (co2_absolute_error / orig_co2[j] * 100) if (recon_valid and orig_co2[j] != 0) else np.nan\n", " \n", " results.append({\n", " 'molecule_id': i + j + 1,\n", " 'original_smiles': orig_smiles,\n", " 'reconstructed_smiles': recon_smiles,\n", " 'original_canonical': orig_canonical,\n", " 'reconstructed_canonical': recon_canonical,\n", " 'chemically_identical': is_identical,\n", " 'original_valid': orig_valid,\n", " 'reconstructed_valid': recon_valid,\n", " 'original_CH4': orig_ch4[j],\n", " 'original_CO2': orig_co2[j],\n", " 'reconstructed_CH4': recon_ch4,\n", " 'reconstructed_CO2': recon_co2,\n", " 'CH4_absolute_error': ch4_absolute_error,\n", " 'CO2_absolute_error': co2_absolute_error,\n", " 'CH4_relative_error': ch4_relative_error,\n", " 'CO2_relative_error': co2_relative_error\n", " })\n", " \n", " # Convert to DataFrame\n", " results_df = pd.DataFrame(results)\n", " \n", " # Calculate aggregate metrics\n", " valid_comparisons = results_df[results_df['reconstructed_valid']].copy()\n", " \n", " if len(valid_comparisons) > 0:\n", " metrics = calculate_property_fidelity_metrics(valid_comparisons)\n", " display_property_fidelity_results(results_df, metrics)\n", " else:\n", " print(\"No valid reconstructions for property comparison!\")\n", " metrics = {}\n", " \n", " return results_df, metrics\n", "\n", "def visualize_property_fidelity(results_df, save_plots=True):\n", " \"\"\"Create comprehensive visualizations of property reconstruction fidelity\"\"\"\n", " \n", " valid_results = results_df[results_df['reconstructed_valid']].copy()\n", " \n", " if len(valid_results) == 0:\n", " print(\"No valid results to visualize!\")\n", " return\n", " \n", " fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n", " \n", " # 1. CH4 Property Correlation\n", " axes[0, 0].scatter(valid_results['original_CH4'], valid_results['reconstructed_CH4'], \n", " alpha=0.6, s=30, edgecolors='black', linewidth=0.5)\n", " \n", " # Perfect correlation line\n", " min_ch4 = min(valid_results['original_CH4'].min(), valid_results['reconstructed_CH4'].min())\n", " max_ch4 = max(valid_results['original_CH4'].max(), valid_results['reconstructed_CH4'].max())\n", " axes[0, 0].plot([min_ch4, max_ch4], [min_ch4, max_ch4], 'r--', linewidth=2, label='Perfect Correlation')\n", " \n", " # Calculate and display R²\n", " ch4_r2 = r2_score(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " axes[0, 0].set_xlabel('Original CH₄ Permeability')\n", " axes[0, 0].set_ylabel('Reconstructed CH₄ Permeability')\n", " axes[0, 0].set_title(f'CH₄ Property Reconstruction (R² = {ch4_r2:.3f})')\n", " axes[0, 0].legend()\n", " axes[0, 0].grid(True, alpha=0.3)\n", " \n", " # 2. CO2 Property Correlation\n", " axes[0, 1].scatter(valid_results['original_CO2'], valid_results['reconstructed_CO2'], \n", " alpha=0.6, s=30, edgecolors='black', linewidth=0.5)\n", " \n", " min_co2 = min(valid_results['original_CO2'].min(), valid_results['reconstructed_CO2'].min())\n", " max_co2 = max(valid_results['original_CO2'].max(), valid_results['reconstructed_CO2'].max())\n", " axes[0, 1].plot([min_co2, max_co2], [min_co2, max_co2], 'r--', linewidth=2, label='Perfect Correlation')\n", " \n", " co2_r2 = r2_score(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " axes[0, 1].set_xlabel('Original CO₂ Permeability')\n", " axes[0, 1].set_ylabel('Reconstructed CO₂ Permeability')\n", " axes[0, 1].set_title(f'CO₂ Property Reconstruction (R² = {co2_r2:.3f})')\n", " axes[0, 1].legend()\n", " axes[0, 1].grid(True, alpha=0.3)\n", " \n", " # 3. Relative Error Distribution - CH4\n", " axes[0, 2].hist(valid_results['CH4_relative_error'].dropna(), bins=30, alpha=0.7, \n", " edgecolor='black', color='skyblue')\n", " axes[0, 2].axvline(valid_results['CH4_relative_error'].median(), color='red', \n", " linestyle='--', linewidth=2, label=f'Median: {valid_results[\"CH4_relative_error\"].median():.1f}%')\n", " axes[0, 2].set_xlabel('CH₄ Relative Error (%)')\n", " axes[0, 2].set_ylabel('Frequency')\n", " axes[0, 2].set_title('CH₄ Relative Error Distribution')\n", " axes[0, 2].legend()\n", " axes[0, 2].grid(True, alpha=0.3)\n", " \n", " # 4. Relative Error Distribution - CO2\n", " axes[1, 0].hist(valid_results['CO2_relative_error'].dropna(), bins=30, alpha=0.7, \n", " edgecolor='black', color='lightcoral')\n", " axes[1, 0].axvline(valid_results['CO2_relative_error'].median(), color='red', \n", " linestyle='--', linewidth=2, label=f'Median: {valid_results[\"CO2_relative_error\"].median():.1f}%')\n", " axes[1, 0].set_xlabel('CO₂ Relative Error (%)')\n", " axes[1, 0].set_ylabel('Frequency')\n", " axes[1, 0].set_title('CO₂ Relative Error Distribution')\n", " axes[1, 0].legend()\n", " axes[1, 0].grid(True, alpha=0.3)\n", " \n", " # 5. Error Comparison\n", " error_comparison = pd.DataFrame({\n", " 'CH₄_Error': valid_results['CH4_relative_error'].dropna(),\n", " 'CO₂_Error': valid_results['CO2_relative_error'].dropna()\n", " })\n", " \n", " axes[1, 1].scatter(error_comparison['CH₄_Error'], error_comparison['CO₂_Error'], \n", " alpha=0.6, s=30, edgecolors='black', linewidth=0.5)\n", " axes[1, 1].set_xlabel('CH₄ Relative Error (%)')\n", " axes[1, 1].set_ylabel('CO₂ Relative Error (%)')\n", " axes[1, 1].set_title('Property Error Correlation')\n", " axes[1, 1].grid(True, alpha=0.3)\n", " \n", " # 6. Chemical Identity vs Property Preservation\n", " identical_mask = valid_results['chemically_identical']\n", " \n", " ch4_errors_identical = valid_results[identical_mask]['CH4_relative_error'].dropna()\n", " ch4_errors_different = valid_results[~identical_mask]['CH4_relative_error'].dropna()\n", " \n", " axes[1, 2].boxplot([ch4_errors_identical, ch4_errors_different], \n", " labels=['Chemically Identical', 'Chemically Different'])\n", " axes[1, 2].set_ylabel('CH₄ Relative Error (%)')\n", " axes[1, 2].set_title('Property Error by Chemical Identity')\n", " axes[1, 2].grid(True, alpha=0.3)\n", " \n", " plt.tight_layout()\n", " \n", " if save_plots:\n", " plt.savefig('property_reconstruction_fidelity.png', dpi=300, bbox_inches='tight')\n", " print(\"Visualization saved as 'property_reconstruction_fidelity.png'\")\n", " \n", " plt.show()\n", "\n", "\n", "def calculate_property_fidelity_metrics(valid_results):\n", " \"\"\"Calculate comprehensive property fidelity metrics\"\"\"\n", " \n", " metrics = {}\n", " \n", " # Correlation metrics\n", " ch4_corr_pearson, ch4_p_pearson = pearsonr(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " co2_corr_pearson, co2_p_pearson = pearsonr(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " \n", " ch4_corr_spearman, ch4_p_spearman = spearmanr(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " co2_corr_spearman, co2_p_spearman = spearmanr(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " \n", " # Regression metrics\n", " ch4_r2 = r2_score(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " co2_r2 = r2_score(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " \n", " ch4_mae = mean_absolute_error(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " co2_mae = mean_absolute_error(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " \n", " ch4_rmse = np.sqrt(mean_squared_error(valid_results['original_CH4'], valid_results['reconstructed_CH4']))\n", " co2_rmse = np.sqrt(mean_squared_error(valid_results['original_CO2'], valid_results['reconstructed_CO2']))\n", " \n", " # Relative error statistics\n", " ch4_mean_rel_error = valid_results['CH4_relative_error'].mean()\n", " co2_mean_rel_error = valid_results['CO2_relative_error'].mean()\n", " \n", " ch4_median_rel_error = valid_results['CH4_relative_error'].median()\n", " co2_median_rel_error = valid_results['CO2_relative_error'].median()\n", " \n", " metrics = {\n", " 'CH4': {\n", " 'pearson_correlation': ch4_corr_pearson,\n", " 'spearman_correlation': ch4_corr_spearman,\n", " 'r2_score': ch4_r2,\n", " 'mae': ch4_mae,\n", " 'rmse': ch4_rmse,\n", " 'mean_relative_error_pct': ch4_mean_rel_error,\n", " 'median_relative_error_pct': ch4_median_rel_error\n", " },\n", " 'CO2': {\n", " 'pearson_correlation': co2_corr_pearson,\n", " 'spearman_correlation': co2_corr_spearman,\n", " 'r2_score': co2_r2,\n", " 'mae': co2_mae,\n", " 'rmse': co2_rmse,\n", " 'mean_relative_error_pct': co2_mean_rel_error,\n", " 'median_relative_error_pct': co2_median_rel_error\n", " }\n", " }\n", " \n", " return metrics\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "3e6bcd77-7314-4829-8d26-26eab56f418c", "metadata": {}, "outputs": [], "source": [ "from rdkit import RDLogger\n", "\n", "# Disable all RDKit logs\n", "RDLogger.DisableLog('rdApp.*') " ] }, { "cell_type": "code", "execution_count": 26, "id": "04ecf928-10d8-4366-a9d6-1a646301d387", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0SmilesTgHeN2O2CH4CO2synthesizable
51170515117051Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
16542681654268Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
66922756692275Ic1ccc(cn1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
830270830270Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
66922766692276Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
..............................
1095010950Ic1cc(Oc2cc(Oc3cc(cc(c3)n3c(=O)c4c(c3=O)cc3c(c...538.151.235460.001050.004040.001950.00785True
53090145309014Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O...552.002.563850.001380.007770.002870.00772False
25984182598418Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O...552.002.563850.001380.007770.002870.00772False
746918746918Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O...552.002.563850.001380.007770.002870.00772False
22552255Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O...586.801.805440.000680.003300.001840.00419False
\n", "

6726950 rows × 9 columns

\n", "
" ], "text/plain": [ " Unnamed: 0 Smiles \\\n", "5117051 5117051 Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... \n", "1654268 1654268 Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(... \n", "6692275 6692275 Ic1ccc(cn1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(... \n", "830270 830270 Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... \n", "6692276 6692276 Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... \n", "... ... ... \n", "10950 10950 Ic1cc(Oc2cc(Oc3cc(cc(c3)n3c(=O)c4c(c3=O)cc3c(c... \n", "5309014 5309014 Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... \n", "2598418 2598418 Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... \n", "746918 746918 Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... \n", "2255 2255 Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... \n", "\n", " Tg He N2 O2 CH4 CO2 \\\n", "5117051 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "1654268 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "6692275 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "830270 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "6692276 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "... ... ... ... ... ... ... \n", "10950 538.15 1.23546 0.00105 0.00404 0.00195 0.00785 \n", "5309014 552.00 2.56385 0.00138 0.00777 0.00287 0.00772 \n", "2598418 552.00 2.56385 0.00138 0.00777 0.00287 0.00772 \n", "746918 552.00 2.56385 0.00138 0.00777 0.00287 0.00772 \n", "2255 586.80 1.80544 0.00068 0.00330 0.00184 0.00419 \n", "\n", " synthesizable \n", "5117051 False \n", "1654268 False \n", "6692275 False \n", "830270 False \n", "6692276 False \n", "... ... \n", "10950 True \n", "5309014 False \n", "2598418 False \n", "746918 False \n", "2255 False \n", "\n", "[6726950 rows x 9 columns]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.sort_values(by=['CO2'], ascending=False)" ] }, { "cell_type": "code", "execution_count": 56, "id": "576a2486-6184-4d63-9a07-f545c2551df0", "metadata": {}, "outputs": [], "source": [ "enc_gen = model.encoder.state_dict()\n", "enc_reg = regression_model.encoder.state_dict()\n", "\n" ] }, { "cell_type": "code", "execution_count": 89, "id": "19e84d1a-56a8-493e-ba56-5eaa5aebb556", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluating property reconstruction fidelity for 100 molecules...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Processing batches: 100%|█████████████████████████| 1/1 [00:16<00:00, 16.36s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "================================================================================\n", "PROPERTY RECONSTRUCTION FIDELITY ANALYSIS\n", "================================================================================\n", "Dataset Overview:\n", " Total molecules tested: 100\n", " Valid reconstructions: 93 (93.0%)\n", " Chemically identical: 0 (0.0%)\n", "\n", "============================================================\n", "PROPERTY CORRELATION METRICS\n", "============================================================\n", "CH₄ Permeability:\n", " Pearson correlation: 0.0400\n", " Spearman correlation: -0.0705\n", " R² score: -3.0056\n", " MAE: 82.2622\n", " RMSE: 103.3733\n", " Mean relative error: 19.19%\n", " Median relative error: 24.48%\n", "\n", "CO₂ Permeability:\n", " Pearson correlation: -0.1054\n", " Spearman correlation: -0.3209\n", " R² score: -3.9737\n", " MAE: 26978.8016\n", " RMSE: 33540.0662\n", " Mean relative error: 20.97%\n", " Median relative error: 26.74%\n", "\n", "============================================================\n", "PROPERTY PRESERVATION ASSESSMENT\n", "============================================================\n", "CH₄ Property Preservation Quality:\n", " Excellent (<5.0% error): 32/93 (34.4%)\n", " Good (<15.0% error): 33/93 (35.5%)\n", " Acceptable (<30.0% error): 76/93 (81.7%)\n", "\n", "CO₂ Property Preservation Quality:\n", " Excellent (<5.0% error): 28/93 (30.1%)\n", " Good (<15.0% error): 33/93 (35.5%)\n", " Acceptable (<30.0% error): 57/93 (61.3%)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "test_smiles = sample_df['Smiles'].tolist()\n", "test_smiles = df.sort_values(by=['CO2'], ascending=False)['Smiles'].tolist()[:100]\n", "\n", "\n", "results_df, metrics = evaluate_property_reconstruction_fidelity(\n", " model=model, \n", " regression_model=regression_expert_model,\n", " tokenizer=tokenizer,\n", " scaler_ch4=scaler_ch4,\n", " scaler_co2=scaler_co2,\n", " test_smiles_list=test_smiles[:1_000],\n", " batch_size=128\n", " )" ] }, { "cell_type": "code", "execution_count": 10, "id": "62c35e28-78ec-4b04-8211-cdc24ee21969", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from tqdm import tqdm\n", "from rdkit import Chem\n", "from rdkit.Chem import MolToSmiles\n", "\n", "def generate_base_embeddings(model, tokenizer, val_df, max_length=256, batch_size=128):\n", " \"\"\"\n", " Generate embeddings for molecules in validation dataset using batch processing\n", " \n", " Args:\n", " model: Trained encoder-decoder model\n", " tokenizer: SMILES tokenizer\n", " val_df: Validation dataframe containing SMILES\n", " max_length: Maximum SMILES sequence length \n", " batch_size: Number of molecules to process in each batch\n", " \n", " Returns:\n", " base_embeddings: Tensor of shape [num_molecules, embedding_dim]\n", " \"\"\"\n", " model.eval()\n", " device = next(model.parameters()).device\n", " \n", " smiles_list = val_df['Smiles'].tolist()\n", " embeddings_list = []\n", " \n", " with torch.no_grad():\n", " for i in tqdm(range(0, len(smiles_list), batch_size), desc=\"Generating base embeddings in batches\"):\n", " # Get current batch of SMILES\n", " batch_end = min(i + batch_size, len(smiles_list))\n", " batch_smiles = smiles_list[i:batch_end]\n", " \n", " # Tokenize entire batch at once\n", " encoding = tokenizer(\n", " batch_smiles,\n", " max_length=max_length,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " ).to(device)\n", " \n", " with torch.autocast(dtype=torch.float16, device_type='cuda'):\n", " embedding_batch = model.encoder(encoding['input_ids'], encoding['attention_mask'])\n", " \n", " # Move to CPU and store\n", " embeddings_list.append(embedding_batch.cpu().numpy())\n", " \n", " # Stack all batch embeddings into single array\n", " base_embeddings = np.vstack(embeddings_list)\n", " \n", " return torch.tensor(base_embeddings, dtype=torch.float32)" ] }, { "cell_type": "code", "execution_count": null, "id": "bdd463d7-da61-4064-915b-54ab4c3671ad", "metadata": {}, "outputs": [], "source": [ "# Generate base embeddings\n", "print(\"Generating 'base' embeddings\")\n", "base_embeddings_s = generate_base_embeddings(regression_model, tokenizer, df[:100])\n", "print(f\"Generated embeddings shape: {base_embeddings.shape}\")\n" ] }, { "cell_type": "code", "execution_count": 39, "id": "9a27dfc8-3510-4339-8799-460f9544a9a1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = regression_model.encoder\n", "b = model.encoder\n", "a == b" ] }, { "cell_type": "code", "execution_count": 68, "id": "3915b796-c57e-44d1-9ff0-7002a7c86499", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['/home/jovyan/simson_training_bolgov/regression/base_embeddings_new_reg.pickle']" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import joblib\n", "\n", "joblib.dump(base_embeddings, '/home/jovyan/simson_training_bolgov/regression/base_embeddings_new_reg.pickle')" ] }, { "cell_type": "code", "execution_count": 11, "id": "a9a8d315-965d-493d-8d9d-7fb096e5c96e", "metadata": {}, "outputs": [], "source": [ "import joblib\n", "\n", "base_embeddings = joblib.load('/home/jovyan/simson_training_bolgov/regression/sample_embeddings.pickle')\n", "#expert_embeddings = joblib.load('/home/jovyan/simson_training_bolgov/regression/expert_embeddings.pickle')" ] }, { "cell_type": "code", "execution_count": 76, "id": "9778d0de-fc1e-4489-8134-fe50ac868039", "metadata": {}, "outputs": [], "source": [ "actual_encoder = model.encoder.state_dict()\n", "torch.save(actual_encoder, '/home/jovyan/simson_training_bolgov/regression/actual_encoder_state.pkl')" ] }, { "cell_type": "code", "execution_count": 12, "id": "2ff3dc41-1501-4ca7-8292-01a58924bb44", "metadata": {}, "outputs": [], "source": [ "training_smiles_set = set(df['Smiles'].tolist())\n", "\n", "def is_valid_polymer(smiles: str) -> bool:\n", " \"\"\"\n", " Checks if a SMILES string represents a valid polymer according to specific rules.\n", "\n", " A valid polymer must:\n", " 1. Be a chemically valid molecule parsable by RDKit.\n", " 2. Contain exactly two 'I' atoms, representing polymer endpoints.\n", " 3. Have identical bond types connecting to both endpoints.\n", " \"\"\"\n", " if not isinstance(smiles, str):\n", " return False\n", "\n", " # Rule 1: Basic chemical validity\n", " mol = Chem.MolFromSmiles(smiles)\n", " if mol is None:\n", " return False\n", "\n", " # Rule 2: Must contain exactly two endpoints\n", " if smiles.count('I') != 2:\n", " return False\n", "\n", " # Rule 3: Endpoint bonds must match\n", " try:\n", " # Replace 'I' with a standard dummy atom '[*]' for analysis\n", " mol_with_dummy = Chem.MolFromSmiles(smiles.replace('I', '[*]'))\n", " if mol_with_dummy is None:\n", " return False\n", "\n", " # Find the atoms connected to the dummy endpoints\n", " matches = mol_with_dummy.GetSubstructMatches(Chem.MolFromSmarts('[#0]~*'))\n", "\n", " if len(matches) != 2:\n", " return False\n", "\n", " # Get the bonds connecting to the endpoints\n", " bond1 = mol_with_dummy.GetBondBetweenAtoms(matches[0][0], matches[0][1])\n", " bond2 = mol_with_dummy.GetBondBetweenAtoms(matches[1][0], matches[1][1])\n", "\n", " if bond1 is None or bond2 is None:\n", " return False\n", "\n", " # Check if the bond types are identical\n", " if bond1.GetBondType() != bond2.GetBondType():\n", " return False\n", " except Exception:\n", " # Catch any RDKit parsing or processing errors\n", " return False\n", "\n", " return True\n", "\n", "def is_novel_and_valid_polymer(smiles: str, training_set: set) -> bool:\n", " \"\"\"\n", " Performs a final check to ensure a molecule is both novel and a valid polymer.\n", " \"\"\"\n", " # Check 1: Is it new?\n", " if smiles in training_set:\n", " print('NOT NEW')\n", " return False\n", " \n", " # Check 2: Does it meet polymer criteria?\n", " return is_valid_polymer(smiles)" ] }, { "cell_type": "code", "execution_count": 20, "id": "fe3a75c9-eb84-4854-a477-d401761baba1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0SmilesTgHeN2O2CH4CO2synthesizable
00Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1...494.842.695244.7574042.318471.64086148.43644False
11Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=...508.265.338152.9723926.311180.8646782.37635False
22O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c...640.9120.475150.063530.904980.069052.35993False
33Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O...568.044.196920.001910.011340.003620.01418False
44Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1...548.10142.683270.873808.254092.5206730.04739False
..............................
67269456726945Ic1cccc(c1)Cc1cc(C)c(c(c1)C)c1c(C)cc(cc1C)Cc1c...516.2713.329670.119071.101440.169072.63642False
67269466726946Ic1ccc(nc1)Cc1ccc2c(c1)c1cc(ccc1C2(C)C)Cc1ccc(...510.728.964547.9225772.929292.17324247.14446False
67269476726947Ic1ccc(cn1)S(=O)(=O)c1ccc(nc1)N1C(=O)c2c(C1=O)...610.244.9075820.24364121.244941.01011650.18364False
67269486726948Ic1ccc(c(c1)C)Oc1ccc2c(c1)Cc1c2ccc(c1)Oc1ccc(c...510.9012.409070.193071.413350.117284.00573True
67269496726949Ic1cc(Sc2ccc3c(c2)cc(cc3)Sc2cc(cc(c2)C(F)(F)F)...462.3114.963420.139161.052520.092982.54411False
\n", "

6726950 rows × 9 columns

\n", "
" ], "text/plain": [ " Unnamed: 0 Smiles \\\n", "0 0 Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1... \n", "1 1 Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=... \n", "2 2 O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c... \n", "3 3 Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O... \n", "4 4 Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1... \n", "... ... ... \n", "6726945 6726945 Ic1cccc(c1)Cc1cc(C)c(c(c1)C)c1c(C)cc(cc1C)Cc1c... \n", "6726946 6726946 Ic1ccc(nc1)Cc1ccc2c(c1)c1cc(ccc1C2(C)C)Cc1ccc(... \n", "6726947 6726947 Ic1ccc(cn1)S(=O)(=O)c1ccc(nc1)N1C(=O)c2c(C1=O)... \n", "6726948 6726948 Ic1ccc(c(c1)C)Oc1ccc2c(c1)Cc1c2ccc(c1)Oc1ccc(c... \n", "6726949 6726949 Ic1cc(Sc2ccc3c(c2)cc(cc3)Sc2cc(cc(c2)C(F)(F)F)... \n", "\n", " Tg He N2 O2 CH4 CO2 \\\n", "0 494.84 2.69524 4.75740 42.31847 1.64086 148.43644 \n", "1 508.26 5.33815 2.97239 26.31118 0.86467 82.37635 \n", "2 640.91 20.47515 0.06353 0.90498 0.06905 2.35993 \n", "3 568.04 4.19692 0.00191 0.01134 0.00362 0.01418 \n", "4 548.10 142.68327 0.87380 8.25409 2.52067 30.04739 \n", "... ... ... ... ... ... ... \n", "6726945 516.27 13.32967 0.11907 1.10144 0.16907 2.63642 \n", "6726946 510.72 8.96454 7.92257 72.92929 2.17324 247.14446 \n", "6726947 610.24 4.90758 20.24364 121.24494 1.01011 650.18364 \n", "6726948 510.90 12.40907 0.19307 1.41335 0.11728 4.00573 \n", "6726949 462.31 14.96342 0.13916 1.05252 0.09298 2.54411 \n", "\n", " synthesizable \n", "0 False \n", "1 False \n", "2 False \n", "3 False \n", "4 False \n", "... ... \n", "6726945 False \n", "6726946 False \n", "6726947 False \n", "6726948 True \n", "6726949 False \n", "\n", "[6726950 rows x 9 columns]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 13, "id": "aa424b6f-5cf4-42aa-bdce-83eff2b690c7", "metadata": {}, "outputs": [], "source": [ "def pred_from_embeds(embeds, regression_model):\n", " \n", " x = regression_model.relu(embeds)\n", " return regression_model.clf(x)\n", "\n", "def gradient_based_extrapolation(\n", " generative_model, regression_model, regression_expert_model, embeddings, scaler, target_idx=-2,\n", " learning_rate=0.0001, steps=50, batch_size=32, best_idx=None, lambda_reg=0.3, scale_factor=1000, expert_threshold=120_000\n", "):\n", " \"\"\"\n", " Use property gradients to guide extrapolation, with checks for novelty and polymer validity.\n", " If an invalid or non-novel molecule is generated, optimization stops.\n", " \"\"\"\n", " device = next(regression_model.parameters()).device\n", "\n", " # --- Freeze models to prevent parameter updates ---\n", " for param in regression_model.parameters():\n", " param.requires_grad = False\n", " regression_model.eval()\n", "\n", " for param in regression_expert_model.parameters():\n", " param.requires_grad = False\n", " regression_expert_model.eval()\n", " \n", " for param in generative_model.parameters():\n", " param.requires_grad = False\n", " generative_model.eval()\n", "\n", " # --- 1. Find the best starting embedding (batched) ---\n", " if best_idx is None:\n", " properties = []\n", " with torch.no_grad():\n", " for i in tqdm(range(0, len(embeddings), batch_size), desc=\"Computing initial properties\"):\n", " batch = embeddings[i:i+batch_size]\n", " preds = regression_model.clf(batch)\n", " properties.append(preds[:, target_idx].cpu())\n", " properties = torch.cat(properties)\n", " best_idx = torch.argmax(properties)\n", "\n", " start_embedding = embeddings[best_idx].cuda().reshape(1, -1).requires_grad_(True)\n", " initial_embedding = start_embedding.clone().cuda()\n", " # This will store the last embedding that successfully passed all checks\n", " last_valid_embedding = start_embedding.clone().detach()\n", "\n", " # --- ADDITION: Store initial prediction for MAE calculation ---\n", " with torch.no_grad():\n", " initial_pred_scaled = pred_from_embeds(initial_embedding, regression_model)[:, target_idx]\n", " initial_pred_unscaled = initial_pred_scaled * torch.tensor(scaler.scale_, device=device) + torch.tensor(scaler.mean_, device=device)\n", " initial_prediction_value = initial_pred_unscaled.item()\n", "\n", " optimizer = torch.optim.Adam([start_embedding], lr=learning_rate)\n", "\n", " # --- Extract scaler attributes for PyTorch-based inverse transform ---\n", " scale = torch.tensor(scaler.scale_, device=device, dtype=torch.float32)\n", " mean = torch.tensor(scaler.mean_, device=device, dtype=torch.float32)\n", " # --- 2. Run gradient ascent with step-by-step checks ---\n", " predicting_model = regression_expert_model\n", " pred_unscaled = 0\n", "\n", " with torch.no_grad():\n", " pred = pred_from_embeds(start_embedding.cuda(), regression_model)\n", " sample_scaled = pred[:, target_idx]\n", " sample_unscaled = sample_scaled * scale + mean\n", " print(f'FIRST PRED: {sample_unscaled.detach().item()}') \n", "\n", " for step in range(steps):\n", " if pred_unscaled >= expert_threshold:\n", " predicting_model = regression_expert_model\n", " else:\n", " predicting_model = regression_expert_model\n", " \n", " optimizer.zero_grad()\n", "\n", " pred_scaled = pred_from_embeds(start_embedding.cuda(), regression_model)[:, target_idx]\n", " pred_unscaled = pred_scaled * scale + mean\n", " property_loss = -pred_unscaled.mean()\n", " regularization_loss = torch.norm(start_embedding - initial_embedding, p=2)**2\n", " loss = property_loss + scale_factor * (lambda_reg * regularization_loss)\n", " \n", " loss.backward()\n", " optimizer.step()\n", "\n", " with torch.no_grad():\n", " start_embedding.clamp_(-20, 20)\n", "\n", " # --- VALIDATION BLOCK: Check molecule after each step ---\n", " with torch.no_grad():\n", " reconstructed_ids = generative_model.generate(start_embedding.cuda(), temperature=1.7)\n", " reconstructed_smiles = generative_model.tokenizer.decode(reconstructed_ids[0], skip_special_tokens=True)\n", " \n", " # Use the combined helper function for all checks\n", " is_valid_and_novel = is_novel_and_valid_polymer(reconstructed_smiles, training_smiles_set)\n", " \n", " print(f\"Step {step+1}/{steps} | Value: {pred_unscaled.item():.4f} | SMILES: {reconstructed_smiles} | Novel & Valid Polymer: {is_valid_and_novel}\")\n", "\n", " if is_valid_and_novel:\n", " # Tokenize the reconstructed SMILES\n", " recon_tokens = generative_model.tokenizer(\n", " reconstructed_smiles,\n", " max_length=512,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " ).to(device)\n", " \n", " # Get embedding from the reconstructed SMILES\n", " recon_embedding = generative_model.encoder(recon_tokens['input_ids'], recon_tokens['attention_mask'])\n", " \n", " CO2_scaled = pred_from_embeds(recon_embedding, predicting_model)[:, target_idx]\n", " CO2_unscaled = CO2_scaled * scale + mean\n", " \n", " print(f\" -> Reconstructed molecule CO2 permeability: {CO2_unscaled.item():.4f}\")\n", " \n", " # --- ADDITION 2: Calculate MAE between initial and reconstructed prediction ---\n", " mae_value = float(abs(CO2_unscaled + property_loss))\n", " print(f\" -> MAE between initial embedding prediction and current prediction: {mae_value:.4f}\")\n", " \n", "\n", " # If it passes all checks, save this as the current best state\n", " last_valid_embedding = start_embedding.clone().detach()\n", " else:\n", " pass\n", "\n", " print(f'FINAL VALUE: {pred_unscaled.detach().item():.4f}')\n", " return last_valid_embedding\n" ] }, { "cell_type": "code", "execution_count": null, "id": "53916086-d696-47e4-8984-9e4824bbd4d2", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 14, "id": "34e0c67c-ad86-4b7c-9b3c-8132d9aec45b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(161379.46, 5117051)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sorted_df = df.sort_values(by=['CO2'], ascending=False)\n", "best_idx = sorted_df.index.tolist()[0]\n", "best_co2 = float(sorted_df['CO2'][best_idx])\n", "best_embedding = base_embeddings[best_idx, :].reshape(1, -1).cuda()\n", "best_co2, best_idx" ] }, { "cell_type": "code", "execution_count": null, "id": "d4394402-1921-448d-8bd8-bcca0cd997f3", "metadata": {}, "outputs": [], "source": [ "optimized_embedding = gradient_based_extrapolation(\n", " model, regression_model, regression_model, base_embeddings, scaler_co2, target_idx=-1, learning_rate=1e-2, steps=30, batch_size=32, best_idx=500, lambda_reg=1, scale_factor=0, expert_threshold=100_000\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "a1dc47ec-4ce4-44c0-b5da-c14ade1b415b", "metadata": {}, "outputs": [], "source": [ "example_novel_polymer = 'Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)N1C(=O)c2c(C1=O)c(ccc2)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(F)(F)F)(C(F)(F)F)c1ccc2c(c1)C(=O)N(C2=O)I'" ] }, { "cell_type": "code", "execution_count": 58, "id": "d30a41fe-a8a8-4c6c-9e48-d5403a079f36", "metadata": {}, "outputs": [], "source": [ "df_robson = sample_df.copy()" ] }, { "cell_type": "code", "execution_count": 4, "id": "f4443775-23a3-4c34-a8d4-3ba4cd5744d2", "metadata": {}, "outputs": [], "source": [ "!cp -r polygnn_kit/polygnn_kit ." ] }, { "cell_type": "code", "execution_count": 79, "id": "30876742-2f6e-4564-90e5-6b186a505fac", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "'GT Open Source General Use License.pdf' polygnn_kit.py tests\n", " __init__.py\t\t\t\t __pycache__\t utils.py\n", " poetry.lock\t\t\t\t pyproject.toml\n", " polygnn_kit\t\t\t\t README.md\n" ] } ], "source": [ "!cd polygnn_kit && ls" ] }, { "cell_type": "code", "execution_count": 81, "id": "2d56df39-af0c-4d87-8966-9a5ee6c776cc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" ] } ], "source": [ "!rm -rf polygnn_kit/polygnn_kit" ] }, { "cell_type": "code", "execution_count": 15, "id": "d371d73c-6362-4ad7-bf69-2e54ccfac3eb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'exp_perm_CH4__Barrer': 'Forward(MinMaxScaler(dim: 0, max: 4.07918119430542, min: -3.387216091156006))',\n", " 'exp_perm_CO2__Barrer': 'Forward(MinMaxScaler(dim: 0, max: 4.645422458648682, min: -5.92081880569458))',\n", " 'exp_perm_H2__Barrer': 'Forward(MinMaxScaler(dim: 0, max: 4.22788667678833, min: -1.642065167427063))',\n", " 'exp_perm_He__Barrer': 'Forward(MinMaxScaler(dim: 0, max: 3.8041393756866455, min: -1.2612193822860718))',\n", " 'exp_perm_N2__Barrer': 'Forward(MinMaxScaler(dim: 0, max: 3.7075700759887695, min: -3.795880079269409))',\n", " 'exp_perm_O2__Barrer': 'Forward(MinMaxScaler(dim: 0, max: 3.931457757949829, min: -6.15490198135376))',\n", " 'exp_solubility__MPa**0.5': 'Forward(MinMaxScaler(dim: 0, max: 29.200000762939453, min: 12.300000190734863))'}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import joblib\n", "import json\n", "\n", "with open('polygnn/trained_models/solubility_and_permeability/metadata/selectors.json', 'r') as file:\n", " a = json.load(file)\n", " selectors = a['exp_perm_CO2__Barrer'][0]\n", "\n", "with open('polygnn/trained_models/solubility_and_permeability/metadata/scalers.json', 'r') as file:\n", " scalers = json.load(file)\n", " polygnn_scaler = scalers['exp_perm_CO2__Barrer']\n", "\n", "scalers" ] }, { "cell_type": "code", "execution_count": 17, "id": "d0977829-6ef6-4272-a934-023c73b78f1a", "metadata": {}, "outputs": [], "source": [ "root_dir = 'polygnn/trained_models/solubility_and_permeability'" ] }, { "cell_type": "code", "execution_count": 71, "id": "774351fc-6cd0-4626-bcbe-fb6fb0ba4ac7", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import pandas as pd\n", "import polygnn\n", "import polygnn_trainer as pt\n", "from tqdm import tqdm\n", "import joblib\n", "import os\n", "\n", "def inverse_minmax(scaled_tensor,\n", " min_val=-5.92081880569458,\n", " max_val= 4.645422458648682):\n", "\n", " return scaled_tensor * (max_val - min_val) + min_val\n", "\n", "def load_polygnn_model(model_path, device='cuda'):\n", " \"\"\"\n", " Load a pre-trained polyGNN model and its configuration dynamically \n", " from pickled files in the metadata folder.\n", " \n", " Args:\n", " model_path (str): Path to the saved model directory.\n", " device (str): Device to load the model on ('cuda' or 'cpu').\n", " \n", " Returns:\n", " tuple: Loaded ensemble model, SMILES featurizer, and the model path.\n", " \"\"\"\n", " metadata_path = os.path.join(model_path, \"metadata\")\n", " if not os.path.exists(metadata_path):\n", " raise FileNotFoundError(f\"Metadata folder not found at {metadata_path}\")\n", "\n", "\n", " \n", " bond_config = polygnn.featurize.BondConfig(True, True, True)\n", " atom_config = polygnn.featurize.AtomConfig(\n", " True,\n", " True,\n", " True,\n", " True,\n", " True,\n", " True,\n", " combo_hybrid=False,\n", " aromatic=True,\n", " )\n", " \n", " ensemble = pt.load.load_ensemble(\n", " model_path,\n", " polygnn.models.polyGNN,\n", " device,\n", " submodel_kwargs_dict={\n", " \"node_size\": atom_config.n_features,\n", " \"edge_size\": bond_config.n_features,\n", " \"selector_dim\": len(selectors),\n", " },\n", " \n", " )\n", " \n", " # --- 5. Create the SMILES featurizer with loaded configs ---\n", " import functools\n", " kwargs = dict(bond_config=bond_config,\n", " atom_config=atom_config,\n", " representation=\"trimer\")\n", " smiles_featurizer = functools.partial(polygnn.featurize.get_minimum_graph_tensor, **kwargs)\n", " \n", " print(\"polyGNN model and configuration loaded successfully from pickle files.\")\n", " return ensemble, smiles_featurizer, model_path\n", "\n", "\n", "def predict_co2_permeability_polygnn(smiles_list, ensemble, smiles_featurizer, model_path, device='cuda'):\n", " \"\"\"\n", " Predict CO2 permeability using polyGNN model\n", " \n", " Args:\n", " smiles_list: List of SMILES strings\n", " ensemble: Loaded polyGNN ensemble model\n", " smiles_featurizer: SMILES featurization function\n", " model_path: Path to model directory (for loading scalers)\n", " device: Device for computation\n", " \n", " Returns:\n", " Tensor of CO2 permeability predictions\n", " \"\"\"\n", " # Create a temporary dataframe for prediction\n", " temp_df = pd.DataFrame({\n", " 'smiles_string': smiles_list,\n", " 'prop': ['exp_perm_CO2__Barrer'] * len(smiles_list), # Adjust property name as needed\n", " })\n", " \n", " # Run predictions\n", " with torch.no_grad():\n", " y, y_mean_hat, y_std_hat, _selectors = pt.infer.eval_ensemble(\n", " model=ensemble,\n", " root_dir=model_path,\n", " dataframe=temp_df,\n", " smiles_featurizer=smiles_featurizer,\n", " device=device,\n", " ensemble_kwargs_dict={\"monte_carlo\": False},\n", " )\n", " return y, y_mean_hat, y_std_hat\n", "\n", "def pred_from_embeds(embeds, regression_model):\n", " x = regression_model.relu(embeds)\n", " return regression_model.clf(x)\n", "\n", "def gradient_based_extrapolation(\n", " generative_model, regression_model, regression_expert_model, embeddings, scaler, target_idx=-2,\n", " learning_rate=0.0001, steps=50, batch_size=32, best_idx=None, lambda_reg=0.3, scale_factor=1000, \n", " expert_threshold=120_000, polygnn_model_path=None\n", "):\n", " \"\"\"\n", " Use property gradients to guide extrapolation, with checks for novelty and polymer validity.\n", " Now includes PolyGNN predictions for CO2 permeability.\n", " \n", " Args:\n", " polygnn_model_path: Path to pre-trained polyGNN model for CO2 permeability prediction\n", " \"\"\"\n", " device = next(regression_model.parameters()).device\n", "\n", " # Load PolyGNN model if path is provided\n", " polygnn_ensemble = None\n", " polygnn_featurizer = None\n", " if polygnn_model_path:\n", " polygnn_ensemble, polygnn_featurizer, _ = load_polygnn_model(polygnn_model_path, device)\n", " \n", "\n", " # --- Freeze models to prevent parameter updates ---\n", " for param in regression_model.parameters():\n", " param.requires_grad = False\n", " regression_model.eval()\n", "\n", " for param in regression_expert_model.parameters():\n", " param.requires_grad = False\n", " regression_expert_model.eval()\n", " \n", " for param in generative_model.parameters():\n", " param.requires_grad = False\n", " generative_model.eval()\n", "\n", " # --- 1. Find the best starting embedding (batched) ---\n", " if best_idx is None:\n", " properties = []\n", " with torch.no_grad():\n", " for i in tqdm(range(0, len(embeddings), batch_size), desc=\"Computing initial properties\"):\n", " batch = embeddings[i:i+batch_size]\n", " preds = regression_model.clf(batch)\n", " properties.append(preds[:, target_idx].cpu())\n", " properties = torch.cat(properties)\n", " best_idx = torch.argmax(properties)\n", "\n", " start_embedding = embeddings[best_idx].cuda().reshape(1, -1).requires_grad_(True)\n", " initial_embedding = start_embedding.clone().cuda()\n", " last_valid_embedding = start_embedding.clone().detach()\n", "\n", " # --- Store initial prediction for MAE calculation ---\n", " with torch.no_grad():\n", " initial_pred_scaled = pred_from_embeds(initial_embedding, regression_model)[:, target_idx]\n", " initial_pred_unscaled = initial_pred_scaled * torch.tensor(scaler.scale_, device=device) + torch.tensor(scaler.mean_, device=device)\n", " initial_prediction_value = initial_pred_unscaled.item()\n", "\n", " optimizer = torch.optim.Adam([start_embedding], lr=learning_rate)\n", "\n", " # --- Extract scaler attributes for PyTorch-based inverse transform ---\n", " scale = torch.tensor(scaler.scale_, device=device, dtype=torch.float32)\n", " mean = torch.tensor(scaler.mean_, device=device, dtype=torch.float32)\n", " \n", " # --- 2. Run gradient ascent with step-by-step checks ---\n", " predicting_model = regression_expert_model\n", " pred_unscaled = 0\n", "\n", " with torch.no_grad():\n", " pred = pred_from_embeds(start_embedding.cuda(), regression_model)\n", " sample_scaled = pred[:, target_idx]\n", " sample_unscaled = sample_scaled * scale + mean\n", " print(f'FIRST PRED: {sample_unscaled.detach().item()}') \n", "\n", " for step in range(steps):\n", " if pred_unscaled >= expert_threshold:\n", " predicting_model = regression_expert_model\n", " else:\n", " predicting_model = regression_expert_model\n", " \n", " optimizer.zero_grad()\n", "\n", " pred_scaled = pred_from_embeds(start_embedding.cuda(), regression_model)[:, target_idx]\n", " pred_unscaled = pred_scaled * scale + mean\n", " property_loss = -pred_unscaled.mean()\n", " regularization_loss = torch.norm(start_embedding - initial_embedding, p=2)**2\n", " loss = property_loss + scale_factor * (lambda_reg * regularization_loss)\n", " \n", " loss.backward()\n", " optimizer.step()\n", "\n", " with torch.no_grad():\n", " start_embedding.clamp_(-20, 20)\n", "\n", " # --- VALIDATION BLOCK: Check molecule after each step ---\n", " with torch.no_grad():\n", " reconstructed_ids = generative_model.generate(start_embedding.cuda(), temperature=1.5)\n", " reconstructed_smiles = generative_model.tokenizer.decode(reconstructed_ids[0], skip_special_tokens=True)\n", " \n", " # Use the combined helper function for all checks\n", " is_valid_and_novel = is_novel_and_valid_polymer(reconstructed_smiles, training_smiles_set)\n", " \n", " print(f\"Step {step+1}/{steps} | Value: {pred_unscaled.item():.4f} | SMILES: {reconstructed_smiles} | Novel & Valid Polymer: {is_valid_and_novel}\")\n", "\n", " if is_valid_and_novel:\n", " # Tokenize the reconstructed SMILES\n", " recon_tokens = generative_model.tokenizer(\n", " reconstructed_smiles,\n", " max_length=512,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " ).to(device)\n", " \n", " # Get embedding from the reconstructed SMILES\n", " recon_embedding = generative_model.encoder(recon_tokens['input_ids'], recon_tokens['attention_mask'])\n", " \n", " CO2_scaled = pred_from_embeds(recon_embedding, predicting_model)[:, target_idx]\n", " CO2_unscaled = CO2_scaled * scale + mean\n", " \n", " print(f\" -> Reconstructed molecule CO2 permeability: {CO2_unscaled.item():.4f}\")\n", " \n", " # --- NEW: PolyGNN CO2 permeability prediction ---\n", " if polygnn_ensemble and polygnn_featurizer:\n", " corrected_smiles = reconstructed_smiles.replace(\"I\", \"[*]\")\n", " print(corrected_smiles)\n", " polygnn_pred = predict_co2_permeability_polygnn(\n", " [corrected_smiles], \n", " polygnn_ensemble, \n", " polygnn_featurizer, \n", " polygnn_model_path, \n", " device\n", " )\n", " polygnn_pred = inverse_minmax(polygnn_pred)\n", " print(f\" -> PolyGNN CO2 permeability prediction: {polygnn_pred.item():.4f}\")\n", " \n", " # Compare predictions\n", " pred_diff = abs(CO2_unscaled.item() - polygnn_pred.item())\n", " print(f\" -> Prediction difference (|Main - PolyGNN|): {pred_diff:.4f}\")\n", " \n", "\n", " \n", " # --- Calculate MAE between initial and reconstructed prediction ---\n", " mae_value = float(abs(CO2_unscaled + property_loss))\n", " print(f\" -> MAE between initial embedding prediction and current prediction: {mae_value:.4f}\")\n", "\n", " # If it passes all checks, save this as the current best state\n", " last_valid_embedding = start_embedding.clone().detach()\n", " else:\n", " pass\n", "\n", " print(f'FINAL VALUE: {pred_unscaled.detach().item():.4f}')\n", " return last_valid_embedding\n" ] }, { "cell_type": "code", "execution_count": 61, "id": "dd790c00-2e48-4a51-8195-003933fc92c5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I'" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['Smiles'][0]" ] }, { "cell_type": "code", "execution_count": null, "id": "9a0b18aa-e95c-4405-bae3-6f1cc72525ea", "metadata": {}, "outputs": [], "source": [ "result = gradient_based_extrapolation(\n", " generative_model=model,\n", " regression_model=regression_model,\n", " regression_expert_model=regression_model,\n", " embeddings=base_embeddings,\n", " scaler=scaler_co2,\n", " best_idx=best_idx,\n", " target_idx=-1,\n", " polygnn_model_path=\"polygnn/trained_models/solubility_and_permeability\"\n", " \n", ")" ] }, { "cell_type": "code", "execution_count": 152, "id": "e93b3bd5-89e3-42d9-8a87-46af7bf3c338", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.00419, 166.28537)" ] }, "execution_count": 152, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def get_min_max(real_values):\n", " \"\"\"\n", " Derive the Min-Max‐scaler parameters directly from unscaled labels.\n", "\n", " Parameters\n", " ----------\n", " real_values : list | np.ndarray | torch.Tensor\n", " Iterable of ground-truth permeability values (original units).\n", "\n", " Returns\n", " -------\n", " tuple(float, float)\n", " (min_value, max_value) for the dataset.\n", " \"\"\"\n", " arr = np.asarray(real_values, dtype=float)\n", " return float(arr.min()), float(arr.max())\n", "\n", "mi, ma = get_min_max(df.sort_values(by=['CO2'], ascending=True).reset_index(drop=True)[:5_000_000]['CO2'])\n", "mi, ma" ] }, { "cell_type": "code", "execution_count": 151, "id": "b3c2b0ff-04e9-4274-935f-ab9e8440742b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0SmilesTgHeN2O2CH4CO2synthesizable
51170515117051Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
16542681654268Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
66922756692275Ic1ccc(cn1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
830270830270Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
66922766692276Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(...544.42343.893983284.4872018341.89400528.13604161379.46000False
..............................
1095010950Ic1cc(Oc2cc(Oc3cc(cc(c3)n3c(=O)c4c(c3=O)cc3c(c...538.151.235460.001050.004040.001950.00785True
53090145309014Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O...552.002.563850.001380.007770.002870.00772False
25984182598418Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O...552.002.563850.001380.007770.002870.00772False
746918746918Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O...552.002.563850.001380.007770.002870.00772False
22552255Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O...586.801.805440.000680.003300.001840.00419False
\n", "

6726950 rows × 9 columns

\n", "
" ], "text/plain": [ " Unnamed: 0 Smiles \\\n", "5117051 5117051 Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... \n", "1654268 1654268 Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(... \n", "6692275 6692275 Ic1ccc(cn1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(... \n", "830270 830270 Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... \n", "6692276 6692276 Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... \n", "... ... ... \n", "10950 10950 Ic1cc(Oc2cc(Oc3cc(cc(c3)n3c(=O)c4c(c3=O)cc3c(c... \n", "5309014 5309014 Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... \n", "2598418 2598418 Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... \n", "746918 746918 Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... \n", "2255 2255 Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... \n", "\n", " Tg He N2 O2 CH4 CO2 \\\n", "5117051 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "1654268 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "6692275 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "830270 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "6692276 544.42 343.89398 3284.48720 18341.89400 528.13604 161379.46000 \n", "... ... ... ... ... ... ... \n", "10950 538.15 1.23546 0.00105 0.00404 0.00195 0.00785 \n", "5309014 552.00 2.56385 0.00138 0.00777 0.00287 0.00772 \n", "2598418 552.00 2.56385 0.00138 0.00777 0.00287 0.00772 \n", "746918 552.00 2.56385 0.00138 0.00777 0.00287 0.00772 \n", "2255 586.80 1.80544 0.00068 0.00330 0.00184 0.00419 \n", "\n", " synthesizable \n", "5117051 False \n", "1654268 False \n", "6692275 False \n", "830270 False \n", "6692276 False \n", "... ... \n", "10950 True \n", "5309014 False \n", "2598418 False \n", "746918 False \n", "2255 False \n", "\n", "[6726950 rows x 9 columns]" ] }, "execution_count": 151, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.sort_values(by=['CO2'], ascending=False)" ] }, { "cell_type": "code", "execution_count": 76, "id": "bfb10438-7d6b-4e73-8856-33fdea15ec72", "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 77, "id": "25fa51a4-27a0-4f63-872e-3e5f80953efe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5ccc(C(=O)c6ccc7c(c6)C(=O)N(c6ccc(C(=O)c8cccc([*])c8C)c(Cl)c6)C7=O)cn5)cn4)cc3C2=O)c(Cl)c1',\n", " '[*]c1ccc(Oc2cccc(Oc3ccc(N4C(=O)c5ccc(C(=O)c6ccc(C(=O)c7ccc8c(c7)C(=O)N([*])C8=O)cn6)cc5C4=O)cc3C)c2)c(C)c1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3cccc(C(=O)c4ccc(Oc5ccc(C(=O)c6ccc7c(c6)C(=O)N(c6ccc(C(=O)c8ccc(C)c([*])c8)cc6Cl)C7=O)cc5)cc4)c3C2=O)cc1Cl',\n", " '[*]C(=O)c1cccc(C(=O)c2ccc3c(c2)C(=O)N(c2ccc(C(=O)c4ccc(C(=O)c5ccc(N6C(=O)c7ccc(C(=O)c8ccc([*])cc8)cc7C6=O)cc5Cl)cc4)c(Cl)c2)C3=O)c1',\n", " '[*]Cc1ccc(N2C(=O)c3ccc(Oc4ccc5c(c4)C(=O)N(c4ccc(Cc6cc(C)c([*])cc6C)c(Cl)c4)C5=O)cc3C2=O)cc1Cl',\n", " '[*]C(=O)c1ccc(C(=O)c2ccc3c(c2)C(=O)N(c2ccc(N4C(=O)c5ccc(C(=O)c6ccc([*])cc6)cc5C4=O)c(Cl)c2Cl)C3=O)cc1',\n", " '[*]Oc1ccc2c(c1)[nH]c1cc(Oc3cc(C(=O)O)cc(N4C(=O)c5ccc(Oc6ccc7ccc(Oc8cccc9c8C(=O)N(c8cc([*])cc(C(=O)O)c8)C9=O)cc7c6)cc5C4=O)c3)ccc12',\n", " '[*]c1ccc(C)c(N2C(=O)c3cccc(Oc4cccc(Sc5ccc6c(c5)C(=O)N([*])C6=O)c4)c3C2=O)c1',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1ccc3c(c1)Cc1cc(N4C(=O)c5ccc(C(=O)c6ccc7cc([*])ccc7c6)cc5C4=O)ccc1-3)C2=O',\n", " '[*]Sc1cccc(Sc2ccc(C)c(N3C(=O)c4cccc(C(=O)c5cccc(C(=O)c6ccc7c(c6)C(=O)N(c6cc([*])ccc6C)C7=O)c5)c4C3=O)c2)c1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(Sc4cccc(Sc5ccc6c(c5)C(=O)N(c5ccc(-c7ccc(C)c([*])c7)cc5)C6=O)c4)cc3C2=O)cc1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(Oc4ccc(Oc5ccc(Oc6ccc7c(c6)C(=O)N(c6ccc(C(=O)c8ccc9c(c8)[nH]c8cc([*])ccc89)c(Cl)c6Cl)C7=O)cc5)cc4)cc3C2=O)c(Cl)c1Cl',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1ccc(N3C(=O)c4cccc(C(=O)c5ccc(Sc6ccc([*])cc6)cc5)c4C3=O)cc1)C2=O',\n", " '[*]Oc1ccc(Sc2ccc(Oc3ccc4c(c3)C(=O)N(c3cc(Cl)cc(N5C(=O)c6cccc([*])c6C5=O)c3)C4=O)cc2)cc1',\n", " '[*]c1ccc(C(=O)c2cccc(C(=O)c3ccc(N4C(=O)c5ccc(Oc6cccc(Oc7cccc(Oc8cccc9c8C(=O)N([*])C9=O)c7)c6)cc5C4=O)c(C)c3)c2)c(C)c1',\n", " '[*]C(=O)c1cccc(C(=O)c2cccc(N3C(=O)c4cccc(Oc5cccc(C(=O)c6cccc(Oc7cccc8c7C(=O)N(c7cccc([*])c7C)C8=O)c6)c5)c4C3=O)c2C)c1',\n", " '[*]Oc1ccc(N2C(=O)c3ccc(C(=O)c4cc5ccccc5cc4C(=O)c4cccc5c4C(=O)N(c4ccc([*])c(C)c4)C5=O)cc3C2=O)c(Cl)c1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5cccc6c5C(=O)N(c5ccc(C(=O)c7ccc(C)c([*])c7)cc5C)C6=O)cn4)cc3C2=O)c(C)c1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4cccc(Oc5cccc(C(=O)c6cccc7c6C(=O)N(c6ccc(C(=O)c8cc([*])cc(C(=O)O)c8)c(Cl)c6Cl)C7=O)c5)c4)cc3C2=O)c(Cl)c1Cl',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4ccc5ccc(C(=O)c6ccc7c(c6)C(=O)N(c6ccc(C(=O)c8ccc([*])cc8C)c(Cl)c6)C7=O)cc5c4)cc3C2=O)cc1Cl',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1ccc(Sc3cccc(Sc4ccc(N5C(=O)c6cccc(C(=O)c7ccc(C(=O)c8ccc([*])cn8)nc7)c6C5=O)c(Cl)c4Cl)c3)c(Cl)c1Cl)C2=O',\n", " '[*]C(=O)c1cccc(N2C(=O)c3cccc(C(=O)c4ccc(C(=O)c5ccc(C(=O)c6cccc7c6C(=O)N(c6cccc(C(=O)c8cc([*])ccc8C)c6)C7=O)cn5)nc4)c3C2=O)c1',\n", " '[*]C(=O)c1cccc(C(=O)c2ccc(N3C(=O)c4ccc(Oc5cccc(Oc6ccc7c(c6)C(=O)N(c6ccc([*])cc6C)C7=O)c5)cc4C3=O)cc2C)c1',\n", " '[*]C(=O)c1ccc(C(=O)c2cccc3c2C(=O)N(c2ccc(Cc4cc(-c5ccc(-c6ccc(C(=O)c7ccc([*])cn7)cc6Cl)c(Cl)c5)cc(C(F)(F)F)c4)c(Cl)c2)C3=O)cn1',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1ccc(Oc3ccc4c(c3)[nH]c3cc(Oc5ccc(N6C(=O)c7ccc(C(=O)c8ccc9ccc([*])cc9c8)cc7C6=O)c(Cl)c5)ccc34)cc1Cl)C2=O',\n", " '[*]c1ccc(C(=O)c2ccc3c(c2)[nH]c2cc(C(=O)c4ccc(N5C(=O)c6ccc(C(=O)c7cccc(Sc8cccc(C(=O)c9ccc%10c(c9)C(=O)N([*])C%10=O)c8)c7)cc6C5=O)c(C)c4)ccc23)c(C)c1',\n", " '[*]c1cc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5ccc6c(c5)C(=O)N([*])C6=O)nc4)cc3C2=O)ccc1C',\n", " '[*]Nc1cccc(C(=O)c2ccc3c(c2)C(=O)N(c2ccc4cc(N5C(=O)c6ccc(C(=O)c7cccc(C([*])=O)c7)cc6C5=O)ccc4c2)C3=O)c1',\n", " '[*]C(=O)c1cccc(C(=O)c2cccc(C(=O)c3ccc4c(c3)C(=O)N(c3ccc(-c5c(C)cc([*])cc5C)c(Cl)c3Cl)C4=O)c2)c1',\n", " '[*]c1cccc(C(C)(C)c2cc(C(=O)O)cc(N3C(=O)c4ccc(Oc5ccc6ccc(Oc7ccc8c(c7)C(=O)N(c7cc(C(=O)O)cc(C([*])(C)C)c7)C8=O)cc6c5)cc4C3=O)c2)c1',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1ccc(N3C(=O)c4ccc(C(=O)c5ccc(Oc6ccc([*])cn6)nc5)cc4C3=O)nc1)C2=O',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4cccc(Cc5cccc(C(=O)c6cccc7c6C(=O)N(c6ccc(C(=O)c8cc([*])cc(C(=O)O)c8)cc6)C7=O)c5)c4)cc3C2=O)cc1',\n", " '[*]c1ccc(C)c(C(=O)c2cccc(C(=O)c3ccc(C)c(N4C(=O)c5cccc(Oc6cccc(C(=O)c7cccc(Oc8cccc9c8C(=O)N([*])C9=O)c7)c6)c5C4=O)c3)c2C)c1',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)c1cc(C(=O)C3=CC(C)=C(N4C(=O)c5cccc(C(=O)c6ccc7c(c6)[nH]c6cc(C(=O)c8ccc9c(c8)C(=O)N(c8ccc([*])cc8C)C9=O)ccc67)c5C4=O)C=CS3)ccc1-2',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(Oc4ccc5c(c4)C(=O)N(c4ccc(Oc6ccc([*])c(Cl)c6)cc4Cl)C5=O)cc3C2=O)cc1Cl',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)c1cc(C(=O)c3ccc(N4C(=O)c5ccc(C(=O)c6ccc7c(c6)C(=O)N(c6ccc([*])c(C)c6)C7=O)cc5C4=O)c(C)c3)ccc1-2',\n", " '[*]C(=O)c1ccc(Cc2cccc3c2C(=O)N(c2ccc4c(c2)[nH]c2cc(-c5cccc(N6C(=O)c7cccc(Cc8ccc([*])cn8)c7C6=O)c5)ccc24)C3=O)cn1',\n", " '[*]C(=O)c1ccc(Oc2ccc(Oc3ccc4c(c3)C(=O)N(c3ccc(C(=O)c5cc6ccccc6cc5C(=O)c5ccc(N6C(=O)c7ccc(Oc8ccc([*])nc8)cc7C6=O)c(Cl)c5)cc3Cl)C4=O)cn2)cn1',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1ccc(C(=O)c3ccc4c(c3)Cc3cc(C(=O)c5ccc(N6C(=O)c7ccc(C(=O)c8cccc(Oc9cccc([*])c9)c8)cc7C6=O)nc5)ccc3-4)nc1)C2=O',\n", " '[*]Cc1ccc(N2C(=O)c3ccc(C(=O)c4ccc(Sc5ccc(C(=O)c6ccc7c(c6)C(=O)N(c6ccc(Cc8ccc(Oc9ccc(C)c([*])c9)c(Cl)c8)cc6Cl)C7=O)nc5)cn4)cc3C2=O)c(Cl)c1',\n", " '[*]C1=C2C=C3OC(=C1)C=C(C(=O)O)C=CC(C(=O)c1ccc(C(=O)c4cccc(C(=O)c5ccc6c(c5)C(=O)N([*])C6=O)c4)cn1)=Cc1ccc4c(c1)C(=O)N(C4=O)C2=C3C',\n", " '[*]C(=O)c1ccc(Sc2cccc3c2C(=O)N(c2cccc(N4C(=O)c5ccc(Sc6ccc([*])cn6)cc5C4=O)c2)C3=O)nc1',\n", " '[*]Sc1ccc(N2C(=O)c3ccc(C(=O)c4ccc5cc(C(=O)c6ccc7c(c6)C(=O)N(c6ccc([*])cc6C)C7=O)ccc5c4)cc3C2=O)c(C)c1',\n", " '[*]Sc1cccc(C(=O)c2ccc(N3C(=O)c4cccc(C(=O)c5cccc(Sc6cccc(C(=O)c7ccc8c(c7)C(=O)N(c7ccc(C([*])=O)cn7)C8=O)c6)c5)c4C3=O)nc2)c1',\n", " '[*]Oc1ccc2cc(Oc3ccc(N4C(=O)c5ccc(C(=O)c6ccc(C(=O)c7ccc(C(=O)c8ccc(C(=O)C(=O)c9ccc%10c(c9)C(=O)N(c9ccc([*])cc9C)C%10=O)cc8)cc7)cc6)cc5C4=O)c(C)c3)ccc2c1',\n", " '[*]c1ccc(C(=O)c2cc(C(=O)O)cc(C(=O)c3ccc(N4C(=O)c5cccc(Sc6cccc(C(=O)c7cccc(Sc8ccc9c(c8)C(=O)N([*])C9=O)c7)c6)c5C4=O)cc3C)c2)c(C)c1',\n", " '[*]c1cccc(C(=O)c2cccc3c2C(=O)N(c2ccc(Oc4ccc(Oc5ccc(N6C(=O)c7ccc(C(=O)c8cccc(C([*])(C)C)c8)cc7C6=O)c(Cl)c5)cc4)cc2Cl)C3=O)c1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5ccc6c(c5)C(=O)N(c5ccc(C(=O)c7cc([*])cc(C(=O)O)c7)cn5)C6=O)cn4)cc3C2=O)nc1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(Oc4ccc5ccc(Oc6ccc7c(c6)C(=O)N(c6ccc([*])c(Cl)c6Cl)C7=O)cc5c4)cc3C2=O)c(Cl)c1Cl',\n", " '[*]C(=O)c1cc(C(=O)O)cc(C(=O)c2ccc(N3C(=O)c4cccc(Oc5ccc(C(=O)c6ccc(Oc7cccc8c7C(=O)N(c7ccc([*])cc7C)C8=O)cn6)nc5)c4C3=O)c(C)c2)c1',\n", " '[*]C(=O)c1ccc2cc(C(=O)c3ccc(N4C(=O)c5cccc(C(=O)c6cccc(Oc7cccc(C(=O)c8ccc9c(c8)C(=O)N(c8ccc([*])cc8)C9=O)c7)c6)c5C4=O)cc3)ccc2c1',\n", " '[*]c1ccc(C)c(C(=O)c2ccc(C(=O)c3cc(N4C(=O)c5cccc(C(=O)c6cccc(C(=O)c7cccc(C(=O)c8ccc9c(c8)C(=O)N([*])C9=O)c7)c6)c5C4=O)ccc3C)c(Cl)c2)c1',\n", " '[*]C(=O)c1ccc2ccc(C(=O)c3cc(C(=O)O)cc(N4C(=O)c5ccc(C(=O)c6ccc(C(=O)c7ccc(C(=O)c8ccc9c(c8)C(=O)N(c8cc([*])cc(C(=O)O)c8)C9=O)nc7)cn6)cc5C4=O)c3)cc2c1',\n", " '[*]C(=O)c1ccc(C(=O)c2ccc(C)c(N3C(=O)c4ccc(Oc5ccc6ccc(Oc7cc(-c8cc([*])ccc8C)ccc7C)cc6c5)cc4C3=O)c2)cc1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5ccc(C(=O)c6ccc(C(=O)c7ccc8c(c7)C(=O)N(c7ccc([*])cc7C)C8=O)nc6)cc5)nc4)cc3C2=O)c(C)c1',\n", " '[*]C(=O)c1ccc(Oc2cccc3c2C(=O)N(c2cccc(N4C(=O)c5ccc(Oc6ccc([*])nc6)cc5C4=O)c2)C3=O)cn1',\n", " '[*]C(=O)c1cc(C(=O)O)cc(C(=O)c2cc(C)c(N3C(=O)c4ccc(C(=O)c5cccc(Oc6cccc(C(=O)c7ccc8c(c7)C(=O)N(c7c(C)cc([*])cc7C)C8=O)c6)c5)cc4C3=O)c(C)c2)c1',\n", " '[*]Sc1ccc2c(c1)C(=O)N(c1ccc(-c3ccc(N4C(=O)c5ccc([*])cc5C4=O)c(Cl)c3)c(Cl)c1Cl)C2=O',\n", " '[*]Oc1ccc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5ccc6c(c5)C(=O)N(c5ccc(Oc7ccc([*])cc7C)cn5)C6=O)nc4)cc3C2=O)cn1',\n", " '[*]C(=O)c1ccc2cc(C(=O)c3ccc4c(c3)C(=O)N(c3ccc(C(=O)c5cc(S)cc(C(=O)c6ccc(N7C(=O)c8ccc([*])cc8C7=O)c(Cl)c6Cl)c5)c(Cl)c3Cl)C4=O)ccc2c1',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1ccc(Oc3ccc(Oc4ccc(N5C(=O)c6cccc(C(=O)c7ccc(Oc8ccc([*])cn8)nc7)c6C5=O)c(Cl)c4)cn3)cc1Cl)C2=O',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)c1cc(C(=O)c3cc(C)c(N4C(=O)c5cccc(Oc6ccc(C(=O)c7ccc(Oc8ccc9c(c8)C(=O)N(c8cc(C)c([*])cc8C)C9=O)nc7)cn6)c5C4=O)cc3C)ccc1-2',\n", " '[*]C(=O)c1ccc2c(c1)Cc1cc(C(=O)c3ccc(N4C(=O)c5ccc(Sc6cccc(C(=O)c7cccc(Sc8ccc9c(c8)C(=O)N(c8ccc([*])cc8C)C9=O)c7)c6)cc5C4=O)c(C)c3)ccc1-2',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1ccc(C(=O)c3ccc(C(=O)c4ccc(N5C(=O)c6cccc(C(=O)c7cccc(C(=O)c8cccc([*])c8)c7)c6C5=O)cn4)cc3Cl)cn1)C2=O',\n", " '[*]C(=O)c1ccc2cc(C(=O)c3ccc(N4C(=O)c5ccc(Sc6cccc(Oc7cccc(Sc8ccc9c(c8)C(=O)N(c8ccc([*])cn8)C9=O)c7)c6)cc5C4=O)nc3)ccc2c1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3cccc(C(=O)c4cccc(C(=O)c5ccc6c(c5)C(=O)N(c5ccc(C(=O)c7ccc([*])cn7)nc5)C6=O)c4)c3C2=O)nc1',\n", " '[*]C(=O)c1cccc(C(=O)c2ccc3c(c2)C(=O)N(c2cc(Cl)cc(Sc4cccc5c(Sc6cc(Cl)cc(N7C(=O)c8ccc([*])cc8C7=O)c6)cccc45)c2)C3=O)c1',\n", " '[*]c1ccc(Cc2cc3ccccc3cc2Cc2ccc(N3C(=O)c4ccc(C(=O)c5ccc(Oc6ccc(C(=O)c7ccc8c(c7)C(=O)N([*])C8=O)cc6)cc5)cc4C3=O)c(C)c2)cc1C',\n", " '[*]C(=O)c1cccc2c(C(=O)c3cccc4c3C(=O)N(c3cccc(N5C(=O)c6ccc([*])cc6C5=O)c3)C4=O)cccc12',\n", " '[*]Cc1ccc2c(c1)C(=O)N(c1ccc(Sc3ccc4cc(Sc5ccc(N6C(=O)c7ccc([*])cc7C6=O)cc5Cl)ccc4c3)cc1Cl)C2=O',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4cccc(C(=O)c5cccc(C(=O)c6cccc7c6C(=O)N(c6ccc(C(=O)c8c(C)cc([*])cc8C)cc6Cl)C7=O)c5)c4)cc3C2=O)cc1Cl',\n", " '[*]C(=O)c1ccc(C(=O)c2ccc3c(c2)C(=O)N(c2ccc(C(=O)c4ccc(C(=O)c5ccc(N6C(=O)c7ccc(C(=O)c8ccc([*])cc8)cc7C6=O)c(Cl)c5)c(Cl)c4Cl)cc2Cl)C3=O)cc1',\n", " '[*]C(=O)c1ccc2c(c1)[nH]c1cc(C(=O)c3cc(C(=O)O)cc(N4C(=O)c5cccc(C(=O)c6ccc(Oc7ccc8c(c7)C(=O)N(c7cc([*])cc(C(=O)O)c7)C8=O)nc6)c5C4=O)c3)ccc12',\n", " '[*]C(=O)c1cccc(N2C(=O)c3ccc(Oc4cccc(C(=O)c5cccc(Oc6ccc7c(c6)C(=O)N(c6cccc(C(=O)c8cc([*])ccc8C)c6)C7=O)c5)c4)cc3C2=O)c1',\n", " '[*]C(=O)c1ccc2ccc(C(=O)c3cccc(N4C(=O)c5cccc(C(=O)c6ccc(C(=O)c7ccc8c(c7)C(=O)N(c7cccc([*])c7)C8=O)nc6)c5C4=O)c3)cc2c1',\n", " '[*]C(=O)c1cc(C(=O)O)cc(C(=O)c2ccc(C)c(N3C(=O)c4cccc(Oc5cccc(C(=O)c6cccc(Oc7ccc8c(c7)C(=O)N(c7cc([*])ccc7C)C8=O)c6)c5)c4C3=O)c2)c1',\n", " '[*]Cc1cccc(N2C(=O)c3ccc(C(=O)c4cc5ccccc5cc4C(=O)c4ccc5c(c4)C(=O)N(c4cccc(Cc6ccc([*])c(C)c6)c4)C5=O)cc3C2=O)c1',\n", " '[*]c1cccc(N2C(=O)c3ccc(Sc4cccc(C(=O)c5cccc(Sc6cccc7c6C(=O)N([*])C7=O)c5)c4)cc3C2=O)c1C',\n", " '[*]c1ccc(C)c(C(=O)c2ccc3c(c2)C(=O)c2cc(C(=O)c4cc(N5C(=O)c6ccc(C(=O)c7ccc(Oc8ccc(C(=O)c9ccc%10c(c9)C(=O)N([*])C%10=O)cn8)nc7)cc6C5=O)ccc4C)ccc2-3)c1',\n", " '[*]C(=O)c1ccc(C(=O)c2ccc(C)c(N3C(=O)c4ccc(C(=O)c5ccc6cc(C(=O)c7ccc8c(c7)C(=O)N(c7cc([*])ccc7C)C8=O)ccc6c5)cc4C3=O)c2)nc1',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1cccc(C(=O)c3cccc(C(=O)c4cccc(N5C(=O)c6cccc(C(=O)c7ccc8c(c7)-c7cc([*])ccc7C8(C)C)c6C5=O)c4)c3)c1)C2=O',\n", " '[*]C(=O)c1cccc(C(=O)c2ccc(N3C(=O)c4cccc(C(=O)c5cccc(Oc6cccc(C(=O)c7ccc8c(c7)C(=O)N(c7ccc([*])cc7C)C8=O)c6)c5)c4C3=O)c(C)c2)c1',\n", " '[*]C(=O)c1ccc(C)c(-c2ccc(C)c(C(=O)c3ccc4c(c3)C(=O)N(c3cc(C(=O)c5cc([*])ccc5C)ccc3C)C4=O)c2)c1',\n", " '[*]Oc1ccc2c(c1)C(=O)c1cc(Oc3ccc(N4C(=O)c5cccc(C(=O)c6ccc7cc(C(=O)c8ccc9c(c8)C(=O)N(c8ccc([*])cc8C)C9=O)ccc7c6)c5C4=O)c(C)c3)ccc1-2',\n", " '[*]C(=O)c1ccc(N2C(=O)c3ccc(Oc4ccc(Oc5ccc(Oc6ccc7c(c6)C(=O)N(c6ccc(C(=O)c8cc([*])cc(C(=O)O)c8)c(Cl)c6Cl)C7=O)nc5)cn4)cc3C2=O)c(Cl)c1Cl',\n", " '[*]C(=O)c1cccc(N2C(=O)c3ccc(C(=O)c4cccc(Sc5cccc(C(=O)c6cccc7c6C(=O)N(c6cccc(C(=O)c8ccc(C)c([*])c8)c6)C7=O)c5)c4)cc3C2=O)c1',\n", " '[*]c1ccc(C)c(C(=O)c2ccc(C(=O)c3cc(N4C(=O)c5ccc(C(=O)c6cccc(Sc7cccc(C(=O)c8ccc9c(c8)C(=O)N([*])C9=O)c7)c6)cc5C4=O)ccc3C)cn2)c1',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)c1cc(C(=O)c3cc(C(=O)O)cc(N4C(=O)c5cccc(C(=O)c6ccc7c(c6)C(=O)N(c6cc([*])cc(C(=O)O)c6)C7=O)c5C4=O)c3)ccc1-2',\n", " '[*]C(=O)c1cccc(C(=O)c2cc(N3C(=O)c4ccc(C(=O)c5ccc(C(=O)c6ccc7c(c6)C(=O)N(c6cc([*])ccc6C)C7=O)nc5)cc4C3=O)ccc2C)c1',\n", " '[*]C(=O)c1ccc(C(=O)c2ccc(N3C(=O)c4ccc(Oc5ccc(C(=O)c6cccc7c6C(=O)N(c6ccc([*])cc6C)C7=O)cc5)cc4C3=O)c(C)c2)cn1',\n", " '[*]c1cccc(N2C(=O)c3ccc(Sc4cccc(Oc5ccc6c(c5)C(=O)N([*])C6=O)c4)cc3C2=O)c1C',\n", " '[*]C(=O)c1ccc(C(=O)c2ccc(C)c(N3C(=O)c4ccc(Cc5cccc(C(=O)c6cccc(Cc7ccc8c(c7)C(=O)N(c7cc([*])ccc7C)C8=O)c6)c5)cc4C3=O)c2)cn1',\n", " '[*]c1ccc(C)c(C(=O)c2ccc(C(=O)c3ccc(C)c(N4C(=O)c5ccc(C(=O)c6cccc7c(C(=O)c8ccc9c(c8)C(=O)N([*])C9=O)cccc67)cc5C4=O)c3)cc2C)c1',\n", " '[*]c1ccc(C)c(C(=O)c2ccc(C(=O)c3ccc(C(=O)c4ccc(C)c(N5C(=O)c6ccc(C(=O)c7ccc(C(=O)c8ccc9c(c8)C(=O)N([*])C9=O)cc7)cc6C5=O)c4)cc3)cc2)c1',\n", " '[*]C(=O)c1ccc(N2C(=O)c3cccc(Sc4cccc(Oc5cccc(Sc6ccc7c(c6)C(=O)N(c6ccc(C(=O)c8ccc([*])c(Cl)c8Cl)c(Cl)c6Cl)C7=O)c5)c4)c3C2=O)c(Cl)c1Cl',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)N(c1cccc3c(N4C(=O)c5ccc(C(=O)c6ccc([*])cc6)cc5C4=O)cccc13)C2=O',\n", " '[*]C(=O)c1ccc2c(c1)C(=O)c1cc(C(=O)c3ccc(N4C(=O)c5ccc(C(=O)c6cccc(Oc7cccc(C(=O)c8ccc9c(c8)C(=O)N(c8ccc([*])cc8C)C9=O)c7)c6)cc5C4=O)c(C)c3)ccc1-2',\n", " '[*]C(=O)c1cccc(N2C(=O)c3ccc(C(=O)c4ccc5c(c4)C(=O)N(c4cccc(C(=O)c6ccc([*])cn6)c4)C5=O)cc3C2=O)c1',\n", " '[*]C(=O)c1ccc(C(=O)c2cccc3c2C(=O)N(c2ccc(C(=O)c4ccc5c(c4)[nH]c4cc(C(=O)c6ccc(N7C(=O)c8ccc([*])cc8C7=O)c(Cl)c6Cl)ccc45)c(Cl)c2Cl)C3=O)cn1',\n", " '[*]C(=O)c1cccc(Oc2ccc3c(c2)C(=O)N(c2ccc(C(=O)c4ccc5ccc(C(=O)c6ccc(N7C(=O)c8cccc(Oc9cccc([*])c9)c8C7=O)c(Cl)c6)cc5c4)c(Cl)c2)C3=O)c1']" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_sample = pd.read_csv('/home/jovyan/simson_training_bolgov/regression/polyGNN_combined_mols_.csv')[['SMILES', 'exp_perm_CO2__Barrer_mean']]\n", "df_sample = df_sample.iloc[:100]\n", "df_sample['SMILES'].to_list()" ] }, { "cell_type": "code", "execution_count": 78, "id": "c7b1a601-9076-4427-844b-bed96dec70e1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SMILESexp_perm_CO2__Barrer_mean
0[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5c...0.669329
1[*]c1ccc(Oc2cccc(Oc3ccc(N4C(=O)c5ccc(C(=O)c6cc...0.404411
2[*]C(=O)c1ccc(N2C(=O)c3cccc(C(=O)c4ccc(Oc5ccc(...0.785077
3[*]C(=O)c1cccc(C(=O)c2ccc3c(c2)C(=O)N(c2ccc(C(...0.386811
4[*]Cc1ccc(N2C(=O)c3ccc(Oc4ccc5c(c4)C(=O)N(c4cc...0.733848
.........
95[*]C(=O)c1ccc2c(c1)C(=O)N(c1cccc3c(N4C(=O)c5cc...1.072687
96[*]C(=O)c1ccc2c(c1)C(=O)c1cc(C(=O)c3ccc(N4C(=O...0.705203
97[*]C(=O)c1cccc(N2C(=O)c3ccc(C(=O)c4ccc5c(c4)C(...0.285455
98[*]C(=O)c1ccc(C(=O)c2cccc3c2C(=O)N(c2ccc(C(=O)...1.053971
99[*]C(=O)c1cccc(Oc2ccc3c(c2)C(=O)N(c2ccc(C(=O)c...0.651191
\n", "

100 rows × 2 columns

\n", "
" ], "text/plain": [ " SMILES \\\n", "0 [*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5c... \n", "1 [*]c1ccc(Oc2cccc(Oc3ccc(N4C(=O)c5ccc(C(=O)c6cc... \n", "2 [*]C(=O)c1ccc(N2C(=O)c3cccc(C(=O)c4ccc(Oc5ccc(... \n", "3 [*]C(=O)c1cccc(C(=O)c2ccc3c(c2)C(=O)N(c2ccc(C(... \n", "4 [*]Cc1ccc(N2C(=O)c3ccc(Oc4ccc5c(c4)C(=O)N(c4cc... \n", ".. ... \n", "95 [*]C(=O)c1ccc2c(c1)C(=O)N(c1cccc3c(N4C(=O)c5cc... \n", "96 [*]C(=O)c1ccc2c(c1)C(=O)c1cc(C(=O)c3ccc(N4C(=O... \n", "97 [*]C(=O)c1cccc(N2C(=O)c3ccc(C(=O)c4ccc5c(c4)C(... \n", "98 [*]C(=O)c1ccc(C(=O)c2cccc3c2C(=O)N(c2ccc(C(=O)... \n", "99 [*]C(=O)c1cccc(Oc2ccc3c(c2)C(=O)N(c2ccc(C(=O)c... \n", "\n", " exp_perm_CO2__Barrer_mean \n", "0 0.669329 \n", "1 0.404411 \n", "2 0.785077 \n", "3 0.386811 \n", "4 0.733848 \n", ".. ... \n", "95 1.072687 \n", "96 0.705203 \n", "97 0.285455 \n", "98 1.053971 \n", "99 0.651191 \n", "\n", "[100 rows x 2 columns]" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_sample" ] }, { "cell_type": "code", "execution_count": 74, "id": "875a49f5-3aad-466e-a84f-3d77055a5ae4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "polyGNN model and configuration loaded successfully from pickle files.\n", "The following properties will be modeled: ['exp_perm_CO2__Barrer']\n", "Detected 100 data points for exp_perm_CO2__Barrer\n", "[]\n" ] } ], "source": [ "import math\n", "def try_normalize(smiles):\n", " # функция выполняет перевод молекулы в формат rdkit и обратно. Это фильтрует некорректные молекулы и нормализует их, т.е. приводит к единому виду, чтобы затем можно было отфильтровать дубликаты молекул\n", " try:\n", " return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))\n", " except Exception as e:\n", " # print(e)\n", " return None\n", "\n", "polygnn_ensemble, polygnn_featurizer, path = load_polygnn_model(\"polygnn/trained_models/solubility_and_permeability\", 'cuda')\n", "\n", "def make_pred(smiles):\n", " corrected_smiles = [try_normalize(smile) for smile in smiles]\n", " corrected_smiles = [smile.replace(\"I\", \"[*]\") for smile in smiles]\n", " \n", " #corrected_smiles = [smile.replace('*', 'I') for smile in corrected_smiles]\n", " y, mean, dev = predict_co2_permeability_polygnn(\n", " corrected_smiles, \n", " polygnn_ensemble, \n", " polygnn_featurizer, \n", " path, \n", " 'cuda'\n", " )\n", " polygnn_preds = [10**(pred) for pred in mean]\n", " return polygnn_preds\n", "\n", "df_sample = df.iloc[:100]\n", "new_preds = make_pred(df_sample['Smiles'].to_list())" ] }, { "cell_type": "code", "execution_count": 75, "id": "9c0245f1-3a68-4f9b-b799-407e3436b4b5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0SmilesTgHeN2O2CH4CO2synthesizablenew_preds
00Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1...494.842.695244.7574042.318471.64086148.43644False1.036345
11Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=...508.265.338152.9723926.311180.8646782.37635False2.058726
22O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c...640.9120.475150.063530.904980.069052.35993False21.385293
33Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O...568.044.196920.001910.011340.003620.01418False0.802363
44Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1...548.10142.683270.873808.254092.5206730.04739False132.455960
55Cc1cc(Sc2ccc(nc2)Sc2cc(C)c(c(c2)C)I)cc(c1N1C(=...501.8212.379681.6854012.252541.0329967.42722False21.316367
66Cc1cc(Sc2ccc(cn2)Sc2cc(C)c(c(c2)C)I)cc(c1N1C(=...501.8212.379681.6854012.252541.0329967.42722False21.316372
77Clc1cc(ccc1C(=O)c1cc(ccc1C)C(=O)c1ccc(c(c1)Cl)...529.6410.150460.020210.256110.132490.51867False1.720072
88Ic1ccc(c(c1)Cl)C(=O)c1cc(ccc1C)C(=O)c1ccc(c(c1...529.6410.150460.020210.256110.132490.51867False1.720071
99Ic1ccc(c(c1)C)C(=O)c1cc2ccccc2cc1C(=O)c1ccc(cc...556.3120.844910.054260.436480.065030.84243False3.767171
1010Ic1ccc(c(c1)Sc1cc(cc(c1)C(=O)O)Sc1cc(ccc1C)N1C...470.243.032010.021440.159340.025580.33708True1.093549
1111Cc1cc(cc(c1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C...575.93382.379184.7111327.504594.06784134.46742False262.278064
1212Ic1cc(C)c(c(c1)C)C(c1ccc2c(c1)[nH]c1c2ccc(c1)C...537.39108.948661.078818.351801.6789039.29826True206.577505
1313Ic1cc(Oc2ccc(cc2)C2(c3ccc(cc3)Oc3cc(cc(c3)N3C(...556.564.470370.041590.280070.038870.73329True9.642768
1414Ic1ccc(c(c1)Cl)Sc1ccc2c(c1)ccc(c2)Sc1ccc(c(c1)...481.624.674110.014640.173880.053330.64280False1.471597
1515Clc1cc(ccc1Sc1ccc2c(c1)ccc(c2)Sc1ccc(c(c1)Cl)I...481.624.674110.014640.173880.053330.64280False1.471597
1616Cc1cc(ccc1N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O...596.8818.216500.099350.680890.120152.16558True2.989342
1717Ic1ccc(c(c1)C)C(c1ccc(cc1C)C(c1ccc(c(c1)C)N1C(...467.0758.201490.519753.978240.6884215.89978True35.433246
1818O=C1c2cccc(c2C(=O)N1c1ccc(c(c1)C)C(c1ccc(cc1C)...467.0758.201490.519753.978240.6884215.89978True35.433265
1919Cc1cc(c(cc1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C...668.9535.184910.138041.234730.074194.58177False7.471155
2020O=C1c2cc(ccc2c2c1cc(cc2)C(C(F)(F)F)(C(F)(F)F)c...543.52111.916301.339309.698740.9523324.32080False241.805804
2121Ic1ccc(cn1)C(c1cc(Cl)cc(c1)C(c1ccc(nc1)N1C(=O)...508.9815.303207.2732975.862468.07197315.19387False8.272223
2222O=C1c2cccc(c2C(=O)N1c1cc(C)c(c(c1)C)S(=O)(=O)c...628.4839.343090.184961.187680.024773.25713True25.922022
2323Ic1cc(C)c(c(c1)C)S(=O)(=O)c1ccc2c(c1)ccc(c2)S(...628.4839.343090.184961.187680.024773.25713True25.922022
2424Ic1cc(Sc2cc(cc(c2)n2c(=O)c3c(c2=O)cc2c(c3)c(=O...564.081.079220.001460.006290.002170.01674True0.422021
2525Ic1ccc(cn1)C(=O)c1cc(C)c(c(c1)C)C(=O)c1ccc(cn1...564.3311.3956111.7386380.006088.36759438.07856False6.608806
2626Ic1ccc(nc1)C(=O)c1c(C)cc(cc1C)C(=O)c1ccc(nc1)n...564.3311.3956111.7386380.006088.36759438.07856False6.608806
2727Ic1cc(cc(c1)C(F)(F)F)C(=O)c1cc(C)c(cc1C)C(=O)c...526.2547.175890.255971.835500.381714.23498False14.162805
2828Ic1ccc(cn1)S(=O)(=O)c1cc(C)c(c(c1)C)c1c(C)cc(c...618.7921.181748.2896854.048870.66770251.39951False41.374632
2929Ic1ccc(c(c1)C(C(F)(F)F)(C(F)(F)F)c1cc(C)c(c(c1...562.21208.210512.2198715.290972.5838358.80082False373.653450
\n", "
" ], "text/plain": [ " Unnamed: 0 Smiles Tg \\\n", "0 0 Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1... 494.84 \n", "1 1 Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=... 508.26 \n", "2 2 O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c... 640.91 \n", "3 3 Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O... 568.04 \n", "4 4 Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1... 548.10 \n", "5 5 Cc1cc(Sc2ccc(nc2)Sc2cc(C)c(c(c2)C)I)cc(c1N1C(=... 501.82 \n", "6 6 Cc1cc(Sc2ccc(cn2)Sc2cc(C)c(c(c2)C)I)cc(c1N1C(=... 501.82 \n", "7 7 Clc1cc(ccc1C(=O)c1cc(ccc1C)C(=O)c1ccc(c(c1)Cl)... 529.64 \n", "8 8 Ic1ccc(c(c1)Cl)C(=O)c1cc(ccc1C)C(=O)c1ccc(c(c1... 529.64 \n", "9 9 Ic1ccc(c(c1)C)C(=O)c1cc2ccccc2cc1C(=O)c1ccc(cc... 556.31 \n", "10 10 Ic1ccc(c(c1)Sc1cc(cc(c1)C(=O)O)Sc1cc(ccc1C)N1C... 470.24 \n", "11 11 Cc1cc(cc(c1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C... 575.93 \n", "12 12 Ic1cc(C)c(c(c1)C)C(c1ccc2c(c1)[nH]c1c2ccc(c1)C... 537.39 \n", "13 13 Ic1cc(Oc2ccc(cc2)C2(c3ccc(cc3)Oc3cc(cc(c3)N3C(... 556.56 \n", "14 14 Ic1ccc(c(c1)Cl)Sc1ccc2c(c1)ccc(c2)Sc1ccc(c(c1)... 481.62 \n", "15 15 Clc1cc(ccc1Sc1ccc2c(c1)ccc(c2)Sc1ccc(c(c1)Cl)I... 481.62 \n", "16 16 Cc1cc(ccc1N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O... 596.88 \n", "17 17 Ic1ccc(c(c1)C)C(c1ccc(cc1C)C(c1ccc(c(c1)C)N1C(... 467.07 \n", "18 18 O=C1c2cccc(c2C(=O)N1c1ccc(c(c1)C)C(c1ccc(cc1C)... 467.07 \n", "19 19 Cc1cc(c(cc1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C... 668.95 \n", "20 20 O=C1c2cc(ccc2c2c1cc(cc2)C(C(F)(F)F)(C(F)(F)F)c... 543.52 \n", "21 21 Ic1ccc(cn1)C(c1cc(Cl)cc(c1)C(c1ccc(nc1)N1C(=O)... 508.98 \n", "22 22 O=C1c2cccc(c2C(=O)N1c1cc(C)c(c(c1)C)S(=O)(=O)c... 628.48 \n", "23 23 Ic1cc(C)c(c(c1)C)S(=O)(=O)c1ccc2c(c1)ccc(c2)S(... 628.48 \n", "24 24 Ic1cc(Sc2cc(cc(c2)n2c(=O)c3c(c2=O)cc2c(c3)c(=O... 564.08 \n", "25 25 Ic1ccc(cn1)C(=O)c1cc(C)c(c(c1)C)C(=O)c1ccc(cn1... 564.33 \n", "26 26 Ic1ccc(nc1)C(=O)c1c(C)cc(cc1C)C(=O)c1ccc(nc1)n... 564.33 \n", "27 27 Ic1cc(cc(c1)C(F)(F)F)C(=O)c1cc(C)c(cc1C)C(=O)c... 526.25 \n", "28 28 Ic1ccc(cn1)S(=O)(=O)c1cc(C)c(c(c1)C)c1c(C)cc(c... 618.79 \n", "29 29 Ic1ccc(c(c1)C(C(F)(F)F)(C(F)(F)F)c1cc(C)c(c(c1... 562.21 \n", "\n", " He N2 O2 CH4 CO2 synthesizable \\\n", "0 2.69524 4.75740 42.31847 1.64086 148.43644 False \n", "1 5.33815 2.97239 26.31118 0.86467 82.37635 False \n", "2 20.47515 0.06353 0.90498 0.06905 2.35993 False \n", "3 4.19692 0.00191 0.01134 0.00362 0.01418 False \n", "4 142.68327 0.87380 8.25409 2.52067 30.04739 False \n", "5 12.37968 1.68540 12.25254 1.03299 67.42722 False \n", "6 12.37968 1.68540 12.25254 1.03299 67.42722 False \n", "7 10.15046 0.02021 0.25611 0.13249 0.51867 False \n", "8 10.15046 0.02021 0.25611 0.13249 0.51867 False \n", "9 20.84491 0.05426 0.43648 0.06503 0.84243 False \n", "10 3.03201 0.02144 0.15934 0.02558 0.33708 True \n", "11 382.37918 4.71113 27.50459 4.06784 134.46742 False \n", "12 108.94866 1.07881 8.35180 1.67890 39.29826 True \n", "13 4.47037 0.04159 0.28007 0.03887 0.73329 True \n", "14 4.67411 0.01464 0.17388 0.05333 0.64280 False \n", "15 4.67411 0.01464 0.17388 0.05333 0.64280 False \n", "16 18.21650 0.09935 0.68089 0.12015 2.16558 True \n", "17 58.20149 0.51975 3.97824 0.68842 15.89978 True \n", "18 58.20149 0.51975 3.97824 0.68842 15.89978 True \n", "19 35.18491 0.13804 1.23473 0.07419 4.58177 False \n", "20 111.91630 1.33930 9.69874 0.95233 24.32080 False \n", "21 15.30320 7.27329 75.86246 8.07197 315.19387 False \n", "22 39.34309 0.18496 1.18768 0.02477 3.25713 True \n", "23 39.34309 0.18496 1.18768 0.02477 3.25713 True \n", "24 1.07922 0.00146 0.00629 0.00217 0.01674 True \n", "25 11.39561 11.73863 80.00608 8.36759 438.07856 False \n", "26 11.39561 11.73863 80.00608 8.36759 438.07856 False \n", "27 47.17589 0.25597 1.83550 0.38171 4.23498 False \n", "28 21.18174 8.28968 54.04887 0.66770 251.39951 False \n", "29 208.21051 2.21987 15.29097 2.58383 58.80082 False \n", "\n", " new_preds \n", "0 1.036345 \n", "1 2.058726 \n", "2 21.385293 \n", "3 0.802363 \n", "4 132.455960 \n", "5 21.316367 \n", "6 21.316372 \n", "7 1.720072 \n", "8 1.720071 \n", "9 3.767171 \n", "10 1.093549 \n", "11 262.278064 \n", "12 206.577505 \n", "13 9.642768 \n", "14 1.471597 \n", "15 1.471597 \n", "16 2.989342 \n", "17 35.433246 \n", "18 35.433265 \n", "19 7.471155 \n", "20 241.805804 \n", "21 8.272223 \n", "22 25.922022 \n", "23 25.922022 \n", "24 0.422021 \n", "25 6.608806 \n", "26 6.608806 \n", "27 14.162805 \n", "28 41.374632 \n", "29 373.653450 " ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_sample['new_preds'] = new_preds\n", "df_sample.head(30)" ] }, { "cell_type": "code", "execution_count": 40, "id": "5bf78fae-dd12-4093-a4da-77429214a114", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c1ccc2c(c1)c1cc(ccc1C2(C)C)S(=O)(=O)c1ccc(c(c1Cl)Cl)I)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O)c1cc(cc(c1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C(=O)O',\n", " 'Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1C)c1cc(C)c(c(c1)C)C(C(F)(F)F)(C(F)(F)F)c1cc(Cl)cc(c1)I)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Cc1cc(Sc2ccc(nc2)Sc2cc(C)c(c(c2)C)I)cc(c1N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C',\n", " 'Cc1cc(Sc2ccc(cn2)Sc2cc(C)c(c(c2)C)I)cc(c1N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C',\n", " 'Clc1cc(ccc1C(=O)c1cc(ccc1C)C(=O)c1ccc(c(c1)Cl)I)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)Cl)C(=O)c1cc(ccc1C)C(=O)c1ccc(c(c1)Cl)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)C)C(=O)c1cc2ccccc2cc1C(=O)c1ccc(cc1C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)Sc1cc(cc(c1)C(=O)O)Sc1cc(ccc1C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C',\n", " 'Cc1cc(cc(c1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)C)C(C(F)(F)F)(C(F)(F)F)c1cc(cc(c1)C(F)(F)F)C(C(F)(F)F)(C(F)(F)F)c1cc(C)c(c(c1)C)I',\n", " 'Ic1cc(C)c(c(c1)C)C(c1ccc2c(c1)[nH]c1c2ccc(c1)C(c1c(C)cc(cc1C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)(C)C)(C)C',\n", " 'Ic1cc(Oc2ccc(cc2)C2(c3ccc(cc3)Oc3cc(cc(c3)N3C(=O)c4c(C3=O)cccc4c3ccc4c(c3)C(=O)N(C4=O)I)C(=O)O)c3ccccc3c3c2cccc3)cc(c1)C(=O)O',\n", " 'Ic1ccc(c(c1)Cl)Sc1ccc2c(c1)ccc(c2)Sc1ccc(c(c1)Cl)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Clc1cc(ccc1Sc1ccc2c(c1)ccc(c2)Sc1ccc(c(c1)Cl)I)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Cc1cc(ccc1N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C(=O)c1ccc(c(c1)C)I',\n", " 'Ic1ccc(c(c1)C)C(c1ccc(cc1C)C(c1ccc(c(c1)C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)(C)C)(C)C',\n", " 'O=C1c2cccc(c2C(=O)N1c1ccc(c(c1)C)C(c1ccc(cc1C)C(c1ccc(c(c1)C)I)(C)C)(C)C)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Cc1cc(c(cc1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)C)S(=O)(=O)c1ccc(c(c1Cl)Cl)S(=O)(=O)c1cc(C)c(cc1C)I',\n", " 'O=C1c2cc(ccc2c2c1cc(cc2)C(C(F)(F)F)(C(F)(F)F)c1cccc(c1C)I)C(C(F)(F)F)(C(F)(F)F)c1cccc(c1C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(cn1)C(c1cc(Cl)cc(c1)C(c1ccc(nc1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)(C)C)(C)C',\n", " 'O=C1c2cccc(c2C(=O)N1c1cc(C)c(c(c1)C)S(=O)(=O)c1ccc2c(c1)ccc(c2)S(=O)(=O)c1cc(C)c(c(c1)C)I)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(C)c(c(c1)C)S(=O)(=O)c1ccc2c(c1)ccc(c2)S(=O)(=O)c1cc(C)c(c(c1)C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(Sc2cc(cc(c2)n2c(=O)c3c(c2=O)cc2c(c3)c(=O)n(c2=O)I)C(=O)O)cc(c1)C(=O)O',\n", " 'Ic1ccc(cn1)C(=O)c1cc(C)c(c(c1)C)C(=O)c1ccc(cn1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(nc1)C(=O)c1c(C)cc(cc1C)C(=O)c1ccc(nc1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1cc(cc(c1)C(F)(F)F)C(=O)c1cc(C)c(cc1C)C(=O)c1cc(cc(c1)C(F)(F)F)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(cn1)S(=O)(=O)c1cc(C)c(c(c1)C)c1c(C)cc(cc1C)S(=O)(=O)c1ccc(nc1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)C(C(F)(F)F)(C(F)(F)F)c1cc(C)c(c(c1)C)c1c(C)cc(cc1C)C(C(F)(F)F)(C(F)(F)F)c1cc(ccc1C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C',\n", " 'Cc1cc(cc(c1c1c(C)cc(cc1C)S(=O)(=O)c1ccc(c(c1)C)I)C)S(=O)(=O)c1ccc(c(c1)C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'O=C(c1ccc(c(c1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)C)c1ccc(cc1)C1(c2ccc(cc2)C(=O)c2ccc(c(c2)I)C)c2ccccc2c2c1cccc2',\n", " 'Cc1cc(C)c(c(c1C(c1ccc(cc1)C1(c2ccc(cc2)C(c2c(C)cc(c(c2C)I)C)(C)C)c2ccccc2c2c1cccc2)(C)C)C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(c(c1)C)Cc1ccc2c(c1)c1cc(ccc1C2(C)C)Cc1ccc(c(c1)C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Cc1cc(ccc1Cc1ccc2c(c1)c1cc(ccc1C2(C)C)Cc1ccc(c(c1)C)I)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)C(C(F)(F)F)(C(F)(F)F)c1ccc2c(c1)C(=O)c1c2ccc(c1)C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C)C',\n", " 'Cc1ccc(cc1C(C(F)(F)F)(C(F)(F)F)c1ccc2c(c1)C(=O)c1c2ccc(c1)C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)I)C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1cccc(c1)Cc1cccc(c1)Cc1cccc(c1)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)Oc1cc(Oc2ccc(c(c2)n2c(=O)c3c(c2=O)cc2c(c3)c(=O)n(c2=O)I)C)cc(c1)C(F)(F)F)C',\n", " 'Cc1ccc(cc1Oc1cc(Oc2ccc(c(c2)I)C)cc(c1)C(F)(F)F)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(nc1)Oc1ccc(cc1)C1(c2ccc(cc2)Oc2ccc(nc2)n2c(=O)c3c(c2=O)cc2c(c3)c(=O)n(c2=O)I)c2ccccc2c2c1cccc2',\n", " 'Ic1ccc(cn1)Oc1ccc(cc1)C1(c2ccc(cc2)Oc2ccc(cn2)n2c(=O)c3c(c2=O)cc2c(c3)c(=O)n(c2=O)I)c2ccccc2c2c1cccc2',\n", " 'Cc1ccc(cc1I)Oc1ccc(cc1)N(c1ccccc1)c1ccc(cc1)Oc1ccc(c(c1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C',\n", " 'OC(=O)c1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)I)C)C(C(F)(F)F)(C(F)(F)F)c1cc(ccc1C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)C(C(F)(F)F)(C(F)(F)F)c1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C)C(=O)O)C',\n", " 'O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c1cc2ccccc2cc1S(=O)(=O)c1ccc(c(c1Cl)Cl)I)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Cc1cc(cc(c1C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)Cl)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C)c1cc(C)c(c(c1)C)C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)Cl)I',\n", " 'Ic1ccc(c(c1)C)Cc1ccc(c(c1)Cc1ccc(c(c1)C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C',\n", " 'Cc1cc(ccc1Cc1ccc(c(c1)Cc1ccc(c(c1)C)I)C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)S(=O)(=O)c1c(C)cc(cc1C)S(=O)(=O)c1cc(ccc1C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)C',\n", " 'Ic1ccc(c(c1)S(=O)(=O)c1cc(C)c(c(c1)C)S(=O)(=O)c1cc(ccc1C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)C',\n", " 'Ic1ccc(c(c1)Cl)C(=O)c1ccc2c(c1)[nH]c1c2ccc(c1)C(=O)c1ccc(cc1Cl)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Cc1ccc(cc1C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)Cl)C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)I)C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(c(c1)C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)Cl)C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C)C',\n", " 'Ic1ccc(cc1)C(c1c(C)cc(cc1C)c1cc(C)c(c(c1)C)C(c1ccc(cc1)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)(C)C)(C)C',\n", " 'Cc1ccc(cc1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)Cc1ccc2c(c1)ccc(c2)Cc1ccc(c(c1)I)C',\n", " 'Cc1cc(c(cc1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)C)C(c1cc(C)c(cc1C)C(c1cc(C)c(cc1C)I)(C)C)(C)C',\n", " 'Ic1ccc(nc1)C(c1ccc2c(c1)ccc(c2)C(c1ccc(cn1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)(C)C)(C)C',\n", " 'Ic1ccc(nc1)C(=O)c1ccc(cc1)c1ccc(cc1)C(=O)c1ccc(cn1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(c(c1)Sc1ccc2c(c1)cc(cc2)Sc1ccc(c(c1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C)C',\n", " 'Cc1ccc(cc1Sc1ccc2c(c1)cc(cc2)Sc1ccc(c(c1)I)C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(cn1)Oc1cc(C)c(c(c1)C)Oc1ccc(cn1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(nc1)Oc1c(C)cc(cc1C)Oc1ccc(nc1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(cc1)Oc1cc(C)c(cc1C)Oc1ccc(cc1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)Cl)Oc1cccc2c1cccc2Oc1ccc(c(c1)Cl)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Clc1cc(ccc1Oc1cccc2c1cccc2Oc1ccc(c(c1)Cl)I)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Cc1cc(c(cc1Sc1ccc(cc1)c1ccc(cc1)Sc1cc(C)c(cc1C)I)C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(nc1)Sc1cc(ccc1C)Sc1ccc(cn1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(nc1)Sc1ccc(c(c1)Sc1ccc(cn1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C',\n", " 'O=C(c1cc(C)c(c(c1)C)I)c1ccc(cc1)N(c1ccccc1)c1ccc(cc1)C(=O)c1c(C)cc(cc1C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(C)c(c(c1)C)C(=O)c1ccc(cc1)N(c1ccccc1)c1ccc(cc1)C(=O)c1cc(C)c(c(c1)C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(cc(c1)C(F)(F)F)C(=O)c1ccc2c(c1)c1cc(ccc1C2(C)C)C(=O)c1cc(cc(c1)C(F)(F)F)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Clc1cc(cc(c1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C(=O)c1ccc(cc1)c1ccc(cc1)C(=O)c1cc(Cl)cc(c1)I',\n", " 'Cc1cc(cc(c1S(=O)(=O)c1cc(ccc1C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C)c1cc(C)c(c(c1)C)S(=O)(=O)c1ccc(c(c1)I)C',\n", " 'Ic1ccc(c(c1)S(=O)(=O)c1c(C)cc(cc1C)c1cc(C)c(c(c1)C)S(=O)(=O)c1ccc(c(c1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C)C',\n", " 'Clc1c(ccc(c1Cl)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)Oc1ccc2c(c1)[nH]c1c2ccc(c1)Oc1ccc(c(c1Cl)Cl)I',\n", " 'Ic1ccc(c(c1)C)Sc1ccc(cc1Cl)Sc1ccc(cc1C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(c(c1)C)Sc1ccc(c(c1)Cl)Sc1ccc(cc1C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1ccc(cc1)S(=O)(=O)c1ccc2c(c1)c1cc(ccc1C2(C)C)S(=O)(=O)c1ccc(cc1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1cc(C)c(c(c1)C)Cc1ccc2c(c1)Cc1c2ccc(c1)Cc1cc(C)c(c(c1)C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Cc1cc(cc(c1Cc1ccc2c(c1)Cc1c2ccc(c1)Cc1cc(C)c(c(c1)C)I)C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(C)c(c(c1)C)Sc1ccc2c(c1)cc(cc2)Sc1cc(C)c(c(c1)C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Cc1cc(cc(c1Sc1ccc2c(c1)cc(cc2)Sc1cc(C)c(c(c1)C)I)C)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(nc1)Sc1c(C)cc(c(c1C)Sc1ccc(cn1)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I)C',\n", " 'Ic1cc(C)c(c(c1)C)Oc1ccc(nc1)Oc1c(C)cc(cc1C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(C)c(c(c1)C)Oc1ccc(cn1)Oc1c(C)cc(cc1C)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Cc1cc(ccc1N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C(C(F)(F)F)(C(F)(F)F)c1cccc2c1cccc2C(C(F)(F)F)(C(F)(F)F)c1ccc(c(c1)C)I',\n", " 'Ic1cc(C)c(c(c1)C)C(c1ccc(c(c1)Cl)C(c1cc(C)c(c(c1)C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)(C)C)(C)C',\n", " 'O=C1c2cc(ccc2C(=O)N1c1cc(C)c(c(c1)C)C(c1ccc(c(c1)Cl)C(c1cc(C)c(c(c1)C)I)(C)C)(C)C)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Cc1cc(Cc2ccc(cc2)c2ccc(cc2)Cc2cc(C)c(c(c2)C)I)cc(c1N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C',\n", " 'O=C(c1ccc2c(c1)[nH]c1c2ccc(c1)C(=O)c1ccc(c(c1)Cl)I)c1ccc(c(c1)Cl)n1c(=O)c2c(c1=O)cc1c(c2)c(=O)n(c1=O)I',\n", " 'Ic1ccc(c(c1)Cl)C(=O)c1ccc2c(c1)ccc(c2)C(=O)c1ccc(c(c1)Cl)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Clc1cc(ccc1C(=O)c1ccc2c(c1)ccc(c2)C(=O)c1ccc(c(c1)Cl)I)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(C)c(c(c1)C)Cc1ccc2c(c1)C(=O)c1c2ccc(c1)Cc1c(C)cc(cc1C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I',\n", " 'Ic1cc(cc(c1)C(=O)O)Oc1ccc(cc1C)Oc1cc(cc(c1)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)C(=O)O',\n", " 'Ic1cc(Oc2ccc(c(c2)C)Oc2cc(cc(c2)N2C(=O)c3c(C2=O)cccc3c2ccc3c(c2)C(=O)N(C3=O)I)C(=O)O)cc(c1)C(=O)O',\n", " 'Cc1cc2c3cc(C)c(cc3S(=O)(=O)c2cc1Oc1ccc(c(c1)C)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I)Oc1ccc(c(c1)C)I',\n", " 'Clc1c(ccc(c1Cl)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O)N(C2=O)I)Cc1ccc(cc1)Cc1ccc(c(c1Cl)Cl)I',\n", " 'Ic1ccc(nc1)S(=O)(=O)c1ccc(c(c1Cl)Cl)S(=O)(=O)c1ccc(cn1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I']" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_sample['Smiles'].to_list()" ] }, { "cell_type": "code", "execution_count": 32, "id": "cd87f55e-d88b-4fcf-b2a8-3c7ba6250494", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0SmilesTgHeN2O2CH4CO2synthesizablenew_preds
00Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1...494.842.695244.7574042.318471.64086148.43644False1.036345
11Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=...508.265.338152.9723926.311180.8646782.37635False2.058726
22O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c...640.9120.475150.063530.904980.069052.35993False21.385310
33Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O...568.044.196920.001910.011340.003620.01418False0.802363
44Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1...548.10142.683270.873808.254092.5206730.04739False132.455960
.................................
9595Ic1cc(cc(c1)C(=O)O)Oc1ccc(cc1C)Oc1cc(cc(c1)N1C...528.863.658970.009160.046550.011350.10014True1.777913
9696Ic1cc(Oc2ccc(c(c2)C)Oc2cc(cc(c2)N2C(=O)c3c(C2=...528.863.658970.009160.046550.011350.10014True1.777911
9797Cc1cc2c3cc(C)c(cc3S(=O)(=O)c2cc1Oc1ccc(c(c1)C)...554.7627.719330.370862.363560.101317.45523False10.918775
9898Clc1c(ccc(c1Cl)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1...575.957.172130.013500.266750.210260.83210False2.465392
9999Ic1ccc(nc1)S(=O)(=O)c1ccc(c(c1Cl)Cl)S(=O)(=O)c...639.666.868282.8033826.780620.40880117.95785False4.053435
\n", "

100 rows × 10 columns

\n", "
" ], "text/plain": [ " Unnamed: 0 Smiles Tg \\\n", "0 0 Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1... 494.84 \n", "1 1 Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=... 508.26 \n", "2 2 O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c... 640.91 \n", "3 3 Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O... 568.04 \n", "4 4 Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1... 548.10 \n", ".. ... ... ... \n", "95 95 Ic1cc(cc(c1)C(=O)O)Oc1ccc(cc1C)Oc1cc(cc(c1)N1C... 528.86 \n", "96 96 Ic1cc(Oc2ccc(c(c2)C)Oc2cc(cc(c2)N2C(=O)c3c(C2=... 528.86 \n", "97 97 Cc1cc2c3cc(C)c(cc3S(=O)(=O)c2cc1Oc1ccc(c(c1)C)... 554.76 \n", "98 98 Clc1c(ccc(c1Cl)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1... 575.95 \n", "99 99 Ic1ccc(nc1)S(=O)(=O)c1ccc(c(c1Cl)Cl)S(=O)(=O)c... 639.66 \n", "\n", " He N2 O2 CH4 CO2 synthesizable \\\n", "0 2.69524 4.75740 42.31847 1.64086 148.43644 False \n", "1 5.33815 2.97239 26.31118 0.86467 82.37635 False \n", "2 20.47515 0.06353 0.90498 0.06905 2.35993 False \n", "3 4.19692 0.00191 0.01134 0.00362 0.01418 False \n", "4 142.68327 0.87380 8.25409 2.52067 30.04739 False \n", ".. ... ... ... ... ... ... \n", "95 3.65897 0.00916 0.04655 0.01135 0.10014 True \n", "96 3.65897 0.00916 0.04655 0.01135 0.10014 True \n", "97 27.71933 0.37086 2.36356 0.10131 7.45523 False \n", "98 7.17213 0.01350 0.26675 0.21026 0.83210 False \n", "99 6.86828 2.80338 26.78062 0.40880 117.95785 False \n", "\n", " new_preds \n", "0 1.036345 \n", "1 2.058726 \n", "2 21.385310 \n", "3 0.802363 \n", "4 132.455960 \n", ".. ... \n", "95 1.777913 \n", "96 1.777911 \n", "97 10.918775 \n", "98 2.465392 \n", "99 4.053435 \n", "\n", "[100 rows x 10 columns]" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_sample['new_preds'] = new_preds\n", "df_sample" ] }, { "cell_type": "code", "execution_count": 33, "id": "e6d9e29e-eee0-4c5a-8966-58622a4260f8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C(=O)N(C2=O)I'" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_sample['Smiles'][0]" ] }, { "cell_type": "code", "execution_count": 67, "id": "51a57a9b-e8b6-4c25-a7b0-38463ebed37e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing batches: 100%|█████████████████████████| 8/8 [00:00<00:00, 18.36it/s]\n" ] } ], "source": [ "CO2s = []\n", "co2_preds = []\n", "ch4_preds = []\n", "\n", "robson_smiles = sample_df['Smiles'].tolist()\n", "\n", "batch_size_robson = 128\n", "for i in tqdm(range(0, len(robson_smiles), batch_size_robson), desc=\"Processing batches\"):\n", " batch_end = min(i + batch_size_robson, len(robson_smiles))\n", " batch_smiles = robson_smiles[i:batch_end]\n", " \n", " # Step 3: Predict properties for both original and reconstructed\n", " # Original properties\n", " encoding = tokenizer(\n", " batch_smiles,\n", " max_length=256,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " )\n", " \n", " with torch.autocast(dtype=torch.float16, device_type='cuda'):\n", " with torch.no_grad():\n", " orig_predictions = regression_model(encoding['input_ids'].cuda(), encoding['attention_mask'].cuda())\n", " \n", " ch4_scaled = orig_predictions[:, -2].cpu().numpy().reshape(-1, 1)\n", " co2_scaled = orig_predictions[:, -1].cpu().numpy().reshape(-1, 1)\n", "\n", " ch4 = scaler_ch4.inverse_transform(ch4_scaled.astype(np.float64)).flatten()\n", " co2 = scaler_co2.inverse_transform(co2_scaled.astype(np.float64)).flatten()\n", "\n", " co2_preds.extend([float(pred) for pred in co2])\n", " ch4_preds.extend([float(pred) for pred in ch4])\n", "\n", "df_robson['CO2_pred'] = co2_preds\n", "df_robson['CH4_pred'] = ch4_preds\n" ] }, { "cell_type": "code", "execution_count": null, "id": "253af0c3-0ced-417b-a572-034c4cfb601f", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 70, "id": "5aec327c-1fd9-4ebc-aa25-74ca8749b477", "metadata": {}, "outputs": [], "source": [ "if 'P_CO2/P_CH4' not in df_robson.columns:\n", " df_robson['P_CO2/P_CH4'] = df_robson['CO2_pred']/df_robson['CH4_pred']" ] }, { "cell_type": "code", "execution_count": 71, "id": "7b46c427-5297-4565-b657-27fc7bf5bc38", "metadata": {}, "outputs": [], "source": [ "k_2019 = 2.26e7\n", "n_2019 = -2.401\n", "k_2008 = 5_369_140\n", "n_2008 = -2.636\n", "k_1991 = 1_073_700\n", "n_1991 = -2.6264\n", "\n", "df_robson['upper_bound_Robeson'] = (df_robson['CO2_pred']/k_2019)**(1/n_2019)\n", "df_robson['above_Robeson'] = df_robson['P_CO2/P_CH4'] > df_robson['upper_bound_Robeson']\n", "\n", "df_robson['upper_bound_Robeson_2008'] = (df_robson['CO2_pred']/k_2008)**(1/n_2008)\n", "df_robson['above_Robeson_2008'] = df_robson['P_CO2/P_CH4'] > df_robson['upper_bound_Robeson_2008']\n", "\n", "df_robson['upper_bound_Robeson_1991'] = (df_robson['CO2_pred']/k_1991)**(1/n_1991)\n", "df_robson['above_Robeson_1991'] = df_robson['P_CO2/P_CH4'] > df_robson['upper_bound_Robeson_1991']" ] }, { "cell_type": "code", "execution_count": 72, "id": "3e2a37c5-9534-4cde-8ae9-bad906c39ae9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(log) x_points_on_Robeson2019_line [np.float64(-0.3245689409002736), np.float64(-0.11311188919602205), np.float64(0.0983451625082295), np.float64(0.30980221421248105), np.float64(0.5212592659167326), np.float64(0.7327163176209841), np.float64(0.9441733693252357), np.float64(1.1556304210294872), np.float64(1.3670874727337388), np.float64(1.5785445244379903), np.float64(1.7900015761422419), np.float64(2.0014586278464934), np.float64(2.212915679550745), np.float64(2.424372731254997), np.float64(2.6358297829592483), np.float64(2.8472868346634996), np.float64(3.0587438863677514), np.float64(3.270200938072003), np.float64(3.4816579897762545), np.float64(3.693115041480506), np.float64(3.9045720931847576)]\n", "x_points_on_Robeson2019_line [0.4736211185063034, 0.7707048833980547, 1.2541375248783457, 2.0408083109233783, 3.3209265166816033, 5.404012160362607, 8.79373490580214, 14.309696443824098, 23.285602136958403, 37.89173788621753, 61.65972396131289, 100.33642612016979, 163.27349134558037, 265.688484302277, 432.3443451174293, 703.5368252632127, 1144.8376047731194, 1862.9488809093987, 3031.502912213795, 4933.044594478952, 8027.3480434650555]\n" ] } ], "source": [ "num_bins = 20\n", "x_points_on_Robeson2019_line = [ # референсные значения, использованные при подготовке обучающей выборки для первого этапа обучения\n", " 0.4736211185063034,\n", " 0.7707048833980547,\n", " 1.2541375248783457,\n", " 2.0408083109233783,\n", " 3.3209265166816033,\n", " 5.404012160362607,\n", " 8.79373490580214,\n", " 14.309696443824098,\n", " 23.285602136958403,\n", " 37.89173788621753,\n", " 61.65972396131289,\n", " 100.33642612016979,\n", " 163.27349134558037,\n", " 265.688484302277,\n", " 432.3443451174293,\n", " 703.5368252632127,\n", " 1144.8376047731194,\n", " 1862.9488809093987,\n", " 3031.502912213795,\n", " 4933.044594478952,\n", " 8027.3480434650555]\n", "\n", "log_bin_size = (np.log10(x_points_on_Robeson2019_line[1])-\n", " np.log10(x_points_on_Robeson2019_line[0]))/num_bins\n", "\n", "print('(log) x_points_on_Robeson2019_line', [np.log10(xp) for xp in x_points_on_Robeson2019_line])\n", "print('x_points_on_Robeson2019_line', x_points_on_Robeson2019_line)" ] }, { "cell_type": "code", "execution_count": 73, "id": "4ec4c897-02fa-41a6-bc42-f353a068bace", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_265640/2901639124.py:11: RuntimeWarning:\n", "\n", "invalid value encountered in log10\n", "\n" ] } ], "source": [ "# определение бинов Робсона, для последующей отрисовки матрицы путанности\n", "\n", "def calculate_bin(row, x_col, y_col):\n", " x_point = row[x_col]\n", " y_point = row[y_col]\n", " slope = (1/n_2019)\n", " \n", " bin_for_point = 1\n", " for x_line in x_points_on_Robeson2019_line[1:]:\n", " y_Robeson = (x_line/k_2019)**(1/n_2019) # upper_bound_Robeson_2019\n", " y_bin = 10**(-1/slope * np.log10(x_point/x_line) + np.log10(y_Robeson))\n", " if y_point > y_bin:\n", " break\n", " bin_for_point += 1\n", " return bin_for_point\n", "\n", "df_robson['Robeson_bin'] = df_robson.apply(calculate_bin,\n", " x_col='CO2_pred',\n", " y_col='P_CO2/P_CH4',\n", " axis=1)" ] }, { "cell_type": "code", "execution_count": 74, "id": "576b03a6-8c94-4dfa-ae58-d38652376f69", "metadata": {}, "outputs": [], "source": [ "import ipywidgets\n", "import numpy as np\n", "import pandas as pd\n", "import plotly.express as px\n", "import plotly.graph_objects as go\n", "import plotly.io as pio\n", "\n", "pio.renderers.default = \"notebook\" # fixes duplicate plotly plots problem" ] }, { "cell_type": "code", "execution_count": 77, "id": "d5c23fc9-926e-4b05-8edb-88c872944f9c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0SmilesTgHeN2O2CH4CO2synthesizable
00Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1...494.842.695244.7574042.318471.64086148.43644False
11Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=...508.265.338152.9723926.311180.8646782.37635False
22O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c...640.9120.475150.063530.904980.069052.35993False
33Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O...568.044.196920.001910.011340.003620.01418False
44Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1...548.10142.683270.873808.254092.5206730.04739False
..............................
67269456726945Ic1cccc(c1)Cc1cc(C)c(c(c1)C)c1c(C)cc(cc1C)Cc1c...516.2713.329670.119071.101440.169072.63642False
67269466726946Ic1ccc(nc1)Cc1ccc2c(c1)c1cc(ccc1C2(C)C)Cc1ccc(...510.728.964547.9225772.929292.17324247.14446False
67269476726947Ic1ccc(cn1)S(=O)(=O)c1ccc(nc1)N1C(=O)c2c(C1=O)...610.244.9075820.24364121.244941.01011650.18364False
67269486726948Ic1ccc(c(c1)C)Oc1ccc2c(c1)Cc1c2ccc(c1)Oc1ccc(c...510.9012.409070.193071.413350.117284.00573True
67269496726949Ic1cc(Sc2ccc3c(c2)cc(cc3)Sc2cc(cc(c2)C(F)(F)F)...462.3114.963420.139161.052520.092982.54411False
\n", "

6726950 rows × 9 columns

\n", "
" ], "text/plain": [ " Unnamed: 0 Smiles \\\n", "0 0 Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1... \n", "1 1 Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=... \n", "2 2 O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c... \n", "3 3 Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O... \n", "4 4 Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1... \n", "... ... ... \n", "6726945 6726945 Ic1cccc(c1)Cc1cc(C)c(c(c1)C)c1c(C)cc(cc1C)Cc1c... \n", "6726946 6726946 Ic1ccc(nc1)Cc1ccc2c(c1)c1cc(ccc1C2(C)C)Cc1ccc(... \n", "6726947 6726947 Ic1ccc(cn1)S(=O)(=O)c1ccc(nc1)N1C(=O)c2c(C1=O)... \n", "6726948 6726948 Ic1ccc(c(c1)C)Oc1ccc2c(c1)Cc1c2ccc(c1)Oc1ccc(c... \n", "6726949 6726949 Ic1cc(Sc2ccc3c(c2)cc(cc3)Sc2cc(cc(c2)C(F)(F)F)... \n", "\n", " Tg He N2 O2 CH4 CO2 \\\n", "0 494.84 2.69524 4.75740 42.31847 1.64086 148.43644 \n", "1 508.26 5.33815 2.97239 26.31118 0.86467 82.37635 \n", "2 640.91 20.47515 0.06353 0.90498 0.06905 2.35993 \n", "3 568.04 4.19692 0.00191 0.01134 0.00362 0.01418 \n", "4 548.10 142.68327 0.87380 8.25409 2.52067 30.04739 \n", "... ... ... ... ... ... ... \n", "6726945 516.27 13.32967 0.11907 1.10144 0.16907 2.63642 \n", "6726946 510.72 8.96454 7.92257 72.92929 2.17324 247.14446 \n", "6726947 610.24 4.90758 20.24364 121.24494 1.01011 650.18364 \n", "6726948 510.90 12.40907 0.19307 1.41335 0.11728 4.00573 \n", "6726949 462.31 14.96342 0.13916 1.05252 0.09298 2.54411 \n", "\n", " synthesizable \n", "0 False \n", "1 False \n", "2 False \n", "3 False \n", "4 False \n", "... ... \n", "6726945 False \n", "6726946 False \n", "6726947 False \n", "6726948 True \n", "6726949 False \n", "\n", "[6726950 rows x 9 columns]" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 76, "id": "69c7cfef-8075-41eb-aa2d-2d621ce4d1c4", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f5c49709dca144f69bb703be1756c165", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntSlider(value=10, continuous_update=False, description='Num points', max=1000, min=10, step=10)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2f3aca246b424d2c914e9e1fb84b2429", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# отрисовка троек молекул (2 разгонных + 1 сгенерированная) на диаграмме Робсона\n", "\n", "# new_pi_df_unique = new_pi_df.drop_duplicates(['smiles_start'])\n", "sub_df_preds = df_robson.sample(frac=1, random_state=1010)\n", "\n", "\n", "def plot_with_controls(num_points=1):\n", "\n", "\n", " sub_df_preds_plotted = sub_df_preds.iloc[:num_points]\n", "\n", " x_stuff = 'CO2'\n", " y_stuff = 'P_CO2/P_CH4'\n", "\n", " \n", " log_scale = True\n", " fig = px.scatter(sub_df_preds_plotted, x=x_stuff, y=y_stuff, log_x=log_scale, log_y=log_scale, \n", " color_discrete_sequence=['red'],\n", " )\n", "\n", " \n", " fig2 = px.line(sub_df_preds.iloc[:1000],\n", " x='CO2_pred', y='upper_bound_Robeson', labels='Robeson_2019')\n", " fig2.update_traces(line_color='red', line_width=1)\n", " fig3 = px.line(sub_df_preds.iloc[:1000],\n", " x='CO2_pred', y='upper_bound_Robeson_2008', labels='Robeson_2008')\n", " fig3.update_traces(line_color='green', line_width=1)\n", " fig4 = px.line(sub_df_preds.iloc[:1000],\n", " x='CO2_pred', y='upper_bound_Robeson_1991', labels='Robeson_1991')\n", " fig4.update_traces(line_color='blue', line_width=1)\n", " fig = go.Figure(data=fig.data + fig2.data + fig3.data + fig4.data, layout = fig.layout)\n", " fig.update_layout(fig2.layout)\n", "\n", " fig.update_layout(width=1100, height=550, margin=dict(l=40, r=40, t=10, b=10),)\n", "\n", " list_of_all_arrows = []\n", " \n", " \n", " \n", " fig.show()\n", "\n", "\n", "num_points_slider = ipywidgets.IntSlider(\n", " # value=len(sub_df_preds),\n", " value=10,\n", " # min=1000,\n", " min=10,\n", " # max=len(sub_df_preds),\n", " max=min(len(sub_df_preds),1000),\n", " step=10,\n", " description='Num points',\n", " disabled=False,\n", " continuous_update=False,\n", " orientation='horizontal',\n", " readout=True,\n", " readout_format='d'\n", ")\n", "\n", "\n", "\n", "output = ipywidgets.interactive_output(plot_with_controls,\n", " {\n", " 'num_points': num_points_slider,\n", " })\n", "display(num_points_slider, output)" ] }, { "cell_type": "code", "execution_count": null, "id": "07e0e913-820e-4041-bae4-3be8f85d292c", "metadata": {}, "outputs": [], "source": [ "\n", "reconstructed_ids = model.generate(best_embedding)\n", "reconstructed_smiles = [tokenizer.decode(seq, skip_special_tokens=True) for seq in reconstructed_ids]\n", "mol_gen = Chem.MolFromSmiles(reconstructed_smiles[0])\n", "\n", "print(reconstructed_smiles[0])\n", "\n", "if mol_gen is not None:\n", " print('Valid')\n", "else:\n", " print('Skill issue')" ] }, { "cell_type": "code", "execution_count": null, "id": "93af70eb-5e42-4372-b0f9-4e03577742bb", "metadata": {}, "outputs": [], "source": [ "def comprehensive_generation(model, tokenizer, val_df, base_embeddings):\n", " \"\"\"Generate molecules with different enhancement levels\"\"\"\n", " \n", " all_results = []\n", " extrapolation_factors = [factor / 1000 for factor in range(0, 100, 10)]\n", " print(extrapolation_factors)\n", " for factor in extrapolation_factors:\n", " co2_results = generate_gradient(\n", " model, tokenizer, base_embeddings, val_df, 'CO2', extrapolation_factor=factor)\n", "\n", " all_results.extend(co2_results)# + ch4_results + dual_results)\n", " \n", " return all_results\n", "\n", "# Run comprehensive generation\n", "all_generated_molecules = comprehensive_generation(model, tokenizer, df, base_embeddings)" ] }, { "cell_type": "code", "execution_count": null, "id": "8eeed790-7e65-4ffa-b7cc-c92036ba0aba", "metadata": {}, "outputs": [], "source": [ "def predict_molecule_properties(smiles_list, regression_model, tokenizer, scaler_ch4, scaler_co2, \n", " batch_size=32, max_length=512, \n", " baseline_ch4=None, baseline_co2=None):\n", " \"\"\"\n", " Predict CO2 and CH4 permeability properties for a list of SMILES\n", "\n", " Returns:\n", " ch4_predictions: Array of CH4 permeability predictions\n", " CO2ictions: Array of CO2 permeability predictions\n", " molecules_exceeding_baselines: List of dicts for molecules exceeding baselines\n", " \"\"\"\n", " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", " regression_model.to(device)\n", " regression_model.eval()\n", "\n", " all_predictions = []\n", " with torch.no_grad():\n", " for i in range(0, len(smiles_list), batch_size):\n", " batch_smiles = smiles_list[i:i+batch_size]\n", "\n", " tokens = tokenizer(\n", " batch_smiles,\n", " max_length=max_length,\n", " truncation=True,\n", " padding='max_length',\n", " return_tensors='pt'\n", " )\n", " input_ids = tokens['input_ids'].to(device)\n", " attention_mask = tokens['attention_mask'].to(device)\n", "\n", " predictions = regression_model(input_ids, attention_mask)\n", " all_predictions.append(predictions.cpu().numpy())\n", "\n", " all_predictions = np.vstack(all_predictions)\n", " ch4_scaled = all_predictions[:, -2].reshape(-1, 1)\n", " co2_scaled = all_predictions[:, -1].reshape(-1, 1)\n", "\n", " ch4_predictions = scaler_ch4.inverse_transform(ch4_scaled).flatten()\n", " CO2ictions = scaler_co2.inverse_transform(co2_scaled).flatten()\n", "\n", " # Use provided baselines or max in predictions\n", " if baseline_ch4 is None:\n", " baseline_ch4 = ch4_predictions.max()\n", " if baseline_co2 is None:\n", " baseline_co2 = CO2ictions.max()\n", "\n", " molecules_exceeding_baselines = []\n", " for idx, (smiles, ch4_pred, CO2) in enumerate(zip(smiles_list, ch4_predictions, CO2ictions)):\n", " exceeds_ch4 = ch4_pred > baseline_ch4\n", " exceeds_co2 = CO2 > baseline_co2\n", " if exceeds_ch4 or exceeds_co2:\n", " molecules_exceeding_baselines.append({\n", " \"index\": idx,\n", " \"smiles\": smiles,\n", " \"predicted_CH4\": ch4_pred,\n", " \"predicted_CO2\": CO2,\n", " \"exceeds_CH4\": exceeds_ch4,\n", " \"exceeds_CO2\": exceeds_co2\n", " })\n", "\n", " return ch4_predictions, CO2ictions, molecules_exceeding_baselines\n", "\n", "\n", "def analyze_generated_molecules_with_properties(generated_molecules, regression_model, tokenizer, \n", " scaler_ch4, scaler_co2, val_df):\n", " \"\"\"\n", " Analyze generated molecules and predict their actual properties\n", " \n", " Args:\n", " generated_molecules: List of generated molecule dictionaries\n", " regression_model: Trained regression model\n", " tokenizer: SMILES tokenizer\n", " scaler_ch4, scaler_co2: Property scalers\n", " val_df: Validation dataframe for baseline comparison\n", " \n", " Returns:\n", " enhanced_results: DataFrame with predicted properties\n", " \"\"\"\n", " \n", " # Extract SMILES from generated molecules\n", " generated_smiles = []\n", " for mol in generated_molecules:\n", " if mol['is_valid']: # Only predict for valid molecules\n", " generated_smiles.append(mol['generated_smiles'])\n", " \n", " if not generated_smiles:\n", " print(\"No valid molecules to analyze!\")\n", " return None\n", " \n", " print(f\"Predicting properties for {len(generated_smiles)} valid generated molecules...\")\n", " \n", " # Predict properties\n", " pred_ch4, pred_co2, molecules_exceeding_baselines = predict_molecule_properties(\n", " generated_smiles, regression_model, tokenizer, scaler_ch4, scaler_co2,\n", " baseline_ch4=val_df['CH4'].max(),\n", " baseline_co2=val_df['CO2'].max()\n", " )\n", " \n", " # Create results dataframe\n", " results_data = []\n", " pred_idx = 0\n", " \n", " for mol in generated_molecules:\n", " if mol['is_valid']:\n", " results_data.append({\n", " 'generated_smiles': mol['generated_smiles'],\n", " 'target_property': mol['target_property'],\n", " 'extrapolation_factor': mol['extrapolation_factor'],\n", " 'predicted_CH4': pred_ch4[pred_idx],\n", " 'predicted_CO2': pred_co2[pred_idx],\n", " 'is_valid': mol['is_valid']\n", " })\n", " pred_idx += 1\n", " else:\n", " # Include invalid molecules with NaN predictions\n", " results_data.append({\n", " 'generated_smiles': mol['generated_smiles'],\n", " 'target_property': mol['target_property'],\n", " 'extrapolation_factor': mol['extrapolation_factor'],\n", " 'predicted_CH4': np.nan,\n", " 'predicted_CO2': np.nan,\n", " 'is_valid': mol['is_valid']\n", " })\n", " \n", " results_df = pd.DataFrame(results_data)\n", " \n", " # Calculate baseline statistics from validation set\n", " baseline_ch4_mean = val_df['CH4'].mean()\n", " baseline_ch4_max = val_df['CH4'].max()\n", " baseline_co2_mean = val_df['CO2'].mean()\n", " baseline_co2_max = val_df['CO2'].max()\n", " \n", " # Calculate enhancement statistics for valid molecules only\n", " valid_results = results_df[results_df['is_valid']].copy()\n", " \n", " if len(valid_results) > 0:\n", " print(f\"\\n{'='*80}\")\n", " print(f\"INDIVIDUAL MOLECULE PROPERTIES\")\n", " print(f\"{'='*80}\")\n", " \n", " # Display properties for each individual molecule\n", " for idx, (_, mol) in enumerate(valid_results.iterrows(), 1):\n", " ch4_vs_baseline = (mol['predicted_CH4'] / baseline_ch4_max - 1) * 100\n", " co2_vs_baseline = (mol['predicted_CO2'] / baseline_co2_max - 1) * 100\n", " \n", " print(f\" SMILES: {mol['generated_smiles']}\")\n", " print(f\" Target: {mol['target_property']} (factor: {mol['extrapolation_factor']})\")\n", " print(f\" CH₄ Permeability: {mol['predicted_CH4']:.4f} ({ch4_vs_baseline:+.1f}% vs baseline max)\")\n", " print(f\" CO₂ Permeability: {mol['predicted_CO2']:.4f} ({co2_vs_baseline:+.1f}% vs baseline max)\")\n", " \n", " # Enhancement indicators\n", " enhancement_flags = []\n", " if mol['predicted_CH4'] > baseline_ch4_max:\n", " enhancement_flags.append(\"CH₄↑\")\n", " if mol['predicted_CO2'] > baseline_co2_max:\n", " enhancement_flags.append(\"CO₂↑\")\n", " if enhancement_flags:\n", " print(f\" Enhancements: {', '.join(enhancement_flags)}\")\n", " \n", " print(\"-\" * 70)\n", " \n", " print(f\"Baseline Dataset Statistics:\")\n", " print(f\" CH4 - Mean: {baseline_ch4_mean:.4f}, Max: {baseline_ch4_max:.4f}\")\n", " print(f\" CO2 - Mean: {baseline_co2_mean:.4f}, Max: {baseline_co2_max:.4f}\")\n", " \n", " print(f\"\\nGenerated Molecules Statistics:\")\n", " print(f\" CH4 - Mean: {valid_results['predicted_CH4'].mean():.4f}, Max: {valid_results['predicted_CH4'].max():.4f}\")\n", " print(f\" CO2 - Mean: {valid_results['predicted_CO2'].mean():.4f}, Max: {valid_results['predicted_CO2'].max():.4f}\")\n", " \n", " # Check for improvements\n", " ch4_improvements = valid_results['predicted_CH4'] > baseline_ch4_max\n", " co2_improvements = valid_results['predicted_CO2'] > baseline_co2_max\n", " print(f\"\\nEnhancement Analysis:\")\n", " print(f\" Molecules exceeding baseline CH4 max: {ch4_improvements.sum()}/{len(valid_results)} ({ch4_improvements.mean()*100:.1f}%)\")\n", " print(f\" Molecules exceeding baseline CO2 max: {co2_improvements.sum()}/{len(valid_results)} ({co2_improvements.mean()*100:.1f}%)\")\n", " # Enhancement by target property\n", " property_analysis = valid_results.groupby('target_property').agg({\n", " 'predicted_CH4': ['mean', 'max', 'count'],\n", " 'predicted_CO2': ['mean', 'max', 'count']\n", " }).round(4)\n", " \n", " print(f\"\\nProperty Enhancement by Generation Target:\")\n", " \n", " # Enhancement by extrapolation factor\n", " factor_analysis = valid_results.groupby('extrapolation_factor').agg({\n", " 'predicted_CH4': ['mean', 'max'],\n", " 'predicted_CO2': ['mean', 'max']\n", " }).round(4)\n", " \n", " print(f\"\\nProperty Enhancement by Extrapolation Factor:\")\n", " print(\"\\nMolecules that exceed baselines:\")\n", " for mol in molecules_exceeding_baselines:\n", " print(f\"SMILES: {mol['smiles']}\")\n", " print(f\" CH₄ Predicted: {mol['predicted_CH4']:.4f} (Exceeds baseline: {mol['exceeds_CH4']})\")\n", " print(f\" CO₂ Predicted: {mol['predicted_CO2']:.4f} (Exceeds baseline: {mol['exceeds_CO2']})\\n\")\n", "\n", " \n", " \n", " else:\n", " print(\"No valid molecules generated for property prediction!\")\n", " \n", " return results_df, molecules_exceeding_baselines\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3df27d92-9da1-433c-b56b-27cd68b8004f", "metadata": {}, "outputs": [], "source": [ "results_df, molecules_exceeding_baselines = analyze_generated_molecules_with_properties(\n", " all_generated_molecules, regression_model, tokenizer, scaler_ch4, scaler_co2, df\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "262b158d-afe5-4f44-85df-c49c175fd454", "metadata": {}, "outputs": [], "source": [ "exceeding_mols = []\n", "for mol_info in molecules_exceeding_baselines:\n", " if mol_info['exceeds_CH4'] and mol_info['exceeds_CO2']:\n", " exceeding_mols.append(mol_info['smiles'])\n", "\n", "exceeding_mols" ] }, { "cell_type": "code", "execution_count": null, "id": "c133d4c6-cae8-4a43-be7e-8871d0403f1e", "metadata": {}, "outputs": [], "source": [ "smiles_generated_col = 'smiles'\n", "gen_df = pd.DataFrame({smiles_generated_col: exceeding_mols})" ] }, { "cell_type": "code", "execution_count": null, "id": "503ba331-1748-430c-96ba-a6de9658b053", "metadata": {}, "outputs": [], "source": [ "def split_dot(smiles):\n", " if '.' in smiles:\n", " parts = smiles.split('.')\n", " # выбираем наиболее длинную отдельную часть молекулы в качестве самой молекулы\n", " smiles = sorted(parts, key=lambda x:len(x), reverse=True)[0]\n", " return smiles\n", "\n", "partial_mols_count = sum(gen_df[smiles_generated_col].str.contains('.', regex=False))\n", "print(f'{partial_mols_count} partial mols fixed')\n", "if partial_mols_count > 0:\n", " gen_df[smiles_generated_col] = gen_df[smiles_generated_col].transform(split_dot)" ] }, { "cell_type": "code", "execution_count": null, "id": "ad57b93d-4c4c-4e08-b4ba-6beaa4ab55d6", "metadata": {}, "outputs": [], "source": [ "def try_normalize(smiles):\n", " # функция выполняет перевод молекулы в формат rdkit и обратно. Это фильтрует некорректные молекулы и нормализует их, т.е. приводит к единому виду, чтобы затем можно было отфильтровать дубликаты молекул\n", " try:\n", " return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))\n", " except Exception as e:\n", " # print(e)\n", " return None" ] }, { "cell_type": "code", "execution_count": null, "id": "d4bc739e-eb6d-42b7-a8e9-4aeb82d6f93e", "metadata": {}, "outputs": [], "source": [ "symbols = ['*', '[*]', 'I']\n", "contain_counts = {}\n", "for symbol in symbols:\n", " contain_counts[symbol] = gen_df[smiles_generated_col].apply(lambda smiles: symbol in smiles).sum()\n", " print(f'contain {symbol}: {contain_counts[symbol]} mols')\n", "\n", "most_frequent_symbol = max(contain_counts, key=contain_counts.get)\n", "assert most_frequent_symbol == 'I' # else think about it" ] }, { "cell_type": "code", "execution_count": null, "id": "11079464-727b-4b44-8fd7-c674edf76b70", "metadata": {}, "outputs": [], "source": [ "gen_df[smiles_generated_col] = gen_df[smiles_generated_col].transform(lambda x: x.replace('[*]', 'I'))\n", "gen_df[smiles_generated_col] = gen_df[smiles_generated_col].transform(lambda x: x.replace('*', 'I'))" ] }, { "cell_type": "code", "execution_count": null, "id": "a2893bdd-c582-4380-bbed-602c5bc22c22", "metadata": {}, "outputs": [], "source": [ "smiles_normalized_col = 'SMILES_normalized'\n", "\n", "gen_df[smiles_normalized_col] = gen_df[smiles_generated_col].apply(try_normalize)\n", "\n", "n_before = len(gen_df)\n", "gen_df = gen_df.loc[~gen_df[smiles_normalized_col].isnull()]\n", "n_after = len(gen_df)\n", "\n", "print(f'deleted: {n_before-n_after} incorrect mols'\n", " f' (before: {n_before} mols, after: {n_after} mols)')" ] }, { "cell_type": "code", "execution_count": null, "id": "75d34fa6-1e40-4f63-b134-938d1b9cd309", "metadata": {}, "outputs": [], "source": [ "temp_smiles_col = smiles_normalized_col+'2'\n", "gen_df[temp_smiles_col] = gen_df[smiles_normalized_col].apply(try_normalize)\n", "\n", "n_before = len(gen_df)\n", "gen_df = gen_df.loc[~gen_df[temp_smiles_col].isnull()]\n", "n_after = len(gen_df)\n", "\n", "gen_df = gen_df.drop(columns=[temp_smiles_col])\n", "\n", "print(f'deleted: {n_before-n_after} incorrect mols'\n", " f' (before: {n_before} mols, after: {n_after} mols)')" ] }, { "cell_type": "code", "execution_count": null, "id": "32c392d9-145d-46ba-8cb7-9b05dd453514", "metadata": {}, "outputs": [], "source": [ "n_before = len(gen_df)\n", "gen_df = gen_df.drop_duplicates(subset=[smiles_normalized_col])\n", "n_after = len(gen_df)\n", "\n", "print(f'deleted: {n_before-n_after} duplicates'\n", " f' (before: {n_before} mols, after: {n_after} mols)')" ] }, { "cell_type": "code", "execution_count": null, "id": "740e57e4-ad55-4638-ab0f-f753a242ee0a", "metadata": {}, "outputs": [], "source": [ "def filter_two_endpoints(smiles):\n", " return smiles if smiles.count('I') == 2 else None\n", "\n", "\n", "gen_df[smiles_generated_col] = gen_df[smiles_generated_col].apply(filter_two_endpoints)\n", "\n", "n_before = len(gen_df)\n", "gen_df = gen_df.loc[~gen_df[smiles_generated_col].isnull()]\n", "n_after = len(gen_df)\n", "\n", "print(f'deleted: {n_before-n_after} incorrect mols'\n", " f' (before: {n_before} mols, after: {n_after} mols)')" ] }, { "cell_type": "code", "execution_count": null, "id": "2cf9b3e8-caa6-4cd3-b76f-c1a3bcfe47db", "metadata": {}, "outputs": [], "source": [ "def filter_matching_endpoint_bonds(smiles):\n", " # фильтрация на предмет того, что типы связей у эндпоинтов должны быть одинаковы\n", " try:\n", " mol = Chem.MolFromSmiles(smiles.replace('I', '[*]'))\n", " inds = tuple(mol.GetSubstructMatches(Chem.MolFromSmarts(\"[#0]~*\")))\n", " inds = tuple(zip(*inds))\n", " star_inds = list(inds[0])\n", " connector_inds = list(inds[1])\n", " b1_type = mol.GetBondBetweenAtoms(star_inds[0], connector_inds[0]).GetBondType()\n", " b2_type = mol.GetBondBetweenAtoms(star_inds[1], connector_inds[1]).GetBondType()\n", " if b1_type != b2_type:\n", " return None\n", " else:\n", " return smiles\n", " except:\n", " return None\n", "gen_df[smiles_generated_col] = gen_df[smiles_generated_col].apply(filter_matching_endpoint_bonds)\n", "\n", "n_before = len(gen_df)\n", "gen_df = gen_df.loc[~gen_df[smiles_generated_col].isnull()]\n", "n_after = len(gen_df)\n", "\n", "print(f'deleted: {n_before-n_after} incorrect mols'\n", " f' (before: {n_before} mols, after: {n_after} mols)')" ] }, { "cell_type": "code", "execution_count": null, "id": "4a889461-90fd-4e00-a865-8dad2edccc73", "metadata": {}, "outputs": [], "source": [ "molecules_exceeding_baselines" ] }, { "cell_type": "code", "execution_count": null, "id": "99f92d05-4166-4b2a-a586-78f38e9a3041", "metadata": {}, "outputs": [], "source": [ "gen_df = gen_df.reset_index(drop=True)\n", "gen_df" ] }, { "cell_type": "code", "execution_count": null, "id": "d3e2cdab-226d-421a-99c1-516c6709671f", "metadata": {}, "outputs": [], "source": [ "exceeding_mols_co2 = []\n", "exceeding_mols_ch4 = []\n", "\n", "for mol in gen_df['smiles']:\n", " for candidat_mol_info in molecules_exceeding_baselines:\n", " if mol == candidat_mol_info['smiles']:\n", " exceeding_mols_co2.append(float(candidat_mol_info['predicted_CO2']))\n", " exceeding_mols_ch4.append(float(candidat_mol_info['predicted_CH4']))\n", " break\n", "\n", "exceeding_mols_co2, exceeding_mols_ch4" ] }, { "cell_type": "code", "execution_count": null, "id": "f6fc01a8-9bd8-451a-993e-fe3329e70ee8", "metadata": {}, "outputs": [], "source": [ "gen_df['smiles'].to_list()" ] }, { "cell_type": "code", "execution_count": null, "id": "ed1ae34b-0f9b-4ef3-948e-cec35411764b", "metadata": {}, "outputs": [], "source": [ "gen_df['predicted_CO2'] = exceeding_mols_co2\n", "gen_df['predicted_CH4'] = exceeding_mols_ch4\n", "gen_df" ] }, { "cell_type": "code", "execution_count": null, "id": "db560332-8eeb-438b-86b6-210a64efceea", "metadata": {}, "outputs": [], "source": [ "final_output = '/home/jovyan/simson_training_bolgov/regression/exceeding_mols.csv'\n", "gen_df.to_csv(final_output, index=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "8ce9bb58-943e-4ee7-b22b-3f7e91cce862", "metadata": {}, "outputs": [], "source": [ "df_robson = gen_df\n", "df_robson['P_CO2/P_CH4'] = gen_df['predicted_CO2']/gen_df['predicted_CH4']" ] }, { "cell_type": "code", "execution_count": null, "id": "7f294eef-86cc-4aed-8105-35e3c261bd69", "metadata": {}, "outputs": [], "source": [ "t = df_robson['smiles'][0]\n", "df_robson.iloc[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "56e42e5e-4988-444f-a476-3ca364fc9d95", "metadata": {}, "outputs": [], "source": [ "df[df['Smiles'] == t]" ] }, { "cell_type": "code", "execution_count": null, "id": "9f3cb85b-a441-4491-bd3a-f96f619fd0ad", "metadata": {}, "outputs": [], "source": [ "df_robson['in_synth_DB'] = df_robson['smiles'].isin(df['Smiles'])\n", "df_robson" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:.mlspace-bolgov_simson_training]", "language": "python", "name": "conda-env-.mlspace-bolgov_simson_training-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }