Spaces:
Sleeping
Sleeping
""" Evaluation functions for the protac_splitter package. They need to be generic to accommodate predictions coming from different models. """
import math | |
import re | |
import logging | |
from typing import Tuple, Any, Dict, Optional, Union | |
import numpy as np | |
from rdkit import Chem, RDLogger | |
from rdkit.Chem import DataStructs | |
# Disable RDKit logging: when checking SMILES validity, we suppress warnings | |
RDLogger.DisableLog("rdApp.*") | |
from .chemoinformatics import ( | |
canonize, | |
canonize_smiles, | |
remove_stereo, | |
get_substr_match, | |
) | |
from .protac_cheminformatics import reassemble_protac | |
from .graphs_utils import ( | |
get_smiles2graph_edit_distance, | |
get_smiles2graph_edit_distance_norm, | |
) | |
def is_valid_smiles(
    smiles: Optional[str],
    return_mol: bool = False,
) -> Union[bool, Tuple[bool, Optional["Chem.Mol"]]]:
    """ Check if a SMILES is valid, i.e., it can be parsed by RDKit.

    Args:
        smiles (Optional[str]): The SMILES to check. None is treated as invalid.
        return_mol (bool): If True, return the RDKit molecule object too, i.e., `(is_valid, mol)`.

    Returns:
        bool | Tuple[bool, Optional[Chem.Mol]]: True if the SMILES is valid, False otherwise.
        If return_mol is True, a `(is_valid, mol)` tuple is returned instead, where `mol` is
        None whenever the SMILES is invalid (including when `smiles` is None).
    """
    if smiles is None:
        # Keep the return shape consistent so callers can always unpack a tuple
        # when they asked for one.
        return (False, None) if return_mol else False
    mol = Chem.MolFromSmiles(smiles)
    is_valid = mol is not None
    return (is_valid, mol) if return_mol else is_valid
def has_three_substructures(smiles: Optional[str]) -> bool:
    """ Tell whether a split PROTAC SMILES is made of exactly three fragments.

    The check is purely textual: it counts the "." separators and does not
    parse the SMILES. A None input is never considered to have three parts.
    """
    return smiles is not None and smiles.count(".") == 2
def has_all_attachment_points(
    smiles: Optional[str],
    poi_attachment_id: int = 1,
    e3_attachment_id: int = 2,
) -> bool:
    """ Check if a PROTAC SMILES has all attachment points, two of each label.

    With the default ids this means `[*:1]` and `[*:2]` each appear exactly
    twice (POI ligand + linker, and linker + E3 binder, respectively). The
    check is textual and does not parse the SMILES.

    Args:
        smiles (Optional[str]): The dot-separated SMILES of the split PROTAC. None yields False.
        poi_attachment_id (int): The attachment point ID for the POI substructure. Default is 1.
        e3_attachment_id (int): The attachment point ID for the E3 substructure. Default is 2.

    Returns:
        bool: True if both attachment labels appear exactly twice, False otherwise.
    """
    if smiles is None:
        return False
    poi_tag = f"[*:{poi_attachment_id}]"
    e3_tag = f"[*:{e3_attachment_id}]"
    return smiles.count(poi_tag) == 2 and smiles.count(e3_tag) == 2
def split_prediction(
    pred: Optional[str],
    poi_attachment_id: int = 1,
    e3_attachment_id: int = 2,
) -> Dict[str, Optional[str]]:
    """ Split a PROTAC SMILES prediction into its three substructures.

    Each of the three dot-separated fragments is classified by its attachment
    labels: only the POI label -> 'poi', only the E3 label -> 'e3', both labels
    -> 'linker'. Fragments with neither label are ignored. If several fragments
    fall in the same class, the last one wins.

    Args:
        pred (Optional[str]): The dot-separated SMILES of the split PROTAC.
        poi_attachment_id (int): The attachment point ID for the POI substructure.
        e3_attachment_id (int): The attachment point ID for the E3 substructure.

    Returns:
        Dict[str, Optional[str]]: A dictionary that always has the keys 'poi',
        'linker' and 'e3'. A value is None when the matching substructure
        could not be identified, e.g., when `pred` is None or does not contain
        exactly three fragments (then all values are None). The dictionary
        itself is never None.
    """
    ret: Dict[str, Optional[str]] = {k: None for k in ('poi', 'linker', 'e3')}
    if pred is None:
        return ret
    ligands = pred.split('.')
    if len(ligands) != 3:
        return ret
    poi_tag = f'[*:{poi_attachment_id}]'
    e3_tag = f'[*:{e3_attachment_id}]'
    for ligand in ligands:
        has_poi = poi_tag in ligand
        has_e3 = e3_tag in ligand
        if has_poi and has_e3:
            ret['linker'] = ligand
        elif has_poi:
            ret['poi'] = ligand
        elif has_e3:
            ret['e3'] = ligand
    return ret
