File size: 14,331 Bytes
9dd777e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
""" Train a PPO and DPO model for PROTAC-Splitter using Hugging Face
Transformers and TRL. This is a work in progress code, so it's not tested nor
used in the package.
"""
from typing import Optional, Literal
from functools import partial
import os
import shutil
import subprocess

import torch
import evaluate
import huggingface_hub as hf
from tqdm import tqdm
from datasets import load_dataset
from rdkit import Chem
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    EncoderDecoderModel,
    AutoConfig,
)
from trl import (
    AutoModelForSeq2SeqLMWithValueHead,
    PPOConfig,
    PPOTrainer,
    DPOTrainer,
)

from protac_splitter.llms.data_utils import (
    load_trl_dataset,
    data_collator_for_trl,
)

from protac_splitter.llms.hf_utils import (
    create_hf_repository,
    delete_hf_repository,
    repo_exists,
)
from protac_splitter.llms.evaluation import decode_and_get_metrics
from protac_splitter.evaluation import check_substructs, split_prediction


def clean_text(text: str) -> str:
    """ Strips the ``<s>`` and ``</s>`` special tokens from a decoded string.

    Args:
        text (str): Text that may contain ``<s>`` / ``</s>`` tokens.

    Returns:
        str: The input with every ``<s>`` occurrence removed first, then every
        ``</s>`` occurrence removed.
    """
    # Order matters for overlapping patterns, so "<s>" is removed before "</s>".
    for special_token in ("<s>", "</s>"):
        text = text.replace(special_token, "")
    return text


def reward_function(
        query: str,
        response: str,
) -> float:
    """ Reward function for the RL-based models.

    The response is split into its ``poi``, ``linker`` and ``e3``
    substructures, which are then checked against the query PROTAC.

    Args:
        query (str): The query SMILES string (the full PROTAC).
        response (str): The response SMILES string holding the predicted
            substructures.

    Returns:
        float: ``-1.0`` if the response cannot be split into substructures,
        ``0.0`` if the substructures do not pass ``check_substructs`` against
        the query, and ``1.0`` otherwise.
    """
    substructs = split_prediction(response)
    if substructs is None:
        # Unparseable prediction: penalize it.
        return -1.0

    # NOTE: Plain floats are returned on purpose: the legacy ``torch.Tensor``
    # constructor raises a TypeError when given a Python float, and callers
    # convert the reward to a tensor themselves.
    if not check_substructs(
        protac_smiles=query,
        poi_smiles=substructs['poi'],
        linker_smiles=substructs['linker'],
        e3_smiles=substructs['e3'],
        return_bond_types=False,
        poi_attachment_id=1,
        e3_attachment_id=2,
    ):
        return 0.0

    return 1.0


