{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f72ec62c-c849-45b2-9321-b913c9f32979", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Encoder parameters loaded\n", "Encoder parameters frozen\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from transformers import BertConfig, BertModel, AutoTokenizer\n", "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, EarlyStoppingCallback\n", "from torch.utils.data import Dataset\n", "import pandas as pd\n", "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", "\n", "def global_ap(x):\n", " return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)\n", "\n", "class SimSonEncoder(nn.Module):\n", " def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n", " super(SimSonEncoder, self).__init__()\n", " self.config = config\n", " self.max_len = max_len\n", " self.bert = BertModel(config, add_pooling_layer=False)\n", " self.linear = nn.Linear(config.hidden_size, max_len)\n", " self.dropout = nn.Dropout(dropout)\n", " \n", " def forward(self, input_ids, attention_mask=None):\n", " if attention_mask is None:\n", " attention_mask = input_ids.ne(0)\n", " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", " hidden_states = outputs.last_hidden_state\n", " hidden_states = self.dropout(hidden_states)\n", " pooled = global_ap(hidden_states)\n", " out = self.linear(pooled)\n", " return out\n", "\n", "class SimSonDecoder(nn.Module):\n", " def __init__(self, embedding_dim, hidden_dim, vocab_size, max_len):\n", " super(SimSonDecoder, self).__init__()\n", " self.embedding_dim = embedding_dim\n", " self.hidden_dim = hidden_dim\n", " self.vocab_size = vocab_size\n", " self.max_len = max_len\n", " \n", " # Project embedding to hidden dimension\n", " self.embedding_projection = nn.Linear(embedding_dim, hidden_dim)\n", " \n", " # Token embeddings for decoder input\n", " self.token_embeddings = nn.Embedding(vocab_size, hidden_dim)\n", " self.position_embeddings = nn.Embedding(max_len, hidden_dim)\n", " \n", " # Transformer decoder layers\n", " decoder_layer = nn.TransformerDecoderLayer(\n", " d_model=hidden_dim,\n", " nhead=12,\n", " dim_feedforward=2048,\n", " dropout=0.1,\n", " batch_first=True\n", " )\n", " self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", " \n", " # Output projection to vocabulary\n", " self.output_projection = nn.Linear(hidden_dim, vocab_size)\n", " self.dropout = nn.Dropout(0.1)\n", " \n", " # Layer normalization\n", " self.layer_norm = nn.LayerNorm(hidden_dim)\n", " \n", " def forward(self, molecule_embeddings, decoder_input_ids, decoder_attention_mask=None):\n", " batch_size, seq_len = decoder_input_ids.shape\n", " device = decoder_input_ids.device\n", " \n", " # Project molecule embeddings to memory for cross-attention\n", " memory = self.embedding_projection(molecule_embeddings) # [batch, hidden_dim]\n", " memory = memory.unsqueeze(1) # [batch, 1, hidden_dim]\n", " \n", " # Create decoder input embeddings\n", " token_emb = self.token_embeddings(decoder_input_ids) # [batch, seq_len, hidden_dim]\n", " \n", " # Add positional embeddings\n", " positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)\n", " pos_emb = self.position_embeddings(positions)\n", " \n", " decoder_inputs = self.layer_norm(token_emb + pos_emb)\n", " decoder_inputs = self.dropout(decoder_inputs)\n", " \n", " # Create causal mask for 
decoder\n", " tgt_mask = self._generate_square_subsequent_mask(seq_len).to(device)\n", " \n", " # Create attention mask for decoder inputs\n", " if decoder_attention_mask is not None:\n", " # Convert padding mask to additive mask\n", " tgt_key_padding_mask = (decoder_attention_mask == 0)\n", " else:\n", " tgt_key_padding_mask = None\n", " \n", " # Apply transformer decoder\n", " decoder_output = self.transformer_decoder(\n", " tgt=decoder_inputs,\n", " memory=memory,\n", " tgt_mask=tgt_mask,\n", " tgt_key_padding_mask=tgt_key_padding_mask\n", " )\n", " \n", " # Project to vocabulary\n", " logits = self.output_projection(decoder_output)\n", " \n", " return logits\n", " \n", " def _generate_square_subsequent_mask(self, sz):\n", " \"\"\"Generate causal mask for decoder\"\"\"\n", " mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)\n", " return mask\n", "\n", "# Combined Encoder-Decoder Model\n", "class SimSonEncoderDecoder(nn.Module):\n", " def __init__(self, encoder, decoder, tokenizer):\n", " super().__init__()\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " self.tokenizer = tokenizer\n", "\n", " def forward(self, input_ids, attention_mask, labels=None):\n", " # Encode SMILES to embeddings\n", " embeddings = self.encoder(input_ids, attention_mask)\n", " \n", " if labels is not None:\n", " # During training, use teacher forcing\n", " # Shift labels for decoder input (remove last token, add BOS token)\n", " decoder_input_ids = torch.cat([\n", " torch.full((labels.shape[0], 1), self.tokenizer.cls_token_id, \n", " dtype=labels.dtype, device=labels.device),\n", " labels[:, :-1]\n", " ], dim=1)\n", " \n", " # Generate decoder attention mask\n", " decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id).long()\n", " \n", " # Decode embeddings to SMILES\n", " logits = self.decoder(embeddings, decoder_input_ids, decoder_attention_mask)\n", " \n", " # Calculate loss\n", " loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)\n", " loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))\n", " \n", " return {\"loss\": loss, \"logits\": logits}\n", " else:\n", " # During inference\n", " return self.generate(embeddings)\n", " \n", " def generate(self, embeddings, max_length=512):\n", " \"\"\"Generate SMILES from embeddings\"\"\"\n", " batch_size = embeddings.shape[0]\n", " device = embeddings.device\n", " \n", " # Start with CLS token\n", " generated = torch.full((batch_size, 1), self.tokenizer.cls_token_id, \n", " dtype=torch.long, device=device)\n", " \n", " for _ in range(max_length - 1):\n", " decoder_attention_mask = (generated != self.tokenizer.pad_token_id).long()\n", " logits = self.decoder(embeddings, generated, decoder_attention_mask)\n", " \n", " # Get next token probabilities\n", " next_token_logits = logits[:, -1, :]\n", " next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)\n", " \n", " # Append to generated sequence\n", " generated = torch.cat([generated, next_tokens], dim=1)\n", " \n", " # Stop if all sequences have generated EOS token\n", " if torch.all(next_tokens.squeeze() == self.tokenizer.sep_token_id):\n", " break\n", " \n", " return generated\n", "\n", "# Initialize tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n", "\n", "# Initialize encoder config\n", "config = BertConfig(\n", " vocab_size=tokenizer.vocab_size,\n", " hidden_size=768,\n", " num_hidden_layers=4,\n", " num_attention_heads=4,\n", " intermediate_size=2048,\n", " 
max_position_embeddings=512\n",
")\n",
"\n",
"# Initialize and load encoder\n",
"encoder = SimSonEncoder(config=config, max_len=512)\n",
"\n",
"# Load encoder parameters (extract the encoder weights from the regression model checkpoint)\n",
"regression_state_dict = torch.load('/home/jovyan/simson_training_bolgov/regression/better_regression_states/best_state.bin', weights_only=False)\n",
"\n",
"encoder_state_dict = {}\n",
"for key, value in regression_state_dict.items():\n",
"    key = key.removeprefix('_orig_mod.')  # strip the torch.compile wrapper prefix if present\n",
"    if key.startswith('encoder.'):\n",
"        encoder_state_dict[key[len('encoder.'):]] = value\n",
"\n",
"encoder.load_state_dict(encoder_state_dict)\n",
"print(\"Encoder parameters loaded\")\n",
"\n",
"# Freeze encoder parameters\n",
"for param in encoder.parameters():\n",
"    param.requires_grad = False\n",
"print(\"Encoder parameters frozen\")\n",
"\n",
"# Initialize decoder\n",
"decoder = SimSonDecoder(\n",
"    embedding_dim=512,  # encoder.max_len\n",
"    hidden_dim=768,     # config.hidden_size\n",
"    vocab_size=tokenizer.vocab_size,\n",
"    max_len=512\n",
")\n",
"\n",
"# Create combined model\n",
"model = SimSonEncoderDecoder(encoder, decoder, tokenizer)\n" ] }, { "cell_type": "markdown", "id": "27946411-ddb4-48f2-ab5d-baec24e53954", "metadata": {}, "source": [ "regression_params = torch.load('/home/jovyan/simson_training_bolgov/regression/simson_encoder_decoder.bin', weights_only=False)\n",
"\n",
"uncompiled_state_dict = {}\n",
"for key, value in regression_params.items():\n",
"    uncompiled_state_dict[key.removeprefix('_orig_mod.')] = value\n",
"\n",
"torch.save(uncompiled_state_dict, '/home/jovyan/simson_training_bolgov/regression/encoder_decoder_uncompiled.bin')" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bd4f17d-e369-4e4c-be51-a91cb0b85a0d", "metadata": {}, "outputs": [], "source": [ "# Load dataset\n",
"df = pd.read_csv('/home/jovyan/simson_training_bolgov/regression/PI_Tg_P308K_synth_db_chem.csv')" ] }, { "cell_type": "code", "execution_count": null, "id": "6a40fac8-9c46-4c96-8a47-94ea32803b21", "metadata": {}, "outputs": [], "source": [] }, 
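{ "cell_type": "markdown", "id": "0a1b2c3d-1111-4aaa-9bbb-000000000001", "metadata": {}, "source": [ "A small sanity check (illustrative; not part of the original run): push a dummy batch through the frozen encoder and the decoder with teacher forcing, and confirm the logits shape and a finite loss. The two SMILES below are arbitrary." ] }, { "cell_type": "code", "execution_count": null, "id": "0a1b2c3d-1111-4aaa-9bbb-000000000002", "metadata": {}, "outputs": [], "source": [ "# Hedged sanity check: one teacher-forced forward pass on made-up inputs.\n",
"with torch.no_grad():\n",
"    dummy = tokenizer(['CCO', 'c1ccccc1'], max_length=32,\n",
"                      padding='max_length', truncation=True,\n",
"                      return_tensors='pt')\n",
"    out = model(dummy['input_ids'], dummy['attention_mask'],\n",
"                labels=dummy['input_ids'])\n",
"print(out['logits'].shape)  # expected: [2, 32, vocab_size]\n",
"print(out['loss'].item())   # should be a finite cross-entropy value" ] }, 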
{ "cell_type": "code", "execution_count": 3, "id": "5951cb13-e188-4a2c-8b42-9f8125e96165", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training samples: 6390602\n", "Validation samples: 336348\n" ] } ], "source": [ "class SMILESReconstructionDataset(Dataset):\n",
"    def __init__(self, dataframe, tokenizer, max_length=256):\n",
"        self.data = dataframe\n",
"        self.tokenizer = tokenizer\n",
"        self.max_length = max_length\n",
"    \n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"    \n",
"    def __getitem__(self, idx):\n",
"        # Get SMILES string (adjust column name as needed)\n",
"        smiles = self.data.iloc[idx]['smiles'] if 'smiles' in self.data.columns else self.data.iloc[idx]['Smiles']\n",
"        \n",
"        # Tokenize input SMILES\n",
"        encoding = self.tokenizer(\n",
"            smiles,\n",
"            max_length=self.max_length,\n",
"            padding='max_length',\n",
"            truncation=True,\n",
"            return_tensors='pt'\n",
"        )\n",
"        \n",
"        return {\n",
"            'input_ids': encoding['input_ids'].squeeze(0),\n",
"            'attention_mask': encoding['attention_mask'].squeeze(0),\n",
"            'labels': encoding['input_ids'].squeeze(0)  # Same as input for reconstruction\n",
"        }\n",
"\n",
"def create_stratified_splits_regression(\n",
"    df,\n",
"    label_cols,\n",
"    n_bins=10,\n",
"    val_frac=0.05,\n",
"    seed=42\n",
"):\n",
"    values = df[label_cols].values\n",
"    # Each label gets its own quantile bins, based on its overall distribution\n",
"    bins = [np.unique(np.quantile(values[:,i], np.linspace(0, 1, n_bins+1))) for i in range(len(label_cols))]\n",
"    # Assign each row to a bin for each label (interior edges only, so values at the\n",
"    # extremes fall into the first/last bin rather than creating edge-only bins)\n",
"    inds = [\n",
"        np.digitize(values[:,i], bins[i][1:-1], right=False)\n",
"        for i in range(len(label_cols))\n",
"    ]\n",
"    # Combine the per-label bins into a single integer stratification variable:\n",
"    # a high bin in any one label puts the row in a high stratum overall\n",
"    strat_col = np.maximum.reduce(inds)\n",
"    # Use sklearn's train_test_split with stratify\n",
"    train_idx, val_idx = train_test_split(\n",
"        df.index.values,\n",
"        test_size=val_frac,\n",
"        random_state=seed,\n",
"        shuffle=True,\n",
"        stratify=strat_col\n",
"    )\n",
"    train = df.loc[train_idx].reset_index(drop=True)\n",
"    val = df.loc[val_idx].reset_index(drop=True)\n",
"    return train, val\n",
"\n",
"\n",
"tokenizer_path = 'DeepChem/ChemBERTa-77M-MTR'\n",
"tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n",
"\n",
"# Same splits\n",
"train, test = create_stratified_splits_regression(\n",
"    df,\n",
"    label_cols=['CO2', 'CH4'], \n",
"    n_bins=10,\n",
"    val_frac=0.05,\n",
"    seed=42\n",
")\n",
"\n",
"train_dataset = SMILESReconstructionDataset(train, tokenizer)\n",
"val_dataset = SMILESReconstructionDataset(test, tokenizer)\n",
"val_df = test\n",
"train_df = train\n",
"print(f\"Training samples: {len(train_dataset)}\")\n",
"print(f\"Validation samples: {len(val_dataset)}\")\n" ] }, 
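{ "cell_type": "markdown", "id": "0a1b2c3d-1111-4aaa-9bbb-000000000003", "metadata": {}, "source": [ "A toy illustration (made-up values; not part of the original pipeline) of how `create_stratified_splits_regression` builds its stratification variable: each label is quantile-binned independently, and `np.maximum.reduce` keeps the highest bin, so a row that is extreme in either property lands in a high stratum." ] }, { "cell_type": "code", "execution_count": null, "id": "0a1b2c3d-1111-4aaa-9bbb-000000000004", "metadata": {}, "outputs": [], "source": [ "# Toy run of the binning used above (two labels, two bins, made-up values).\n",
"import numpy as np\n",
"\n",
"vals = np.array([[0.1, 9.0],   # low CO2, high CH4\n",
"                 [5.0, 0.2],   # mid-high CO2, low CH4\n",
"                 [0.1, 0.2],   # low in both\n",
"                 [9.0, 9.0]])  # high in both\n",
"bins = [np.unique(np.quantile(vals[:, i], np.linspace(0, 1, 3))) for i in range(2)]\n",
"inds = [np.digitize(vals[:, i], bins[i][1:-1], right=False) for i in range(2)]\n",
"print(np.maximum.reduce(inds))  # -> [1 1 0 1]: only the row low in both gets stratum 0" ] }, 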
{ "cell_type": "code", "execution_count": 4, "id": "2b79c8b4-4f5d-4752-b1a7-08ed2d0bb7f5", "metadata": {}, "outputs": [], "source": [ "# Custom data collator for encoder-decoder\n",
"class EncoderDecoderDataCollator:\n",
"    def __init__(self, tokenizer):\n",
"        self.tokenizer = tokenizer\n",
"    \n",
"    def __call__(self, batch):\n",
"        # Stack all batch elements\n",
"        input_ids = torch.stack([item['input_ids'] for item in batch])\n",
"        attention_mask = torch.stack([item['attention_mask'] for item in batch])\n",
"        labels = torch.stack([item['labels'] for item in batch])\n",
"        \n",
"        return {\n",
"            'input_ids': input_ids,\n",
"            'attention_mask': attention_mask,\n",
"            'labels': labels\n",
"        }\n",
"\n",
"\n",
"data_collator = EncoderDecoderDataCollator(tokenizer)\n",
"\n",
"early_stopping_callback = EarlyStoppingCallback(\n",
"    early_stopping_patience=10,\n",
")\n",
"\n",
"training_args = Seq2SeqTrainingArguments(\n",
"    output_dir='/home/jovyan/simson_training_bolgov/regression/decoder_checkpoints',\n",
"    eval_strategy='steps',\n",
"    eval_steps=10_000,\n",
"    logging_steps=10_000,\n",
"    save_steps=10_000,\n",
"    save_total_limit=3,\n",
"    learning_rate=5e-5,\n",
"    per_device_train_batch_size=256,\n",
"    per_device_eval_batch_size=256,\n",
"    gradient_accumulation_steps=1,\n",
"    num_train_epochs=5,\n",
"    warmup_steps=50,\n",
"    weight_decay=0.01,\n",
"    logging_dir='./logs',\n",
"    report_to='none',\n",
"    load_best_model_at_end=True,\n",
"    metric_for_best_model='eval_loss',\n",
"    greater_is_better=False,\n",
"    fp16=True,\n",
"    dataloader_pin_memory=True,\n",
"    remove_unused_columns=False,\n",
"    predict_with_generate=False\n",
")\n",
"\n",
"# Initialize trainer (processing_class replaces the deprecated `tokenizer` argument;\n",
"# the early-stopping callback defined above is now actually passed in)\n",
"trainer = Seq2SeqTrainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=val_dataset,\n",
"    data_collator=data_collator,\n",
"    processing_class=tokenizer,\n",
"    callbacks=[early_stopping_callback]\n",
")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "88eccafd-c269-4bc6-8290-0bdb5d184c31", "metadata": {}, "outputs": [], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "9ac6754e-dd6d-4282-b661-0b11e65cadc3", "metadata": {}, "outputs": [], "source": [ "final_state = trainer.model.state_dict()\n",
"torch.save(final_state, '/home/jovyan/simson_training_bolgov/regression/simson_encoder_decoder_better_regression.bin')" ] }, { "cell_type": "code", "execution_count": 5, "id": "18637615-64f4-42d5-a431-0765df5bc024", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model initialized with 72,264,271 trainable parameters\n" ] } ], "source": [ "import torch.nn.functional as F\n",
"\n",
"\n",
"def global_ap(x):\n",
"    return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)\n",
"\n",
"class SimSonEncoder(nn.Module):\n",
"    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n",
"        super(SimSonEncoder, self).__init__()\n",
"        self.config = config\n",
"        self.max_len = max_len\n",
"        self.bert = BertModel(config, add_pooling_layer=False)\n",
"        self.linear = nn.Linear(config.hidden_size, max_len)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"    \n",
"    def forward(self, input_ids, attention_mask=None):\n",
"        if attention_mask is None:\n",
"            attention_mask = input_ids.ne(0)\n",
"        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
"        hidden_states = outputs.last_hidden_state\n",
"        hidden_states = self.dropout(hidden_states)\n",
"        pooled = global_ap(hidden_states)\n",
"        out = self.linear(pooled)\n",
"        return out\n",
"\n",
"class SimSonDecoder(nn.Module):\n",
"    def __init__(self, embedding_dim, hidden_dim, vocab_size, max_len):\n",
"        super(SimSonDecoder, self).__init__()\n",
"        self.embedding_dim = embedding_dim\n",
"        self.hidden_dim = hidden_dim\n",
"        self.vocab_size = vocab_size\n",
"        self.max_len = max_len\n",
"        \n",
"        # Project embedding to hidden dimension\n",
"        self.embedding_projection = nn.Linear(embedding_dim, hidden_dim)\n",
"        \n",
"        # Token embeddings for decoder input\n",
"        self.token_embeddings = nn.Embedding(vocab_size, hidden_dim)\n",
"        self.position_embeddings = nn.Embedding(max_len, hidden_dim)\n",
"        \n",
"        # Transformer decoder layers\n",
"        decoder_layer = nn.TransformerDecoderLayer(\n",
"            d_model=hidden_dim,\n",
"            nhead=12,\n",
"            dim_feedforward=2048,\n",
"            dropout=0.1,\n",
"            batch_first=True\n",
"        )\n",
"        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n",
"        \n",
"        # Output projection to vocabulary\n",
"        self.output_projection = nn.Linear(hidden_dim, vocab_size)\n",
"        self.dropout = nn.Dropout(0.1)\n",
"        \n",
"        # Layer normalization\n",
"        self.layer_norm = nn.LayerNorm(hidden_dim)\n",
"    \n",
"    def forward(self, molecule_embeddings, decoder_input_ids, decoder_attention_mask=None):\n",
"        batch_size, seq_len = decoder_input_ids.shape\n",
"        device = decoder_input_ids.device\n",
"        \n",
"        # Project molecule embeddings to memory for cross-attention\n",
"        memory = self.embedding_projection(molecule_embeddings)  # [batch, hidden_dim]\n",
"        memory = memory.unsqueeze(1)  # [batch, 1, hidden_dim]\n",
"        \n",
"        # Create decoder input embeddings\n",
"        token_emb = self.token_embeddings(decoder_input_ids)  # [batch, seq_len, hidden_dim]\n",
"        \n",
"        # Add positional embeddings\n",
"        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)\n",
"        pos_emb = 
self.position_embeddings(positions)\n", " \n", " decoder_inputs = self.layer_norm(token_emb + pos_emb)\n", " decoder_inputs = self.dropout(decoder_inputs)\n", " \n", " # Create causal mask for decoder\n", " tgt_mask = self._generate_square_subsequent_mask(seq_len).to(device)\n", " \n", " # Create attention mask for decoder inputs\n", " if decoder_attention_mask is not None:\n", " # Convert padding mask to additive mask\n", " tgt_key_padding_mask = (decoder_attention_mask == 0)\n", " else:\n", " tgt_key_padding_mask = None\n", " \n", " # Apply transformer decoder\n", " decoder_output = self.transformer_decoder(\n", " tgt=decoder_inputs,\n", " memory=memory,\n", " tgt_mask=tgt_mask,\n", " tgt_key_padding_mask=tgt_key_padding_mask\n", " )\n", " \n", " # Project to vocabulary\n", " logits = self.output_projection(decoder_output)\n", " \n", " return logits\n", " \n", " def _generate_square_subsequent_mask(self, sz):\n", " \"\"\"Generate causal mask for decoder\"\"\"\n", " mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)\n", " return mask\n", "\n", "# Combined Encoder-Decoder Model\n", "class SimSonEncoderDecoder(nn.Module):\n", " def __init__(self, encoder, decoder, tokenizer):\n", " super().__init__()\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " self.tokenizer = tokenizer\n", " \n", " def forward(self, input_ids, attention_mask, labels=None):\n", " # Encode SMILES to embeddings\n", " embeddings = self.encoder(input_ids, attention_mask)\n", " \n", " if labels is not None:\n", " # During training, use teacher forcing\n", " # Shift labels for decoder input (remove last token, add BOS token)\n", " decoder_input_ids = torch.cat([\n", " torch.full((labels.shape[0], 1), self.tokenizer.cls_token_id, \n", " dtype=labels.dtype, device=labels.device),\n", " labels[:, :-1]\n", " ], dim=1)\n", " \n", " # Generate decoder attention mask\n", " decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id).long()\n", " \n", " # Decode embeddings to SMILES\n", " logits = self.decoder(embeddings, decoder_input_ids, decoder_attention_mask)\n", " \n", " # Calculate loss\n", " loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)\n", " loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))\n", " \n", " return {\"loss\": loss, \"logits\": logits}\n", " else:\n", " # During inference\n", " return self.generate(embeddings)\n", " \n", " def generate(self, embeddings, max_length=512, temperature=1.0):\n", "\n", " batch_size = embeddings.shape[0]\n", " device = embeddings.device\n", " \n", " # Start with the CLS token for all sequences in the batch\n", " generated = torch.full((batch_size, 1), self.tokenizer.cls_token_id, \n", " dtype=torch.long, device=device)\n", " \n", " # --- THE FIX: Keep track of which sequences have finished ---\n", " # A boolean tensor, initially all False.\n", " is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)\n", " \n", " for _ in range(max_length - 1):\n", " # Pass the current generated sequences to the decoder\n", " decoder_attention_mask = (generated != self.tokenizer.pad_token_id).long()\n", " logits = self.decoder(embeddings, generated, decoder_attention_mask)\n", " \n", " # Focus on the logits for the last token in each sequence\n", " next_token_logits = logits[:, -1, :]\n", " \n", " # Apply temperature sampling\n", " if temperature == 0.0:\n", " next_tokens = torch.argmax(next_token_logits, dim=-1)\n", " else:\n", " probs = F.softmax(next_token_logits / temperature, dim=-1)\n", " 
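# multinomial draws one token id per row from the tempered distribution (added note)\n", "            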
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)\n",
"            \n",
"            # --- THE FIX: Prevent finished sequences from generating new tokens ---\n",
"            # If a sequence is already finished, its next token should be a pad token.\n",
"            # Otherwise, use the newly generated token.\n",
"            next_tokens = torch.where(is_finished, self.tokenizer.pad_token_id, next_tokens)\n",
"            \n",
"            # Append the new tokens to the generated sequences\n",
"            generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=1)\n",
"            \n",
"            # --- THE FIX: Update the `is_finished` status for any sequence that just produced an EOS token ---\n",
"            is_finished |= (next_tokens == self.tokenizer.sep_token_id)\n",
"            \n",
"            # --- THE FIX: Stop if all sequences in the batch are finished ---\n",
"            if torch.all(is_finished):\n",
"                break\n",
"        \n",
"        return generated\n",
"\n",
"\n",
"# Initialize tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n",
"\n",
"# Initialize encoder config\n",
"config = BertConfig(\n",
"    vocab_size=tokenizer.vocab_size,\n",
"    hidden_size=768,\n",
"    num_hidden_layers=4,\n",
"    num_attention_heads=12,\n",
"    intermediate_size=2048,\n",
"    max_position_embeddings=512\n",
")\n",
"\n",
"# Initialize and load encoder\n",
"encoder = SimSonEncoder(config=config, max_len=512)\n",
"\n",
"# Initialize decoder\n",
"decoder = SimSonDecoder(\n",
"    embedding_dim=512,  # encoder.max_len\n",
"    hidden_dim=768,     # config.hidden_size\n",
"    vocab_size=tokenizer.vocab_size,\n",
"    max_len=512\n",
")\n",
"\n",
"# Create combined model\n",
"model = SimSonEncoderDecoder(encoder, decoder, tokenizer)\n",
"model.load_state_dict(torch.load('/home/jovyan/simson_training_bolgov/regression/simson_encoder_decoder.bin', weights_only=False))\n",
"model.cuda()\n",
"print(f\"Model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters\")" ] }, 
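{ "cell_type": "markdown", "id": "0a1b2c3d-1111-4aaa-9bbb-000000000005", "metadata": {}, "source": [ "An illustrative round-trip sketch (assumes the cell above ran; not part of the original notebook): encode one arbitrary SMILES (ibuprofen), decode it back greedily with `temperature=0.0`, and print the reconstruction." ] }, { "cell_type": "code", "execution_count": null, "id": "0a1b2c3d-1111-4aaa-9bbb-000000000006", "metadata": {}, "outputs": [], "source": [ "# Hedged round-trip sketch: encode -> generate -> decode one arbitrary SMILES.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    toks = tokenizer('CC(C)Cc1ccc(C(C)C(=O)O)cc1', max_length=256,\n",
"                     padding='max_length', truncation=True,\n",
"                     return_tensors='pt').to('cuda')\n",
"    emb = model.encoder(toks['input_ids'], toks['attention_mask'])\n",
"    gen = model.generate(emb, max_length=256, temperature=0.0)  # 0.0 -> greedy argmax branch\n",
"print(tokenizer.decode(gen[0], skip_special_tokens=True))" ] }, 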
{ "cell_type": "code", "execution_count": 6, "id": "3f250d5d-0573-4f5b-a665-44ca3ecf135d", "metadata": {}, "outputs": [], "source": [ "import torch\n",
"import torch.nn as nn\n",
"import joblib\n",
"import numpy as np\n",
"from transformers import BertConfig, AutoTokenizer, BertModel\n",
"from tqdm import tqdm\n",
"\n",
"# Load tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')\n",
"\n",
"# Initialize model configuration\n",
"config = BertConfig(\n",
"    vocab_size=tokenizer.vocab_size,\n",
"    hidden_size=768,\n",
"    num_hidden_layers=4,\n",
"    num_attention_heads=12,\n",
"    intermediate_size=2048,\n",
"    max_position_embeddings=512\n",
")\n",
"\n",
"# Initialize regression model\n",
"def global_ap(x):\n",
"    return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)\n",
"\n",
"class SimSonEncoder(nn.Module):\n",
"    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):\n",
"        super(SimSonEncoder, self).__init__()\n",
"        self.config = config\n",
"        self.max_len = max_len\n",
"        self.bert = BertModel(config, add_pooling_layer=False)\n",
"        self.linear = nn.Linear(config.hidden_size, max_len)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"    \n",
"    def forward(self, input_ids, attention_mask=None):\n",
"        if attention_mask is None:\n",
"            attention_mask = input_ids.ne(0)\n",
"        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
"        hidden_states = outputs.last_hidden_state\n",
"        hidden_states = self.dropout(hidden_states)\n",
"        pooled = global_ap(hidden_states)\n",
"        out = self.linear(pooled)\n",
"        return out\n",
"\n",
"class SimSonClassifier(nn.Module):\n",
"    def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):\n",
"        super(SimSonClassifier, self).__init__()\n",
"        self.encoder = encoder\n",
"        self.clf = nn.Linear(encoder.max_len, num_labels)\n",
"        self.relu = nn.ReLU()\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, input_ids, attention_mask=None, labels=None):\n",
"        x = self.encoder(input_ids, attention_mask)\n",
"        x = self.relu(self.dropout(x))\n",
"        x = self.clf(x)\n",
"        return x\n",
"\n",
"compiling = False\n",
"# Load model and parameters\n",
"encoder = SimSonEncoder(config=config, max_len=512)\n",
"if compiling:\n",
"    encoder = torch.compile(encoder)\n",
"encoder_expert = SimSonEncoder(config=config, max_len=512)\n",
"\n",
"regression_model = SimSonClassifier(encoder=encoder, num_labels=6)\n",
"regression_expert_model = SimSonClassifier(encoder=encoder_expert, num_labels=6)\n",
"\n",
"if compiling:\n",
"    regression_model = torch.compile(regression_model)\n",
"\n",
"# Load trained parameters\n",
"regression_params = torch.load('/home/jovyan/simson_training_bolgov/regression/better_regression_states/best_state.bin', weights_only=False)\n",
"regression_expert_params = torch.load('/home/jovyan/simson_training_bolgov/regression/high_regression_old_scalers_simson.pth', weights_only=False)\n",
"\n",
"\n",
"regression_model.load_state_dict(regression_params)\n",
"regression_expert_model.load_state_dict(regression_expert_params)\n",
"\n",
"regression_model.eval()\n",
"regression_model.cuda()\n",
"\n",
"regression_expert_model.eval()\n",
"regression_expert_model.cuda()\n",
"\n",
"# Load scalers\n",
"scalers = joblib.load('/home/jovyan/simson_training_bolgov/regression/scalers')\n",
"scaler_ch4 = scalers[-2]  # CH4 scaler\n",
"scaler_co2 = scalers[-1]  # CO2 scaler" ] }, { "cell_type": "code", "execution_count": 7, "id": "7a775a04-26de-434f-a138-bafb84b661ea", "metadata": {}, "outputs": [], "source": [ "from rdkit import Chem\n",
"from rdkit.Chem import MolToSmiles\n",
"from tqdm import tqdm\n",
"import torch\n",
"\n",
"def validate_reconstruction(model, tokenizer, test_smiles_list, num_samples=20, max_length=256):\n",
"    \"\"\"\n",
"    Evaluate reconstruction quality.\n",
"    • exact_match ........ literal string equality\n",
"    • canonical_match .... same molecule after canonicalisation\n",
"    • generated_valid .... 
RDKit is able to parse generated SMILES\n", " \"\"\"\n", " model.eval()\n", " device = next(model.parameters()).device\n", "\n", " results = []\n", " with torch.no_grad():\n", " for i, smiles in enumerate(test_smiles_list[:num_samples]):\n", " # ---------- encode ----------\n", " tokens = tokenizer(\n", " smiles,\n", " max_length=max_length,\n", " truncation=True,\n", " padding=\"max_length\",\n", " return_tensors=\"pt\"\n", " ).to(device)\n", "\n", " embedding = model.encoder(tokens[\"input_ids\"], tokens[\"attention_mask\"])\n", "\n", " # ---------- decode ----------\n", " gen_ids = model.generate(embedding)\n", " gen_smiles = tokenizer.decode(gen_ids[0], skip_special_tokens=True)\n", "\n", " # ---------- validity ----------\n", " mol_orig = Chem.MolFromSmiles(smiles)\n", " mol_gen = Chem.MolFromSmiles(gen_smiles)\n", " orig_valid = mol_orig is not None\n", " gen_valid = mol_gen is not None\n", "\n", " # ---------- canonical comparison ----------\n", " if orig_valid and gen_valid:\n", " can_orig = MolToSmiles(mol_orig, canonical=True)\n", " can_gen = MolToSmiles(mol_gen, canonical=True)\n", " canonical_match = (can_orig == can_gen)\n", " else:\n", " canonical_match = False\n", "\n", " results.append({\n", " \"original\": smiles,\n", " \"reconstructed\": gen_smiles,\n", " \"exact_match\": smiles == gen_smiles,\n", " \"canonical_match\": canonical_match,\n", " \"generated_valid\": gen_valid\n", " })\n", "\n", " print(f\"\\nSample {i+1}\")\n", " print(f\" Original: {smiles}\")\n", " print(f\" Reconstructed: {gen_smiles}\")\n", " print(f\" Valid (gen): {gen_valid}\")\n", " print(f\" Same molecule: {canonical_match}\")\n", " print(\"-\" * 60)\n", "\n", " # ---------- aggregated metrics ----------\n", " exact_acc = sum(r[\"exact_match\"] for r in results) / len(results)\n", " canonical_acc = sum(r[\"canonical_match\"] for r in results) / len(results)\n", " validity_rate = sum(r[\"generated_valid\"] for r in results) / len(results)\n", "\n", " print(f\"\\nExact-string match accuracy ...... {exact_acc:.2%}\")\n", " print(f\"Canonical match accuracy ......... {canonical_acc:.2%}\")\n", " print(f\"Generated validity rate .......... {validity_rate:.2%}\")\n", "\n", " return results\n", "\n", "test_smiles = val_df['smiles' if 'smiles' in val_df.columns else 'Smiles'].tolist()\n", "#validation_results = validate_reconstruction(model, tokenizer, test_smiles, num_samples=20)\n" ] }, { "cell_type": "code", "execution_count": 115, "id": "21d72f76-9bc6-480d-8bf1-205c34ed7e2e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████| 999/999 [04:08<00:00, 4.02it/s]\n" ] } ], "source": [ "from sklearn.metrics import pairwise_distances\n", "from sklearn.preprocessing import MinMaxScaler\n", "\n", "\n", "features = ['CO2', 'CH4']\n", "\n", "# --- 3. 
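Max–Min (farthest-point) selection: repeatedly pick the point whose\n", "#        nearest already-selected neighbour is farthest, yielding a spread-out subset (added note)\n", "# --- 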
Max–Min selection ----------------------------------------------------\n", "def select_diverse_maxmin(data, cols, k=1_000):\n", " X = data[cols].to_numpy()\n", " \n", " # scale to [0,1] so CO₂ and CH₄ get equal weight\n", " X = MinMaxScaler().fit_transform(X)\n", " \n", " n = len(X)\n", " if k >= n: # nothing to do\n", " return data\n", " \n", " # start with a random seed point\n", " selected = [np.random.randint(0, n)]\n", " \n", " # pre-allocate distance cache\n", " min_dist = pairwise_distances(\n", " X, X[selected], metric='euclidean'\n", " ).ravel() # distance to first point\n", " \n", " for _ in tqdm(range(1, k)):\n", " # pick the point with the largest distance to the current set\n", " idx = np.argmax(min_dist)\n", " selected.append(idx)\n", " \n", " # update distance cache (keep the shortest distance to any selected pt)\n", " dist_to_new = pairwise_distances(X, X[[idx]], metric='euclidean').ravel()\n", " min_dist = np.minimum(min_dist, dist_to_new)\n", " \n", " return data.iloc[selected].reset_index(drop=True)\n", "\n", "sample_df = select_diverse_maxmin(df, features, k=1_000)" ] }, { "cell_type": "code", "execution_count": 8, "id": "ecae2eaa-dcd0-4a0d-a8d6-347c254285d0", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "import pandas as pd\n", "from rdkit import Chem\n", "from rdkit.Chem import MolToSmiles\n", "from scipy.stats import pearsonr, spearmanr\n", "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n", "from tqdm import tqdm\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "def display_property_fidelity_results(results_df, metrics):\n", " \"\"\"Display comprehensive property fidelity analysis\"\"\"\n", " \n", " valid_results = results_df[results_df['reconstructed_valid']].copy()\n", " total_molecules = len(results_df)\n", " valid_molecules = len(valid_results)\n", " identical_molecules = results_df['chemically_identical'].sum()\n", " \n", " print(f\"\\n{'='*80}\")\n", " print(f\"PROPERTY RECONSTRUCTION FIDELITY ANALYSIS\")\n", " print(f\"{'='*80}\")\n", " \n", " print(f\"Dataset Overview:\")\n", " print(f\" Total molecules tested: {total_molecules}\")\n", " print(f\" Valid reconstructions: {valid_molecules} ({valid_molecules/total_molecules*100:.1f}%)\")\n", " print(f\" Chemically identical: {identical_molecules} ({identical_molecules/total_molecules*100:.1f}%)\")\n", " \n", " if valid_molecules > 0:\n", " print(f\"\\n{'='*60}\")\n", " print(f\"PROPERTY CORRELATION METRICS\")\n", " print(f\"{'='*60}\")\n", " \n", " print(f\"CH₄ Permeability:\")\n", " print(f\" Pearson correlation: {metrics['CH4']['pearson_correlation']:.4f}\")\n", " print(f\" Spearman correlation: {metrics['CH4']['spearman_correlation']:.4f}\")\n", " print(f\" R² score: {metrics['CH4']['r2_score']:.4f}\")\n", " print(f\" MAE: {metrics['CH4']['mae']:.4f}\")\n", " print(f\" RMSE: {metrics['CH4']['rmse']:.4f}\")\n", " print(f\" Mean relative error: {metrics['CH4']['mean_relative_error_pct']:.2f}%\")\n", " print(f\" Median relative error: {metrics['CH4']['median_relative_error_pct']:.2f}%\")\n", " \n", " print(f\"\\nCO₂ Permeability:\")\n", " print(f\" Pearson correlation: {metrics['CO2']['pearson_correlation']:.4f}\")\n", " print(f\" Spearman correlation: {metrics['CO2']['spearman_correlation']:.4f}\")\n", " print(f\" R² score: {metrics['CO2']['r2_score']:.4f}\")\n", " print(f\" MAE: {metrics['CO2']['mae']:.4f}\")\n", " print(f\" RMSE: {metrics['CO2']['rmse']:.4f}\")\n", " print(f\" Mean relative error: 
{metrics['CO2']['mean_relative_error_pct']:.2f}%\")\n", " print(f\" Median relative error: {metrics['CO2']['median_relative_error_pct']:.2f}%\")\n", " \n", " # Property preservation quality assessment\n", " print(f\"\\n{'='*60}\")\n", " print(f\"PROPERTY PRESERVATION ASSESSMENT\")\n", " print(f\"{'='*60}\")\n", " \n", " # Define quality thresholds\n", " excellent_threshold = 5.0 # <5% relative error\n", " good_threshold = 15.0 # <15% relative error\n", " acceptable_threshold = 30.0 # <30% relative error\n", " \n", " ch4_excellent = (valid_results['CH4_relative_error'] < excellent_threshold).sum()\n", " ch4_good = (valid_results['CH4_relative_error'] < good_threshold).sum()\n", " ch4_acceptable = (valid_results['CH4_relative_error'] < acceptable_threshold).sum()\n", " \n", " co2_excellent = (valid_results['CO2_relative_error'] < excellent_threshold).sum()\n", " co2_good = (valid_results['CO2_relative_error'] < good_threshold).sum()\n", " co2_acceptable = (valid_results['CO2_relative_error'] < acceptable_threshold).sum()\n", " \n", " print(f\"CH₄ Property Preservation Quality:\")\n", " print(f\" Excellent (<{excellent_threshold}% error): {ch4_excellent}/{valid_molecules} ({ch4_excellent/valid_molecules*100:.1f}%)\")\n", " print(f\" Good (<{good_threshold}% error): {ch4_good}/{valid_molecules} ({ch4_good/valid_molecules*100:.1f}%)\")\n", " print(f\" Acceptable (<{acceptable_threshold}% error): {ch4_acceptable}/{valid_molecules} ({ch4_acceptable/valid_molecules*100:.1f}%)\")\n", " \n", " print(f\"\\nCO₂ Property Preservation Quality:\")\n", " print(f\" Excellent (<{excellent_threshold}% error): {co2_excellent}/{valid_molecules} ({co2_excellent/valid_molecules*100:.1f}%)\")\n", " print(f\" Good (<{good_threshold}% error): {co2_good}/{valid_molecules} ({co2_good/valid_molecules*100:.1f}%)\")\n", " print(f\" Acceptable (<{acceptable_threshold}% error): {co2_acceptable}/{valid_molecules} ({co2_acceptable/valid_molecules*100:.1f}%)\")\n", "\n", "\n", "def evaluate_property_reconstruction_fidelity(model, regression_model, tokenizer, \n", " scaler_ch4, scaler_co2, test_smiles_list,\n", " batch_size=16, max_length=256):\n", " \"\"\"\n", " Evaluate how well reconstructed SMILES preserve chemical properties\n", " \n", " Args:\n", " model: Trained encoder-decoder model\n", " regression_model: Trained regression model for property prediction\n", " tokenizer: SMILES tokenizer\n", " scaler_ch4, scaler_co2: Property scalers\n", " test_smiles_list: List of original SMILES to test\n", " batch_size: Batch size for processing\n", " max_length: Maximum SMILES length\n", " \n", " Returns:\n", " results_df: DataFrame with detailed property comparison results\n", " metrics: Dictionary of aggregate evaluation metrics\n", " \"\"\"\n", " \n", " model.eval()\n", " regression_model.eval()\n", " device = next(model.parameters()).device\n", " \n", " results = []\n", " \n", " print(f\"Evaluating property reconstruction fidelity for {len(test_smiles_list)} molecules...\")\n", " \n", " # Process in batches for memory efficiency\n", " for i in tqdm(range(0, len(test_smiles_list), batch_size), desc=\"Processing batches\"):\n", " batch_end = min(i + batch_size, len(test_smiles_list))\n", " batch_smiles = test_smiles_list[i:batch_end]\n", " \n", " with torch.no_grad():\n", " # Step 1: Generate embeddings from original SMILES\n", " original_tokens = tokenizer(\n", " batch_smiles,\n", " max_length=max_length,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " ).to(device)\n", " \n", " # Get 
embeddings\n", " embeddings = model.encoder(original_tokens['input_ids'].cuda(), original_tokens['attention_mask'].cuda())\n", " \n", " # Step 2: Reconstruct SMILES from embeddings\n", " reconstructed_ids = model.generate(embeddings, temperature=1.5)\n", " \n", " # Step 3: Predict properties for both original and reconstructed\n", " # Original properties\n", " orig_predictions = regression_model(original_tokens['input_ids'], original_tokens['attention_mask'])\n", " orig_ch4_scaled = orig_predictions[:, -2].cpu().numpy().reshape(-1, 1)\n", " orig_co2_scaled = orig_predictions[:, -1].cpu().numpy().reshape(-1, 1)\n", " orig_ch4 = scaler_ch4.inverse_transform(orig_ch4_scaled).flatten()\n", " orig_co2 = scaler_co2.inverse_transform(orig_co2_scaled).flatten()\n", " \n", " # Reconstructed properties\n", " reconstructed_smiles = [tokenizer.decode(seq, skip_special_tokens=True) for seq in reconstructed_ids]\n", " # Validate reconstructed SMILES and predict properties\n", " for j, (orig_smiles, recon_smiles) in enumerate(zip(batch_smiles, reconstructed_smiles)):\n", " # Chemical validity checks\n", " orig_mol = Chem.MolFromSmiles(orig_smiles)\n", " recon_mol = Chem.MolFromSmiles(recon_smiles)\n", " \n", " orig_valid = orig_mol is not None\n", " recon_valid = recon_mol is not None\n", " \n", " # Canonical SMILES comparison\n", " if orig_valid and recon_valid:\n", " orig_canonical = MolToSmiles(orig_mol, canonical=True)\n", " recon_canonical = MolToSmiles(recon_mol, canonical=True)\n", " is_identical = orig_canonical == recon_canonical\n", " else:\n", " orig_canonical = \"INVALID\" if not orig_valid else MolToSmiles(orig_mol, canonical=True)\n", " recon_canonical = \"INVALID\" if not recon_valid else MolToSmiles(recon_mol, canonical=True)\n", " is_identical = False\n", " \n", " # Property prediction for reconstructed SMILES\n", " if recon_valid:\n", " recon_tokens = tokenizer(\n", " recon_smiles,\n", " max_length=max_length,\n", " padding='max_length',\n", " truncation=True,\n", " return_tensors='pt'\n", " ).to(device)\n", " \n", " recon_predictions = regression_model(recon_tokens['input_ids'], recon_tokens['attention_mask'])\n", " recon_ch4_scaled = recon_predictions[0, -2].cpu().numpy().reshape(-1, 1)\n", " recon_co2_scaled = recon_predictions[0, -1].cpu().numpy().reshape(-1, 1)\n", " recon_ch4 = scaler_ch4.inverse_transform(recon_ch4_scaled)[0, 0]\n", " recon_co2 = scaler_co2.inverse_transform(recon_co2_scaled)[0, 0]\n", " else:\n", " recon_ch4 = np.nan\n", " recon_co2 = np.nan\n", " \n", " # Calculate property differences\n", " ch4_absolute_error = abs(orig_ch4[j] - recon_ch4) if recon_valid else np.nan\n", " co2_absolute_error = abs(orig_co2[j] - recon_co2) if recon_valid else np.nan\n", " \n", " ch4_relative_error = (ch4_absolute_error / orig_ch4[j] * 100) if (recon_valid and orig_ch4[j] != 0) else np.nan\n", " co2_relative_error = (co2_absolute_error / orig_co2[j] * 100) if (recon_valid and orig_co2[j] != 0) else np.nan\n", " \n", " results.append({\n", " 'molecule_id': i + j + 1,\n", " 'original_smiles': orig_smiles,\n", " 'reconstructed_smiles': recon_smiles,\n", " 'original_canonical': orig_canonical,\n", " 'reconstructed_canonical': recon_canonical,\n", " 'chemically_identical': is_identical,\n", " 'original_valid': orig_valid,\n", " 'reconstructed_valid': recon_valid,\n", " 'original_CH4': orig_ch4[j],\n", " 'original_CO2': orig_co2[j],\n", " 'reconstructed_CH4': recon_ch4,\n", " 'reconstructed_CO2': recon_co2,\n", " 'CH4_absolute_error': ch4_absolute_error,\n", " 'CO2_absolute_error': 
co2_absolute_error,\n", " 'CH4_relative_error': ch4_relative_error,\n", " 'CO2_relative_error': co2_relative_error\n", " })\n", " \n", " # Convert to DataFrame\n", " results_df = pd.DataFrame(results)\n", " \n", " # Calculate aggregate metrics\n", " valid_comparisons = results_df[results_df['reconstructed_valid']].copy()\n", " \n", " if len(valid_comparisons) > 0:\n", " metrics = calculate_property_fidelity_metrics(valid_comparisons)\n", " display_property_fidelity_results(results_df, metrics)\n", " else:\n", " print(\"No valid reconstructions for property comparison!\")\n", " metrics = {}\n", " \n", " return results_df, metrics\n", "\n", "def visualize_property_fidelity(results_df, save_plots=True):\n", " \"\"\"Create comprehensive visualizations of property reconstruction fidelity\"\"\"\n", " \n", " valid_results = results_df[results_df['reconstructed_valid']].copy()\n", " \n", " if len(valid_results) == 0:\n", " print(\"No valid results to visualize!\")\n", " return\n", " \n", " fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n", " \n", " # 1. CH4 Property Correlation\n", " axes[0, 0].scatter(valid_results['original_CH4'], valid_results['reconstructed_CH4'], \n", " alpha=0.6, s=30, edgecolors='black', linewidth=0.5)\n", " \n", " # Perfect correlation line\n", " min_ch4 = min(valid_results['original_CH4'].min(), valid_results['reconstructed_CH4'].min())\n", " max_ch4 = max(valid_results['original_CH4'].max(), valid_results['reconstructed_CH4'].max())\n", " axes[0, 0].plot([min_ch4, max_ch4], [min_ch4, max_ch4], 'r--', linewidth=2, label='Perfect Correlation')\n", " \n", " # Calculate and display R²\n", " ch4_r2 = r2_score(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " axes[0, 0].set_xlabel('Original CH₄ Permeability')\n", " axes[0, 0].set_ylabel('Reconstructed CH₄ Permeability')\n", " axes[0, 0].set_title(f'CH₄ Property Reconstruction (R² = {ch4_r2:.3f})')\n", " axes[0, 0].legend()\n", " axes[0, 0].grid(True, alpha=0.3)\n", " \n", " # 2. CO2 Property Correlation\n", " axes[0, 1].scatter(valid_results['original_CO2'], valid_results['reconstructed_CO2'], \n", " alpha=0.6, s=30, edgecolors='black', linewidth=0.5)\n", " \n", " min_co2 = min(valid_results['original_CO2'].min(), valid_results['reconstructed_CO2'].min())\n", " max_co2 = max(valid_results['original_CO2'].max(), valid_results['reconstructed_CO2'].max())\n", " axes[0, 1].plot([min_co2, max_co2], [min_co2, max_co2], 'r--', linewidth=2, label='Perfect Correlation')\n", " \n", " co2_r2 = r2_score(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " axes[0, 1].set_xlabel('Original CO₂ Permeability')\n", " axes[0, 1].set_ylabel('Reconstructed CO₂ Permeability')\n", " axes[0, 1].set_title(f'CO₂ Property Reconstruction (R² = {co2_r2:.3f})')\n", " axes[0, 1].legend()\n", " axes[0, 1].grid(True, alpha=0.3)\n", " \n", " # 3. Relative Error Distribution - CH4\n", " axes[0, 2].hist(valid_results['CH4_relative_error'].dropna(), bins=30, alpha=0.7, \n", " edgecolor='black', color='skyblue')\n", " axes[0, 2].axvline(valid_results['CH4_relative_error'].median(), color='red', \n", " linestyle='--', linewidth=2, label=f'Median: {valid_results[\"CH4_relative_error\"].median():.1f}%')\n", " axes[0, 2].set_xlabel('CH₄ Relative Error (%)')\n", " axes[0, 2].set_ylabel('Frequency')\n", " axes[0, 2].set_title('CH₄ Relative Error Distribution')\n", " axes[0, 2].legend()\n", " axes[0, 2].grid(True, alpha=0.3)\n", " \n", " # 4. 
Relative Error Distribution - CO2\n", " axes[1, 0].hist(valid_results['CO2_relative_error'].dropna(), bins=30, alpha=0.7, \n", " edgecolor='black', color='lightcoral')\n", " axes[1, 0].axvline(valid_results['CO2_relative_error'].median(), color='red', \n", " linestyle='--', linewidth=2, label=f'Median: {valid_results[\"CO2_relative_error\"].median():.1f}%')\n", " axes[1, 0].set_xlabel('CO₂ Relative Error (%)')\n", " axes[1, 0].set_ylabel('Frequency')\n", " axes[1, 0].set_title('CO₂ Relative Error Distribution')\n", " axes[1, 0].legend()\n", " axes[1, 0].grid(True, alpha=0.3)\n", " \n", " # 5. Error Comparison\n", " error_comparison = pd.DataFrame({\n", " 'CH₄_Error': valid_results['CH4_relative_error'].dropna(),\n", " 'CO₂_Error': valid_results['CO2_relative_error'].dropna()\n", " })\n", " \n", " axes[1, 1].scatter(error_comparison['CH₄_Error'], error_comparison['CO₂_Error'], \n", " alpha=0.6, s=30, edgecolors='black', linewidth=0.5)\n", " axes[1, 1].set_xlabel('CH₄ Relative Error (%)')\n", " axes[1, 1].set_ylabel('CO₂ Relative Error (%)')\n", " axes[1, 1].set_title('Property Error Correlation')\n", " axes[1, 1].grid(True, alpha=0.3)\n", " \n", " # 6. Chemical Identity vs Property Preservation\n", " identical_mask = valid_results['chemically_identical']\n", " \n", " ch4_errors_identical = valid_results[identical_mask]['CH4_relative_error'].dropna()\n", " ch4_errors_different = valid_results[~identical_mask]['CH4_relative_error'].dropna()\n", " \n", " axes[1, 2].boxplot([ch4_errors_identical, ch4_errors_different], \n", " labels=['Chemically Identical', 'Chemically Different'])\n", " axes[1, 2].set_ylabel('CH₄ Relative Error (%)')\n", " axes[1, 2].set_title('Property Error by Chemical Identity')\n", " axes[1, 2].grid(True, alpha=0.3)\n", " \n", " plt.tight_layout()\n", " \n", " if save_plots:\n", " plt.savefig('property_reconstruction_fidelity.png', dpi=300, bbox_inches='tight')\n", " print(\"Visualization saved as 'property_reconstruction_fidelity.png'\")\n", " \n", " plt.show()\n", "\n", "\n", "def calculate_property_fidelity_metrics(valid_results):\n", " \"\"\"Calculate comprehensive property fidelity metrics\"\"\"\n", " \n", " metrics = {}\n", " \n", " # Correlation metrics\n", " ch4_corr_pearson, ch4_p_pearson = pearsonr(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " co2_corr_pearson, co2_p_pearson = pearsonr(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " \n", " ch4_corr_spearman, ch4_p_spearman = spearmanr(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " co2_corr_spearman, co2_p_spearman = spearmanr(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " \n", " # Regression metrics\n", " ch4_r2 = r2_score(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " co2_r2 = r2_score(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " \n", " ch4_mae = mean_absolute_error(valid_results['original_CH4'], valid_results['reconstructed_CH4'])\n", " co2_mae = mean_absolute_error(valid_results['original_CO2'], valid_results['reconstructed_CO2'])\n", " \n", " ch4_rmse = np.sqrt(mean_squared_error(valid_results['original_CH4'], valid_results['reconstructed_CH4']))\n", " co2_rmse = np.sqrt(mean_squared_error(valid_results['original_CO2'], valid_results['reconstructed_CO2']))\n", " \n", " # Relative error statistics\n", " ch4_mean_rel_error = valid_results['CH4_relative_error'].mean()\n", " co2_mean_rel_error = valid_results['CO2_relative_error'].mean()\n", " \n", " 
ch4_median_rel_error = valid_results['CH4_relative_error'].median()\n", " co2_median_rel_error = valid_results['CO2_relative_error'].median()\n", " \n", " metrics = {\n", " 'CH4': {\n", " 'pearson_correlation': ch4_corr_pearson,\n", " 'spearman_correlation': ch4_corr_spearman,\n", " 'r2_score': ch4_r2,\n", " 'mae': ch4_mae,\n", " 'rmse': ch4_rmse,\n", " 'mean_relative_error_pct': ch4_mean_rel_error,\n", " 'median_relative_error_pct': ch4_median_rel_error\n", " },\n", " 'CO2': {\n", " 'pearson_correlation': co2_corr_pearson,\n", " 'spearman_correlation': co2_corr_spearman,\n", " 'r2_score': co2_r2,\n", " 'mae': co2_mae,\n", " 'rmse': co2_rmse,\n", " 'mean_relative_error_pct': co2_mean_rel_error,\n", " 'median_relative_error_pct': co2_median_rel_error\n", " }\n", " }\n", " \n", " return metrics\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "3e6bcd77-7314-4829-8d26-26eab56f418c", "metadata": {}, "outputs": [], "source": [ "from rdkit import RDLogger\n", "\n", "# Disable all RDKit logs\n", "RDLogger.DisableLog('rdApp.*') " ] }, { "cell_type": "code", "execution_count": 26, "id": "04ecf928-10d8-4366-a9d6-1a646301d387", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | Unnamed: 0 | \n", "Smiles | \n", "Tg | \n", "He | \n", "N2 | \n", "O2 | \n", "CH4 | \n", "CO2 | \n", "synthesizable | \n", "
---|---|---|---|---|---|---|---|---|---|
5117051 | \n", "5117051 | \n", "Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
1654268 | \n", "1654268 | \n", "Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
6692275 | \n", "6692275 | \n", "Ic1ccc(cn1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
830270 | \n", "830270 | \n", "Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
6692276 | \n", "6692276 | \n", "Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
10950 | \n", "10950 | \n", "Ic1cc(Oc2cc(Oc3cc(cc(c3)n3c(=O)c4c(c3=O)cc3c(c... | \n", "538.15 | \n", "1.23546 | \n", "0.00105 | \n", "0.00404 | \n", "0.00195 | \n", "0.00785 | \n", "True | \n", "
5309014 | \n", "5309014 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... | \n", "552.00 | \n", "2.56385 | \n", "0.00138 | \n", "0.00777 | \n", "0.00287 | \n", "0.00772 | \n", "False | \n", "
2598418 | \n", "2598418 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... | \n", "552.00 | \n", "2.56385 | \n", "0.00138 | \n", "0.00777 | \n", "0.00287 | \n", "0.00772 | \n", "False | \n", "
746918 | \n", "746918 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... | \n", "552.00 | \n", "2.56385 | \n", "0.00138 | \n", "0.00777 | \n", "0.00287 | \n", "0.00772 | \n", "False | \n", "
2255 | \n", "2255 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... | \n", "586.80 | \n", "1.80544 | \n", "0.00068 | \n", "0.00330 | \n", "0.00184 | \n", "0.00419 | \n", "False | \n", "
6726950 rows × 9 columns
\n", "\n", " | Unnamed: 0 | \n", "Smiles | \n", "Tg | \n", "He | \n", "N2 | \n", "O2 | \n", "CH4 | \n", "CO2 | \n", "synthesizable | \n", "
---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1... | \n", "494.84 | \n", "2.69524 | \n", "4.75740 | \n", "42.31847 | \n", "1.64086 | \n", "148.43644 | \n", "False | \n", "
1 | \n", "1 | \n", "Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=... | \n", "508.26 | \n", "5.33815 | \n", "2.97239 | \n", "26.31118 | \n", "0.86467 | \n", "82.37635 | \n", "False | \n", "
2 | \n", "2 | \n", "O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c... | \n", "640.91 | \n", "20.47515 | \n", "0.06353 | \n", "0.90498 | \n", "0.06905 | \n", "2.35993 | \n", "False | \n", "
3 | \n", "3 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O... | \n", "568.04 | \n", "4.19692 | \n", "0.00191 | \n", "0.01134 | \n", "0.00362 | \n", "0.01418 | \n", "False | \n", "
4 | \n", "4 | \n", "Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1... | \n", "548.10 | \n", "142.68327 | \n", "0.87380 | \n", "8.25409 | \n", "2.52067 | \n", "30.04739 | \n", "False | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
6726945 | \n", "6726945 | \n", "Ic1cccc(c1)Cc1cc(C)c(c(c1)C)c1c(C)cc(cc1C)Cc1c... | \n", "516.27 | \n", "13.32967 | \n", "0.11907 | \n", "1.10144 | \n", "0.16907 | \n", "2.63642 | \n", "False | \n", "
6726946 | \n", "6726946 | \n", "Ic1ccc(nc1)Cc1ccc2c(c1)c1cc(ccc1C2(C)C)Cc1ccc(... | \n", "510.72 | \n", "8.96454 | \n", "7.92257 | \n", "72.92929 | \n", "2.17324 | \n", "247.14446 | \n", "False | \n", "
6726947 | \n", "6726947 | \n", "Ic1ccc(cn1)S(=O)(=O)c1ccc(nc1)N1C(=O)c2c(C1=O)... | \n", "610.24 | \n", "4.90758 | \n", "20.24364 | \n", "121.24494 | \n", "1.01011 | \n", "650.18364 | \n", "False | \n", "
6726948 | \n", "6726948 | \n", "Ic1ccc(c(c1)C)Oc1ccc2c(c1)Cc1c2ccc(c1)Oc1ccc(c... | \n", "510.90 | \n", "12.40907 | \n", "0.19307 | \n", "1.41335 | \n", "0.11728 | \n", "4.00573 | \n", "True | \n", "
6726949 | \n", "6726949 | \n", "Ic1cc(Sc2ccc3c(c2)cc(cc3)Sc2cc(cc(c2)C(F)(F)F)... | \n", "462.31 | \n", "14.96342 | \n", "0.13916 | \n", "1.05252 | \n", "0.09298 | \n", "2.54411 | \n", "False | \n", "
6726950 rows × 9 columns
\n", "\n", " | Unnamed: 0 | \n", "Smiles | \n", "Tg | \n", "He | \n", "N2 | \n", "O2 | \n", "CH4 | \n", "CO2 | \n", "synthesizable | \n", "
---|---|---|---|---|---|---|---|---|---|
5117051 | \n", "5117051 | \n", "Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
1654268 | \n", "1654268 | \n", "Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
6692275 | \n", "6692275 | \n", "Ic1ccc(cn1)C(C(F)(F)F)(C(F)(F)F)c1ccc(cn1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
830270 | \n", "830270 | \n", "Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
6692276 | \n", "6692276 | \n", "Ic1ccc(nc1)C(C(F)(F)F)(C(F)(F)F)c1ccc(nc1)C(C(... | \n", "544.42 | \n", "343.89398 | \n", "3284.48720 | \n", "18341.89400 | \n", "528.13604 | \n", "161379.46000 | \n", "False | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
10950 | \n", "10950 | \n", "Ic1cc(Oc2cc(Oc3cc(cc(c3)n3c(=O)c4c(c3=O)cc3c(c... | \n", "538.15 | \n", "1.23546 | \n", "0.00105 | \n", "0.00404 | \n", "0.00195 | \n", "0.00785 | \n", "True | \n", "
5309014 | \n", "5309014 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... | \n", "552.00 | \n", "2.56385 | \n", "0.00138 | \n", "0.00777 | \n", "0.00287 | \n", "0.00772 | \n", "False | \n", "
2598418 | \n", "2598418 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... | \n", "552.00 | \n", "2.56385 | \n", "0.00138 | \n", "0.00777 | \n", "0.00287 | \n", "0.00772 | \n", "False | \n", "
746918 | \n", "746918 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... | \n", "552.00 | \n", "2.56385 | \n", "0.00138 | \n", "0.00777 | \n", "0.00287 | \n", "0.00772 | \n", "False | \n", "
2255 | \n", "2255 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1cc(cc(c1)C(=O)O)C(=O... | \n", "586.80 | \n", "1.80544 | \n", "0.00068 | \n", "0.00330 | \n", "0.00184 | \n", "0.00419 | \n", "False | \n", "
6726950 rows × 9 columns
\n", "\n", " | SMILES | \n", "exp_perm_CO2__Barrer_mean | \n", "
---|---|---|
0 | \n", "[*]C(=O)c1ccc(N2C(=O)c3ccc(C(=O)c4ccc(C(=O)c5c... | \n", "0.669329 | \n", "
1 | \n", "[*]c1ccc(Oc2cccc(Oc3ccc(N4C(=O)c5ccc(C(=O)c6cc... | \n", "0.404411 | \n", "
2 | \n", "[*]C(=O)c1ccc(N2C(=O)c3cccc(C(=O)c4ccc(Oc5ccc(... | \n", "0.785077 | \n", "
3 | \n", "[*]C(=O)c1cccc(C(=O)c2ccc3c(c2)C(=O)N(c2ccc(C(... | \n", "0.386811 | \n", "
4 | \n", "[*]Cc1ccc(N2C(=O)c3ccc(Oc4ccc5c(c4)C(=O)N(c4cc... | \n", "0.733848 | \n", "
... | \n", "... | \n", "... | \n", "
95 | \n", "[*]C(=O)c1ccc2c(c1)C(=O)N(c1cccc3c(N4C(=O)c5cc... | \n", "1.072687 | \n", "
96 | \n", "[*]C(=O)c1ccc2c(c1)C(=O)c1cc(C(=O)c3ccc(N4C(=O... | \n", "0.705203 | \n", "
97 | \n", "[*]C(=O)c1cccc(N2C(=O)c3ccc(C(=O)c4ccc5c(c4)C(... | \n", "0.285455 | \n", "
98 | \n", "[*]C(=O)c1ccc(C(=O)c2cccc3c2C(=O)N(c2ccc(C(=O)... | \n", "1.053971 | \n", "
99 | \n", "[*]C(=O)c1cccc(Oc2ccc3c(c2)C(=O)N(c2ccc(C(=O)c... | \n", "0.651191 | \n", "
100 rows × 2 columns
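This 100-polymer table pairs repeat-unit SMILES (with `[*]` attachment points) against a mean experimental CO2 permeability, `exp_perm_CO2__Barrer_mean`. A labeled set this small is a natural fit for a lightweight regressor on top of the frozen SimSon encoder. The sketch below is one plausible way to do that, not the notebook's verified pipeline: the CSV file name is hypothetical, `encoder`/`tokenizer` are assumed to be the frozen encoder and tokenizer defined earlier, and whether the target is raw or log-scaled Barrer is not stated, so the code leaves it untouched:

```python
import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split

# Hypothetical file name; the source of the 100-polymer table is not shown.
exp_df = pd.read_csv("experimental_co2_permeability.csv")

encoder.eval()  # frozen SimSonEncoder from earlier in the notebook (assumed name)

@torch.no_grad()
def embed_smiles(smiles, batch_size=32):
    """Embed SMILES strings with the frozen encoder (CPU assumed)."""
    chunks = []
    for i in range(0, len(smiles), batch_size):
        enc = tokenizer(smiles[i:i + batch_size], padding=True, truncation=True,
                        max_length=512, return_tensors="pt")
        chunks.append(encoder(enc["input_ids"], enc["attention_mask"]).numpy())
    return np.vstack(chunks)

X = embed_smiles(exp_df["SMILES"].tolist())
y = exp_df["exp_perm_CO2__Barrer_mean"].to_numpy()  # scale (raw vs. log Barrer) unverified

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)
reg = Ridge(alpha=1.0).fit(X_tr, y_tr)
print(f"held-out R^2: {reg.score(X_te, y_te):.3f}")
```

Ridge regression is only one choice here; with 100 labeled points, anything heavier than a linear model or a small MLP on the frozen embeddings risks overfitting.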
\n", "\n", " | Unnamed: 0 | \n", "Smiles | \n", "Tg | \n", "He | \n", "N2 | \n", "O2 | \n", "CH4 | \n", "CO2 | \n", "synthesizable | \n", "new_preds | \n", "
---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1... | \n", "494.84 | \n", "2.69524 | \n", "4.75740 | \n", "42.31847 | \n", "1.64086 | \n", "148.43644 | \n", "False | \n", "1.036345 | \n", "
1 | \n", "1 | \n", "Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=... | \n", "508.26 | \n", "5.33815 | \n", "2.97239 | \n", "26.31118 | \n", "0.86467 | \n", "82.37635 | \n", "False | \n", "2.058726 | \n", "
2 | \n", "2 | \n", "O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c... | \n", "640.91 | \n", "20.47515 | \n", "0.06353 | \n", "0.90498 | \n", "0.06905 | \n", "2.35993 | \n", "False | \n", "21.385293 | \n", "
3 | \n", "3 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O... | \n", "568.04 | \n", "4.19692 | \n", "0.00191 | \n", "0.01134 | \n", "0.00362 | \n", "0.01418 | \n", "False | \n", "0.802363 | \n", "
4 | \n", "4 | \n", "Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1... | \n", "548.10 | \n", "142.68327 | \n", "0.87380 | \n", "8.25409 | \n", "2.52067 | \n", "30.04739 | \n", "False | \n", "132.455960 | \n", "
5 | \n", "5 | \n", "Cc1cc(Sc2ccc(nc2)Sc2cc(C)c(c(c2)C)I)cc(c1N1C(=... | \n", "501.82 | \n", "12.37968 | \n", "1.68540 | \n", "12.25254 | \n", "1.03299 | \n", "67.42722 | \n", "False | \n", "21.316367 | \n", "
6 | \n", "6 | \n", "Cc1cc(Sc2ccc(cn2)Sc2cc(C)c(c(c2)C)I)cc(c1N1C(=... | \n", "501.82 | \n", "12.37968 | \n", "1.68540 | \n", "12.25254 | \n", "1.03299 | \n", "67.42722 | \n", "False | \n", "21.316372 | \n", "
7 | \n", "7 | \n", "Clc1cc(ccc1C(=O)c1cc(ccc1C)C(=O)c1ccc(c(c1)Cl)... | \n", "529.64 | \n", "10.15046 | \n", "0.02021 | \n", "0.25611 | \n", "0.13249 | \n", "0.51867 | \n", "False | \n", "1.720072 | \n", "
8 | \n", "8 | \n", "Ic1ccc(c(c1)Cl)C(=O)c1cc(ccc1C)C(=O)c1ccc(c(c1... | \n", "529.64 | \n", "10.15046 | \n", "0.02021 | \n", "0.25611 | \n", "0.13249 | \n", "0.51867 | \n", "False | \n", "1.720071 | \n", "
9 | \n", "9 | \n", "Ic1ccc(c(c1)C)C(=O)c1cc2ccccc2cc1C(=O)c1ccc(cc... | \n", "556.31 | \n", "20.84491 | \n", "0.05426 | \n", "0.43648 | \n", "0.06503 | \n", "0.84243 | \n", "False | \n", "3.767171 | \n", "
10 | \n", "10 | \n", "Ic1ccc(c(c1)Sc1cc(cc(c1)C(=O)O)Sc1cc(ccc1C)N1C... | \n", "470.24 | \n", "3.03201 | \n", "0.02144 | \n", "0.15934 | \n", "0.02558 | \n", "0.33708 | \n", "True | \n", "1.093549 | \n", "
11 | \n", "11 | \n", "Cc1cc(cc(c1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C... | \n", "575.93 | \n", "382.37918 | \n", "4.71113 | \n", "27.50459 | \n", "4.06784 | \n", "134.46742 | \n", "False | \n", "262.278064 | \n", "
12 | \n", "12 | \n", "Ic1cc(C)c(c(c1)C)C(c1ccc2c(c1)[nH]c1c2ccc(c1)C... | \n", "537.39 | \n", "108.94866 | \n", "1.07881 | \n", "8.35180 | \n", "1.67890 | \n", "39.29826 | \n", "True | \n", "206.577505 | \n", "
13 | \n", "13 | \n", "Ic1cc(Oc2ccc(cc2)C2(c3ccc(cc3)Oc3cc(cc(c3)N3C(... | \n", "556.56 | \n", "4.47037 | \n", "0.04159 | \n", "0.28007 | \n", "0.03887 | \n", "0.73329 | \n", "True | \n", "9.642768 | \n", "
14 | \n", "14 | \n", "Ic1ccc(c(c1)Cl)Sc1ccc2c(c1)ccc(c2)Sc1ccc(c(c1)... | \n", "481.62 | \n", "4.67411 | \n", "0.01464 | \n", "0.17388 | \n", "0.05333 | \n", "0.64280 | \n", "False | \n", "1.471597 | \n", "
15 | \n", "15 | \n", "Clc1cc(ccc1Sc1ccc2c(c1)ccc(c2)Sc1ccc(c(c1)Cl)I... | \n", "481.62 | \n", "4.67411 | \n", "0.01464 | \n", "0.17388 | \n", "0.05333 | \n", "0.64280 | \n", "False | \n", "1.471597 | \n", "
16 | \n", "16 | \n", "Cc1cc(ccc1N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1)C(=O... | \n", "596.88 | \n", "18.21650 | \n", "0.09935 | \n", "0.68089 | \n", "0.12015 | \n", "2.16558 | \n", "True | \n", "2.989342 | \n", "
17 | \n", "17 | \n", "Ic1ccc(c(c1)C)C(c1ccc(cc1C)C(c1ccc(c(c1)C)N1C(... | \n", "467.07 | \n", "58.20149 | \n", "0.51975 | \n", "3.97824 | \n", "0.68842 | \n", "15.89978 | \n", "True | \n", "35.433246 | \n", "
18 | \n", "18 | \n", "O=C1c2cccc(c2C(=O)N1c1ccc(c(c1)C)C(c1ccc(cc1C)... | \n", "467.07 | \n", "58.20149 | \n", "0.51975 | \n", "3.97824 | \n", "0.68842 | \n", "15.89978 | \n", "True | \n", "35.433265 | \n", "
19 | \n", "19 | \n", "Cc1cc(c(cc1N1C(=O)c2c(C1=O)cc(cc2)c1ccc2c(c1)C... | \n", "668.95 | \n", "35.18491 | \n", "0.13804 | \n", "1.23473 | \n", "0.07419 | \n", "4.58177 | \n", "False | \n", "7.471155 | \n", "
20 | \n", "20 | \n", "O=C1c2cc(ccc2c2c1cc(cc2)C(C(F)(F)F)(C(F)(F)F)c... | \n", "543.52 | \n", "111.91630 | \n", "1.33930 | \n", "9.69874 | \n", "0.95233 | \n", "24.32080 | \n", "False | \n", "241.805804 | \n", "
21 | \n", "21 | \n", "Ic1ccc(cn1)C(c1cc(Cl)cc(c1)C(c1ccc(nc1)N1C(=O)... | \n", "508.98 | \n", "15.30320 | \n", "7.27329 | \n", "75.86246 | \n", "8.07197 | \n", "315.19387 | \n", "False | \n", "8.272223 | \n", "
22 | \n", "22 | \n", "O=C1c2cccc(c2C(=O)N1c1cc(C)c(c(c1)C)S(=O)(=O)c... | \n", "628.48 | \n", "39.34309 | \n", "0.18496 | \n", "1.18768 | \n", "0.02477 | \n", "3.25713 | \n", "True | \n", "25.922022 | \n", "
23 | \n", "23 | \n", "Ic1cc(C)c(c(c1)C)S(=O)(=O)c1ccc2c(c1)ccc(c2)S(... | \n", "628.48 | \n", "39.34309 | \n", "0.18496 | \n", "1.18768 | \n", "0.02477 | \n", "3.25713 | \n", "True | \n", "25.922022 | \n", "
24 | \n", "24 | \n", "Ic1cc(Sc2cc(cc(c2)n2c(=O)c3c(c2=O)cc2c(c3)c(=O... | \n", "564.08 | \n", "1.07922 | \n", "0.00146 | \n", "0.00629 | \n", "0.00217 | \n", "0.01674 | \n", "True | \n", "0.422021 | \n", "
25 | \n", "25 | \n", "Ic1ccc(cn1)C(=O)c1cc(C)c(c(c1)C)C(=O)c1ccc(cn1... | \n", "564.33 | \n", "11.39561 | \n", "11.73863 | \n", "80.00608 | \n", "8.36759 | \n", "438.07856 | \n", "False | \n", "6.608806 | \n", "
26 | \n", "26 | \n", "Ic1ccc(nc1)C(=O)c1c(C)cc(cc1C)C(=O)c1ccc(nc1)n... | \n", "564.33 | \n", "11.39561 | \n", "11.73863 | \n", "80.00608 | \n", "8.36759 | \n", "438.07856 | \n", "False | \n", "6.608806 | \n", "
27 | \n", "27 | \n", "Ic1cc(cc(c1)C(F)(F)F)C(=O)c1cc(C)c(cc1C)C(=O)c... | \n", "526.25 | \n", "47.17589 | \n", "0.25597 | \n", "1.83550 | \n", "0.38171 | \n", "4.23498 | \n", "False | \n", "14.162805 | \n", "
28 | \n", "28 | \n", "Ic1ccc(cn1)S(=O)(=O)c1cc(C)c(c(c1)C)c1c(C)cc(c... | \n", "618.79 | \n", "21.18174 | \n", "8.28968 | \n", "54.04887 | \n", "0.66770 | \n", "251.39951 | \n", "False | \n", "41.374632 | \n", "
29 | \n", "29 | \n", "Ic1ccc(c(c1)C(C(F)(F)F)(C(F)(F)F)c1cc(C)c(c(c1... | \n", "562.21 | \n", "208.21051 | \n", "2.21987 | \n", "15.29097 | \n", "2.58383 | \n", "58.80082 | \n", "False | \n", "373.653450 | \n", "
\n", " | Unnamed: 0 | \n", "Smiles | \n", "Tg | \n", "He | \n", "N2 | \n", "O2 | \n", "CH4 | \n", "CO2 | \n", "synthesizable | \n", "new_preds | \n", "
---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1... | \n", "494.84 | \n", "2.69524 | \n", "4.75740 | \n", "42.31847 | \n", "1.64086 | \n", "148.43644 | \n", "False | \n", "1.036345 | \n", "
1 | \n", "1 | \n", "Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=... | \n", "508.26 | \n", "5.33815 | \n", "2.97239 | \n", "26.31118 | \n", "0.86467 | \n", "82.37635 | \n", "False | \n", "2.058726 | \n", "
2 | \n", "2 | \n", "O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c... | \n", "640.91 | \n", "20.47515 | \n", "0.06353 | \n", "0.90498 | \n", "0.06905 | \n", "2.35993 | \n", "False | \n", "21.385310 | \n", "
3 | \n", "3 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O... | \n", "568.04 | \n", "4.19692 | \n", "0.00191 | \n", "0.01134 | \n", "0.00362 | \n", "0.01418 | \n", "False | \n", "0.802363 | \n", "
4 | \n", "4 | \n", "Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1... | \n", "548.10 | \n", "142.68327 | \n", "0.87380 | \n", "8.25409 | \n", "2.52067 | \n", "30.04739 | \n", "False | \n", "132.455960 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
95 | \n", "95 | \n", "Ic1cc(cc(c1)C(=O)O)Oc1ccc(cc1C)Oc1cc(cc(c1)N1C... | \n", "528.86 | \n", "3.65897 | \n", "0.00916 | \n", "0.04655 | \n", "0.01135 | \n", "0.10014 | \n", "True | \n", "1.777913 | \n", "
96 | \n", "96 | \n", "Ic1cc(Oc2ccc(c(c2)C)Oc2cc(cc(c2)N2C(=O)c3c(C2=... | \n", "528.86 | \n", "3.65897 | \n", "0.00916 | \n", "0.04655 | \n", "0.01135 | \n", "0.10014 | \n", "True | \n", "1.777911 | \n", "
97 | \n", "97 | \n", "Cc1cc2c3cc(C)c(cc3S(=O)(=O)c2cc1Oc1ccc(c(c1)C)... | \n", "554.76 | \n", "27.71933 | \n", "0.37086 | \n", "2.36356 | \n", "0.10131 | \n", "7.45523 | \n", "False | \n", "10.918775 | \n", "
98 | \n", "98 | \n", "Clc1c(ccc(c1Cl)N1C(=O)c2c(C1=O)cccc2c1ccc2c(c1... | \n", "575.95 | \n", "7.17213 | \n", "0.01350 | \n", "0.26675 | \n", "0.21026 | \n", "0.83210 | \n", "False | \n", "2.465392 | \n", "
99 | \n", "99 | \n", "Ic1ccc(nc1)S(=O)(=O)c1ccc(c(c1Cl)Cl)S(=O)(=O)c... | \n", "639.66 | \n", "6.86828 | \n", "2.80338 | \n", "26.78062 | \n", "0.40880 | \n", "117.95785 | \n", "False | \n", "4.053435 | \n", "
100 rows × 10 columns
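The two result tables above attach a `new_preds` column to 100 screening candidates. The code that produces it is not visible in this part of the notebook; one plausible mechanism, reusing the hypothetical `embed_smiles` helper and `reg` regressor from the earlier sketch, would be:

```python
# Score the first 100 candidates with the fitted regressor and attach the
# result as `new_preds`. `screening_df`, `embed_smiles`, and `reg` are all
# assumed names from the sketches above, not the notebook's verified code.
subset = screening_df.head(100).copy()
subset["new_preds"] = reg.predict(embed_smiles(subset["Smiles"].tolist()))
subset
```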
\n", "\n", " | Unnamed: 0 | \n", "Smiles | \n", "Tg | \n", "He | \n", "N2 | \n", "O2 | \n", "CH4 | \n", "CO2 | \n", "synthesizable | \n", "
---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "Ic1ccc(nc1)Cc1cccc(c1)Cc1ccc(cn1)N1C(=O)c2c(C1... | \n", "494.84 | \n", "2.69524 | \n", "4.75740 | \n", "42.31847 | \n", "1.64086 | \n", "148.43644 | \n", "False | \n", "
1 | \n", "1 | \n", "Ic1ccc(nc1)Cc1ccc2c(c1)cc(cc2)Cc1ccc(cn1)N1C(=... | \n", "508.26 | \n", "5.33815 | \n", "2.97239 | \n", "26.31118 | \n", "0.86467 | \n", "82.37635 | \n", "False | \n", "
2 | \n", "2 | \n", "O=C1c2cc(ccc2C(=O)N1c1ccc(c(c1Cl)Cl)S(=O)(=O)c... | \n", "640.91 | \n", "20.47515 | \n", "0.06353 | \n", "0.90498 | \n", "0.06905 | \n", "2.35993 | \n", "False | \n", "
3 | \n", "3 | \n", "Ic1cc(cc(c1)C(=O)O)C(=O)c1ccc2c(c1)ccc(c2)C(=O... | \n", "568.04 | \n", "4.19692 | \n", "0.00191 | \n", "0.01134 | \n", "0.00362 | \n", "0.01418 | \n", "False | \n", "
4 | \n", "4 | \n", "Clc1cc(cc(c1)C(C(F)(F)F)(C(F)(F)F)c1c(C)cc(cc1... | \n", "548.10 | \n", "142.68327 | \n", "0.87380 | \n", "8.25409 | \n", "2.52067 | \n", "30.04739 | \n", "False | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
6726945 | \n", "6726945 | \n", "Ic1cccc(c1)Cc1cc(C)c(c(c1)C)c1c(C)cc(cc1C)Cc1c... | \n", "516.27 | \n", "13.32967 | \n", "0.11907 | \n", "1.10144 | \n", "0.16907 | \n", "2.63642 | \n", "False | \n", "
6726946 | \n", "6726946 | \n", "Ic1ccc(nc1)Cc1ccc2c(c1)c1cc(ccc1C2(C)C)Cc1ccc(... | \n", "510.72 | \n", "8.96454 | \n", "7.92257 | \n", "72.92929 | \n", "2.17324 | \n", "247.14446 | \n", "False | \n", "
6726947 | \n", "6726947 | \n", "Ic1ccc(cn1)S(=O)(=O)c1ccc(nc1)N1C(=O)c2c(C1=O)... | \n", "610.24 | \n", "4.90758 | \n", "20.24364 | \n", "121.24494 | \n", "1.01011 | \n", "650.18364 | \n", "False | \n", "
6726948 | \n", "6726948 | \n", "Ic1ccc(c(c1)C)Oc1ccc2c(c1)Cc1c2ccc(c1)Oc1ccc(c... | \n", "510.90 | \n", "12.40907 | \n", "0.19307 | \n", "1.41335 | \n", "0.11728 | \n", "4.00573 | \n", "True | \n", "
6726949 | \n", "6726949 | \n", "Ic1cc(Sc2ccc3c(c2)cc(cc3)Sc2cc(cc(c2)C(F)(F)F)... | \n", "462.31 | \n", "14.96342 | \n", "0.13916 | \n", "1.05252 | \n", "0.09298 | \n", "2.54411 | \n", "False | \n", "
6726950 rows × 9 columns
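The final display returns to the full nine-column screening table. Given the `synthesizable` flag and the property columns, one way to shortlist candidates from a table like this is a filter-then-rank pass; a minimal sketch using only the column names shown above (`screening_df` again hypothetical):

```python
# Shortlist: synthesizable candidates with the highest predicted CO2 permeability.
shortlist = (
    screening_df[screening_df["synthesizable"]]
    .sort_values("CO2", ascending=False)
    .head(20)
)
```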
\n", "