def rename_attachment_id(
    mol: Union[str, "Chem.Mol"],
    old_id: int,
    new_id: int,
) -> Optional[Union[str, "Chem.Mol"]]:
    """ Rename an attachment point ID in a molecule.

    Both the atom-map form `[*:old_id]` and the isotope form `[old_id*]` are
    rewritten to `[*:new_id]`, and the result is canonicalized.

    NOTE(review): the output type is the *opposite* of the input type. A
    `Chem.Mol` input yields a canonical SMILES string, while a SMILES string
    input yields a `Chem.Mol`. This is kept as-is for backward compatibility
    with existing callers; confirm whether it is intended.

    Args:
        mol: The input molecule, as a SMILES string or an RDKit molecule.
        old_id: The old attachment point ID.
        new_id: The new attachment point ID.

    Returns:
        The renamed molecule (see the note above for its type), or None if
        `mol` is None or the renamed SMILES cannot be canonicalized or parsed.
    """
    if mol is None:
        return None
    # Set for Mol inputs: those are turned into SMILES here and returned as SMILES.
    return_str = isinstance(mol, Chem.Mol)
    smiles = Chem.MolToSmiles(mol, canonical=True) if return_str else mol
    # Regex-replace the patterns "[*:old_id]" or "[old_id*]" with "[*:new_id]"
    smiles = re.sub(rf'\[\*:{old_id}\]', f'[*:{new_id}]', smiles)
    smiles = re.sub(rf'\[{old_id}\*\]', f'[*:{new_id}]', smiles)
    smiles = canonize_smiles(smiles)
    if smiles is None:
        return None
    renamed = Chem.MolFromSmiles(smiles)
    if renamed is None:
        # Previously a Mol input crashed here in MolToSmiles(None).
        return None
    if return_str:
        return Chem.MolToSmiles(renamed, canonical=True)
    return renamed
def at_least_two_ligands_correct(
    protac_smiles: str,
    ligands_smiles: Optional[str],
) -> bool:
    """ Check if at least two ligands are correct.

    NOTE(review): this is currently only a placeholder. `protac_smiles` is not
    used, and no ligand is compared against anything. The function only
    rejects predictions that are a single fragment (no ".") or None, and
    returns True otherwise. TODO: implement the per-ligand comparison that the
    name promises.

    Args:
        protac_smiles (str): The SMILES of the PROTAC molecule (currently unused).
        ligands_smiles (Optional[str]): The dot-separated SMILES of the predicted ligands.

    Returns:
        bool: False if `ligands_smiles` is None or contains no ".", True otherwise.
    """
    if ligands_smiles is None:
        return False
    # Without any "." the prediction is a single fragment, so it cannot hold two ligands.
    return "." in ligands_smiles
