In [11]:
import torch
from tqdm import tqdm
import os
import warnings

import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
 AutoConfig,
 AutoModel,
 AutoTokenizer,
 BertConfig,
 BertModel,
 BertTokenizerFast,
 PreTrainedModel,
)
from transformers.activations import ACT2FN

def global_ap(x):
 return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)

class SimSonEncoder(nn.Module):
 def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
 super(SimSonEncoder, self).__init__()
 self.config = config
 self.max_len = max_len
 
 self.bert = BertModel(config, add_pooling_layer=False)
 
 self.linear = nn.Linear(config.hidden_size, max_len)
 self.dropout = nn.Dropout(dropout)
 
 def forward(self, input_ids, attention_mask=None):
 if attention_mask is None:
 attention_mask = input_ids.ne(0)
 
 outputs = self.bert(
 input_ids=input_ids,
 attention_mask=attention_mask
 )
 
 hidden_states = outputs.last_hidden_state
 
 hidden_states = self.dropout(hidden_states)
 
 pooled = global_ap(hidden_states)
 
 out = self.linear(pooled)
 
 return out

class SimSonClassifier(nn.Module):
 def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):
 super(SimSonClassifier, self).__init__()
 self.encoder = encoder
 self.clf = nn.Linear(encoder.max_len, num_labels)
 self.relu = nn.ReLU()
 self.dropout = nn.Dropout(dropout)

 def forward(self, input_ids, attention_mask=None, labels=None):
 x = self.encoder(input_ids, attention_mask)
 x = self.relu(self.dropout(x))
 x = self.clf(x)
 return x

In [8]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from tqdm import tqdm

# 1. BINNING, SAMPLING, SAMPLE WEIGHTING
def compute_multitarget_sample_weights(labels_unscaled, bins4, bins5):
 # Each returns shape (N,)
 inds4 = np.digitize(labels_unscaled[:, 0], bins4, right=False) - 1
 inds5 = np.digitize(labels_unscaled[:, 1], bins5, right=False) - 1
 freq4 = np.bincount(inds4, minlength=len(bins4))
 freq5 = np.bincount(inds5, minlength=len(bins5))
 w4 = 1.0 / (freq4[inds4] + 1e-8)
 w5 = 1.0 / (freq5[inds5] + 1e-8)
 main_weights = np.maximum(w4, w5) # Or average, or sum as suits
 main_weights /= main_weights.mean() # Normalize for stability
 return main_weights # shape (N_samples,)


class TargetedSampler(Sampler):
 """
 Enforces a proportion of 'high value' samples for either of the main labels in every batch.
 """
 def __init__(self, inds4, inds5, high_bins4, high_bins5, batch_size, high_frac=0.3, shuffle=True):
 # indices for which either label 4 or label 5 is in a 'high' bin
 high_mask = (inds4 >= high_bins4) | (inds5 >= high_bins5)
 self.high_indices = np.where(high_mask)[0]
 self.low_indices = np.where(~high_mask)[0]
 self.batch_size = batch_size
 self.high_count = int(batch_size * high_frac)
 self.low_count = batch_size - self.high_count
 self.shuffle = shuffle
 
 def __iter__(self):
 high = np.copy(self.high_indices)
 low = np.copy(self.low_indices)
 if self.shuffle:
 np.random.shuffle(high)
 np.random.shuffle(low)
 hi_ptr, low_ptr = 0, 0
 while hi_ptr < len(high) or low_ptr < len(low):
 batch_high = high[hi_ptr: hi_ptr+self.high_count]
 batch_low = low[low_ptr: low_ptr+self.low_count]
 if len(batch_high) < self.high_count:
 np.random.shuffle(high)
 hi_ptr = 0
 batch_high = high[hi_ptr: hi_ptr+self.high_count]
 if len(batch_low) < self.low_count:
 np.random.shuffle(low)
 low_ptr = 0
 batch_low = low[low_ptr: low_ptr+self.low_count]
 batch = np.concatenate([batch_high, batch_low])
 np.random.shuffle(batch)
 yield batch.tolist()
 hi_ptr += self.high_count
 low_ptr += self.low_count
 
 def __len__(self):
 return (len(self.high_indices) + len(self.low_indices)) // self.batch_size

