from typing import Union
from transformers import AutoTokenizer, EvalPrediction
import numpy as np
from rdkit import Chem, DataStructs
import evaluate
import multiprocessing as mp
import datetime
from protac_splitter.evaluation import (
# is_valid_smiles,
# has_three_substructures,
# has_all_attachment_points,
# check_substructs,
score_prediction,
)


def process_predictions(args) -> list:
    """ Process one batch of prediction scoring.

    Args:
        args (tuple): Tuple of (pred_smiles, protac_smiles, label_smiles,
            fpgen, compute_rdkit_metrics, compute_graph_metrics), matching the
            per-batch arguments built in `decode_and_get_metrics`.

    Returns:
        list: One score dict per prediction in the batch.
    """
pred_smiles, protac_smiles, label_smiles, fpgen, compute_rdkit_metrics, compute_graph_metrics = args
scores = []
for protac, pred, label in zip(protac_smiles, pred_smiles, label_smiles):
scores.append(score_prediction(
protac_smiles=protac,
label_smiles=label,
pred_smiles=pred,
fpgen=fpgen,
compute_rdkit_metrics=compute_rdkit_metrics,
compute_graph_metrics=compute_graph_metrics,
graph_edit_kwargs={"timeout": 0.05},
))
return scores
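
# Illustrative usage (hypothetical SMILES, not real PROTACs): score one batch
# synchronously, mirroring the `num_proc == 1` path in `decode_and_get_metrics`.
# batch = (["CCO.CC.CCN"], ["CCOC(=O)CCN"], ["CCO.CC.CCN"], None, False, True)
# scores = process_predictions(batch)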


def decode_and_get_metrics(
pred: EvalPrediction,
tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
    rouge=None,  # Optional[evaluate.metrics.rouge.Rouge]
    fpgen=None,  # Optional[Chem.rdFingerprintGenerator]
compute_rdkit_metrics: bool = False,
compute_graph_metrics: bool = True,
num_proc: int = 1,
batch_size: int = 128,
use_nan_for_missing: bool = True,
causal_language_modeling: bool = False,
) -> dict[str, float]:
""" Compute metrics for tokenized PROTAC predictions.
Args:
pred (transformers.EvalPrediction): The predictions from the model.
rouge (Rouge): The Rouge object to use for scoring. Example: `rouge = evaluate.load("rouge")`
tokenizer (AutoTokenizer | str): The tokenizer to use for decoding the predictions. If a string, the tokenizer will be loaded using `AutoTokenizer.from_pretrained(tokenizer)`. Default: "seyonec/ChemBERTa-zinc-base-v1"
fpgen (Chem.rdFingerprintGenerator): The fingerprint generator to use for computing the Tanimoto similarity. Default: `Chem.rdFingerprintGenerator.GetMorganGenerator(radius=8, fpSize=2048)`
Returns:
dict[str, float]: A dictionary containing the scores for the predictions
"""
print(f"[{datetime.datetime.now()}] Starting decode_and_get_metrics (protac_splitter/llms/evaluation.py)")
    if causal_language_modeling:
        # NOTE: For causal language models, we only care about perplexity, so
        # we only need the eval_loss, which is automatically added. Note that
        # this early return makes the `causal_language_modeling` branches
        # further below unreachable.
        return {}
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
labels_ids = pred.label_ids
pred_ids = pred.predictions
input_ids = pred.inputs
if causal_language_modeling:
# The prediction logits will be of shape: (batch_size, sequence_length, vocabulary_size)
# So we need to get the argmax of the last dimension to get the
# predicted token IDs.
# NOTE: Not exactly the same as what would happen during generation, but
# hopefully it's close enough to assess model performance during
# training.
pred_ids = np.argmax(pred_ids, axis=-1)
# Replace -100 in the IDs with the tokenizer pad token id
# NOTE: Check the `ignore_index` argument in nn.CrossEntropyLoss.
# TODO: Understand why this needs to be done to the inputs as well
ignore_index = -100
labels_ids[labels_ids == ignore_index] = tokenizer.pad_token_id
pred_ids[pred_ids == ignore_index] = tokenizer.pad_token_id
# Get strings from IDs
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
if not causal_language_modeling:
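        # NOTE: `pred.inputs` is only populated when the trainer passes inputs
        # to the metrics function (e.g., `include_inputs_for_metrics=True` in
        # `TrainingArguments`); otherwise `input_ids` is None here.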
input_ids[input_ids == ignore_index] = tokenizer.pad_token_id
input_str = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
else:
        # NOTE: For causal language models, i.e., decoder-only, the input
        # PROTAC is part of the label. The label looks something like
        # "PROTAC.E3.Linker.WH", so we take the first dot-separated part as
        # the input and rejoin the remaining (three) parts as the label.
input_str = [str(s.split('.')[0]) for s in label_str]
label_str = ['.'.join(s.split('.')[1:]) for s in label_str]
pred_str = ['.'.join(s.split('.')[1:]) if '.' in s else s for s in pred_str]
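        # Illustrative example (hypothetical SMILES): a decoded label like
        # "CCOC(=O)CCN.CCO.CC.CCN" yields input_str "CCOC(=O)CCN" and
        # label_str "CCO.CC.CCN".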
# Get scores
if num_proc == 1:
scores = process_predictions((
pred_str, input_str, label_str, fpgen, compute_rdkit_metrics, compute_graph_metrics
))
    else:
        # Score batches of predictions in parallel. Submitting all batches in
        # a single map call lets the pool keep up to `num_proc` workers busy.
        batches = [
            (
                pred_str[i:i + batch_size],
                input_str[i:i + batch_size],
                label_str[i:i + batch_size],
                fpgen,
                compute_rdkit_metrics,
                compute_graph_metrics,
            )
            for i in range(0, len(pred_str), batch_size)
        ]
        with mp.Pool(processes=num_proc) as pool:
            batch_scores = pool.map(process_predictions, batches)
        # Flatten the per-batch lists into a single list of score dicts
        scores = [s for batch in batch_scores for s in batch]
# Aggregate scores
scores_labels = set()
for s in scores:
scores_labels.update(s.keys())
aggregated_scores = {}
for k in scores_labels:
values = np.array([s.get(k, np.nan) for s in scores], dtype=float)
        # If all values are NaN, report the score as None and continue
        if np.all(np.isnan(values)):
            aggregated_scores[k] = None
            continue
        # Compute the average, excluding NaN values. NOTE: `np.nanmean`
        # already skips NaNs, so both branches below compute the same mean
        # over the non-NaN values.
        if use_nan_for_missing:
            aggregated_scores[k] = np.nanmean(values)
        else:
            valid_values = values[~np.isnan(values)]
            aggregated_scores[k] = np.mean(valid_values) if valid_values.size > 0 else float('nan')
# Get Rouge score
if rouge is not None:
rouge_output = rouge.compute(predictions=pred_str, references=label_str)
        aggregated_scores.update(rouge_output)
# TODO
# # Get tanimoto score
# pred_str = np.array(pred_str)[valid_smiles == 1]
# label_str = np.array(label_str)[valid_smiles == 1]
# if len(pred_str) == 0:
# scores['tanimoto'] = 0.0
# return scores
# pred_mols = [Chem.MolFromSmiles(s) for s in pred_str]
# label_mols = [Chem.MolFromSmiles(s) for s in label_str]
# pred_fps = [fpgen.GetFingerprint(m) for m in pred_mols]
# label_fps = [fpgen.GetFingerprint(m) for m in label_mols]
# tanimoto = [DataStructs.TanimotoSimilarity(l, p) for l, p in zip(label_fps, pred_fps)]
# scores['tanimoto'] = np.array(tanimoto).mean()
print(f"[{datetime.datetime.now()}] Done with decode_and_get_metrics (protac_splitter/llms/evaluation.py)")
return aggregated_scores
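

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the training
    # pipeline). It assumes network access to download the tokenizer and that
    # `score_prediction` tolerates `fpgen=None`; the SMILES below are
    # placeholders, not real PROTACs.
    tok = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    protacs = ["CCOC(=O)CCN", "CCNC(=O)CCO"]
    labels = ["CCO.CC.CCN", "CCN.CC.CCO"]

    def encode(seqs):
        # Tokenize to padded numpy ID matrices, as expected by batch_decode
        return tok(seqs, padding=True, return_tensors="np")["input_ids"]

    pred = EvalPrediction(
        predictions=encode(labels),
        label_ids=encode(labels),
        inputs=encode(protacs),
    )
    metrics = decode_and_get_metrics(
        pred,
        tokenizer=tok,
        compute_rdkit_metrics=False,
        compute_graph_metrics=True,
    )
    print(metrics)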