def train_ppo_model(
    model_id: str = "PROTAC-Splitter-PPO-standard_rand_recombined-ChemBERTa-zinc-base",
    organization: Optional[str] = 'ailab-bio',
    output_dir: str = "./models/",
    max_steps: int = 2000,
    ppo_epochs: int = 5,
    batch_size: int = 128,
    hub_token: Optional[str] = None,
    pretrained_model_name: str = "ailab-bio/PROTAC-Splitter-standard_rand_recombined-ChemBERTa-zinc-base",
    max_length: int = 512,
    delete_repo_if_exists: bool = False,
    delete_local_repo_if_exists: bool = False,
    ds_name: str = "ailab-bio/PROTAC-Splitter-Dataset",
    ds_config: str = "standard",
):
    """ Trains a PPO model on a given dataset.

    Args:
        model_id (str, optional): The name of the model to be trained. It is
            also used as the local subdirectory of ``output_dir`` and as the
            Hub repository name. Defaults to "PROTAC-Splitter-PPO-standard_rand_recombined-ChemBERTa-zinc-base".
        organization (Optional[str], optional): The Hub organization. If None,
            no Hub repository is created and the trained model and tokenizer
            are saved to the local output directory instead. Defaults to 'ailab-bio'.
        output_dir (str, optional): The base output directory; ``model_id`` is
            appended to it. Defaults to "./models/".
        max_steps (int, optional): Number of training steps, passed to
            ``PPOConfig(steps=...)``. Defaults to 2000.
        ppo_epochs (int, optional): The number of PPO epochs, passed to
            ``PPOConfig``. Must be >= 1. Defaults to 5.
        batch_size (int, optional): The batch size. Defaults to 128.
        hub_token (Optional[str], optional): The Hugging Face token. Defaults to None.
        pretrained_model_name (str, optional): The name of the pretrained model,
            used for both the trained and the reference model and for the
            tokenizer. Defaults to "ailab-bio/PROTAC-Splitter-standard_rand_recombined-ChemBERTa-zinc-base".
        max_length (int, optional): The maximum sequence length used when
            loading the models and the dataset and during generation. Defaults to 512.
        delete_repo_if_exists (bool, optional): Whether to delete the Hub
            repository before training (only when ``organization`` is set).
            Defaults to False.
        delete_local_repo_if_exists (bool, optional): Whether to delete the
            local output directory before training. Defaults to False.
        ds_name (str, optional): The dataset name passed to ``load_trl_dataset``.
            Defaults to "ailab-bio/PROTAC-Splitter-Dataset".
        ds_config (str, optional): The dataset configuration passed to
            ``load_trl_dataset``. Defaults to "standard".

    Returns:
        None. The function returns early, without training, if a repository
        that should be deleted could not be deleted.

    Raises:
        ValueError: If ``ppo_epochs`` is smaller than 1.
    """
    if ppo_epochs < 1:
        raise ValueError(f"ppo_epochs must be >= 1, got {ppo_epochs}.")
    if hub_token is not None:
        hf.login(token=hub_token)

    # Setup output directory and Hugging Face repository
    output_dir = os.path.join(output_dir, model_id)
    hub_model_id = f"{organization}/{model_id}" if organization is not None else None

    if (
        hub_model_id is not None
        and delete_repo_if_exists
        and repo_exists(hub_model_id, token=hub_token)
    ):
        delete_hf_repository(repo_id=hub_model_id, token=hub_token)
        if repo_exists(hub_model_id, token=hub_token):
            print(f"Repository '{hub_model_id}' could not be deleted.")
            return
        print(f"Repository '{hub_model_id}' deleted.")

    # The local output directory does not depend on the Hub repository, so it
    # is cleaned up whether or not an organization was given.
    if delete_local_repo_if_exists and os.path.exists(output_dir):
        # Best-effort removal (like `rm -rf`, but portable); success is
        # verified right after.
        shutil.rmtree(output_dir, ignore_errors=True)
        if os.path.exists(output_dir):
            print(f"Local repository '{output_dir}' could not be deleted.")
            return
        print(f"Local repository '{output_dir}' deleted.")

    if hub_model_id is not None:
        repo_url = create_hf_repository(
            repo_id=hub_model_id,
            repo_type="model",
            exist_ok=True,
            private=True,
            token=hub_token,
        )
        print(f"Repository '{hub_model_id}' created at URL: {repo_url}")
    print(f"Hub model ID: {hub_model_id}")

    # Load the model to train and the reference model (both from the same
    # pretrained checkpoint).
    model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
        pretrained_model_name,
        max_length=max_length,
    )
    ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
        pretrained_model_name,
        max_length=max_length,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Get dataset
    train_dataset = load_trl_dataset(
        tokenizer=tokenizer,
        token=hub_token,
        max_length=max_length,
        dataset_name=ds_name,
        ds_config=ds_config,
    ).shuffle(seed=42).flatten_indices()

    # Hub upload settings; None when no Hub repository is used, in which case
    # the model is saved locally at the end instead of being pushed.
    if hub_model_id is not None:
        hub_configs = {
            "repo_id": hub_model_id,
            "commit_message": "Initial version",
            "private": True,
        }
    else:
        hub_configs = None

    # Setup PPO trainer
    ppo_config = PPOConfig(
        # Learning parameters
        learning_rate=1e-5,
        steps=max_steps, # Default: 20_000
        ppo_epochs=ppo_epochs, # Default: 4
        batch_size=batch_size, # Default: 256
        gradient_accumulation_steps=1, # Default: 1
        optimize_device_cache=True,
        # PPO parameters
        init_kl_coef=1.0,
        adap_kl_ctrl=True,
        target=0.5,
        horizon=1000,
        cliprange=0.1,
        early_stopping=True,
        target_kl=0.5,
        max_grad_norm=1.0,
        use_score_scaling=True,
        use_score_norm=True,
        whiten_rewards=True,
        # Logging parameters
        # NOTE: Check this guide for more information about the logged metrics:
        # https://huggingface.co/docs/trl/v0.10.1/logging
        model_name=hub_model_id,
        push_to_hub_if_best_kwargs=hub_configs if hub_configs is not None else {},
        log_with="tensorboard", # ["wandb", LoggerType.TENSORBOARD],
        project_kwargs={"logging_dir": output_dir},
        seed=42,
    )
    ppo_trainer = PPOTrainer(
        model=model,
        ref_model=ref_model,
        num_shared_layers=0,
        config=ppo_config,
        tokenizer=tokenizer,
        dataset=train_dataset,
        data_collator=data_collator_for_trl,
        # NOTE: A custom lr_scheduler must be a torch.optim.lr_scheduler.LRScheduler;
        # CosineAnnealingLR is not supported.
    )

    # Training Loop
    generation_kwargs = {
        "do_sample": True,
        "num_beams": 5,
        "top_k": 20,
        "max_length": max_length,
        "pad_token_id": tokenizer.eos_token_id,
    }

    for batch in tqdm(ppo_trainer.dataloader, total=len(ppo_trainer.dataloader)):
        query_tensors = batch["input_ids"]

        # Get response from SFTModel
        response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

        # Compute reward scores. `as_tensor` accepts both Python floats and
        # tensors, so the conversion is safe whatever the reward type.
        rewards = [
            torch.as_tensor(
                reward_function(clean_text(q), clean_text(r)),
                dtype=torch.float,
            )
            for q, r in zip(batch["query"], batch["response"])
        ]

        # Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

    # Save model and tokenizer: push to the Hub when a repository was set up,
    # otherwise (pushing with repo_id=None would fail) save them locally.
    if hub_configs is not None:
        ppo_trainer.push_to_hub(**hub_configs)
        tokenizer.push_to_hub(**hub_configs)
    else:
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)