class SMILESDataset(torch.utils.data.Dataset):
 def __init__(self, smiles_list, labels, sample_weights, tokenizer, max_length=256):
 self.smiles_list = smiles_list
 self.labels = labels # shape (N, 6), already scaled
 self.tokenizer = tokenizer
 self.max_length = max_length
 self.sample_weights = sample_weights

 def __len__(self):
 return len(self.smiles_list)

 def __getitem__(self, idx):
 smiles = self.tokenizer.cls_token + self.smiles_list[idx]
 encoding = self.tokenizer(
 smiles,
 truncation=True,
 padding='max_length',
 max_length=self.max_length,
 return_tensors='pt'
 )
 return {
 'input_ids': encoding['input_ids'].flatten(),
 'attention_mask': encoding['attention_mask'].flatten(),
 'labels': torch.tensor(self.labels[idx], dtype=torch.float32),
 'weight': torch.tensor(self.sample_weights[idx], dtype=torch.float32),
 'index': torch.tensor(idx, dtype=torch.long)
 }


def calculate_weighted_loss(predictions, labels, weights):
 """
 Calculate weighted loss for two labels with masking
 
 Args:
 predictions: Model outputs (batch_size, 6)
 labels: Ground truth labels (batch_size, 6)
 label_mask: Mask for valid labels (batch_size, 6)
 label_weights: Weights for each label (6,)
 """
 loss_fn = nn.MSELoss(reduction='none')
 
 # Calculate per-sample, per-label losses
 losses = loss_fn(predictions, labels) # Shape: (batch_size, 2)
 return losses.mean()

def stratified_metrics(preds, trues, bins, scalers):
 # Only for last two labels
 results = {}
 for li, label_col in enumerate([4, 5]):
 unscaled_pred = scalers[label_col].inverse_transform(preds[:,label_col].reshape(-1,1)).flatten()
 unscaled_true = scalers[label_col].inverse_transform(trues[:,label_col].reshape(-1,1)).flatten()
 bin_idx = np.digitize(unscaled_true, bins[li])
 for i in range(len(bins[li])+1):
 in_bin = bin_idx == i
 if np.sum(in_bin) == 0:
 continue
 bin_mae = np.mean(np.abs(unscaled_pred[in_bin] - unscaled_true[in_bin]))
 bin_r2 = 1 - (np.mean((unscaled_pred[in_bin] - unscaled_true[in_bin])**2) /
 (np.var(unscaled_true[in_bin]) + 1e-8))
 results[f'label{label_col}_bin{i}_count'] = np.sum(in_bin)
 results[f'label{label_col}_bin{i}_mae'] = bin_mae
 return results

