import os

# Pin training to the first GPU; set before CUDA is initialized so it takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from transformers import TrainerCallback
from trl import SFTConfig, SFTTrainer
from rdkit import Chem

from protac_splitter.llms.data_utils import load_tokenized_dataset
from protac_splitter.llms.model_utils import get_model
# Placeholder for a scoring function that evaluates the generated SMILES.
def score_function(smiles1, predicted_smiles):
    """Score a generated SMILES sequence by validity. `smiles1` is unused here
    but kept in the signature for scorers that compare against the prompt."""
    mol = Chem.MolFromSmiles(predicted_smiles)
    return 1 if mol is not None else 0  # 1 if RDKit can parse the SMILES, else 0
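
# A sketch of a stricter scorer, assuming a reference completion is available.
# `score_with_reference` and `reference_smiles` are hypothetical (not part of
# the original pipeline): it rewards exact canonical-SMILES matches and falls
# back to a partial credit for merely valid outputs.
def score_with_reference(predicted_smiles, reference_smiles):
    pred = Chem.MolFromSmiles(predicted_smiles)
    if pred is None:
        return 0.0  # invalid SMILES
    ref = Chem.MolFromSmiles(reference_smiles)
    if ref is not None and Chem.MolToSmiles(pred) == Chem.MolToSmiles(ref):
        return 1.0  # exact match after canonicalization
    return 0.5  # valid but different from the reference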
# Custom Trainer subclass that scores decoded SMILES during evaluation.
class CustomSFTTrainer(SFTTrainer):
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        if eval_dataset is None:
            eval_dataset = self.eval_dataset

        # Generate predictions. For a causal LM, `predictions.predictions`
        # holds logits, so take the argmax to get token ids before decoding.
        # Note this is a teacher-forced greedy decode; see the generation-based
        # sketch after this class for free-running generation.
        predictions = self.predict(eval_dataset, ignore_keys=ignore_keys)
        pred_ids = predictions.predictions.argmax(-1)
        generated_texts = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

        total_score = 0
        total_samples = len(generated_texts)
        for i, example in enumerate(eval_dataset):
            input_text = example["text"]  # Full input: "Smiles1 Smiles2.Smiles3.Smiles4"
            smiles1 = input_text.split(" ")[0]  # Extract Smiles1 (the prompt)
            # Remove the prompt from the generated text to get the predicted completion
            predicted_completion = generated_texts[i].removeprefix(smiles1).strip()
            # Compute custom score
            total_score += score_function(smiles1, predicted_completion)

        # Compute average score
        average_score = total_score / total_samples if total_samples > 0 else 0

        # Log metrics
        metrics = {f"{metric_key_prefix}_average_score": average_score}
        self.log(metrics)
        return metrics
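
# Sketch of a free-running alternative to the argmax decode above: generate a
# completion from the prompt alone with `model.generate`. Assumes a standard
# Hugging Face causal LM; `max_new_tokens=128` is an arbitrary choice.
def generate_completion(model, tokenizer, prompt_smiles, max_new_tokens=128):
    inputs = tokenizer(prompt_smiles, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return text.removeprefix(prompt_smiles).strip()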
def train():
    """Main training function."""
    model = get_model()  # Load the model
    tokenizer = model.tokenizer  # Get tokenizer from model

    # Load dataset
    dataset = load_tokenized_dataset()

    # Training arguments (SFTConfig subclasses transformers.TrainingArguments)
    training_args = SFTConfig(
        output_dir="./trained_model",
        evaluation_strategy="steps",
        save_strategy="steps",
        logging_steps=100,
        save_steps=500,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=5e-5,
        save_total_limit=2,
    )
    # Initialize custom trainer
    trainer = CustomSFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,  # recent TRL versions take processing_class= instead
    )
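
    # Optional: a minimal callback sketch showing where TrainerCallback can
    # hook in. `ScoreLoggingCallback` is illustrative, not part of the
    # original pipeline; it just prints the metrics logged by evaluate().
    class ScoreLoggingCallback(TrainerCallback):
        def on_evaluate(self, args, state, control, metrics=None, **kwargs):
            if metrics is not None:
                print(f"step {state.global_step}: {metrics}")

    trainer.add_callback(ScoreLoggingCallback())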
    # Train model
    trainer.train()


if __name__ == "__main__":
    train()