import os
import random
import logging
from typing import Optional, Union
import torch
from datasets import load_dataset, concatenate_datasets, Dataset
from transformers import AutoTokenizer
from rdkit import Chem
from protac_splitter.evaluation import split_prediction


def randomize_smiles_dataset(
batch: dict,
repeat: int = 1,
prob: float = 0.5,
apply_to_text: bool = True,
apply_to_labels: bool = False,
) -> dict:
""" Randomize SMILES in a batch of data.
Args:
batch (dict): Batch of data with "text" and "labels" keys.
repeat (int, optional): Number of times to repeat the randomization. Defaults to 1.
prob (float, optional): Probability of randomizing SMILES. Defaults to 0.5.
apply_to_text (bool, optional): Whether to apply randomization to text. Defaults to True.
apply_to_labels (bool, optional): Whether to apply randomization to labels. Defaults to False.
Returns:
dict: Randomized batch of data.
"""
new_texts, new_labels = [], []
for text, label in zip(batch["text"], batch["labels"]):
        try:
            mol_text = Chem.MolFromSmiles(text)
            mol_label = Chem.MolFromSmiles(label)
        except Exception:
            mol_text, mol_label = None, None
        # NOTE: Chem.MolFromSmiles returns None on invalid SMILES instead of
        # raising, so an explicit None check is needed as well.
        if mol_text is None or mol_label is None:
            logging.error("Failed to convert SMILES to Mol!")
            new_texts.append(text)
            new_labels.append(label)
            continue
if random.random() < prob:
if apply_to_text:
rand_texts = [Chem.MolToSmiles(mol_text, canonical=False, doRandom=True) for _ in range(repeat)]
else:
rand_texts = [text] * repeat
if apply_to_labels:
rand_labels = [Chem.MolToSmiles(mol_label, canonical=False, doRandom=True) for _ in range(repeat)]
else:
rand_labels = [label] * repeat
new_texts.extend(rand_texts)
new_labels.extend(rand_labels)
else:
new_texts.append(text)
new_labels.append(label)
return {"text": new_texts, "labels": new_labels}
def process_data_to_model_inputs(
    batch: dict,
    tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
    encoder_max_length: int = 512,
    decoder_max_length: int = 512,
) -> dict:
    """ Tokenize the "text" and "labels" columns of a batch into model inputs."""
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
# tokenize the inputs and labels
inputs = tokenizer(batch["text"], truncation=True, max_length=encoder_max_length)
outputs = tokenizer(batch["labels"], truncation=True, max_length=decoder_max_length)
batch["input_ids"] = inputs.input_ids
batch["attention_mask"] = inputs.attention_mask
batch["labels"] = outputs.input_ids.copy()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# batch["input_ids"] = batch["input_ids"].to(device)
# batch["attention_mask"] = batch["attention_mask"].to(device)
# batch["labels"] = batch["labels"].to(device)
# Because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`.
# We have to make sure that the PAD token is ignored when calculating the loss.
# NOTE: Check the `ignore_index` argument in nn.CrossEntropyLoss.
# NOTE: The following is already done in the DataCollatorForSeq2Seq
# batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]
return batch
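
# Usage sketch (hedged): meant to be used with `Dataset.map` (see
# `load_tokenized_dataset` below); a direct call on a toy batch would look like
# this, with illustrative SMILES values:
#
#   toy = {"text": ["CCO", "c1ccccc1"], "labels": ["CO", "C1CCCCC1"]}
#   enc = process_data_to_model_inputs(toy, encoder_max_length=64, decoder_max_length=64)
#   # enc gains "input_ids" and "attention_mask"; "labels" becomes token ids.

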
def get_fragments_in_labels(labels: str, linkers_only_as_labels: bool = True) -> Optional[str]:
    """ Get the fragment SMILES to use as labels.
    Args:
        labels (str): The labels, i.e., the SMILES of the PROTAC fragments.
        linkers_only_as_labels (bool, optional): Whether to return only the linker. Defaults to True.
    Returns:
        Optional[str]: The fragment SMILES to use as labels, or None if any fragment is missing.
    """
    ligands = split_prediction(labels)
    if linkers_only_as_labels:
        return ligands.get("linker", None)
    if None in ligands.values():
        return None
    return f"{ligands['e3']}.{ligands['poi']}"
def load_tokenized_dataset(
dataset_dir: str,
dataset_config: str = 'default',
tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
batch_size: int = 512,
encoder_max_length: int = 512,
decoder_max_length: int = 512,
token: Optional[str] = None,
num_proc_map: int = 1,
randomize_smiles: bool = False,
randomize_smiles_prob: float = 0.5,
randomize_smiles_repeat: int = 1,
randomize_text: bool = True,
randomize_labels: bool = False,
cache_dir: Optional[str] = None,
all_fragments_as_labels: bool = True,
linkers_only_as_labels: bool = False,
causal_language_modeling: bool = False,
train_size_ratio: float = 1.0,
) -> Dataset:
""" Load dataset and tokenize it.
Args:
dataset_dir (str): The directory of the dataset or the name of the data on the Hugging Face Hub.
dataset_config (str, optional): The configuration of the dataset. Defaults to 'default'.
tokenizer (AutoTokenizer | str, optional): The tokenizer to use for tokenization. If a string, the tokenizer will be loaded using `AutoTokenizer.from_pretrained(tokenizer)`. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
batch_size (int, optional): The batch size for tokenization. Defaults to 512.
encoder_max_length (int, optional): The maximum length of the encoder input sequence. Defaults to 512.
decoder_max_length (int, optional): The maximum length of the decoder input sequence. Defaults to 512.
token (Optional[str], optional): The Hugging Face API token. Defaults to None.
num_proc_map (int, optional): The number of processes to use for mapping. Defaults to 1.
randomize_smiles (bool, optional): Whether to randomize SMILES. Defaults to False.
randomize_smiles_prob (float, optional): The probability of randomizing SMILES. Defaults to 0.5.
randomize_smiles_repeat (int, optional): The number of times to repeat the randomization. Defaults to 1.
randomize_text (bool, optional): Whether to randomize text. Defaults to True.
randomize_labels (bool, optional): Whether to randomize labels. Defaults to False.
cache_dir (Optional[str], optional): The directory to cache the dataset. Defaults to None.
all_fragments_as_labels (bool, optional): Whether to get all fragments in the labels. Defaults to True.
linkers_only_as_labels (bool, optional): Whether to get only the linkers in the labels. Defaults to False.
causal_language_modeling (bool, optional): Whether to use causal language modeling. Defaults to False.
train_size_ratio (float, optional): The ratio of the training dataset to use. Defaults to 1.0.
Returns:
Dataset: The tokenized dataset.
"""
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
if os.path.exists(dataset_dir):
# NOTE: We need a different argument to load a dataset from disk:
dataset = load_dataset(
dataset_dir,
data_dir=dataset_config,
)
print(f"Dataset loaded from disk at: \"{dataset_dir}\". Length: {dataset.num_rows}")
else:
dataset = load_dataset(
dataset_dir,
dataset_config,
token=token,
cache_dir=cache_dir,
)
print(f"Dataset loaded from hub. Length: {dataset.num_rows}")
    if 0 < train_size_ratio < 1.0:
        # Reduce the size of the training dataset by selecting only a fraction of the samples
        dataset["train"] = dataset["train"].select(range(int(train_size_ratio * dataset["train"].num_rows)))
        print(f"Reduced training dataset size to {train_size_ratio}. Length: {dataset.num_rows}")
    elif train_size_ratio > 1.0 or train_size_ratio <= 0:
        raise ValueError("train_size_ratio must be in the range (0, 1].")
if not all_fragments_as_labels:
dataset = dataset.map(
lambda x: {
"text": x["text"],
"labels": get_fragments_in_labels(x["labels"], linkers_only_as_labels),
},
batched=False,
num_proc=num_proc_map,
load_from_cache_file=True,
desc="Getting fragments in labels",
)
# Filter out the samples with None labels
dataset = dataset.filter(lambda x: x["labels"] is not None)
if linkers_only_as_labels:
print(f"Set labels to linkers only. Length: {dataset.num_rows}")
else:
print(f"Set labels to E3 and WH only. Length: {dataset.num_rows}")
if randomize_smiles:
dataset["train"] = dataset["train"].map(
randomize_smiles_dataset,
batched=True,
batch_size=batch_size,
fn_kwargs={
"repeat": randomize_smiles_repeat,
"prob": randomize_smiles_prob,
"apply_to_text": randomize_text,
"apply_to_labels": randomize_labels,
},
num_proc=num_proc_map,
load_from_cache_file=True,
desc="Randomizing SMILES",
)
print(f"Randomized SMILES in dataset. Length: {dataset.num_rows}")
if causal_language_modeling:
dataset = dataset.map(
lambda x: {
"text": x["text"] + "." + x["labels"],
"labels": x["labels"],
},
batched=False,
num_proc=num_proc_map,
load_from_cache_file=True,
desc="Setting labels to text",
)
print(f"Appended labels to text. Length: {dataset.num_rows}")
# NOTE: Remove the "labels" column if causal language modeling, since the
# DataCollatorForLM will automatically set the labels to the input_ids.
dataset = dataset.map(
process_data_to_model_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["text", "labels"] if causal_language_modeling else ["text"],
fn_kwargs={
"tokenizer": tokenizer,
"encoder_max_length": encoder_max_length,
"decoder_max_length": decoder_max_length,
},
num_proc=num_proc_map,
load_from_cache_file=True,
desc="Tokenizing dataset",
)
print(f"Tokenized dataset. Length: {dataset.num_rows}")
return dataset
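
# Usage sketch (hedged): a typical call for an encoder-decoder setup. The hub
# dataset name is the one used by `load_trl_dataset` below; the config name and
# keyword values here are illustrative assumptions.
#
#   dataset = load_tokenized_dataset(
#       "ailab-bio/PROTAC-Splitter-Dataset",
#       dataset_config="standard",
#       tokenizer="seyonec/ChemBERTa-zinc-base-v1",
#       randomize_smiles=True,
#       num_proc_map=4,
#   )
#   train_ds = dataset["train"]

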
def load_trl_dataset(
    tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
    token: Optional[str] = None,
    max_length: int = 512,
    dataset_name: str = "ailab-bio/PROTAC-Splitter-Dataset",
    ds_config: str = "standard",
    ds_unlabeled: Optional[str] = None,
) -> Dataset:
    """ Load the training split (optionally concatenated with an unlabeled
    configuration) and tokenize it into "input_ids"/"query" columns for TRL.
    """
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
# Load training data
train_dataset = load_dataset(
dataset_name,
ds_config,
split="train",
token=token,
)
train_dataset = train_dataset.rename_column("text", "query")
train_dataset = train_dataset.remove_columns(["labels"])
    if ds_unlabeled is not None:
        # Load unlabeled data
unlabeled_dataset = load_dataset(
dataset_name,
            ds_unlabeled,
split="train",
token=token,
)
unlabeled_dataset = unlabeled_dataset.rename_column("text", "query")
unlabeled_dataset = unlabeled_dataset.remove_columns(["labels"])
# Concatenate datasets row-wise
dataset = concatenate_datasets([train_dataset, unlabeled_dataset])
else:
dataset = train_dataset
    def tokenize(sample, tokenizer, max_length=512):
        # Truncate and pad each query to a fixed length.
        input_ids = tokenizer.encode(sample["query"], truncation=True, padding="max_length", max_length=max_length)
        return {"input_ids": input_ids, "query": sample["query"]}
return dataset.map(lambda x: tokenize(x, tokenizer, max_length), batched=False)
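
# Usage sketch (hedged): build the query dataset for PPO-style training; the
# configuration name below is an illustrative assumption.
#
#   trl_ds = load_trl_dataset(
#       tokenizer="seyonec/ChemBERTa-zinc-base-v1",
#       ds_config="standard",
#   )
#   # trl_ds has "query" (SMILES string) and "input_ids" (padded token ids).

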
def data_collator_for_trl(batch):
    """ Collate TRL samples: tensorize "input_ids" and keep "query" strings as a list."""
    return {
"input_ids": [torch.tensor(x["input_ids"]) for x in batch],
"query": [x["query"] for x in batch],
}
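
# Usage sketch (hedged): the collator keeps "query" strings as-is and wraps each
# "input_ids" list in a 1-D tensor (no padding applied), e.g.:
#
#   samples = [{"input_ids": [12, 34, 56], "query": "CCO"},
#              {"input_ids": [78, 90], "query": "c1ccccc1"}]
#   collated = data_collator_for_trl(samples)
#   # collated["input_ids"] -> [tensor([12, 34, 56]), tensor([78, 90])]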