# 5. MAIN TRAIN/VAL LOOP WITH TARGETED SAMPLING AND STRATIFIED EVALUATION
def run_training(smiles_train, smiles_test, labels_train, labels_test, model, tokenizer, scalers,
 num_epochs=5, learning_rate=1e-5, batch_size=256, validation_steps=500):
 # 1. Bins for columns 4 & 5 using **unscaled train-data**
 bins_label4 = create_bins(train['unscaled_CO2'], n_bins=10)
 bins_label5 = create_bins(train['unscaled_CH4'], n_bins=10)
 bins = [bins_label4, bins_label5]
 # 2. Bin indicators for each sample in train
 inds4 = np.digitize(train['unscaled_CO2'], bins_label4, right=False) - 1
 inds5 = np.digitize(train['unscaled_CH4'], bins_label5, right=False) - 1
 # 3. Choose high-bin threshold (e.g. top bin, or top 2 bins as "high"), adjust as needed
 high_bins4 = len(bins_label4) - 1
 high_bins5 = len(bins_label5) - 1
 # 4. Compute multitarget weights (max rarity of either label-of-interest, UNscaled)
 sample_weights = compute_multitarget_sample_weights(
 train[['unscaled_CO2', 'unscaled_CH4']].values, bins_label4, bins_label5)
 val_sample_weights = compute_multitarget_sample_weights(
 test[['unscaled_CO2', 'unscaled_CH4']].values, bins_label4, bins_label5)
 # 5. Dataset and batch sampler
 targeted_sampler = TargetedSampler(inds4, inds5, high_bins4, high_bins5, batch_size, high_frac=0.3)
 train_dataset = SMILESDataset(smiles_train, labels_train, sample_weights, tokenizer)
 val_dataset = SMILESDataset(smiles_test, labels_test, val_sample_weights, tokenizer)
 train_loader = DataLoader(train_dataset, batch_sampler=targeted_sampler, num_workers=4)
 val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
 # 6. Model, optimizer, scheduler
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model.to(device)
 optimizer = AdamW(model.parameters(), lr=learning_rate)
 total_steps = len(train_loader)*3
 scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)
 best_val_loss = float('inf')
 best_state = None
 steps_no_improve, patience = 0, 10
 global_step, running_train_loss, train_steps_count = 0, 0, 0

 for epoch in range(num_epochs):
 print(f"Epoch {epoch+1}/{num_epochs}")
 model.train()
 pbar = tqdm(train_loader, desc='Training', total=len(train_loader) * num_epochs)
 for batch in pbar:
 input_ids = batch['input_ids'].to(device)
 attention_mask = batch['attention_mask'].to(device)
 labels = batch['labels'].to(device)
 weights = batch['weight'].to(device)
 optimizer.zero_grad()
 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
 loss = calculate_weighted_loss(outputs, labels, weights)
 loss.backward()
 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 optimizer.step(); scheduler.step()
 running_train_loss += loss.item(); train_steps_count += 1; global_step += 1
 pbar.set_postfix(loss=f'{loss.item():.3f}')
 if global_step % validation_steps == 0:
 avg_train_loss = running_train_loss / train_steps_count
 print(f"Step {global_step}: Mean Weighted Train Loss: {avg_train_loss:.4f}")
 running_train_loss = 0; train_steps_count = 0
 # Validation:
 all_preds, all_trues, all_weights = [], [], []
 model.eval()
 with torch.no_grad():
 for vb in val_loader:
 vi = vb['input_ids'].to(device)
 va = vb['attention_mask'].to(device)
 vl = vb['labels'].to(device)
 vw = vb['weight'].to(device)
 out = model(input_ids=vi, attention_mask=va)
 all_preds.append(out.cpu()); all_trues.append(vl.cpu()); all_weights.append(vw.cpu())
 preds = torch.cat(all_preds).numpy()
 trues = torch.cat(all_trues).numpy()
 weights = torch.cat(all_weights).numpy()
 val_loss = calculate_weighted_loss(torch.tensor(preds), torch.tensor(trues), torch.tensor(weights)).item()
 print(f"Weighted Val MSE (scaled): {val_loss:.4f}")
 metrics = stratified_metrics(preds, trues, bins, scalers)
 for k, v in metrics.items():
 print(f"{k}: {v:.4f}")
 if val_loss < best_val_loss:
 best_val_loss = val_loss; best_state = model.state_dict().copy(); steps_no_improve = 0
 torch.save(model.state_dict(), '/home/jovyan/simson_training_bolgov/regression/better_regression_states/best_state.bin')
 print(f"New best val_loss: {best_val_loss:.4f}")
 else:
 steps_no_improve += 1
 print(f'Patience meter: {steps_no_improve} out of {patience}')
 if steps_no_improve >= patience:
 print(f"Early stopping at step {global_step}")
 if best_state: model.load_state_dict(best_state)
 return
 model.train()
 if best_state: model.load_state_dict(best_state)
 print(f"Training completed, best weighted val_loss: {best_val_loss:.4f}")


