# Author: ribesstefano
# Last change: "Fixed None token to get_pipeline function" (commit 6aa9fe5)
import os
import requests
from typing import Union, Optional, Dict, List
from pathlib import Path
import logging
from datasets import Dataset
import pandas as pd
from protac_splitter.chemoinformatics import canonize
from protac_splitter.fixing_functions import fix_prediction
from protac_splitter.llms.model_utils import get_pipeline, run_pipeline
from protac_splitter.graphs.e3_clustering import get_representative_e3s_fp
from protac_splitter.graphs.edge_classifier import GraphEdgeClassifier
from protac_splitter.graphs.splitting_algorithms import split_protac_graph_based
def load_graph_edge_classifier_from_cache(
    cache_dir: Union[str, Path] = "~/.cache/protac_splitter",
    model_filename: str = "PROTAC-Splitter-XGBoost.joblib",
    download_url: str = "https://docs.google.com/uc?export=download&id=1bb9i5_L_-re3QYPc7tSiCtVNEEbNIzAC",
    expected_sha256: Optional[str] = "513621f4dc2ff7ec819a222bc7311afb8b6e6e89d6d694dd2906e695a50086dd",
    timeout: Optional[float] = 60.0,
) -> GraphEdgeClassifier:
    """
    Loads the GraphEdgeClassifier model from a local cache directory.

    If the model file is not in the cache, it is downloaded from `download_url`
    into a temporary file next to the final location. It is validated there
    (size, and optionally SHA-256), and only then moved into place. Because of
    this, an interrupted or invalid download never leaves a partial file in the
    cache that later calls would mistake for a valid model.

    Args:
        cache_dir (str or Path): Directory to cache the model file. `~` is expanded.
        model_filename (str): Name of the model file inside `cache_dir`.
        download_url (str): URL to download the model from if it is not cached.
        expected_sha256 (str, optional): Hex SHA-256 the downloaded file must
            match. The default is the digest of the default model. Pass None to
            skip the check, e.g. when using a custom `download_url`.
        timeout (float, optional): Timeout in seconds handed to
            `requests.get` (per connect/read operation). None waits forever.
    Returns:
        GraphEdgeClassifier: Loaded classifier.
    Raises:
        requests.HTTPError: If the download request returns an error status.
        RuntimeError: If the download is truncated or fails the checksum. The
            partial file is discarded in both cases.

    Note:
        Only freshly downloaded files are checksum-verified. A file already
        present in the cache is loaded as-is.
    """
    import hashlib
    import tempfile

    cache_dir = Path(os.path.expanduser(cache_dir))
    cache_dir.mkdir(parents=True, exist_ok=True)
    model_path = cache_dir / model_filename

    if not model_path.exists():
        # The temp file lives in the same directory, so the final os.replace
        # is an atomic rename on the same filesystem.
        fd, tmp_name = tempfile.mkstemp(
            dir=cache_dir, prefix=f".{model_filename}.", suffix=".part"
        )
        tmp_path = Path(tmp_name)
        try:
            hasher = hashlib.sha256()
            written = 0
            with os.fdopen(fd, "wb") as f, requests.get(
                download_url, stream=True, timeout=timeout
            ) as response:
                response.raise_for_status()
                expected_size = int(response.headers.get("Content-Length", -1))
                # iter_content transparently decodes gzip/deflate bodies, in
                # which case Content-Length (the *encoded* size) cannot be
                # compared with the number of decoded bytes written.
                content_encoding = (
                    response.headers.get("Content-Encoding", "identity").strip().lower()
                )
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
                        # Hash while streaming instead of re-reading the file.
                        hasher.update(chunk)
                        written += len(chunk)

            if (
                expected_size != -1
                and content_encoding in ("", "identity")
                and written != expected_size
            ):
                raise RuntimeError(
                    f"Download incomplete: got {written}, expected {expected_size}"
                )

            # The checksum also guards the deserialization below. The default
            # file is a .joblib, which presumably means pickle-based loading,
            # so only trusted bytes should reach it.
            if expected_sha256 is not None:
                digest = hasher.hexdigest()
                if digest != expected_sha256.lower():
                    raise RuntimeError(
                        f"Downloaded model checksum mismatch: got {digest}, expected {expected_sha256}. "
                        "The downloaded file was discarded; please try again."
                    )

            os.replace(tmp_path, model_path)
        except BaseException:
            # Never leave a partial/invalid file behind, even on Ctrl-C.
            tmp_path.unlink(missing_ok=True)
            raise

    return GraphEdgeClassifier.load(model_path)
