""" Hugging Face utilities for model loading and pipeline creation. """
from typing import Optional, List, Dict, Union
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    EncoderDecoderModel,
    AutoModelForCausalLM,
    Pipeline,
    pipeline,
    GenerationConfig,
)
from transformers.pipelines.pt_utils import KeyDataset
from tqdm import tqdm
import torch
def get_encoder_decoder_model(
pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
max_length: Optional[int] = 512,
tie_encoder_decoder: bool = False,
) -> EncoderDecoderModel:
""" Get the EncoderDecoderModel model for the PROTAC splitter.
Args:
pretrained_encoder (str): The pretrained model to use for the encoder. Default: "seyonec/ChemBERTa-zinc-base-v1"
pretrained_decoder (str): The pretrained model to use for the decoder. Default: "seyonec/ChemBERTa-zinc-base-v1"
max_length (int): The maximum length of the input sequence. Default: 512
tie_encoder_decoder (bool): Whether to tie the encoder and decoder weights. Default: False
Returns:
EncoderDecoderModel: The EncoderDecoderModel model for the PROTAC splitter
"""
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
pretrained_encoder,
pretrained_decoder,
tie_encoder_decoder=tie_encoder_decoder,
)
print(f"Number of parameters: {bert2bert.num_parameters():,}")
tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
# Tokenizer-related configs
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
# Generation configs
    # NOTE: The full list of generation configuration options can be found here: https://huggingface.co/docs/transformers/v4.33.3/en/main_classes/text_generation#transformers.GenerationConfig
bert2bert.encoder.config.max_length = max_length
bert2bert.decoder.config.max_length = max_length
def setup_gen(config):
config.do_sample = True
config.num_beams = 5
config.top_k = 20
        config.max_length = max_length  # Honor the requested maximum sequence length
# config.max_new_tokens = 512
return config
bert2bert.config = setup_gen(bert2bert.config)
bert2bert.encoder.config = setup_gen(bert2bert.encoder.config)
bert2bert.decoder.config = setup_gen(bert2bert.decoder.config)
bert2bert.decoder.config.is_decoder = True
bert2bert.generation_config = setup_gen(bert2bert.generation_config)
# bert2bert.config.do_sample = True
# bert2bert.config.num_beams = 5
# bert2bert.config.top_k = 20
# bert2bert.config.max_length=512
# bert2bert.config.max_new_tokens=512
# bert2bert.generation_config.max_new_tokens = 512
# bert2bert.generation_config.min_new_tokens = 512
# bert2bert.config.max_new_tokens = 514
# bert2bert.config.early_stopping = True
# bert2bert.config.length_penalty = 2.0
# # bert2bert.config.no_repeat_ngram_size = 3 # Default: 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert2bert.to(device)
return bert2bert
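

# Example usage of get_encoder_decoder_model (illustrative sketch only; the defaults
# below mirror the function signature and no finetuned checkpoint is assumed):
#
#   model = get_encoder_decoder_model(
#       pretrained_encoder="seyonec/ChemBERTa-zinc-base-v1",
#       pretrained_decoder="seyonec/ChemBERTa-zinc-base-v1",
#       max_length=512,
#   )
#   tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
#   inputs = tokenizer("CCO", return_tensors="pt").to(model.device)
#   output_ids = model.generate(**inputs)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
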
def get_causal_model(
pretrained_model: str = "seyonec/ChemBERTa-zinc-base-v1",
max_length: Optional[int] = 512,
) -> AutoModelForCausalLM:
""" Get the causal language model for the PROTAC splitter.
Args:
pretrained_model (str): The pretrained model to use for the causal language model. Default: "seyonec/ChemBERTa-zinc-base-v1"
max_length (int): The maximum length of the input sequence. Default: 512
Returns:
AutoModelForCausalLM: The causal language model for the PROTAC splitter
"""
model = AutoModelForCausalLM.from_pretrained(pretrained_model, is_decoder=True)
# model.is_decoder = True # It might not be necessary, but it's good to be explicit
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
return model
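

# Example usage of get_causal_model (illustrative sketch; assumes the default
# ChemBERTa checkpoint, which is loaded as a decoder via is_decoder=True):
#
#   causal_model = get_causal_model("seyonec/ChemBERTa-zinc-base-v1")
#   print(f"Loaded {causal_model.config.model_type} decoder with "
#         f"{causal_model.num_parameters():,} parameters")
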
# REF: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/generation/configuration_utils.py#L71
GENERATION_STRATEGY_PARAMS = {
"greedy": {"num_beams": 1, "do_sample": False},
"contrastive_search": {"penalty_alpha": 0.1, "top_k": 10},
"multinomial_sampling": {"num_beams": 1, "do_sample": True},
"beam_search_decoding": {"num_beams": 5, "do_sample": False, "num_return_sequences": 5},
"beam_search_multinomial_sampling": {"num_beams": 5, "do_sample": True, "num_return_sequences": 5},
"diverse_beam_search_decoding": {"num_beams": 5, "num_beam_groups": 5, "diversity_penalty": 1.0, "num_return_sequences": 5},
}
def avail_generation_strategies() -> List[str]:
""" Get the available generation strategies. """
return list(GENERATION_STRATEGY_PARAMS.keys())
def get_generation_config(generation_strategy: str) -> GenerationConfig:
""" Get the generation config for the given generation strategy. """
return GenerationConfig(
max_length=512,
max_new_tokens=512,
**GENERATION_STRATEGY_PARAMS[generation_strategy],
)
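

# Example (illustrative): build a GenerationConfig for diverse beam search and inspect
# the resulting parameters. Strategy names map directly onto the entries of
# GENERATION_STRATEGY_PARAMS above.
#
#   print(avail_generation_strategies())
#   gen_config = get_generation_config("diverse_beam_search_decoding")
#   print(gen_config.num_beams, gen_config.num_beam_groups, gen_config.diversity_penalty)
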
def get_pipeline(
    model_name: str,
    token: str,
    is_causal_language_model: bool,
    generation_strategy: Optional[str] = None,
    num_return_sequences: int = 1,
    device: Optional[Union[int, str]] = None,
) -> Pipeline:
    """ Get the pipeline for the given model name and generation strategy.
    Args:
        model_name (str): The name of the model on the Hugging Face Hub.
        token (str): The Hugging Face Hub token used to load the model and tokenizer.
        is_causal_language_model (bool): Whether the model is a causal language model or a sequence-to-sequence model.
        generation_strategy (Optional[str]): One of the strategies returned by avail_generation_strategies(). If None, the model's default generation settings are used.
        num_return_sequences (int): The number of sequences to return per input. Only used for causal models when generation_strategy is None. Default: 1
        device (Optional[Union[int, str]]): The device to run the pipeline on. Defaults to CUDA if available, otherwise CPU.
    Returns:
        Pipeline: The "text-generation" or "text2text-generation" pipeline.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gen_kwargs = {}
    if generation_strategy is not None:
        gen_kwargs['generation_config'] = get_generation_config(generation_strategy)
    if is_causal_language_model:
        print('Loading pipeline for causal language models...')
        # Decoder-only models need left padding for batched generation
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, padding_side='left')
        task = "text-generation"
        if generation_strategy is None:
            gen_kwargs['num_return_sequences'] = num_return_sequences
    else:
        print('Loading pipeline for sequence-to-sequence models...')
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        task = "text2text-generation"
    return pipeline(
        task,
        model=model_name,
        tokenizer=tokenizer,
        token=token,
        device=device,
        **gen_kwargs,
    )
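

# Example usage of get_pipeline (illustrative sketch; the model name is a placeholder
# and `hf_token` must be a valid Hugging Face Hub token). For causal models the
# tokenizer is loaded with left padding so that batched generation works as expected:
#
#   causal_pipe = get_pipeline(
#       model_name="ribesstefano/protac-splitter-causal",  # hypothetical checkpoint name
#       token=hf_token,
#       is_causal_language_model=True,
#       generation_strategy=None,
#       num_return_sequences=5,
#   )
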
def run_causal_pipeline(
    pipe: Pipeline,
test_ds: Dataset,
batch_size: int,
smiles_column: str = 'prompt',
) -> List[Dict[str, str]]:
""" Run the pipeline for causal language models and return the predictions.
Args:
        pipe (Pipeline): The pipeline object to use for generating predictions.
        test_ds (Dataset): The test dataset to generate predictions for.
        batch_size (int): The batch size to use for generating predictions.
        smiles_column (str): The column name in the dataset that contains the SMILES prompts. Default: 'prompt'
Returns:
List[Dict[str, str]]: A list of dictionaries containing the predictions.
"""
preds = []
    for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds)):
        generated_text = [p['generated_text'] for p in pred]
        # Remove the prompt (the leading dot-separated fragment) from the generated text
        generated_text = ['.'.join(t.split('.')[1:]) for t in generated_text]
# Add the predictions to the list
p = {f'pred_n{i}': t for i, t in enumerate(generated_text)}
preds.append(p)
return preds
def run_seq2seq_pipeline(
    pipe: Pipeline,
test_ds: Dataset,
batch_size: int,
smiles_column: str = 'text',
) -> List[Dict[str, str]]:
""" Run the pipeline for sequence-to-sequence models and return the predictions.
Args:
        pipe (Pipeline): The pipeline object to use for generating predictions.
        test_ds (Dataset): The test dataset to generate predictions for.
        batch_size (int): The batch size to use for generating predictions.
        smiles_column (str): The column name in the dataset that contains the SMILES strings. Default: 'text'
Returns:
List[Dict[str, str]]: A list of dictionaries containing the predictions.
"""
preds = []
    for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds)):
        p = {f'pred_n{i}': out['generated_text'] for i, out in enumerate(pred)}
preds.append(p)
return preds
def run_pipeline(
    pipe: Pipeline,
test_ds: Dataset,
batch_size: int,
is_causal_language_model: bool,
smiles_column: str = 'text',
) -> List[Dict[str, str]]:
""" Run the pipeline and return the predictions.
Args:
        pipe (Pipeline): The pipeline object to use for generating predictions.
test_ds (Dataset): The test dataset to generate predictions for.
batch_size (int): The batch size to use for generating predictions.
is_causal_language_model (bool): Whether the model is a causal language model or not.
smiles_column (str): The column name in the dataset that contains the SMILES strings. Default: 'text'
Returns:
        List[Dict[str, str]]: A list of dictionaries, one per input, each containing up to beam-size predictions in the format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
"""
if is_causal_language_model:
return run_causal_pipeline(pipe, test_ds, batch_size, smiles_column)
else:
return run_seq2seq_pipeline(pipe, test_ds, batch_size, smiles_column)
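

if __name__ == "__main__":
    # Minimal end-to-end sketch. Assumptions: "ribesstefano/protac-splitter" is a
    # placeholder model name (replace it with an actual checkpoint), the HF_TOKEN
    # environment variable holds a valid Hugging Face Hub token, and the SMILES
    # below is illustrative only.
    import os

    demo_ds = Dataset.from_dict({
        "text": ["CC(=O)Oc1ccccc1C(=O)O"],  # illustrative SMILES, not a real PROTAC
    })
    demo_pipe = get_pipeline(
        model_name="ribesstefano/protac-splitter",  # hypothetical checkpoint name
        token=os.environ.get("HF_TOKEN"),
        is_causal_language_model=False,
        generation_strategy="beam_search_decoding",
    )
    predictions = run_pipeline(
        demo_pipe,
        demo_ds,
        batch_size=1,
        is_causal_language_model=False,
        smiles_column="text",
    )
    print(predictions)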