In [9]:
import pandas as pd

df = pd.read_csv('/home/jovyan/simson_training_bolgov/regression/PI_Tg_P308K_synth_db_chem.csv')
df['unscaled_CO2'] = df['CO2'].copy()
df['unscaled_CH4'] = df['CH4'].copy()
targets = ['Tg', 'He', 'N2', 'O2', 'CH4', 'CO2']

In [18]:
from transformers import AutoTokenizer
tokenizer_path = 'DeepChem/ChemBERTa-77M-MTR'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Only the hidden size is slightly larger, everything else is the same
config = BertConfig(
 vocab_size=tokenizer.vocab_size,
 hidden_size=768,
 num_hidden_layers=4,
 num_attention_heads=12,
 intermediate_size=2048,
 max_position_embeddings=512
 )

simson_params = torch.load('/home/jovyan/simson_training_bolgov/regression/actual_encoder_state.pkl', weights_only=False)

backbone = SimSonEncoder(config=config, max_len=512)
backbone.load_state_dict(simson_params)


model = SimSonClassifier(encoder=backbone, num_labels=len(targets))

for param in model.encoder.parameters():
 param.requires_grad = False

In [14]:
def create_stratified_splits_regression(
 df,
 label_cols,
 n_bins=10,
 val_frac=0.05,
 seed=42
):
 
 values = df[label_cols].values
 # Each label gets its own bins, based on the overall distribution
 bins = [np.unique(np.quantile(values[:,i], np.linspace(0, 1, n_bins+1))) for i in range(len(label_cols))]
 # Assign each row to a bin for each label
 inds = [
 np.digitize(values[:,i], bins[i][1:-1], right=False) # exclude leftmost/rightmost for in-bin, avoids all bin edges as bins
 for i in range(len(label_cols))
 ]
 # Combine into a single integer stratification variable (tuple or max or sum...)
 strat_col = np.maximum.reduce(inds) # This ensures high bin in one = high bin overall
 # Use sklearn's train_test_split with stratify
 train_idx, val_idx = train_test_split(
 df.index.values,
 test_size=val_frac,
 random_state=seed,
 shuffle=True,
 stratify=strat_col
 )
 train = df.loc[train_idx].reset_index(drop=True)
 val = df.loc[val_idx].reset_index(drop=True)
 return train, val


# For your use case:
train, test = create_stratified_splits_regression(
 df,
 label_cols=['unscaled_CO2', 'unscaled_CH4'], # or actual column names
 n_bins=10,
 val_frac=0.05,
 seed=42
)

In [15]:
scalers = []

for target in targets:
 target_scaler = StandardScaler()
 train[target] = target_scaler.fit_transform(train[target].to_numpy().reshape(-1, 1))
 test[target] = target_scaler.transform(test[target].to_numpy().reshape(-1, 1))
 
 scalers.append(target_scaler)

smiles_train = train['Smiles']
smiles_test = test['Smiles']

labels_train = train[targets].values
labels_test = test[targets].values

In [16]:
def create_bins(target_values, n_bins=5, strategy='percentile'):
 """
 Create bins for a target based on the specified strategy.
 - 'percentile' creates approximately equal-sized groups
 - 'uniform' creates equal-width bins
 Returns:
 bin_edges: array of length n_bins+1
 """
 target_values = target_values[~np.isnan(target_values)]
 if strategy == 'percentile':
 return np.percentile(target_values, np.linspace(0, 100, n_bins+1))
 else:
 return np.linspace(np.min(target_values), np.max(target_values), n_bins+1)

In [None]:
# model.load_state_dict(torch.load('/home/jovyan/simson_training_bolgov/regression/regression_simson.pth'))

In [None]:
import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from tqdm import tqdm



