import hashlib
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Union

import requests
from datasets import Dataset
import pandas as pd

from protac_splitter.chemoinformatics import canonize
from protac_splitter.fixing_functions import fix_prediction
from protac_splitter.llms.model_utils import get_pipeline, run_pipeline
from protac_splitter.graphs.e3_clustering import get_representative_e3s_fp
from protac_splitter.graphs.edge_classifier import GraphEdgeClassifier
from protac_splitter.graphs.splitting_algorithms import split_protac_graph_based

# SHA-256 digest every freshly downloaded model file is compared against.
_MODEL_SHA256 = "513621f4dc2ff7ec819a222bc7311afb8b6e6e89d6d694dd2906e695a50086dd"
# Size (in bytes) of the chunks streamed from the server to disk.
_DOWNLOAD_CHUNK_SIZE = 1024 * 1024
# (connect, read) timeouts in seconds; without them a stalled server would
# block the caller forever.
_DOWNLOAD_TIMEOUT = (10, 120)


def _download_model(download_url: str, model_path: Path) -> None:
    """
    Download the model file to `model_path`, verifying it before publishing it.

    The payload is streamed into a temporary `.part` file next to
    `model_path` and only moved into place once the size check (when the
    server sends `Content-Length`) and the SHA-256 check both pass, so a
    failed or corrupted download never ends up in the cache.

    Args:
        download_url (str): URL to download the model from.
        model_path (Path): Final location of the model file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        RuntimeError: If the download is incomplete or its checksum is wrong.
    """
    # The PID keeps concurrent processes from writing to the same partial file.
    partial_path = model_path.with_name(f"{model_path.name}.{os.getpid()}.part")
    digest = hashlib.sha256()
    try:
        with requests.get(download_url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            expected_size = int(response.headers.get("Content-Length", -1))
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                    if chunk:
                        f.write(chunk)
                        # Hash while streaming instead of re-reading the file.
                        digest.update(chunk)

        if expected_size != -1:
            actual = partial_path.stat().st_size
            if actual != expected_size:
                raise RuntimeError(f"Download incomplete: got {actual}, expected {expected_size}")

        h = digest.hexdigest()
        if h != _MODEL_SHA256:
            raise RuntimeError(
                f"Downloaded model checksum mismatch: got {h}, expected {_MODEL_SHA256}. "
                "The downloaded file was discarded; please try again."
            )

        # Atomic on the same filesystem: readers see either no file or a
        # complete, verified one.
        os.replace(partial_path, model_path)
    finally:
        # After a successful `os.replace` the partial file no longer exists.
        if partial_path.exists():
            partial_path.unlink()


def load_graph_edge_classifier_from_cache(
        cache_dir: Union[str, Path] = "~/.cache/protac_splitter",
        model_filename: str = "PROTAC-Splitter-XGBoost.joblib",
        download_url: str = "https://docs.google.com/uc?export=download&id=1bb9i5_L_-re3QYPc7tSiCtVNEEbNIzAC",
) -> GraphEdgeClassifier:
    """
    Loads the GraphEdgeClassifier model from a local cache directory.

    If the model file is not found, it is downloaded from the specified URL,
    verified (size and SHA-256) and only then stored in the cache. A file that
    is already present in the cache is loaded as-is.

    NOTE(review): the expected SHA-256 digest is hard-coded, so a custom
    `download_url` has to serve a file with that exact digest.

    Args:
        cache_dir (str or Path): Directory to cache the model file. A leading
            `~` is expanded to the user's home directory.
        model_filename (str): Name of the model file.
        download_url (str): URL to download the model if not present.

    Returns:
        GraphEdgeClassifier: Loaded classifier.

    Raises:
        requests.HTTPError: If the download request fails.
        RuntimeError: If the downloaded file is incomplete or fails the
            checksum verification.
    """
    cache_dir = Path(os.path.expanduser(cache_dir))
    cache_dir.mkdir(parents=True, exist_ok=True)
    model_path = cache_dir / model_filename

    if not model_path.exists():
        _download_model(download_url, model_path)

    return GraphEdgeClassifier.load(model_path)


def _format_split(pred: Dict[str, Optional[str]]) -> Optional[str]:
    """Join a graph-based split as `e3.linker.poi`, or return None if every value is None."""
    if all(v is None for v in pred.values()):
        return None
    return f"{pred['e3']}.{pred['linker']}.{pred['poi']}"


def _build_dataset(
        protac_smiles: Union[str, List, pd.DataFrame],
        protac_smiles_col: str,
) -> Dataset:
    """
    Canonize the input PROTAC SMILES and wrap them in a single-column Dataset.

    Args:
        protac_smiles (str, list, or pd.DataFrame): The PROTAC SMILES.
        protac_smiles_col (str): Name of the SMILES column in the Dataset (and
            in the input DataFrame, if one is given).

    Returns:
        Dataset: Dataset with a `protac_smiles_col` column holding the
            canonized SMILES.

    Raises:
        ValueError: If any SMILES cannot be canonized, or if the DataFrame
            lacks the `protac_smiles_col` column.
        TypeError: If `protac_smiles` is not a str, list, or DataFrame.
    """
    if isinstance(protac_smiles, str):
        protac_smiles_canon = canonize(protac_smiles)
        if protac_smiles_canon is None:
            raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
        return Dataset.from_dict({protac_smiles_col: [protac_smiles_canon]})

    if isinstance(protac_smiles, list):
        # Canonize and check if all PROTAC SMILES are valid
        protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
        if None in protac_smiles_canon:
            wrong_protacs = [
                protac
                for protac, canon in zip(protac_smiles, protac_smiles_canon)
                if canon is None
            ]
            raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
        return Dataset.from_dict({protac_smiles_col: protac_smiles_canon})

    if isinstance(protac_smiles, pd.DataFrame):
        # Check if the DataFrame contains a column named `protac_smiles_col`
        if protac_smiles_col not in protac_smiles.columns:
            raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
        # Canonize and check if all PROTAC SMILES are valid
        protac_smiles_canon = protac_smiles[protac_smiles_col].apply(canonize)
        if protac_smiles_canon.isnull().any():
            wrong_protacs = protac_smiles[protac_smiles_canon.isnull()]
            raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
        return Dataset.from_pandas(protac_smiles_canon.to_frame(name=protac_smiles_col))

    # Previously an unsupported type surfaced later as an UnboundLocalError.
    raise TypeError(
        "protac_smiles must be a str, a list, or a pandas DataFrame, "
        f"got {type(protac_smiles).__name__}."
    )


def split_protac(
        protac_smiles: Union[str, List, pd.DataFrame],
        use_transformer: bool = False,
        use_xgboost: bool = True,
        fix_predictions: bool = True,
        protac_smiles_col: str = "text",
        batch_size: int = 1,
        beam_size: int = 5,
        device: Optional[Union[int, str]] = None,
        num_proc: int = 1,
        verbose: int = 0,
) -> Union[Dict[str, str], List[Dict[str, str]], pd.DataFrame]:
    """
    Split a PROTAC SMILES into the two ligands and the linker.

    If `use_transformer` and `use_xgboost` are both True, the Transformer model
    will run first, and XGBoost will be used as a fallback for predictions that
    fail re-assembly and fixing.

    If both `use_transformer` and `use_xgboost` are False, a fully
    heuristic-based algorithm will be used for splitting.

    Args:
        protac_smiles (str, list, or pd.DataFrame): The PROTAC SMILES to split.
            If a DataFrame is provided, it must contain a column named
            `protac_smiles_col`.
        use_transformer (bool): Whether to use the transformer model for
            splitting.
        use_xgboost (bool): Whether to use the XGBoost model for splitting.
        fix_predictions (bool): Whether to fix the predictions using
            deterministic cheminformatics rules. Only used if
            `use_transformer` is True.
        protac_smiles_col (str): The name of the column containing the PROTAC
            SMILES in the DataFrame.
        batch_size (int): Batch size for processing. Only used if
            `use_transformer` is True.
        beam_size (int): Number of beam search predictions to generate. Only
            used if `use_transformer` is True. Higher values may yield better
            results but increase computation time.
        device (int or str, optional): Device to run the Transformer model on.
            Defaults to None will attempt to run on GPU if available,
            otherwise CPU.
        num_proc (int): Number of processes to use for parallel processing.
            Useful for large datasets of PROTACs to split. Ignored (forced to
            1) whenever the XGBoost model is used.
        verbose (int): Verbosity level.

    Returns:
        Union[Dict[str, str], List[Dict[str, str]], pd.DataFrame]: Depending
        on the input type, returns:
            - If a single string is provided, returns a dictionary with
              format: `{protac_smiles_col: protac_smiles, "default_pred_n0":
              e3l.linker.warhead, "model_name":
              Transformer|XGBoost|Heuristic}`.
            - If a list of strings is provided, returns a list of dictionaries
              with the same format as above.
            - If a DataFrame is provided, returns a DataFrame with columns:
              `protac_smiles_col`, `default_pred_n0`, and `model_name`. The
              `default_pred_n0` column contains the predicted split strings in
              the format `e3.linker.warhead`.
        `default_pred_n0` is None when no valid split could be produced. When
        the Transformer is used, the raw `pred_n*` beam predictions are kept
        alongside the keys listed above.

    Raises:
        ValueError: If any input SMILES is invalid, or if the DataFrame lacks
            the `protac_smiles_col` column.
        TypeError: If `protac_smiles` is not a str, list, or DataFrame.
    """
    # Validate the input first, so that bad input fails before any model is
    # downloaded or loaded.
    ds = _build_dataset(protac_smiles, protac_smiles_col)

    # Always bound, so the closures below never capture an unassigned variable.
    representative_e3s_fp = None
    xgboost_model = None
    if use_xgboost:
        representative_e3s_fp = get_representative_e3s_fp()
        xgboost_model = load_graph_edge_classifier_from_cache()

    def split_with_graph(row: Dict[str, str]) -> Dict[str, Optional[str]]:
        """Split one row with the graph-based algorithm (XGBoost or heuristic)."""
        protac = row[protac_smiles_col]
        if use_xgboost:
            pred = split_protac_graph_based(
                protac_smiles=protac,
                use_classifier=True,
                classifier=xgboost_model,
                representative_e3s_fp=representative_e3s_fp,
            )
        else:
            pred = split_protac_graph_based(
                protac_smiles=protac,
                use_classifier=False,
            )
        return {
            protac_smiles_col: protac,
            "default_pred_n0": _format_split(pred),
            "model_name": "XGBoost" if use_xgboost else "Heuristic",
        }

    def select_transformer_pred(row: Dict[str, str]) -> Dict[str, Optional[str]]:
        """Pick the best valid beam prediction of a row, with XGBoost as fallback."""
        protac = row[protac_smiles_col]
        beams = {k: v for k, v in row.items() if k.startswith("pred_")}
        if fix_predictions:
            beams = {k: fix_prediction(protac, v, verbose=verbose) for k, v in beams.items()}

        # Select the non-None prediction with the lowest beam index.
        # NOTE: The HF predictions come in lists, with the first element being
        # the one with the highest likelihood. `.get` tolerates rows holding
        # fewer than `beam_size` predictions.
        ranked = (beams.get(f"pred_n{i}") for i in range(beam_size))
        best = next((p for p in ranked if p is not None), None)
        if best is not None:
            return {
                protac_smiles_col: protac,
                "default_pred_n0": best,
                "model_name": "Transformer",
            }

        # No valid Transformer prediction: fall back to XGBoost if allowed.
        if use_xgboost:
            return split_with_graph(row)
        return {
            protac_smiles_col: protac,
            "default_pred_n0": None,
            "model_name": "Transformer",
        }

    if use_transformer:
        pipe = get_pipeline(
            model_name="ailab-bio/PROTAC-Splitter",
            token=os.environ.get("HF_TOKEN", None),
            is_causal_language_model=False,
            num_return_sequences=beam_size,
            device=device,
        )
        # preds will be a list of dictionaries, each containing the beam-size
        # predictions for each input PROTAC SMILES. Format:
        # [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
        preds = run_pipeline(
            pipe,
            ds,
            batch_size,
            is_causal_language_model=False,
            smiles_column=protac_smiles_col,
        )
        # Turn the predictions into a DataFrame and then into a Dataset
        preds_df = pd.DataFrame(preds)
        preds_df[protac_smiles_col] = ds[protac_smiles_col]
        preds_ds = Dataset.from_pandas(preds_df)

        steps = []
        if fix_predictions:
            steps.append("Fixing predictions")
        if use_xgboost:
            steps.append("Replacing predictions with XGBoost fallback")
        # Always map, so that `default_pred_n0` and `model_name` are present
        # even when neither fixing nor the XGBoost fallback is requested.
        preds_ds = preds_ds.map(
            select_transformer_pred,
            # Using XGBoost IN a map function might not be thread-safe
            num_proc=1 if use_xgboost else num_proc,
            desc=" and ".join(steps) or "Selecting top-ranked predictions",
        )
    elif use_xgboost:
        # Use the XGBoost model only
        preds_ds = ds.map(
            split_with_graph,
            num_proc=1,
            desc="Splitting PROTAC SMILES using XGBoost model",
        )
    else:
        # If neither transformer nor XGBoost is used, we use the
        # heuristic-based algorithm, that does not require any model.
        preds_ds = ds.map(
            split_with_graph,
            num_proc=num_proc,
            desc="Splitting PROTAC SMILES using heuristic-based algorithm",
        )

    if isinstance(protac_smiles, str):
        # If the input was a single string, we return the first prediction
        return preds_ds[0]
    if isinstance(protac_smiles, pd.DataFrame):
        # If the input was a DataFrame, we return a dataframe with the predictions
        return preds_ds.to_pandas()
    # The input was a list: convert the Dataset to a list of dictionaries
    return list(preds_ds)


# NOTE(review): Legacy implementation, kept commented out for reference only.
# It uses names that are not imported in the visible part of this module
# (e.g. AutoTokenizer, pipeline, torch, tqdm, KeyDataset, check_reassembly,
# return_check_reassembly) -- confirm before reviving it, or delete it.
#
# if tokenizer is None:
#     if verbose:
#         print(f"Loading tokenizer...")
#     tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
# if pipe is None:
#     if verbose:
#         print("Loading pipeline for \"default\" predictions...")
#     pipe = pipeline(
#         "text2text-generation",
#         model=model_name,
#         tokenizer=tokenizer,
#         device="cuda" if torch.cuda.is_available() else "cpu",
#         token=hf_token,
#         num_return_sequences=beam_size,
#     )
# if isinstance(protac_smiles, str):
#     protac_smiles_canon = canonize(protac_smiles)
#     if protac_smiles_canon is None:
#         raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
#     pred = pipe(protac_smiles_canon)
#     pred = {f"default_pred_n{i}": pred[i]["generated_text"] for i in range(len(pred))}
#     if fix_predictions:
#         p_fixed = {k: fix_prediction(protac_smiles_canon, v, verbose=verbose) for k, v in pred.items()}
#         # For each prediction, if the fixed prediction is not None, we
#         # replace the original prediction with the fixed one.
#         for k, v in p_fixed.items():
#             if v is not None:
#                 pred[k] = v
#     preds = [pred]
# if isinstance(protac_smiles, list):
#     # Canonize and check if all PROTAC SMILES are valid
#     protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
#     if None in protac_smiles_canon:
#         wrong_protacs = [protac for protac, canon in zip(protac_smiles, protac_smiles_canon) if canon is None]
#         raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
#     # Get the predictions for all PROTAC SMILES
#     preds = pipe(protac_smiles_canon, batch_size=batch_size)
#     preds = [{f"default_pred_n{i}": p["generated_text"] for i, p in enumerate(pred)} for pred in preds]
#     if fix_predictions:
#         for i, (protac, pred) in enumerate(zip(protac_smiles_canon, preds)):
#             p_fixed = {k: fix_prediction(protac, v, verbose=verbose) for k, v in pred.items()}
#             # For each prediction, if the fixed prediction is not None, we
#             # replace the original prediction with the fixed one.
#             for k, v in p_fixed.items():
#                 if v is not None:
#                     preds[i][k] = v
# if isinstance(protac_smiles, pd.DataFrame):
#     # Check if the DataFrame contains a columns named `protac_smiles_col`
#     if protac_smiles_col not in protac_smiles.columns:
#         raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
#     # Canonize and check if all PROTAC SMILES are valid
#     protac_smiles_canon = protac_smiles.apply(lambda x: canonize(x[protac_smiles_col]), axis=1)
#     # Check if there are invalid PROTAC SMILES
#     if protac_smiles_canon.isnull().any():
#         wrong_protacs = protac_smiles[protac_smiles_canon.isnull()]
#         raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
#     # Convert the Series to a DataFrame
#     protac_smiles_canon = pd.DataFrame(protac_smiles_canon, columns=[protac_smiles_col])
#     # Convert the DataFrame to a Dataset
#     dataset = Dataset.from_pandas(protac_smiles_canon)
#     preds = []
#     for pred in tqdm(pipe(KeyDataset(dataset, protac_smiles_col), batch_size=batch_size), total=len(dataset) // batch_size, desc="Generating predictions"):
#         p = {f"default_pred_n{i}": pred[i]["generated_text"] for i in range(len(pred))}
#         preds.append(p)
#     if fix_predictions:
#         for i, (protac, pred) in tqdm(enumerate(zip(protac_smiles_canon, preds)), desc="Fixing predictions", total=len(preds)):
#             p_fixed = {k: fix_prediction(protac, v, verbose=verbose) for k, v in pred.items()}
#             # For each prediction, if the fixed prediction is not None, we
#             # replace the original prediction with the fixed one.
#             for k, v in p_fixed.items():
#                 if v is not None:
#                     pred[k] = v
# if return_check_reassembly:
#     if isinstance(protac_smiles_canon, str):
#         protac_smiles_list = [protac_smiles_canon]
#     elif isinstance(protac_smiles_canon, list):
#         protac_smiles_list = protac_smiles_canon
#     elif isinstance(protac_smiles_canon, pd.DataFrame):
#         protac_smiles_list = protac_smiles_canon[protac_smiles_col].tolist()
#     print("Checking re-assembly...")
#     for protac, pred in zip(protac_smiles_list, preds):
#         for i in range(beam_size):
#             pred[f"reassembly_correct_n{i}"] = check_reassembly(protac, pred[f"default_pred_n{i}"])
# # Just take the first prediction if the input was a string
# if isinstance(protac_smiles, str):
#     preds = preds[0]
# return preds