def _format_split(pred: Dict[str, Optional[str]]) -> Optional[str]:
    """Join a graph-based split dict into an "e3.linker.poi" string, or None if every part is None."""
    if all(v is None for v in pred.values()):
        return None
    return f"{pred['e3']}.{pred['linker']}.{pred['poi']}"


def _protac_input_to_dataset(
    protac_smiles: Union[str, List, pd.DataFrame],
    protac_smiles_col: str,
) -> Dataset:
    """Canonize and validate the input PROTAC SMILES and wrap them in a Dataset.

    The returned Dataset has a single column, `protac_smiles_col`, holding the
    canonical SMILES in input order.

    Raises:
        ValueError: If a SMILES fails canonization, or a DataFrame input lacks
            the `protac_smiles_col` column.
        TypeError: If `protac_smiles` is not a str, list/tuple, or DataFrame.
    """
    if isinstance(protac_smiles, str):
        protac_smiles_canon = canonize(protac_smiles)
        if protac_smiles_canon is None:
            raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
        return Dataset.from_dict({protac_smiles_col: [protac_smiles_canon]})

    if isinstance(protac_smiles, (list, tuple)):
        protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
        if None in protac_smiles_canon:
            wrong_protacs = [
                protac
                for protac, canon in zip(protac_smiles, protac_smiles_canon)
                if canon is None
            ]
            raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
        return Dataset.from_dict({protac_smiles_col: protac_smiles_canon})

    if isinstance(protac_smiles, pd.DataFrame):
        if protac_smiles_col not in protac_smiles.columns:
            raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
        protac_smiles_canon = protac_smiles[protac_smiles_col].apply(canonize)
        invalid = protac_smiles_canon.isnull()
        if invalid.any():
            wrong_protacs = protac_smiles[invalid]
            raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
        # Without preserve_index=False, a non-RangeIndex is stored as an extra
        # `__index_level_0__` column that would leak into the returned results.
        return Dataset.from_pandas(
            protac_smiles_canon.to_frame(name=protac_smiles_col),
            preserve_index=False,
        )

    raise TypeError(
        "protac_smiles must be a str, a list of str, or a pandas DataFrame, "
        f"got {type(protac_smiles).__name__}."
    )


