# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import os
import warnings

import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
from rdkit import Chem, RDLogger
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, BertModel, BertConfig
import pandas as pd

# ==============================================================================
# 2. INITIAL SETUP
# ==============================================================================
# Suppress RDKit console output
RDLogger.DisableLog('rdApp.*')
# Ignore warnings for cleaner output
warnings.filterwarnings("ignore")

# ==============================================================================
# 3. MODEL AND LOSS FUNCTION
# ==============================================================================
def global_average_pooling(x):
    """Global Average Pooling: from [B, max_len, hid_dim] to [B, hid_dim]."""
    return torch.mean(x, dim=1)

class SimSonEncoder(nn.Module):
    """The main encoder model based on BERT."""
    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super(SimSonEncoder, self).__init__()
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            attention_mask = input_ids.ne(self.bert.config.pad_token_id)
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = self.dropout(outputs.last_hidden_state)
        pooled_output = global_average_pooling(hidden_states)
        return self.linear(pooled_output)

class ContrastiveLoss(nn.Module):
    """Calculates the contrastive loss for the SimSon model."""
    def __init__(self, temperature=0.2):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.similarity_fn = F.cosine_similarity

    def forward(self, proj_1, proj_2):
        batch_size = proj_1.shape[0]
        device = proj_1.device

        # Normalize projections
        z_i = F.normalize(proj_1, p=2, dim=1)
        z_j = F.normalize(proj_2, p=2, dim=1)

        # Concatenate for similarity matrix calculation
        representations = torch.cat([z_i, z_j], dim=0)

        # Calculate cosine similarity between all pairs
        similarity_matrix = self.similarity_fn(representations.unsqueeze(1),
                                               representations.unsqueeze(0), dim=2)

        # Identify positive pairs (original and its augmentation)
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        nominator = torch.exp(positives / self.temperature)
        # Create a mask to exclude self-comparisons
        mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool, device=device)).float()
        denominator = mask * torch.exp(similarity_matrix / self.temperature)

        # Calculate the final loss
        loss = -torch.log(nominator / torch.sum(denominator, dim=1))
        return torch.sum(loss) / (2 * batch_size)
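
# ------------------------------------------------------------------------------
# Optional sanity check (our addition, not called anywhere in the pipeline):
# with two nearly identical sets of projections, the NT-Xent-style loss above
# should come out clearly lower than with unrelated random projections.
# ------------------------------------------------------------------------------
def _sanity_check_contrastive_loss():
    criterion = ContrastiveLoss(temperature=0.2)
    z = torch.randn(8, 64)
    aligned_loss = criterion(z, z + 0.01 * torch.randn_like(z))  # positives nearly identical
    random_loss = criterion(z, torch.randn(8, 64))               # positives unrelated
    print(f"aligned: {aligned_loss.item():.4f}  random: {random_loss.item():.4f}")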
# ==============================================================================
# 4. DATA HANDLING
# ==============================================================================
class SmilesEnumerator:
    """Generates randomized SMILES strings for data augmentation."""
    def randomize_smiles(self, smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            return smiles

class ContrastiveSmilesDataset(Dataset):
    """Dataset for creating pairs of augmented SMILES for contrastive learning."""
    def __init__(self, smiles_list, tokenizer, max_length=512):
        self.smiles_list = smiles_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.enumerator = SmilesEnumerator()

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        original_smiles = self.smiles_list[idx]

        # Create two different augmentations of the same SMILES
        smiles_1 = self.enumerator.randomize_smiles(original_smiles)
        smiles_2 = self.enumerator.randomize_smiles(original_smiles)

        # Tokenize and pad to max_length so every item has a fixed shape
        tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
        tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')

        return {
            'input_ids_1': torch.tensor(tokens_1['input_ids']),
            'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
            'input_ids_2': torch.tensor(tokens_2['input_ids']),
            'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
        }

class PrecomputedContrastiveSmilesDataset(Dataset):
    """
    A Dataset class that reads pre-augmented SMILES pairs from a Parquet file.
    This is significantly faster, as it offloads the expensive SMILES randomization
    to a one-time preprocessing step.
    """
    def __init__(self, tokenizer, file_path: str, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load the entire dataset from the Parquet file into memory.
        # This is fast and efficient for subsequent access.
        print(f"Loading pre-computed data from {file_path}...")
        self.data = pd.read_parquet(file_path)
        print("Data loaded successfully.")

    def __len__(self):
        """Returns the total number of pairs in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves a pre-augmented pair, tokenizes it, and returns it in the
        format expected by the DataCollator.
        """
        # Retrieve the pre-augmented pair from the DataFrame
        row = self.data.iloc[idx]
        smiles_1 = row['smiles_1']
        smiles_2 = row['smiles_2']

        # Tokenize the pair. This operation is fast and remains in the data loader.
        tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
        tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')

        return {
            'input_ids_1': torch.tensor(tokens_1['input_ids']),
            'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
            'input_ids_2': torch.tensor(tokens_2['input_ids']),
            'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
        }
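
# ------------------------------------------------------------------------------
# Illustrative sketch (our addition): the one-time preprocessing step that
# PrecomputedContrastiveSmilesDataset expects. The helper name and the single-pass
# in-memory approach are assumptions; only the 'smiles_1'/'smiles_2' column names
# are dictated by __getitem__ above.
# ------------------------------------------------------------------------------
def build_precomputed_pairs(smiles_list, output_path):
    enumerator = SmilesEnumerator()
    # Generate two independent randomized views per molecule, once, offline.
    pairs = {
        'smiles_1': [enumerator.randomize_smiles(s) for s in smiles_list],
        'smiles_2': [enumerator.randomize_smiles(s) for s in smiles_list],
    }
    pd.DataFrame(pairs).to_parquet(output_path, index=False)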
class PreTokenizedSmilesDataset(Dataset):
    """
    A Dataset that loads a pre-tokenized and pre-padded dataset created by the
    preprocessing script. It uses memory-mapping for instant loads and high efficiency.
    """
    def __init__(self, dataset_path: str):
        # Load the dataset from disk. This is very fast due to memory-mapping.
        self.dataset = load_from_disk(dataset_path)
        # Set the format to PyTorch tensors for direct use in the model
        self.dataset.set_format(type='torch', columns=[
            'input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2'
        ])
        print(f"Successfully loaded pre-tokenized dataset from {dataset_path}.")

    def __len__(self):
        """Returns the total number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Retrieves a single pre-processed item."""
        return self.dataset[idx]

class DataCollatorWithPadding:
    """
    A collate function that dynamically pads inputs to the longest sequence across
    both augmented views in the batch, ensuring consistent tensor shapes.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features):
        # Create a combined list of features for both views to find the global max length
        combined_features = []
        for feature in features:
            combined_features.append({'input_ids': feature['input_ids_1'], 'attention_mask': feature['attention_mask_1']})
            combined_features.append({'input_ids': feature['input_ids_2'], 'attention_mask': feature['attention_mask_2']})

        # Pad the combined batch. This ensures all sequences are padded to the same length.
        padded_combined = self.tokenizer.pad(combined_features, padding='longest', return_tensors='pt')

        # Split the padded tensors back into two views
        batch_size = len(features)
        input_ids_1, input_ids_2 = torch.split(padded_combined['input_ids'], batch_size, dim=0)
        attention_mask_1, attention_mask_2 = torch.split(padded_combined['attention_mask'], batch_size, dim=0)

        return {
            'input_ids_1': input_ids_1,
            'attention_mask_1': attention_mask_1,
            'input_ids_2': input_ids_2,
            'attention_mask_2': attention_mask_2,
        }

# ==============================================================================
# 5. CHECKPOINT UTILITIES
# ==============================================================================
def save_checkpoint(model, optimizer, scheduler, global_step, save_path):
    """Save a complete checkpoint with model, optimizer, and scheduler states plus the step count."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'global_step': global_step,
    }
    torch.save(checkpoint, save_path)
    print(f"Full checkpoint saved at step {global_step}")

def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
    """Load a checkpoint and return the global step to resume from."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    global_step = checkpoint['global_step']
    print(f"Checkpoint loaded from step {global_step}")
    return global_step
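
# ------------------------------------------------------------------------------
# Optional helper (our addition): torch.compile wraps the module, so state dicts
# saved from a compiled model prefix every key with '_orig_mod.'. This strips the
# prefix when such a checkpoint needs to be loaded into an uncompiled model.
# ------------------------------------------------------------------------------
def strip_compile_prefix(state_dict):
    return {key.replace('_orig_mod.', '', 1): value for key, value in state_dict.items()}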
# ==============================================================================
# 6. TRAINING AND EVALUATION LOOPS
# ==============================================================================
def evaluation_step(model, batch, criterion, device):
    """Performs a single evaluation step on a batch of data."""
    input_ids_1 = batch['input_ids_1'].to(device)
    attention_mask_1 = batch['attention_mask_1'].to(device)
    input_ids_2 = batch['input_ids_2'].to(device)
    attention_mask_2 = batch['attention_mask_2'].to(device)

    combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
    combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)

    with torch.no_grad():
        combined_proj = model(combined_input_ids, combined_attention_mask)

    batch_size = input_ids_1.size(0)
    proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)
    loss = criterion(proj_1, proj_2)
    return proj_1, proj_2, loss

def train_with_step_based_validation(model, train_loader, val_loader, optimizer, criterion, device,
                                     scheduler, checkpoint_path, save_steps, validation_steps,
                                     start_step=0, max_steps=None):
    """Trains the model with step-based validation and checkpointing."""
    model.train()
    global_step = start_step
    best_val_loss = float('inf')

    # Calculate total steps if max_steps is not provided
    if max_steps is None:
        max_steps = len(train_loader)

    progress_bar = tqdm(total=max_steps, desc="Training Steps", initial=start_step)

    # Create an iterator that can be resumed from any point
    train_iterator = iter(train_loader)

    # Skip batches if resuming from a checkpoint
    if start_step > 0:
        batches_to_skip = start_step % len(train_loader)
        for _ in range(batches_to_skip):
            try:
                next(train_iterator)
            except StopIteration:
                train_iterator = iter(train_loader)

    while global_step < max_steps:
        try:
            batch = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_loader)
            batch = next(train_iterator)

        # Training step
        input_ids_1 = batch['input_ids_1'].to(device)
        attention_mask_1 = batch['attention_mask_1'].to(device)
        input_ids_2 = batch['input_ids_2'].to(device)
        attention_mask_2 = batch['attention_mask_2'].to(device)

        optimizer.zero_grad()
        with torch.autocast(dtype=torch.float16, device_type="cuda"):
            combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
            combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)
            combined_proj = model(combined_input_ids, combined_attention_mask)
            batch_size = input_ids_1.size(0)
            proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)
            loss = criterion(proj_1, proj_2)

        loss.backward()
        # Clip gradients before the optimizer update so clipping actually takes effect
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        global_step += 1
        progress_bar.update(1)
        progress_bar.set_postfix(loss=f"{loss.item():.4f}", step=global_step)
        wandb.log({
            "train_batch_loss": loss.item(),
            "learning_rate": scheduler.get_last_lr()[0],
            "global_step": global_step
        })

        # Step-based validation
        if global_step % validation_steps == 0:
            val_loss = validate_epoch(model, val_loader, criterion, device)
            wandb.log({
                "val_loss": val_loss,
                "global_step": global_step
            })

            # Save the best model (model state only for the best checkpoint)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                model_save_path = checkpoint_path.replace('.pt', '_best_model.bin')
                torch.save(model.state_dict(), model_save_path)
                progress_bar.write(f"Step {global_step}: New best model saved with val loss {val_loss:.4f}")

            model.train()  # Resume training mode after validation

        # Step-based checkpointing (full checkpoint)
        if global_step % save_steps == 0:
            save_checkpoint(model, optimizer, scheduler, global_step, checkpoint_path)

    progress_bar.close()
    return global_step
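
# ------------------------------------------------------------------------------
# Optional sketch (our addition, not wired into the loop above): the training loop
# runs its forward pass under float16 autocast but backpropagates without loss
# scaling. The conventional mixed-precision pattern adds a GradScaler (created once,
# e.g. scaler = torch.cuda.amp.GradScaler()) to avoid gradient underflow; a single
# step would then look roughly like this.
# ------------------------------------------------------------------------------
def _amp_training_step_sketch(model, criterion, optimizer, scaler, input_ids, attention_mask, batch_size):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        combined_proj = model(input_ids, attention_mask)
        proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)
        loss = criterion(proj_1, proj_2)
    scaler.scale(loss).backward()   # scale the loss before backprop
    scaler.unscale_(optimizer)      # unscale so clipping sees the true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)          # skips the update if inf/NaN gradients appear
    scaler.update()
    return loss.detach()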
def validate_epoch(model, val_loader, criterion, device):
    """Runs validation over the full validation loader and returns the average loss."""
    model.eval()
    total_loss = 0
    progress_bar = tqdm(val_loader, desc="Validating", leave=False)
    for batch in progress_bar:
        _, _, loss = evaluation_step(model, batch, criterion, device)
        total_loss += loss.item()
    avg_loss = total_loss / len(val_loader)
    print(f'Validation loss: {avg_loss:.4f}')
    return avg_loss

def test_model(model, test_loader, criterion, device):
    """Evaluates the model on the test set, returning loss and pairwise cosine-similarity statistics."""
    model.eval()
    total_loss = 0
    all_similarities = []
    progress_bar = tqdm(test_loader, desc="Testing", leave=False)
    for batch in progress_bar:
        proj_1, proj_2, loss = evaluation_step(model, batch, criterion, device)
        total_loss += loss.item()
        proj_1_norm = F.normalize(proj_1, p=2, dim=1)
        proj_2_norm = F.normalize(proj_2, p=2, dim=1)
        batch_similarities = F.cosine_similarity(proj_1_norm, proj_2_norm, dim=1)
        all_similarities.extend(batch_similarities.cpu().numpy())
    avg_loss = total_loss / len(test_loader)
    avg_sim = np.mean(all_similarities)
    std_sim = np.std(all_similarities)
    return avg_loss, avg_sim, std_sim
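
# ------------------------------------------------------------------------------
# Usage sketch (our addition): embedding a list of SMILES with a trained encoder.
# The helper name, single-batch design, and final normalization are our choices;
# only the forward signature comes from SimSonEncoder above.
# ------------------------------------------------------------------------------
def encode_smiles(model, tokenizer, smiles_list, device, max_length=256):
    model.eval()
    tokens = tokenizer(list(smiles_list), max_length=max_length, truncation=True,
                       padding=True, return_tensors='pt')
    with torch.no_grad():
        embeddings = model(tokens['input_ids'].to(device),
                           tokens['attention_mask'].to(device))
    # L2-normalize so dot products correspond to cosine similarities
    return F.normalize(embeddings, p=2, dim=1).cpu()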
# ==============================================================================
# 7. SINGLE-GPU TRAINING
# ==============================================================================
def run_training(model_config, hparams, data_splits):
    """The main function to run the training and evaluation process with step-based validation."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    wandb_key = os.getenv("WANDB_API_KEY")
    if wandb_key:
        wandb.login(key=wandb_key)
    wandb.init(
        # project="simson-contrastive-learning-single-gpu",
        # name=f"run-{wandb.util.generate_id()}",
        # config=hparams
    )

    # Unpacked for reference only; the loaders below read the pre-computed Parquet splits instead.
    train_smiles, val_smiles, test_smiles = data_splits

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')

    precomputed_train_path = '/home/jovyan/simson_training_bolgov/data/pubchem_119m_splits/train.parquet'
    precomputed_test_path = '/home/jovyan/simson_training_bolgov/data/pubchem_119m_splits/test.parquet'
    precomputed_val_path = '/home/jovyan/simson_training_bolgov/data/pubchem_119m_splits/validation.parquet'

    train_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_train_path, max_length=hparams['max_length'])
    test_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_test_path, max_length=hparams['max_length'])
    val_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_val_path, max_length=hparams['max_length'])

    train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True,
                              num_workers=8, prefetch_factor=128, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False,
                            num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=hparams['batch_size'], shuffle=False,
                             num_workers=2, pin_memory=True)

    print('Initialized all data. Compiling the model...')
    model = SimSonEncoder(config=model_config, max_len=hparams['max_embeddings']).to(device)
    model = torch.compile(model)
    # Warm-start from a previous run. Because the model is compiled, the keys in this
    # state dict are expected to carry the '_orig_mod.' prefix.
    model.load_state_dict(torch.load('/home/jovyan/simson_training_bolgov/simson_checkpoints/checkpoint_best_model.bin'))
    print(model)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters: {total_params // 1_000_000} M")
    wandb.config.update({"total_params_M": total_params // 1_000_000})

    criterion = ContrastiveLoss(temperature=hparams['temperature']).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=hparams['lr'], weight_decay=1e-5, fused=True)

    total_steps = hparams['epochs'] * len(train_loader)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_mult=1, T_0=total_steps)

    print("Starting training...")
    wandb.watch(model, log='all', log_freq=5000)

    start_step = 0
    checkpoint_path = hparams['checkpoint_path']

    # Resume from a checkpoint if provided
    if hparams.get('resume_checkpoint') and os.path.exists(hparams['resume_checkpoint']):
        print(f"Resuming from checkpoint: {hparams['resume_checkpoint']}")
        start_step = load_checkpoint(hparams['resume_checkpoint'], model, optimizer, scheduler)

    # Train with step-based validation
    final_step = train_with_step_based_validation(
        model, train_loader, val_loader, optimizer, criterion, device, scheduler,
        checkpoint_path, hparams['save_steps'], hparams['validation_steps'],
        start_step=start_step, max_steps=total_steps
    )

    print("Training complete. Starting final testing...")

    # Load the best model for testing (model state only)
    best_model_path = checkpoint_path.replace('.pt', '_best_model.bin')
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path))
        print("Loaded best model for testing")

    test_loss, avg_sim, std_sim = test_model(model, test_loader, criterion, device)

    print("\n--- Test Results ---")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Average Cosine Similarity: {avg_sim:.4f} ± {std_sim:.4f}")
    print("--------------------")

    wandb.log({
        "test_loss": test_loss,
        "avg_cosine_similarity": avg_sim,
        "std_cosine_similarity": std_sim
    })

    # Save the final model state only
    final_model_path = hparams['save_path']
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

    wandb.finish()
# ==============================================================================
# 8. MAIN EXECUTION
# ==============================================================================
def main():
    """Main function to configure and run the training process."""
    hparams = {
        'epochs': 1,
        'lr': 6e-6,
        'temperature': 0.05,
        'batch_size': 128,
        'max_length': 256,
        'save_path': "simson_checkpoints_more_epochs/simson_model_single_gpu.bin",
        'checkpoint_path': "simson_checkpoints_more_epochs/checkpoint.pt",  # Full checkpoint
        'save_steps': 50000,        # Save a full checkpoint every 50k steps
        'validation_steps': 5000,   # Validate every 5k steps
        'max_embeddings': 512,
        'resume_checkpoint': None,  # Set to a checkpoint path to resume
    }

    dataset = load_dataset('HoangHa/SMILES-250M')['train']
    smiles_column_name = 'SMILES'

    total_size = len(dataset)
    test_size = int(0.1 * total_size)
    val_size = int(0.1 * (total_size - test_size))

    test_smiles = dataset.select(range(test_size))[smiles_column_name]
    val_smiles = dataset.select(range(test_size, test_size + val_size))[smiles_column_name]
    train_smiles = dataset.select(range(test_size + val_size, total_size))[smiles_column_name]
    data_splits = (train_smiles, val_smiles, test_smiles)

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    model_config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    # Create output directories
    save_dir = os.path.dirname(hparams['save_path'])
    checkpoint_dir = os.path.dirname(hparams['checkpoint_path'])
    for directory in [save_dir, checkpoint_dir]:
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Directly call the training function for a single-GPU run
    run_training(model_config, hparams, data_splits)

if __name__ == '__main__':
    main()