# NOTE(review): lines above this point were Hugging Face file-viewer page chrome
# (status badges, file size, commit hash, line-number gutter) captured by the
# export; replaced with this comment so the file is valid Python.
import os
from typing import Dict, Any
import torch
from transformers import TrainerCallback
from trl import SFTTrainer
from rdkit import Chem
from protac_splitter.llms.data_utils import load_tokenized_dataset
from protac_splitter.llms.model_utils import get_model
# Restrict training to GPU 0; must be set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use GPU if available
# Placeholder scoring hook used by the custom trainer's evaluate loop.
def score_function(smiles1, predicted_smiles):
    """Score a predicted SMILES string by chemical validity.

    Returns 1 when RDKit can parse ``predicted_smiles``, 0 otherwise.
    ``smiles1`` (the prompt) is accepted for future context-aware scoring
    but is currently unused.
    """
    return 0 if Chem.MolFromSmiles(predicted_smiles) is None else 1
# Trainer variant that rates generated SMILES completions during evaluation.
class CustomSFTTrainer(SFTTrainer):
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Generate completions for the eval set and log the mean validity score."""
        dataset = self.eval_dataset if eval_dataset is None else eval_dataset

        # Run generation over the whole evaluation split and decode to text.
        preds = self.predict(dataset)
        decoded = self.tokenizer.batch_decode(preds.predictions, skip_special_tokens=True)

        running_total = 0
        n_samples = len(decoded)
        for idx, example in enumerate(dataset):
            # "text" holds the full sample: "Smiles1 Smiles2.Smiles3.Smiles4";
            # the prompt is the first whitespace-separated token (Smiles1).
            prompt = example["text"].split(" ")[0]
            # Drop the echoed prompt to isolate the model's predicted completion.
            completion = decoded[idx].removeprefix(prompt).strip()
            running_total += score_function(prompt, completion)

        mean_score = running_total / n_samples if n_samples > 0 else 0
        metrics = {f"{metric_key_prefix}_average_score": mean_score}
        self.log(metrics)
        return metrics
def train():
    """Fine-tune the model on the tokenized SMILES dataset.

    Loads the model and its tokenizer, builds the training configuration,
    and runs :class:`CustomSFTTrainer` so generated SMILES are scored for
    validity during evaluation.
    """
    model = get_model()  # Project helper; tokenizer travels on the model object
    tokenizer = model.tokenizer  # Get tokenizer from model

    # Load dataset — assumes "train" and "validation" splits (TODO confirm).
    dataset = load_tokenized_dataset()

    # FIX: the HF Trainer requires a TrainingArguments instance; passing a
    # plain dict as `args=` fails at trainer construction.
    training_args = TrainingArguments(
        output_dir="./trained_model",
        evaluation_strategy="steps",
        save_strategy="steps",
        logging_steps=100,
        save_steps=500,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=5e-5,
        save_total_limit=2,
    )

    # Initialize custom trainer.
    trainer = CustomSFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        # FIX: tokenizer was loaded but never handed to the trainer, yet
        # CustomSFTTrainer.evaluate() calls self.tokenizer.batch_decode.
        tokenizer=tokenizer,
    )

    # Train model
    trainer.train()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()