def split_protac(
    protac_smiles: Union[str, List, pd.DataFrame],
    use_transformer: bool = False,
    use_xgboost: bool = True,
    fix_predictions: bool = True,
    protac_smiles_col: str = "text",
    batch_size: int = 1,
    beam_size: int = 5,
    device: Optional[Union[int, str]] = None,
    num_proc: int = 1,
    verbose: int = 0,
) -> Union[Dict[str, str], List[Dict[str, str]], pd.DataFrame]:
    """ Split a PROTAC SMILES into the two ligands and the linker.

    If `use_transformer` and `use_xgboost` are both True, the Transformer model
    runs first. XGBoost is used as a fallback for inputs where no beam survives
    (optional) fixing. If only `use_xgboost` is True, the XGBoost edge
    classifier is used alone. If both are False, a fully heuristic graph-based
    algorithm is used.

    Args:
        protac_smiles (str, list/tuple, or pd.DataFrame): The PROTAC SMILES to split.
            If a DataFrame is provided, it must contain a column named `protac_smiles_col`.
        use_transformer (bool): Whether to use the transformer model for splitting.
        use_xgboost (bool): Whether to use the XGBoost model for splitting (or as Transformer fallback).
        fix_predictions (bool): Whether to fix the predictions using deterministic cheminformatics rules. Only used if `use_transformer` is True.
        protac_smiles_col (str): The name of the column containing the PROTAC SMILES in the DataFrame.
        batch_size (int): Batch size for processing. Only used if `use_transformer` is True.
        beam_size (int): Number of beam search predictions to generate. Only used if `use_transformer` is True. Higher values may yield better results but increase computation time.
        device (int or str, optional): Device to run the Transformer model on. Defaults to None, which will attempt to run on GPU if available, otherwise CPU.
        num_proc (int): Number of processes for parallel processing. Useful for large datasets. Forced to 1 whenever XGBoost is involved.
        verbose (int): Verbosity level, forwarded to `fix_prediction`.

    Returns:
        Union[Dict[str, str], List[Dict[str, str]], pd.DataFrame]: Depending on the input type:
        - A single string gives a dictionary with format: `{protac_smiles_col: canonical_smiles, "default_pred_n0": "e3.linker.warhead" or None, "model_name": "Transformer"|"XGBoost"|"Heuristic"}`.
        - A list (or tuple) of strings gives a list of dictionaries with the same format. An empty list gives `[]`.
        - A DataFrame gives a DataFrame with columns `protac_smiles_col`, `default_pred_n0`, and `model_name`.
        In Transformer mode, the raw beam columns `pred_n*` are also kept.

    Raises:
        ValueError: If any input SMILES is invalid, or the DataFrame lacks `protac_smiles_col`.
        TypeError: If `protac_smiles` is of an unsupported type.
    """
    # Validate the input before loading any model, so bad SMILES fail fast
    # without triggering a (possibly first-time) model download.
    ds = _protac_input_to_dataset(protac_smiles, protac_smiles_col)
    if ds.num_rows == 0:
        if isinstance(protac_smiles, pd.DataFrame):
            return pd.DataFrame(columns=[protac_smiles_col, "default_pred_n0", "model_name"])
        return []

    representative_e3s_fp = None
    xgboost_model = None
    if use_xgboost:
        representative_e3s_fp = get_representative_e3s_fp()
        xgboost_model = load_graph_edge_classifier_from_cache()

    def split_with_graph(protac: str) -> Optional[str]:
        """Graph-based split of one SMILES. Uses the XGBoost classifier if `use_xgboost`, else the heuristic."""
        if use_xgboost:
            pred = split_protac_graph_based(
                protac_smiles=protac,
                use_classifier=True,
                classifier=xgboost_model,
                representative_e3s_fp=representative_e3s_fp,
            )
        else:
            pred = split_protac_graph_based(
                protac_smiles=protac,
                use_classifier=False,
            )
        # Shared formatting: a failed split is None, never "None.None.None".
        return _format_split(pred)

    if use_transformer:
        pipe = get_pipeline(
            model_name="ailab-bio/PROTAC-Splitter",
            token=os.environ.get("HF_TOKEN", None),
            is_causal_language_model=False,
            num_return_sequences=beam_size,
            device=device,
        )
        # One dict per input SMILES, holding its beam predictions.
        # Format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
        beam_preds = run_pipeline(
            pipe,
            ds,
            batch_size,
            is_causal_language_model=False,
            smiles_column=protac_smiles_col,
        )
        preds_df = pd.DataFrame(beam_preds)
        preds_df[protac_smiles_col] = ds[protac_smiles_col]
        preds_ds = Dataset.from_pandas(preds_df)

        def pick_transformer_prediction(row: Dict[str, Optional[str]]) -> Dict[str, Optional[str]]:
            """Fix the beams (optionally) and keep the best valid one, falling back to XGBoost if enabled."""
            protac = row[protac_smiles_col]
            beams = {k: v for k, v in row.items() if k.startswith("pred_")}
            if fix_predictions:
                beams = {k: fix_prediction(protac, v, verbose=verbose) for k, v in beams.items()}
            # HF returns beams ordered by decreasing likelihood, so the first
            # non-None one is the best. `.get` tolerates missing beams.
            for i in range(beam_size):
                candidate = beams.get(f"pred_n{i}")
                if candidate is not None:
                    return {
                        protac_smiles_col: protac,
                        "default_pred_n0": candidate,
                        "model_name": "Transformer",
                    }
            if use_xgboost:
                return {
                    protac_smiles_col: protac,
                    "default_pred_n0": split_with_graph(protac),
                    "model_name": "XGBoost",
                }
            return {
                protac_smiles_col: protac,
                "default_pred_n0": None,
                "model_name": "Transformer",
            }

        # Always map, so the output consistently has the documented
        # `default_pred_n0` / `model_name` columns, even with fixing and the
        # XGBoost fallback both disabled.
        steps = []
        if fix_predictions:
            steps.append("Fixing predictions")
        if use_xgboost:
            steps.append("Replacing predictions with XGBoost fallback")
        preds_ds = preds_ds.map(
            pick_transformer_prediction,
            # Using XGBoost inside a multi-process map might not be safe.
            num_proc=1 if use_xgboost else num_proc,
            desc=" and ".join(steps) or "Selecting Transformer predictions",
        )
    else:
        model_label = "XGBoost" if use_xgboost else "Heuristic"

        def split_row(row: Dict[str, str]) -> Dict[str, Optional[str]]:
            """Split one row with the graph-based algorithm (XGBoost or heuristic)."""
            protac = row[protac_smiles_col]
            return {
                protac_smiles_col: protac,
                "default_pred_n0": split_with_graph(protac),
                "model_name": model_label,
            }

        preds_ds = ds.map(
            split_row,
            num_proc=1 if use_xgboost else num_proc,
            desc=(
                "Splitting PROTAC SMILES using XGBoost model"
                if use_xgboost
                else "Splitting PROTAC SMILES using heuristic-based algorithm"
            ),
        )

    if isinstance(protac_smiles, str):
        return preds_ds[0]
    if isinstance(protac_smiles, pd.DataFrame):
        return preds_ds.to_pandas()
    # list or tuple input: one dict per input SMILES, in input order.
    return list(preds_ds)