def check_reassembly(
    protac_smiles: str,
    ligands_smiles: str,
    stats: Optional[Dict[str, int]] = None,
    linker_can_be_null: bool = False,
    poi_attachment_id: int = 1,
    e3_attachment_id: int = 2,
    verbose: int = 0,
    return_reassembled_smiles: bool = False,
) -> Union[bool, Tuple[bool, Optional[str]]]:
    """Check if the reassembled PROTAC matches the original PROTAC SMILES.

    The ligands are canonicalized, joined with `Chem.molzip`, canonicalized
    again with `canonize`, and compared to `canonize(protac_smiles)`.

    Args:
        protac_smiles (str): The original PROTAC SMILES.
        ligands_smiles (str): The SMILES of the joined PROTAC ligands, separated by a "." (dot).
        stats (Optional[Dict[str, int]]): If given, failure counters are incremented under the keys
            'molzip failed', 'molzip failed (exception)' and 'MolToSmiles of reassembled failed'.
            Missing keys are created.
        linker_can_be_null (bool): If True, a prediction with exactly two fragments (no linker) is
            accepted. Its POI and E3 attachment points are renamed to a shared new label so the two
            fragments get joined directly. An explicit null linker (`[*:poi][*:e3]` or `[*:e3][*:poi]`)
            inside `ligands_smiles` is always removed and handled the same way, whatever this flag is.
        poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]". Default is 1.
        e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]". Default is 2.
        verbose (int): If non-zero, failures are reported via `logging.error`.
        return_reassembled_smiles (bool): If True, also return the reassembled SMILES.

    Returns:
        bool | Tuple[bool, Optional[str]]: True if the reassembled PROTAC matches the original PROTAC
        SMILES, False otherwise (including when any step fails). If `return_reassembled_smiles` is True,
        a `(is_equal, reassembled_smiles)` tuple is returned instead; `reassembled_smiles` is None on failure.
    """
    def _fail():
        return (False, None) if return_reassembled_smiles else False

    def _count(key: str) -> None:
        # `.get` so that a plain dict without pre-populated keys works too.
        if stats is not None:
            stats[key] = stats.get(key, 0) + 1

    def _log(msg: str) -> None:
        if verbose:
            logging.error(msg)

    ligands_smiles = canonize_smiles(ligands_smiles)
    if ligands_smiles is None:
        _log('Ligand could not be canonicalized.')
        return _fail()

    null_linker_e3 = f'[*:{e3_attachment_id}][*:{poi_attachment_id}]'
    null_linker_poi = f'[*:{poi_attachment_id}][*:{e3_attachment_id}]'
    linker_is_null = False
    if null_linker_e3 in ligands_smiles or null_linker_poi in ligands_smiles:
        # If the linker is empty, remove the linker atoms and the dangling separators.
        ligands_smiles = ligands_smiles.replace(null_linker_poi, '')
        ligands_smiles = ligands_smiles.replace(null_linker_e3, '')
        ligands_smiles = ligands_smiles.replace('..', '.').strip('.')
        ligands_smiles = canonize_smiles(ligands_smiles)
        if ligands_smiles is None:
            _log('Ligands could not be canonicalized after removing the null linker.')
            return _fail()
        linker_is_null = True

    if (linker_can_be_null or linker_is_null) and len(ligands_smiles.split('.')) == 2:
        # Give both attachment points the same new label so molzip bonds POI and E3 directly.
        merged_id = max(poi_attachment_id, e3_attachment_id) + 1
        renamed = rename_attachment_id(ligands_smiles, e3_attachment_id, merged_id)
        if renamed is not None:
            renamed = rename_attachment_id(renamed, poi_attachment_id, merged_id)
        if renamed is None:
            _log('Renaming the attachment points failed.')
            return _fail()
        ligands_smiles = renamed

    ligands_mol = Chem.MolFromSmiles(ligands_smiles)
    if ligands_mol is None:
        _log('ligands_mol is None')
        return _fail()

    try:
        reassembled_mol = Chem.molzip(ligands_mol)
    except Exception:
        _count('molzip failed (exception)')
        _log('molzip failed (exception)')
        return _fail()
    if reassembled_mol is None:
        _count('molzip failed')
        _log('molzip failed')
        return _fail()

    try:
        reassembled_smiles = canonize(Chem.MolToSmiles(reassembled_mol))
    except Exception:
        reassembled_smiles = None
    if reassembled_smiles is None:
        _count('MolToSmiles of reassembled failed')
        _log('MolToSmiles of reassembled failed')
        return _fail()

    is_equal = canonize(protac_smiles) == reassembled_smiles
    return (is_equal, reassembled_smiles) if return_reassembled_smiles else is_equal