def train_dpo_model(
    model_name: str = "ailab-bio/PROTAC-Splitter-DPO",
    output_dir: str = "./models/",
    beta: float = 0.1,
    loss_type: Literal["sigmoid", "hinge"] = "sigmoid",
    learning_rate: float = 5e-5,
    max_steps: int = 2000,
    num_train_epochs: int = -1,
    batch_size: int = 128,
    gradient_accumulation_steps: int = 4,
    resume_from_checkpoint: bool = False,
    hub_token: Optional[str] = None,
    pretrained_model_name: str = "ailab-bio/PROTAC-Splitter_untied_80-20-split",
    pretrained_ref_model_name: str = "ailab-bio/PROTAC-Splitter_untied_80-20-split",
    max_length: Optional[int] = None,
    delete_repo_first: bool = False,
    optuna_search: bool = False,
):
    """ Trains a DPO model on a given dataset.

    The model to train and the reference model are loaded as
    ``EncoderDecoderModel`` checkpoints. They are fine-tuned on the
    ``ailab-bio/PROTAC-Substructures-DPO`` dataset, using its ``train`` split
    for training and its ``test`` split for evaluation. The result is pushed
    to the Hugging Face Hub under ``model_name``.

    Args:
        model_name (str, optional): Hub repository ID of the model to be trained. Defaults to "ailab-bio/PROTAC-Splitter-DPO".
        output_dir (str, optional): Local directory for training outputs and checkpoints. Defaults to "./models/".
        beta (float, optional): DPO ``beta`` parameter passed to ``DPOTrainer``. Defaults to 0.1.
        loss_type (Literal["sigmoid", "hinge"], optional): DPO loss type passed to ``DPOTrainer``. Defaults to "sigmoid".
        learning_rate (float, optional): Optimizer learning rate. Defaults to 5e-5.
        max_steps (int, optional): The maximum number of training steps. Defaults to 2000.
        num_train_epochs (int, optional): Number of training epochs passed to
            ``TrainingArguments`` (a positive ``max_steps`` overrides it). Defaults to -1.
        batch_size (int, optional): Total training batch size; each device
            step uses ``batch_size // gradient_accumulation_steps`` samples. Defaults to 128.
        gradient_accumulation_steps (int, optional): Gradient accumulation steps. Must be >= 1. Defaults to 4.
        resume_from_checkpoint (bool, optional): If True, training resumes
            from the "last-checkpoint" checkpoint and the repository is never
            deleted. Defaults to False.
        hub_token (Optional[str], optional): The Hugging Face token. Defaults to None.
        pretrained_model_name (str, optional): Checkpoint used to initialize
            the trained model and the tokenizer. Defaults to "ailab-bio/PROTAC-Splitter_untied_80-20-split".
        pretrained_ref_model_name (str, optional): Checkpoint of the reference
            model. Defaults to "ailab-bio/PROTAC-Splitter_untied_80-20-split".
        max_length (Optional[int], optional): Maximum prompt/target/total
            length. If None, it is read from the config of ``pretrained_model_name``. Defaults to None.
        delete_repo_first (bool, optional): Whether to delete the ``model_name``
            Hub repository (if it exists) before a fresh training run. Defaults to False.
        optuna_search (bool, optional): If True, ``model_init`` is also passed
            to the trainer. NOTE: the hyperparameter search itself is
            currently disabled. Defaults to False.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` < 1, or if
            ``batch_size`` < ``gradient_accumulation_steps`` (per-device batch size of 0).
    """
    # Validate the batch configuration before any destructive or slow
    # operation (repository deletion, model downloads).
    if gradient_accumulation_steps < 1:
        raise ValueError(
            "gradient_accumulation_steps must be >= 1, "
            f"got {gradient_accumulation_steps}."
        )
    per_device_batch_size = batch_size // gradient_accumulation_steps
    if per_device_batch_size < 1:
        raise ValueError(
            f"batch_size ({batch_size}) must be >= gradient_accumulation_steps "
            f"({gradient_accumulation_steps}) to get a per-device batch size >= 1."
        )

    if hub_token is not None:
        hf.login(token=hub_token)
    # Only wipe the Hub repository on fresh runs: resuming relies on the
    # checkpoint pushed there (see hub_strategy="checkpoint" below).
    if (
        delete_repo_first
        and not resume_from_checkpoint
        and repo_exists(model_name, token=hub_token)
    ):
        delete_hf_repository(repo_id=model_name, token=hub_token)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name,
        token=hub_token,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Get train and eval datasets
    dataset = load_dataset(
        "ailab-bio/PROTAC-Substructures-DPO",
        token=hub_token,
    )
    # Setup models
    def model_init():
        """ Loads a fresh copy of the model to be trained. """
        return EncoderDecoderModel.from_pretrained(
            pretrained_model_name,
            token=hub_token,
        )
    model_ref = EncoderDecoderModel.from_pretrained(
        pretrained_ref_model_name,
        token=hub_token,
    )
    # Setup training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        # Optimizer-related configs
        learning_rate=learning_rate,
        optim="adamw_torch",
        lr_scheduler_type="cosine", # Default: "linear"
        # Batch size and device configs
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        auto_find_batch_size=True,
        # torch_compile=True,
        fp16=True,
        # Evaluation and checkpointing configs
        evaluation_strategy="steps", # TODO: Why is it not working? "steps",
        max_steps=max_steps,
        num_train_epochs=num_train_epochs,
        eval_steps=100,
        save_steps=200,
        # eval_steps=7500,
        # warmup_steps=2000,
        save_strategy="steps",
        save_total_limit=1,
        load_best_model_at_end=True,
        # metric_for_best_model="valid_smiles",
        # Logging configs
        log_level="info",
        logging_steps=50,
        disable_tqdm=True,
        # Hub information configs
        push_to_hub=True, # NOTE: A final push is also done explicitly at the end
        hub_token=hub_token,
        hub_model_id=model_name,
        hub_strategy="checkpoint", # NOTE: Allows to resume training from last checkpoint
        hub_private_repo=True,
        # Other configs
        remove_unused_columns=False,
        seed=42,
        data_seed=42,
    )
    # Setup Metrics
    # TODO: The metric is not working because the predictions include rewards,
    # or something like that, i.e., real values, which cannot be decoded by the
    # tokenizer. Skipping for now and using the default one.
    rouge = evaluate.load("rouge")
    fpgen = Chem.rdFingerprintGenerator.GetMorganGenerator(
        radius=8,
        fpSize=2048,
    )
    metric = partial(
        decode_and_get_metrics,
        rouge=rouge,
        tokenizer=tokenizer,
        fpgen=fpgen,
    )
    # Setup trainer and start training
    if max_length is None:
        max_length = AutoConfig.from_pretrained(
            pretrained_model_name,
            token=hub_token,
        ).max_length
    dpo_trainer = DPOTrainer(
        model=model_init(),
        ref_model=model_ref,
        beta=beta,
        loss_type=loss_type,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        model_init=model_init if optuna_search else None,
        # compute_metrics=metric,
        max_length=max_length,
        max_prompt_length=max_length,
        max_target_length=max_length,
        is_encoder_decoder=True,
        padding_value=tokenizer.pad_token_id,
        truncation_mode="keep_start",
        args=training_args,
    )
    if optuna_search and False:
        # TODO: This is not working because the training arguments do NOT
        # include the beta parameter...
        def optuna_hp_space(trial):
            return {
                "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
                "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64, 128]),
                "beta": trial.suggest_float("beta", 0.1, 0.5),
            }
        best_trials = dpo_trainer.hyperparameter_search(
            direction=["minimize"],
            backend="optuna",
            hp_space=optuna_hp_space,
            n_trials=20,
            # compute_objective=compute_objective,
        )
        print("-" * 80)
        print(f"Best trials:\n{best_trials}")
        print("-" * 80)
    else:
        # With hub_strategy="checkpoint", the latest checkpoint lives in the
        # "last-checkpoint" folder.
        checkpoint = "last-checkpoint" if resume_from_checkpoint else None
        dpo_trainer.train(
            resume_from_checkpoint=checkpoint,
        )
    dpo_trainer.push_to_hub(
        commit_message="Initial version",
        model_name=model_name,
        license="mit",
        finetuned_from=pretrained_model_name,
        tasks=["Text2Text Generation"],
        tags=["PROTAC", "cheminformatics"],
        dataset="ailab-bio/PROTAC-Substructures-DPO",
    )