
import os
from typing import Dict, Any
import torch
from transformers import TrainerCallback
from trl import SFTTrainer
from rdkit import Chem

from protac_splitter.llms.data_utils import load_tokenized_dataset
from protac_splitter.llms.model_utils import get_model

# Pin this process to the first CUDA device (must be set before CUDA is initialized).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Placeholder for a scoring function that evaluates the generated SMILES
def score_function(smiles1, predicted_smiles):
    """Score a generated SMILES string by validity: 1 if RDKit parses it, else 0.

    ``smiles1`` (the prompt) is accepted for interface compatibility but is
    not consulted by this validity-only score.
    """
    parsed = Chem.MolFromSmiles(predicted_smiles)
    return 0 if parsed is None else 1

# Custom Trainer subclass to integrate SMILES evaluation
class CustomSFTTrainer(SFTTrainer):
    """SFTTrainer that evaluates by decoding predictions and scoring SMILES validity.

    Overrides ``evaluate`` to report the fraction of eval-set completions
    that parse as valid SMILES (via ``score_function``).
    """

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Decode predictions on *eval_dataset* and log the average validity score.

        Returns:
            dict: ``{f"{metric_key_prefix}_average_score": float}`` with the
            mean validity score in [0, 1].

        Raises:
            ValueError: if no evaluation dataset is available.
        """
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        if eval_dataset is None:
            raise ValueError("evaluate() requires an eval_dataset.")

        # Trainer.predict returns a PredictionOutput whose `predictions` are
        # raw logits (batch, seq, vocab) unless a preprocess_logits_for_metrics
        # hook already reduced them — reduce to token ids before decoding.
        output = self.predict(eval_dataset)
        preds = output.predictions
        if isinstance(preds, tuple):  # some models return (logits, extras)
            preds = preds[0]
        if getattr(preds, "ndim", 2) == 3:
            preds = preds.argmax(-1)
        generated_texts = self.tokenizer.batch_decode(preds, skip_special_tokens=True)

        total_score = 0
        total_samples = len(generated_texts)

        for i, example in enumerate(eval_dataset):
            # Each example's "text" is "Smiles1 Smiles2.Smiles3.Smiles4";
            # the first whitespace-delimited token is the prompt (Smiles1).
            smiles1 = example["text"].split(" ")[0]

            # Strip the prompt prefix to isolate the predicted completion.
            predicted_completion = generated_texts[i].removeprefix(smiles1).strip()

            total_score += score_function(smiles1, predicted_completion)

        # Average over decoded samples; 0 when the eval set is empty.
        average_score = total_score / total_samples if total_samples > 0 else 0

        metrics = {f"{metric_key_prefix}_average_score": average_score}
        self.log(metrics)

        return metrics

def train():
    """Fine-tune the model with CustomSFTTrainer and step-based evaluation."""
    # Local import: SFTConfig is the TrainingArguments subclass that SFTTrainer
    # expects as `args` — a plain dict is not accepted by the Trainer API.
    from trl import SFTConfig

    model = get_model()
    # NOTE(review): assumes get_model() attaches a tokenizer to the model — confirm.
    tokenizer = model.tokenizer

    # Load the tokenized train/validation splits.
    dataset = load_tokenized_dataset()

    # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in recent
    # transformers releases — adjust if the pinned version requires it.
    training_args = SFTConfig(
        output_dir="./trained_model",
        evaluation_strategy="steps",
        save_strategy="steps",
        logging_steps=100,
        save_steps=500,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=5e-5,
        save_total_limit=2,
    )

    trainer = CustomSFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,  # evaluate() decodes predictions via self.tokenizer
    )

    trainer.train()

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()