""" Hugging Face utilities for model loading and pipeline creation. """
from typing import Optional, List, Dict, Union
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    EncoderDecoderModel,
    AutoModelForCausalLM,
    Pipeline,
    pipeline,
    GenerationConfig,
)
from transformers.pipelines.pt_utils import KeyDataset
from tqdm import tqdm
import torch


def get_encoder_decoder_model(
        pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
        pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
        max_length: Optional[int] = 512,
        tie_encoder_decoder: bool = False,
) -> EncoderDecoderModel:
    """ Get the EncoderDecoderModel model for the PROTAC splitter.

    Args:
        pretrained_encoder (str): The pretrained model to use for the encoder. Default: "seyonec/ChemBERTa-zinc-base-v1"
        pretrained_decoder (str): The pretrained model to use for the decoder. Default: "seyonec/ChemBERTa-zinc-base-v1"
        max_length (int): The maximum length of the input sequence. Default: 512
        tie_encoder_decoder (bool): Whether to tie the encoder and decoder weights. Default: False

    Returns:
        EncoderDecoderModel: The EncoderDecoderModel for the PROTAC splitter.
    """
    bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
        pretrained_encoder,
        pretrained_decoder,
        tie_encoder_decoder=tie_encoder_decoder,
    )
    print(f"Number of parameters: {bert2bert.num_parameters():,}")
    tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
    # Tokenizer-related configs
    bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
    bert2bert.config.eos_token_id = tokenizer.sep_token_id
    bert2bert.config.pad_token_id = tokenizer.pad_token_id
    bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
    # Generation configs
    # NOTE: The full list of generation configuration options can be found here: https://huggingface.co/docs/transformers/v4.33.3/en/main_classes/text_generation#transformers.GenerationConfig
    bert2bert.encoder.config.max_length = max_length
    bert2bert.decoder.config.max_length = max_length

    def setup_gen(config):
        # Beam-search multinomial sampling defaults shared by all generation configs.
        config.do_sample = True
        config.num_beams = 5
        config.top_k = 20
        config.max_length = max_length
        # config.max_new_tokens = 512
        return config

    bert2bert.config = setup_gen(bert2bert.config)
    bert2bert.encoder.config = setup_gen(bert2bert.encoder.config)
    bert2bert.decoder.config = setup_gen(bert2bert.decoder.config)
    bert2bert.decoder.config.is_decoder = True
    bert2bert.generation_config = setup_gen(bert2bert.generation_config)
    
    # Additional generation options (e.g., max_new_tokens, early_stopping,
    # length_penalty, no_repeat_ngram_size) can be set on the configs above if needed.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    bert2bert.to(device)

    return bert2bert
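
# Example usage of get_encoder_decoder_model (a minimal sketch; the default
# ChemBERTa checkpoint is assumed and the SMILES below is only a placeholder):
#
#   model = get_encoder_decoder_model(max_length=512)
#   tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
#   inputs = tokenizer("CC(=O)Oc1ccccc1C(=O)O", return_tensors="pt").to(model.device)
#   outputs = model.generate(**inputs)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))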


def get_causal_model(
        pretrained_model: str = "seyonec/ChemBERTa-zinc-base-v1",
        max_length: Optional[int] = 512,
) -> AutoModelForCausalLM:
    """ Get the causal language model for the PROTAC splitter.

    Args:
        pretrained_model (str): The pretrained model to use for the causal language model. Default: "seyonec/ChemBERTa-zinc-base-v1"
        max_length (int): The maximum length of the input sequence. Default: 512

    Returns:
        AutoModelForCausalLM: The causal language model for the PROTAC splitter
    """
    model = AutoModelForCausalLM.from_pretrained(pretrained_model, is_decoder=True)
    # model.is_decoder = True # It might not be necessary, but it's good to be explicit

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    return model
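
# Example usage of get_causal_model (a minimal sketch; the default ChemBERTa
# checkpoint is assumed and the SMILES below is only a placeholder):
#
#   model = get_causal_model()
#   tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
#   inputs = tokenizer("CC(=O)Oc1ccccc1C(=O)O", return_tensors="pt").to(model.device)
#   outputs = model.generate(**inputs, max_new_tokens=64)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))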


# REF: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/generation/configuration_utils.py#L71
GENERATION_STRATEGY_PARAMS = {
    "greedy": {"num_beams": 1, "do_sample": False},
    "contrastive_search": {"penalty_alpha": 0.1, "top_k": 10},
    "multinomial_sampling": {"num_beams": 1, "do_sample": True},
    "beam_search_decoding": {"num_beams": 5, "do_sample": False, "num_return_sequences": 5},
    "beam_search_multinomial_sampling": {"num_beams": 5, "do_sample": True, "num_return_sequences": 5},
    "diverse_beam_search_decoding": {"num_beams": 5, "num_beam_groups": 5, "diversity_penalty": 1.0, "num_return_sequences": 5},
}

def avail_generation_strategies() -> List[str]:
    """ Get the available generation strategies. """
    return list(GENERATION_STRATEGY_PARAMS.keys())

def get_generation_config(generation_strategy: str) -> GenerationConfig:
    """ Get the generation config for the given generation strategy. """
    return GenerationConfig(
        max_length=512,
        max_new_tokens=512,
        **GENERATION_STRATEGY_PARAMS[generation_strategy],
    )
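
# Example (a minimal sketch): list the available strategies and build a
# beam-search generation config from the presets above:
#
#   print(avail_generation_strategies())
#   gen_config = get_generation_config("beam_search_decoding")
#   print(gen_config.num_beams, gen_config.num_return_sequences)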

def get_pipeline(
        model_name: str,
        token: str,
        is_causal_language_model: bool,
        generation_strategy: Optional[str] = None,
        num_return_sequences: int = 1,
        device: Optional[Union[int, str]] = None,
) -> Pipeline:
    """ Get the pipeline for the given model name and generation strategy.
    
    
    
    """
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    if is_causal_language_model:
        print('Loading pipeline for causal language models...')
        task = 'text-generation'
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, padding_side='left')
    else:
        print('Loading pipeline for sequence-to-sequence models...')
        task = 'text2text-generation'
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    # Extra keyword arguments that depend on the requested generation strategy.
    extra_kwargs = {}
    if generation_strategy is not None:
        extra_kwargs['generation_config'] = get_generation_config(generation_strategy)
    elif is_causal_language_model:
        extra_kwargs['num_return_sequences'] = num_return_sequences
    return pipeline(
        task,
        model=model_name,
        tokenizer=tokenizer,
        token=token,
        device=device,
        **extra_kwargs,
    )
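
# Example usage of get_pipeline (a minimal sketch; the model name and token are
# placeholders, not real checkpoints):
#
#   pipe = get_pipeline(
#       model_name="my-org/protac-splitter-bert2bert",
#       token="hf_...",
#       is_causal_language_model=False,
#       generation_strategy="beam_search_decoding",
#   )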

def run_causal_pipeline(
        pipe: Pipeline,
        test_ds: Dataset,
        batch_size: int,
        smiles_column: str = 'prompt',
) -> List[Dict[str, str]]:
    """ Run the pipeline for causal language models and return the predictions.
    
    Args:
        pipe (pipeline): The pipeline object to use for generating predictions.
        test_ds (Dataset): The test dataset to generate predictions for.
        batch_size (int): The batch size to use for generating predictions.
        smiles_column (str): The column in the dataset containing the prompt SMILES strings. Default: 'prompt'

    Returns:
        List[Dict[str, str]]: A list of dictionaries containing the predictions.
    """
    preds = []
    # The pipeline yields one list of generated sequences per input example,
    # regardless of the internal batch size.
    for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds)):
        generated_text = [p['generated_text'] for p in pred]
        # Remove the prompt from the generated text (the prompt is the part before the first '.')
        generated_text = ['.'.join(t.split('.')[1:]) for t in generated_text]
        # Add the predictions to the list
        p = {f'pred_n{i}': t for i, t in enumerate(generated_text)}
        preds.append(p)
    return preds

def run_seq2seq_pipeline(
        pipe: Pipeline,
        test_ds: Dataset,
        batch_size: int,
        smiles_column: str = 'text',
) -> List[Dict[str, str]]:
    """ Run the pipeline for sequence-to-sequence models and return the predictions.
    
    Args:
        pipe (pipeline): The pipeline object to use for generating predictions.
        test_ds (Dataset): The test dataset to generate predictions for.
        batch_size (int): The batch size to use for generating predictions.
        smiles_column (str): The column in the dataset containing the input SMILES strings. Default: 'text'
        
    Returns:
        List[Dict[str, str]]: A list of dictionaries containing the predictions.
    """
    preds = []
    # The pipeline yields one list of generated sequences per input example,
    # regardless of the internal batch size.
    for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds)):
        p = {f'pred_n{i}': out['generated_text'] for i, out in enumerate(pred)}
        preds.append(p)
    return preds

def run_pipeline(
        pipe: Pipeline,
        test_ds: Dataset,
        batch_size: int,
        is_causal_language_model: bool,
        smiles_column: str = 'text',
) -> List[Dict[str, str]]:
    """ Run the pipeline and return the predictions.
    
    Args:
        pipe (pipeline): The pipeline object to use for generating predictions.
        test_ds (Dataset): The test dataset to generate predictions for.
        batch_size (int): The batch size to use for generating predictions.
        is_causal_language_model (bool): Whether the model is a causal language model or not.
        smiles_column (str): The column name in the dataset that contains the SMILES strings. Default: 'text'
        
    Returns:
        List[Dict[str, str]]: A list of dictionaries with one entry per returned sequence, in the format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
    """
    if is_causal_language_model:
        return run_causal_pipeline(pipe, test_ds, batch_size, smiles_column)
    else:
        return run_seq2seq_pipeline(pipe, test_ds, batch_size, smiles_column)
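
# Example usage of run_pipeline (a minimal sketch; the SMILES string is only a
# placeholder and `pipe` is assumed to come from get_pipeline above):
#
#   test_ds = Dataset.from_dict({'text': ['CC(=O)Oc1ccccc1C(=O)O']})
#   preds = run_pipeline(pipe, test_ds, batch_size=8, is_causal_language_model=False)
#   print(preds[0]['pred_n0'])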