def check_substructs(
    protac_smiles: str,
    poi_smiles: Optional[str] = None,
    linker_smiles: Optional[str] = None,
    e3_smiles: Optional[str] = None,
    return_bond_types: bool = False,
    poi_attachment_id: int = 1,
    e3_attachment_id: int = 2,
    pred: Optional[str] = None,
) -> Union[bool, Tuple[bool, dict[str, str]]]:
    """ DEPRECATED.
    Check if the reassembled PROTAC is correct.

    The substructures come either from `poi_smiles`/`linker_smiles`/`e3_smiles`
    (all three required) or, when all three are None, from splitting `pred`.
    Every combination of E3 and POI bond types ('single', 'double', 'triple')
    is tried, stopping at the first reassembly whose InChI or canonical SMILES
    matches `protac_smiles`.

    Args:
        protac_smiles (str): The SMILES of the PROTAC molecule.
        poi_smiles (Optional[str]): The SMILES of the POI ligand.
        linker_smiles (Optional[str]): The SMILES of the linker.
        e3_smiles (Optional[str]): The SMILES of the E3 binder.
        return_bond_types (bool): If True, also return the bond types used for the reassembly.
        poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]".
        e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]".
        pred (Optional[str]): The dot-separated SMILES of the predicted split, used only when no substructure is given.

    Returns:
        bool | Tuple[bool, dict[str, str]]: True if the reassembled PROTAC is correct, False otherwise.
        If return_bond_types is True, also return the first matching bond types
        (keys 'e3_bond_type' and 'poi_bond_type'; empty dict if none matched).
    """
    def get_failed_return():
        if return_bond_types:
            return False, {}
        return False

    given = (poi_smiles, linker_smiles, e3_smiles)
    if all(s is None for s in given):
        if pred is None:
            logging.warning("Arguments 'pred' and 'poi_smiles', 'linker_smiles', 'e3_smiles' cannot be all None.")
            return get_failed_return()
        # Split the prediction into the substructures
        pred_substructs = split_prediction(pred, poi_attachment_id, e3_attachment_id)
        if any(v is None for v in pred_substructs.values()):
            return get_failed_return()
        poi_smiles = pred_substructs['poi']
        linker_smiles = pred_substructs['linker']
        e3_smiles = pred_substructs['e3']
    elif any(s is None for s in given):
        # Only some substructures were provided: there is nothing consistent to reassemble.
        return get_failed_return()

    poi_tag = f"[*:{poi_attachment_id}]"
    e3_tag = f"[*:{e3_attachment_id}]"
    # Reject ligands carrying the other side's label, and linkers missing either label.
    if poi_tag in e3_smiles or e3_tag in poi_smiles:
        return get_failed_return()
    if poi_tag not in linker_smiles or e3_tag not in linker_smiles:
        return get_failed_return()

    protac_mol = Chem.MolFromSmiles(protac_smiles)
    if protac_mol is None:
        return get_failed_return()
    protac_inchi = Chem.MolToInchi(protac_mol)
    protac_smiles_canon = canonize_smiles(protac_smiles)

    correct_substructs = False
    bond_types = {}
    bonds = ('single', 'double', 'triple')
    # One flat loop so that `break` really stops at the first match.
    bond_pairs = ((e3_bond, poi_bond) for e3_bond in bonds for poi_bond in bonds)
    for e3_bond_type, poi_bond_type in bond_pairs:
        try:
            assmbl_smiles, assmbl_mol = reassemble_protac(
                poi_smiles,
                linker_smiles,
                e3_smiles,
                e3_bond_type,
                poi_bond_type,
                poi_attachment_id,
                e3_attachment_id,
            )
            if assmbl_mol is None:
                continue
            # If either the InChI or SMILES of the reassembled PROTAC is the
            # same as the original PROTAC, then the reassembly is correct.
            is_match = (
                protac_inchi == Chem.MolToInchi(assmbl_mol)
                or protac_smiles_canon == canonize_smiles(assmbl_smiles)
            )
        except Exception:
            # A failed reassembly for this bond-type pair is not fatal; try the next one.
            continue
        if is_match:
            correct_substructs = True
            bond_types['e3_bond_type'] = e3_bond_type
            bond_types['poi_bond_type'] = poi_bond_type
            break

    if return_bond_types:
        return correct_substructs, bond_types
    return correct_substructs
