# protac_splitter/llms/training_mlm_model.py
""" Train a masked language model (MLM) using an encoder-decoder architecture. """
import os
from typing import Optional, Dict, Any, Union
import shutil

import torch
import huggingface_hub as hf
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    AutoTokenizer,
)

from protac_splitter.llms.data_utils import load_tokenized_dataset
from protac_splitter.llms.hf_utils import (
    create_hf_repository,
    delete_hf_repository,
    repo_exists,
)
from protac_splitter.llms.model_utils import get_encoder_decoder_model


def compute_metrics_for_mlm(pred) -> Dict[str, float]:
    """Compute metrics for MLM predictions, i.e., perplexity."""
    # For encoder-decoder models, the predictions may come as a tuple whose
    # first element holds the logits.
    logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    labels = pred.label_ids
    # Convert to torch tensors
    logits = torch.tensor(logits)
    labels = torch.tensor(labels)
    # Compute the masked loss: unmasked positions carry the label -100, so
    # CrossEntropyLoss skips them and the loss covers only the masked tokens.
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
    # Perplexity is the exponential of the mean masked cross-entropy loss.
    return {
        "perplexity": torch.exp(loss).item(),
        "loss": loss.item(),
    }
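
# NOTE: The Trainer calls `compute_metrics_for_mlm` with an `EvalPrediction`
# object during evaluation; it is registered via `compute_metrics=` below and
# is not meant to be called directly on raw tensors.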


def train_mlm_model(
    model_id: str,
    ds_name: str,
    ds_config: str = 'default',
    learning_rate: float = 5e-5,
    max_steps: int = -1,
    num_train_epochs: int = 40,
    batch_size: int = 128,
    batch_size_tokenizer: int = 512,
    gradient_accumulation_steps: int = 4,
    hub_token: Optional[str] = None,
    organization: Optional[str] = None,
    output_dir: str = "./models/",
    tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
    pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
    pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
    encoder_max_length: int = 512,
    decoder_max_length: int = 512,
    tie_encoder_decoder: bool = False,
    delete_repo_if_exists: bool = False,
    delete_local_repo_if_exists: bool = False,
    training_args: Optional[Dict[str, Any]] = None,
    resume_from_checkpoint: Optional[str] = None,
    num_proc_map: int = 1,
    per_device_batch_size: Optional[int] = None,
    lr_scheduler_type: Optional[str] = None,
    mlm_probability: float = 0.15,
    randomize_smiles: bool = False,
    randomize_smiles_prob: float = 0.5,
    randomize_smiles_repeat: int = 1,
):
"""
Trains a masked language model (MLM) using an encoder-decoder architecture.
Args:
model_id (str): The name of the model to be trained.
ds_name (str): The name of the dataset to use for training.
ds_config (str): The configuration of the dataset to use. Default: 'default'.
learning_rate (float): The learning rate for training. Default: 5e-5.
max_steps (int): The maximum number of training steps. Default: -1.
num_train_epochs (int): The number of training epochs. Default: 40.
batch_size (int): The total batch size. Default: 128.
batch_size_tokenizer (int): The batch size for the tokenizer. Default: 512.
gradient_accumulation_steps (int): The number of gradient accumulation steps. Default: 4.
hub_token (str): The Hugging Face token for authentication. Default: None.
organization (str): The organization to push the model to. Default: None.
output_dir (str): The output directory for the model. Default: "./models/".
tokenizer (AutoTokenizer | str): The tokenizer to use for training. Default: "seyonec/ChemBERTa-zinc-base-v1".
pretrained_encoder (str): The pretrained encoder model to use. Default: "seyonec/ChemBERTa-zinc-base-v1".
pretrained_decoder (str): The pretrained decoder model to use. Default: "seyonec/ChemBERTa-zinc-base-v1".
encoder_max_length (int): The maximum length of the encoder input. Default: 512.
decoder_max_length (int): The maximum length of the decoder input. Default: 512.
tie_encoder_decoder (bool): Whether to tie the encoder and decoder weights. Default: False.
delete_repo_if_exists (bool): Whether to delete the repository if it already exists. Default: False.
delete_local_repo_if_exists (bool): Whether to delete the local repository if it already exists. Default: False.
training_args (Dict[str, Any]): The training arguments for the Trainer. Default: None.
resume_from_checkpoint (str): The checkpoint to resume training from. Default: None.
num_optuna_trials (int): The number of Optuna hyperparameter search trials. Default: 0.
num_proc_map (int): The number of processes to use for mapping. Default: 1.
per_device_batch_size (int): The batch size per device. If defined, it will overwrite batch_size. Default: None.
lr_scheduler_type (str): The learning rate scheduler type. Default: None.
mlm_probability (float): The probability of masking tokens in the input. Default: 0.15.
randomize_smiles (bool): Whether to randomize SMILES strings. Default: False.
randomize_smiles_prob (float): The probability of randomizing SMILES strings. Default: 0.5.
randomize_smiles_repeat (int): The number of times to repeat randomizing SMILES strings. Default: 1.
"""
    # Check that the resume checkpoint exists (it can be a file or a directory)
    if resume_from_checkpoint is not None:
        if not os.path.exists(resume_from_checkpoint):
            raise ValueError(f"Checkpoint '{resume_from_checkpoint}' does not exist.")
    if hub_token is not None:
        hf.login(token=hub_token)
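    # Logging in up front lets the dataset loading and Hub pushes below reach
    # private resources with the provided token.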
    # Setup the output directory and the Hugging Face repository
    output_dir = os.path.join(output_dir, model_id)
    if organization is not None:
        hub_model_id = f"{organization}/{model_id}"
        if delete_repo_if_exists and repo_exists(hub_model_id, token=hub_token):
            delete_hf_repository(repo_id=hub_model_id, token=hub_token)
            if not repo_exists(hub_model_id, token=hub_token):
                print(f"Repository '{hub_model_id}' deleted.")
            else:
                print(f"Repository '{hub_model_id}' could not be deleted.")
                return
        if delete_local_repo_if_exists and os.path.exists(output_dir):
            shutil.rmtree(output_dir)
            if not os.path.exists(output_dir):
                print(f"Local repository '{output_dir}' deleted.")
            else:
                print(f"Local repository '{output_dir}' could not be deleted.")
                return
        repo_url = create_hf_repository(
            repo_id=hub_model_id,
            repo_type="model",
            exist_ok=True,
            private=True,
            token=hub_token,
        )
        print(f"Repository '{hub_model_id}' created at URL: {repo_url}")
    else:
        hub_model_id = None
    print(f"Hub model ID: {hub_model_id}")
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    elif tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
    # Use the end-of-sequence token as the padding token, required for MLM training
    tokenizer.pad_token = tokenizer.eos_token
    # Load the tokenized dataset
    print("Loading tokenized dataset.")
    dataset_tokenized = load_tokenized_dataset(
        ds_name,
        ds_config,
        tokenizer,
        batch_size_tokenizer,
        encoder_max_length,
        decoder_max_length,
        token=hub_token,
        num_proc_map=num_proc_map,
        randomize_smiles=randomize_smiles,
        randomize_smiles_prob=randomize_smiles_prob,
        randomize_smiles_repeat=randomize_smiles_repeat,
        randomize_text=True,
        randomize_labels=False,
    )
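    # The tokenized dataset is expected to provide "train" and "validation"
    # splits; both are used by the Trainer below.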
# Remove "labels" column from the dataset
dataset_tokenized = dataset_tokenized.remove_columns(["labels"])
print("Dataset loaded.")
    # Setup the model factory for `model_init` in the Trainer
    def bert2bert():
        return get_encoder_decoder_model(
            pretrained_encoder=pretrained_encoder,
            pretrained_decoder=pretrained_decoder,
            max_length=encoder_max_length,
            tie_encoder_decoder=tie_encoder_decoder,
        )

    # Setup the data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer,
        mlm=True,
        mlm_probability=mlm_probability,
        pad_to_multiple_of=8,
    )
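    # With mlm=True, the collator randomly masks a fraction (`mlm_probability`)
    # of the input tokens in each batch and generates the corresponding labels
    # on the fly; pad_to_multiple_of=8 keeps tensor shapes friendly to fp16 kernels.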
    # Setup the training arguments
    if per_device_batch_size is None:
        per_device_batch_size = batch_size // gradient_accumulation_steps
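    # With the defaults, batch_size=128 and gradient_accumulation_steps=4 give
    # per_device_batch_size=32, i.e. an effective batch of 32 * 4 = 128 per device.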
    if training_args is None:
        training_args = {
            "output_dir": output_dir,
            # Optimizer-related configs
            "learning_rate": learning_rate,
            "optim": "adamw_torch",
            "lr_scheduler_type": "cosine" if lr_scheduler_type is None else lr_scheduler_type,
            "warmup_steps": 8000,  # NOTE: ChemFormer: 8000
            # "warmup_ratio": 0,
            "adam_beta1": 0.9,  # NOTE: ChemFormer: 0.9
            "adam_beta2": 0.999,  # NOTE: ChemFormer: 0.999
            "adam_epsilon": 1e-8,  # Default: 1e-8
            # Batch size, device, and performance optimization configs
            # "torch_compile": True,
            "group_by_length": True,
            "per_device_train_batch_size": per_device_batch_size,
            "per_device_eval_batch_size": per_device_batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "auto_find_batch_size": True,
            "fp16": torch.cuda.is_available(),
            # Evaluation and checkpointing configs
            "max_steps": max_steps,
            "num_train_epochs": num_train_epochs,
            "save_steps": 1000,  # NOTE: 200
            "save_strategy": "steps",
            "eval_steps": 1000,  # NOTE: 500
            "evaluation_strategy": "steps",
            "save_total_limit": 1,
            "load_best_model_at_end": True,
            "metric_for_best_model": "perplexity",
            "greater_is_better": False,  # Lower perplexity is better
            "include_inputs_for_metrics": True,
            # Logging configs
            "log_level": "warning",
            "logging_steps": 500,
            "disable_tqdm": True,
            "report_to": ["tensorboard"],
            "save_only_model": False,  # Default: False
            # Hub information configs
            "push_to_hub": True,  # NOTE: Also manually done further down
            "push_to_hub_model_id": model_id,
            "push_to_hub_organization": organization,
            "hub_model_id": hub_model_id,
            "hub_token": hub_token,
            "hub_strategy": "checkpoint",  # NOTE: Allows resuming training from the last checkpoint
            "hub_private_repo": True,
            # Other configs
            "seed": 42,
            "data_seed": 42,
        }
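    # NOTE: A user-supplied `training_args` dict replaces the defaults above
    # entirely; it is not merged with them.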
    # Setup the Trainer and start training (no Optuna hyperparameter search)
    trainer = Trainer(
        model_init=bert2bert,
        tokenizer=tokenizer,
        data_collator=data_collator,
        args=TrainingArguments(**training_args),
        compute_metrics=compute_metrics_for_mlm,
        train_dataset=dataset_tokenized["train"],
        eval_dataset=dataset_tokenized["validation"],
    )
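    # `model_init` makes the Trainer instantiate a fresh encoder-decoder model at
    # the start of training (and on every trial, if a hyperparameter search were run).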
    if resume_from_checkpoint is not None:
        trainer.train(
            resume_from_checkpoint=resume_from_checkpoint,
        )
    else:
        trainer.train()
print("-" * 80)
print("Training completed.")
print("-" * 80)
if hub_model_id is not None:
print("Pushing model to Hugging Face Hub.")
print("-" * 80)
tokenizer.save_pretrained(output_dir)
trainer.push_to_hub(
commit_message="Initial version",
model_name=hub_model_id,
license="mit",
finetuned_from=f"{pretrained_encoder}",
tasks=["Text2Text Generation", "question-answering"],
tags=["PROTAC", "cheminformatics"],
dataset=[ds_name],
dataset_args=[ds_config],
)
tokenizer.push_to_hub(
repo_id=hub_model_id,
commit_message="Upload tokenizer",
private=True,
token=hub_token,
tags=["PROTAC", "cheminformatics"],
)
print("All done.")