# if tokenizer is None:
# if verbose:
# print(f"Loading tokenizer...")
# tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
# if pipe is None:
# if verbose:
# print("Loading pipeline for \"default\" predictions...")
# pipe = pipeline(
# "text2text-generation",
# model=model_name,
# tokenizer=tokenizer,
# device="cuda" if torch.cuda.is_available() else "cpu",
# token=hf_token,
# num_return_sequences=beam_size,
# )
# if isinstance(protac_smiles, str):
# protac_smiles_canon = canonize(protac_smiles)
# if protac_smiles_canon is None:
# raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
# pred = pipe(protac_smiles_canon)
# pred = {f"default_pred_n{i}": pred[i]["generated_text"] for i in range(len(pred))}
# if fix_predictions:
# p_fixed = {k: fix_prediction(protac_smiles_canon, v, verbose=verbose) for k, v in pred.items()}
# # For each prediction, if the fixed prediction is not None, we
# # replace the original prediction with the fixed one.
# for k, v in p_fixed.items():
# if v is not None:
# pred[k] = v
# preds = [pred]
# if isinstance(protac_smiles, list):
# # Canonize and check if all PROTAC SMILES are valid
# protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
# if None in protac_smiles_canon:
# wrong_protacs = [protac for protac, canon in zip(protac_smiles, protac_smiles_canon) if canon is None]
# raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
# # Get the predictions for all PROTAC SMILES
# preds = pipe(protac_smiles_canon, batch_size=batch_size)
# preds = [{f"default_pred_n{i}": p["generated_text"] for i, p in enumerate(pred)} for pred in preds]
# if fix_predictions:
# for i, (protac, pred) in enumerate(zip(protac_smiles_canon, preds)):
# p_fixed = {k: fix_prediction(protac, v, verbose=verbose) for k, v in pred.items()}
# # For each prediction, if the fixed prediction is not None, we
# # replace the original prediction with the fixed one.
# for k, v in p_fixed.items():
# if v is not None:
# preds[i][k] = v
# if isinstance(protac_smiles, pd.DataFrame):
# # Check if the DataFrame contains a columns named `protac_smiles_col`
# if protac_smiles_col not in protac_smiles.columns:
# raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
# # Canonize and check if all PROTAC SMILES are valid
# protac_smiles_canon = protac_smiles.apply(lambda x: canonize(x[protac_smiles_col]), axis=1)
# # Check if there are invalid PROTAC SMILES
# if protac_smiles_canon.isnull().any():
# wrong_protacs = protac_smiles[protac_smiles_canon.isnull()]
# raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
# # Convert the Series to a DataFrame
# protac_smiles_canon = pd.DataFrame(protac_smiles_canon, columns=[protac_smiles_col])
# # Convert the DataFrame to a Dataset
# dataset = Dataset.from_pandas(protac_smiles_canon)
# preds = []
# for pred in tqdm(pipe(KeyDataset(dataset, protac_smiles_col), batch_size=batch_size), total=len(dataset) // batch_size, desc="Generating predictions"):
# p = {f"default_pred_n{i}": pred[i]["generated_text"] for i in range(len(pred))}
# preds.append(p)
# if fix_predictions:
# for i, (protac, pred) in tqdm(enumerate(zip(protac_smiles_canon, preds)), desc="Fixing predictions", total=len(preds)):
# p_fixed = {k: fix_prediction(protac, v, verbose=verbose) for k, v in pred.items()}
# # For each prediction, if the fixed prediction is not None, we
# # replace the original prediction with the fixed one.
# for k, v in p_fixed.items():
# if v is not None:
# pred[k] = v
# if return_check_reassembly:
# if isinstance(protac_smiles_canon, str):
# protac_smiles_list = [protac_smiles_canon]
# elif isinstance(protac_smiles_canon, list):
# protac_smiles_list = protac_smiles_canon
# elif isinstance(protac_smiles_canon, pd.DataFrame):
# protac_smiles_list = protac_smiles_canon[protac_smiles_col].tolist()
# print("Checking re-assembly...")
# for protac, pred in zip(protac_smiles_list, preds):
# for i in range(beam_size):
# pred[f"reassembly_correct_n{i}"] = check_reassembly(protac, pred[f"default_pred_n{i}"])
# # Just take the first prediction if the input was a string
# if isinstance(protac_smiles, str):
# preds = preds[0]
# return preds