import os import random import logging from typing import Optional, Union import torch from datasets import load_dataset, concatenate_datasets, Dataset from transformers import AutoTokenizer from rdkit import Chem from protac_splitter.evaluation import split_prediction def randomize_smiles_dataset( batch: dict, repeat: int = 1, prob: float = 0.5, apply_to_text: bool = True, apply_to_labels: bool = False, ) -> dict: """ Randomize SMILES in a batch of data. Args: batch (dict): Batch of data with "text" and "labels" keys. repeat (int, optional): Number of times to repeat the randomization. Defaults to 1. prob (float, optional): Probability of randomizing SMILES. Defaults to 0.5. apply_to_text (bool, optional): Whether to apply randomization to text. Defaults to True. apply_to_labels (bool, optional): Whether to apply randomization to labels. Defaults to False. Returns: dict: Randomized batch of data. """ new_texts, new_labels = [], [] for text, label in zip(batch["text"], batch["labels"]): try: mol_text = Chem.MolFromSmiles(text) mol_label = Chem.MolFromSmiles(label) except Exception: logging.error("Failed to convert SMILES to Mol!") new_texts.append(text) new_labels.append(label) continue if random.random() < prob: if apply_to_text: rand_texts = [Chem.MolToSmiles(mol_text, canonical=False, doRandom=True) for _ in range(repeat)] else: rand_texts = [text] * repeat if apply_to_labels: rand_labels = [Chem.MolToSmiles(mol_label, canonical=False, doRandom=True) for _ in range(repeat)] else: rand_labels = [label] * repeat new_texts.extend(rand_texts) new_labels.extend(rand_labels) else: new_texts.append(text) new_labels.append(label) return {"text": new_texts, "labels": new_labels} def process_data_to_model_inputs( batch, tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1", encoder_max_length: int = 512, decoder_max_length: int = 512, ): if isinstance(tokenizer, str): tokenizer = AutoTokenizer.from_pretrained(tokenizer) # tokenize the inputs and labels inputs = tokenizer(batch["text"], truncation=True, max_length=encoder_max_length) outputs = tokenizer(batch["labels"], truncation=True, max_length=decoder_max_length) batch["input_ids"] = inputs.input_ids batch["attention_mask"] = inputs.attention_mask batch["labels"] = outputs.input_ids.copy() # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # batch["input_ids"] = batch["input_ids"].to(device) # batch["attention_mask"] = batch["attention_mask"].to(device) # batch["labels"] = batch["labels"].to(device) # Because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. # We have to make sure that the PAD token is ignored when calculating the loss. # NOTE: Check the `ignore_index` argument in nn.CrossEntropyLoss. # NOTE: The following is already done in the DataCollatorForSeq2Seq # batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]] return batch def get_fragments_in_labels(labels: str, linkers_only_as_labels: bool = True) -> list[str]: """ Get the fragments in the labels. Args: labels (str): The labels. linkers_only_as_labels (bool, optional): Whether to get only the linkers in the labels. Defaults to True. Returns: list[str]: The fragments in the labels. """ ligands = split_prediction(labels) if linkers_only_as_labels: return ligands.get("linker", None) if None in ligands.values(): return None return f"{ligands['e3']}.{ligands['poi']}" def load_tokenized_dataset( dataset_dir: str, dataset_config: str = 'default', tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1", batch_size: int = 512, encoder_max_length: int = 512, decoder_max_length: int = 512, token: Optional[str] = None, num_proc_map: int = 1, randomize_smiles: bool = False, randomize_smiles_prob: float = 0.5, randomize_smiles_repeat: int = 1, randomize_text: bool = True, randomize_labels: bool = False, cache_dir: Optional[str] = None, all_fragments_as_labels: bool = True, linkers_only_as_labels: bool = False, causal_language_modeling: bool = False, train_size_ratio: float = 1.0, ) -> Dataset: """ Load dataset and tokenize it. Args: dataset_dir (str): The directory of the dataset or the name of the data on the Hugging Face Hub. dataset_config (str, optional): The configuration of the dataset. Defaults to 'default'. tokenizer (AutoTokenizer | str, optional): The tokenizer to use for tokenization. If a string, the tokenizer will be loaded using `AutoTokenizer.from_pretrained(tokenizer)`. Defaults to "seyonec/ChemBERTa-zinc-base-v1". batch_size (int, optional): The batch size for tokenization. Defaults to 512. encoder_max_length (int, optional): The maximum length of the encoder input sequence. Defaults to 512. decoder_max_length (int, optional): The maximum length of the decoder input sequence. Defaults to 512. token (Optional[str], optional): The Hugging Face API token. Defaults to None. num_proc_map (int, optional): The number of processes to use for mapping. Defaults to 1. randomize_smiles (bool, optional): Whether to randomize SMILES. Defaults to False. randomize_smiles_prob (float, optional): The probability of randomizing SMILES. Defaults to 0.5. randomize_smiles_repeat (int, optional): The number of times to repeat the randomization. Defaults to 1. randomize_text (bool, optional): Whether to randomize text. Defaults to True. randomize_labels (bool, optional): Whether to randomize labels. Defaults to False. cache_dir (Optional[str], optional): The directory to cache the dataset. Defaults to None. all_fragments_as_labels (bool, optional): Whether to get all fragments in the labels. Defaults to True. linkers_only_as_labels (bool, optional): Whether to get only the linkers in the labels. Defaults to False. causal_language_modeling (bool, optional): Whether to use causal language modeling. Defaults to False. train_size_ratio (float, optional): The ratio of the training dataset to use. Defaults to 1.0. Returns: Dataset: The tokenized dataset. """ if isinstance(tokenizer, str): tokenizer = AutoTokenizer.from_pretrained(tokenizer) if os.path.exists(dataset_dir): # NOTE: We need a different argument to load a dataset from disk: dataset = load_dataset( dataset_dir, data_dir=dataset_config, ) print(f"Dataset loaded from disk at: \"{dataset_dir}\". Length: {dataset.num_rows}") else: dataset = load_dataset( dataset_dir, dataset_config, token=token, cache_dir=cache_dir, ) print(f"Dataset loaded from hub. Length: {dataset.num_rows}") if train_size_ratio < 1.0 and train_size_ratio > 0: # Reduce the size of the training dataset but just selecting a fraction of the samples dataset["train"] = dataset["train"].select(range(int(train_size_ratio * dataset["train"].num_rows))) print(f"Reduced training dataset size to {train_size_ratio}. Length: {dataset.num_rows}") elif train_size_ratio > 1.0 or train_size_ratio < 0: raise ValueError("train_size_ratio must be between 0 and 1.") if not all_fragments_as_labels: dataset = dataset.map( lambda x: { "text": x["text"], "labels": get_fragments_in_labels(x["labels"], linkers_only_as_labels), }, batched=False, num_proc=num_proc_map, load_from_cache_file=True, desc="Getting fragments in labels", ) # Filter out the samples with None labels dataset = dataset.filter(lambda x: x["labels"] is not None) if linkers_only_as_labels: print(f"Set labels to linkers only. Length: {dataset.num_rows}") else: print(f"Set labels to E3 and WH only. Length: {dataset.num_rows}") if randomize_smiles: dataset["train"] = dataset["train"].map( randomize_smiles_dataset, batched=True, batch_size=batch_size, fn_kwargs={ "repeat": randomize_smiles_repeat, "prob": randomize_smiles_prob, "apply_to_text": randomize_text, "apply_to_labels": randomize_labels, }, num_proc=num_proc_map, load_from_cache_file=True, desc="Randomizing SMILES", ) print(f"Randomized SMILES in dataset. Length: {dataset.num_rows}") if causal_language_modeling: dataset = dataset.map( lambda x: { "text": x["text"] + "." + x["labels"], "labels": x["labels"], }, batched=False, num_proc=num_proc_map, load_from_cache_file=True, desc="Setting labels to text", ) print(f"Appended labels to text. Length: {dataset.num_rows}") # NOTE: Remove the "labels" column if causal language modeling, since the # DataCollatorForLM will automatically set the labels to the input_ids. dataset = dataset.map( process_data_to_model_inputs, batched=True, batch_size=batch_size, remove_columns=["text", "labels"] if causal_language_modeling else ["text"], fn_kwargs={ "tokenizer": tokenizer, "encoder_max_length": encoder_max_length, "decoder_max_length": decoder_max_length, }, num_proc=num_proc_map, load_from_cache_file=True, desc="Tokenizing dataset", ) print(f"Tokenized dataset. Length: {dataset.num_rows}") return dataset def load_trl_dataset( tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1", token: Optional[str] = None, max_length: int = 512, dataset_name: str = "ailab-bio/PROTAC-Splitter-Dataset", ds_config: str = "standard", ds_unalabeled: Optional[str] = None, ) -> Dataset: if isinstance(tokenizer, str): tokenizer = AutoTokenizer.from_pretrained(tokenizer) # Load training data train_dataset = load_dataset( dataset_name, ds_config, split="train", token=token, ) train_dataset = train_dataset.rename_column("text", "query") train_dataset = train_dataset.remove_columns(["labels"]) if ds_unalabeled is not None: # Load un-labelled data unlabeled_dataset = load_dataset( dataset_name, ds_unalabeled, split="train", token=token, ) unlabeled_dataset = unlabeled_dataset.rename_column("text", "query") unlabeled_dataset = unlabeled_dataset.remove_columns(["labels"]) # Concatenate datasets row-wise dataset = concatenate_datasets([train_dataset, unlabeled_dataset]) else: dataset = train_dataset def tokenize(sample, tokenizer, max_length=512): input_ids = tokenizer.encode(sample["query"], padding="max_length", max_length=max_length) return {"input_ids": input_ids, "query": sample["query"]} return dataset.map(lambda x: tokenize(x, tokenizer, max_length), batched=False) def data_collator_for_trl(batch): return { "input_ids": [torch.tensor(x["input_ids"]) for x in batch], "query": [x["query"] for x in batch], }