def score_prediction(
    protac_smiles: str,
    label_smiles: str,
    pred_smiles: str,
    rouge = None,
    poi_attachment_id: int = 1,
    e3_attachment_id: int = 2,
    fpgen = Chem.rdFingerprintGenerator.GetMorganGenerator(radius=11, fpSize=2048),
    compute_rdkit_metrics: bool = False,
    compute_graph_metrics: bool = False,
    graph_edit_kwargs: Optional[Dict[str, Any]] = None,
) -> dict[str, float]:
    """ Score a PROTAC SMILES prediction.

    Whole-molecule metrics (validity, reassembly, heavy-atom difference,
    Tanimoto, Rouge) are only computed when `pred_smiles` can be split into a
    POI ligand, a linker and an E3 binder (see `split_prediction`). Otherwise
    they keep their default (failure) values. Per-substructure metrics are
    prefixed with 'e3_', 'poi_' and 'linker_'.

    Args:
        protac_smiles (str): The SMILES of the PROTAC molecule.
        label_smiles (str): The ground-truth split, i.e., the dot-separated SMILES of the POI ligand, linker and E3 binder.
        pred_smiles (str): The predicted split, in the same format as `label_smiles`.
        rouge (Rouge | None): The Rouge object to use for scoring. If None, do not compute Rouge scores. Example: `rouge = evaluate.load("rouge")`
        poi_attachment_id (int): The attachment point ID for the POI substructure.
        e3_attachment_id (int): The attachment point ID for the E3 substructure.
        fpgen: RDKit fingerprint generator used for Tanimoto similarities. If None, they are not computed.
        compute_rdkit_metrics (bool): If True, compute Tanimoto similarities.
        compute_graph_metrics (bool): If True, compute the (normalized) graph edit distance of each substructure.
        graph_edit_kwargs (Optional[Dict[str, Any]]): Extra keyword arguments forwarded to the graph edit distance functions.

    Returns:
        dict[str, float]: A dictionary containing the scores for the prediction.
        'heavy_atoms_difference' is the PROTAC heavy-atom count minus that of
        the reassembled prediction. The per-substructure variant is the label
        count minus the predicted count. Normalized differences divide by the
        reference count, treating a count of 0 (e.g. a null linker made only of
        dummy atoms) as 1 to avoid division by zero.

    Fails with AttributeError if `protac_smiles` cannot be parsed by RDKit.
    """
    if graph_edit_kwargs is None:
        graph_edit_kwargs = {}
    substructs = ('e3', 'poi', 'linker')
    poi_tag = f'[*:{poi_attachment_id}]'
    e3_tag = f'[*:{e3_attachment_id}]'

    protac_mol = Chem.MolFromSmiles(protac_smiles)
    protac_num_atoms = protac_mol.GetNumHeavyAtoms()
    scores = {
        'has_three_substructures': has_three_substructures(pred_smiles),
        'has_all_attachment_points': has_all_attachment_points(pred_smiles),
        'num_fragments': 0 if pred_smiles is None else pred_smiles.count('.') + 1,
        'tanimoto_similarity': 0.0,  # Default value
        'valid': False,
        'reassembly': False,
        'reassembly_nostereo': False,
        'heavy_atoms_difference': protac_num_atoms,
        'heavy_atoms_difference_norm': 1.0,
        'all_ligands_equal': False,
    }

    # Compute metrics for the "entire" predicted PROTAC molecule
    pred_substructs = split_prediction(pred_smiles, poi_attachment_id, e3_attachment_id)
    if None not in pred_substructs.values():
        e3_nostereo = remove_stereo(pred_substructs['e3'])
        linker_nostereo = remove_stereo(pred_substructs['linker'])
        poi_nostereo = remove_stereo(pred_substructs['poi'])
        if None not in (e3_nostereo, linker_nostereo, poi_nostereo):
            pred_nostereo = f"{e3_nostereo}.{linker_nostereo}.{poi_nostereo}"
            scores['reassembly_nostereo'] = check_reassembly(
                remove_stereo(protac_smiles),
                pred_nostereo,
                poi_attachment_id=poi_attachment_id,
                e3_attachment_id=e3_attachment_id,
            )
        scores['valid'] = is_valid_smiles(pred_smiles)
        is_equal, reassembled_smiles = check_reassembly(
            protac_smiles,
            pred_smiles,
            poi_attachment_id=poi_attachment_id,
            e3_attachment_id=e3_attachment_id,
            return_reassembled_smiles=True,
        )
        scores['reassembly'] = is_equal
        # Get the number of heavy atoms difference between the reassembled PROTAC and the ground truth PROTAC
        if reassembled_smiles is not None:
            reassembled_mol = Chem.MolFromSmiles(reassembled_smiles)
            if reassembled_mol is not None:
                scores['heavy_atoms_difference'] -= reassembled_mol.GetNumHeavyAtoms()
                scores['heavy_atoms_difference_norm'] = scores['heavy_atoms_difference'] / max(protac_num_atoms, 1)
        if scores['valid'] and compute_rdkit_metrics and fpgen is not None:
            # Get Tanimoto similarity between the predicted PROTAC and the ground truth PROTAC
            pred_mol = Chem.MolFromSmiles(pred_smiles)
            label_mol = Chem.MolFromSmiles(label_smiles) if label_smiles is not None else None
            if label_mol is not None:
                scores['tanimoto_similarity'] = DataStructs.TanimotoSimilarity(
                    fpgen.GetFingerprint(pred_mol),
                    fpgen.GetFingerprint(label_mol),
                )
        if rouge is not None:
            rouge_output = rouge.compute(predictions=[pred_smiles], references=[label_smiles])
            scores.update(rouge_output)

    # Compute metrics for each substructure
    label_substructs = split_prediction(label_smiles, poi_attachment_id, e3_attachment_id)

    def _warn_missing_label(sub: str) -> None:
        logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")

    # Set default values and parse each label substructure once.
    label_mols = {}
    for sub in substructs:
        scores[f'{sub}_valid'] = False
        scores[f'{sub}_equal'] = False
        scores[f'{sub}_has_attachment_point(s)'] = False
        scores[f'{sub}_tanimoto_similarity'] = 0.0
        # NOTE: The graph edit distance can be very high and dependant on the
        # graphs, but when the molecule is not valid, then we cannot compute it.
        # Because of that, we instead set it to something very large, in case we
        # need to sum the eval metrics.
        scores[f'{sub}_graph_edit_distance'] = 1e64
        scores[f'{sub}_graph_edit_distance_norm'] = 1.0
        label_sub = label_substructs[sub]
        label_mol = Chem.MolFromSmiles(label_sub) if label_sub is not None else None
        label_mols[sub] = label_mol
        if label_mol is None:
            _warn_missing_label(sub)
            scores[f'{sub}_heavy_atoms_difference'] = 0
        else:
            scores[f'{sub}_heavy_atoms_difference'] = label_mol.GetNumHeavyAtoms()
        scores[f'{sub}_heavy_atoms_difference_norm'] = 1.0

    # Attachment labels each substructure must contain / must not contain.
    attachment_rules = {
        'e3': ((e3_tag,), (poi_tag,)),
        'poi': ((poi_tag,), (e3_tag,)),
        'linker': ((poi_tag, e3_tag), ()),
    }
    for sub in substructs:
        pred_sub = pred_substructs[sub]
        label_sub = label_substructs[sub]
        # Skip if the predicted substructure is None from `split_prediction`
        if pred_sub is None:
            continue
        if label_sub is None:
            _warn_missing_label(sub)
            continue
        # Check if the predicted substructure is a valid RDKit molecule
        sub_valid, pred_mol = is_valid_smiles(pred_sub, return_mol=True)
        scores[f'{sub}_valid'] = sub_valid
        if pred_mol is None:
            continue
        # Check if the predicted substructure has the correct attachment point(s)
        required, forbidden = attachment_rules[sub]
        if all(t in pred_sub for t in required) and not any(t in pred_sub for t in forbidden):
            scores[f'{sub}_has_attachment_point(s)'] = True
        # Compare canonical SMILES of the predicted and ground-truth substructures
        scores[f'{sub}_equal'] = canonize_smiles(pred_sub) == canonize_smiles(label_sub)
        # The remaining metrics need a parseable label.
        label_mol = label_mols[sub]
        if label_mol is None:
            _warn_missing_label(sub)
            continue
        # Compute graph-related metrics
        if compute_graph_metrics:
            ged = get_smiles2graph_edit_distance(pred_sub, label_sub, **graph_edit_kwargs)
            scores[f'{sub}_graph_edit_distance'] = ged
            scores[f'{sub}_graph_edit_distance_norm'] = get_smiles2graph_edit_distance_norm(
                smi1=pred_sub,
                smi2=label_sub,
                ged_G1_G2=ged,
                **graph_edit_kwargs,
            )
        # Heavy-atom difference; dummy atoms are not heavy, so a null-linker label has 0.
        scores[f'{sub}_heavy_atoms_difference'] -= pred_mol.GetNumHeavyAtoms()
        scores[f'{sub}_heavy_atoms_difference_norm'] = (
            scores[f'{sub}_heavy_atoms_difference'] / max(label_mol.GetNumHeavyAtoms(), 1)
        )
        # Get Tanimoto similarity b/w the predicted substructure and the ground truth
        if compute_rdkit_metrics and fpgen is not None:
            scores[f'{sub}_tanimoto_similarity'] = DataStructs.TanimotoSimilarity(
                fpgen.GetFingerprint(pred_mol),
                fpgen.GetFingerprint(label_mol),
            )
        # Compute Rouge scores
        if rouge is not None:
            rouge_output = rouge.compute(predictions=[pred_sub], references=[label_sub])
            scores.update({f'{sub}_{k}': v for k, v in rouge_output.items()})

    scores['all_ligands_equal'] = all(scores[f'{sub}_equal'] for sub in substructs)
    return scores