train_losses, val_losses, best_loss = run_training(
 smiles_train, smiles_test, labels_train, labels_test, 
 model, tokenizer, scalers, num_epochs=6, learning_rate=2e-5, batch_size=256, validation_steps=6_500,
)

Epoch 1/6


Training: 4%|▌ | 6499/149778 [21:21<7:51:15, 5.07it/s, loss=0.414]

Step 6500: Mean Weighted Train Loss: 34.2319
Weighted Val MSE (scaled): 0.1242


Training: 4%|▍ | 6500/149778 [25:38<3067:56:23, 77.08s/it, loss=0.414]

label4_bin0_count: 1269.0000
label4_bin0_mae: 2.4281
label4_bin1_count: 208626.0000
label4_bin1_mae: 0.8366
label4_bin2_count: 53528.0000
label4_bin2_mae: 1.2479
label4_bin3_count: 37339.0000
label4_bin3_mae: 2.0683
label4_bin4_count: 21793.0000
label4_bin4_mae: 3.2345
label4_bin5_count: 8194.0000
label4_bin5_mae: 5.3938
label4_bin6_count: 3423.0000
label4_bin6_mae: 8.1634
label4_bin7_count: 1506.0000
label4_bin7_mae: 12.3689
label4_bin8_count: 610.0000
label4_bin8_mae: 28.5760
label4_bin9_count: 60.0000
label4_bin9_mae: 29.3493
label5_bin1_count: 299.0000
label5_bin1_mae: 134.7964
label5_bin2_count: 1068.0000
label5_bin2_mae: 116.7105
label5_bin3_count: 1998.0000
label5_bin3_mae: 114.5267
label5_bin4_count: 4626.0000
label5_bin4_mae: 128.3239
label5_bin5_count: 9057.0000
label5_bin5_mae: 140.8095
label5_bin6_count: 12853.0000
label5_bin6_mae: 146.8567
label5_bin7_count: 17785.0000
label5_bin7_mae: 156.6570
label5_bin8_count: 23487.0000
label5_bin8_mae: 177.3546
label5_bin9_count: 3213

Training: 9%|█ | 12999/149778 [48:04<7:49:37, 4.85it/s, loss=0.321]

Step 13000: Mean Weighted Train Loss: 0.3521
Weighted Val MSE (scaled): 0.1131
label4_bin0_count: 1269.0000
label4_bin0_mae: 1.1570
label4_bin1_count: 208626.0000
label4_bin1_mae: 0.7505
label4_bin2_count: 53528.0000
label4_bin2_mae: 1.1643
label4_bin3_count: 37339.0000
label4_bin3_mae: 1.9008
label4_bin4_count: 21793.0000
label4_bin4_mae: 3.1142
label4_bin5_count: 8194.0000
label4_bin5_mae: 5.0156
label4_bin6_count: 3423.0000
label4_bin6_mae: 7.3518
label4_bin7_count: 1506.0000
label4_bin7_mae: 9.9343
label4_bin8_count: 610.0000
label4_bin8_mae: 25.9864
label4_bin9_count: 60.0000
label4_bin9_mae: 27.8746
label5_bin1_count: 299.0000
label5_bin1_mae: 86.6469
label5_bin2_count: 1068.0000
label5_bin2_mae: 87.6131
label5_bin3_count: 1998.0000
label5_bin3_mae: 87.2714
label5_bin4_count: 4626.0000
label5_bin4_mae: 90.1329
label5_bin5_count: 9057.0000
label5_bin5_mae: 102.7583
label5_bin6_count: 12853.0000
label5_bin6_mae: 109.2145
label5_bin7_count: 17785.0000
label5_bin7_mae: 125.8342
label

Training: 9%|▊ | 13001/149778 [52:21<2053:40:30, 54.05s/it, loss=0.309]

New best val_loss: 0.1131


Training: 13%|█▎ | 19049/149778 [1:12:14<7:14:15, 5.02it/s, loss=0.264]