import hashlib
import os
from pathlib import Path
from typing import Dict, List, Optional, Union

import pandas as pd
import requests
from datasets import Dataset

from protac_splitter.chemoinformatics import canonize
from protac_splitter.fixing_functions import fix_prediction
from protac_splitter.graphs.e3_clustering import get_representative_e3s_fp
from protac_splitter.graphs.edge_classifier import GraphEdgeClassifier
from protac_splitter.graphs.splitting_algorithms import split_protac_graph_based
from protac_splitter.llms.model_utils import get_pipeline, run_pipeline

def load_graph_edge_classifier_from_cache(
cache_dir: Union[str, Path] = "~/.cache/protac_splitter",
model_filename: str = "PROTAC-Splitter-XGBoost.joblib",
download_url: str = "https://docs.google.com/uc?export=download&id=1bb9i5_L_-re3QYPc7tSiCtVNEEbNIzAC",
) -> GraphEdgeClassifier:
"""
Loads the GraphEdgeClassifier model from a local cache directory.
If the model file is not found, downloads it from the specified URL.
Args:
cache_dir (str or Path): Directory to cache the model file.
model_filename (str): Name of the model file.
download_url (str): URL to download the model if not present.
Returns:
GraphEdgeClassifier: Loaded classifier.
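
    Example (hypothetical usage; the first call downloads the model file,
    subsequent calls load it from the local cache):

        clf = load_graph_edge_classifier_from_cache()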
"""
cache_dir = Path(os.path.expanduser(cache_dir))
cache_dir.mkdir(parents=True, exist_ok=True)
model_path = cache_dir / model_filename
if not model_path.exists():
        response = requests.get(download_url, stream=True, timeout=60)
response.raise_for_status()
expected_size = int(response.headers.get("Content-Length", -1))
with open(model_path, "wb") as f:
for chunk in response.iter_content(chunk_size=1024*1024):
if chunk:
f.write(chunk)
if expected_size != -1:
actual = model_path.stat().st_size
if actual != expected_size:
raise RuntimeError(f"Download incomplete: got {actual}, expected {expected_size}")
        # Verify the SHA-256 checksum of the downloaded file against the
        # expected digest to catch corrupted or truncated downloads.
h = hashlib.sha256(model_path.read_bytes()).hexdigest()
h_orig = "513621f4dc2ff7ec819a222bc7311afb8b6e6e89d6d694dd2906e695a50086dd"
if h != h_orig:
raise RuntimeError(
f"Downloaded model checksum mismatch: got {h}, expected {h_orig}. "
"Please delete the model file and try again."
)
    return GraphEdgeClassifier.load(model_path)


def split_protac(
protac_smiles: Union[str, List, pd.DataFrame],
use_transformer: bool = False,
use_xgboost: bool = True,
fix_predictions: bool = True,
protac_smiles_col: str = "text",
batch_size: int = 1,
beam_size: int = 5,
device: Optional[Union[int, str]] = None,
num_proc: int = 1,
verbose: int = 0,
) -> Union[Dict[str, str], List[Dict[str, str]]]:
""" Split a PROTAC SMILES into the two ligands and the linker.
If `use_transformer` and `use_xgboost` are both True, the Transformer model
will run first, and XGBost will be used as a fallback for predictions that
fail re-assembly and fixing. If both `use_transformer` and `use_xgboost`
are False, a fully heuristic-based algorithm will be used for splitting.
Args:
protac_smiles (str, list, or pd.DataFrame): The PROTAC SMILES to split.
If a DataFrame is provided, it must contain a column named `protac_smiles_col`.
use_transformer (bool): Whether to use the transformer model for splitting.
use_xgboost (bool): Whether to use the XGBoost model for splitting.
fix_predictions (bool): Whether to fix the predictions using deterministic cheminformatics rules. Only used if `use_transformer` is True.
protac_smiles_col (str): The name of the column containing the PROTAC SMILES in the DataFrame.
batch_size (int): Batch size for processing. Only used if `use_transformer` is True.
beam_size (int): Number of beam search predictions to generate. Only used if `use_transformer` is True. Higher values may yield better results but increase computation time.
        device (int or str, optional): Device to run the Transformer model on. Defaults to None, which attempts to use the GPU if available, otherwise the CPU.
num_proc (int): Number of processes to use for parallel processing. Useful for large datasets of PROTACs to split.
verbose (int): Verbosity level.
Returns:
Union[Dict[str, str], List[Dict[str, str]]]: Depending on the input type, returns:
            - If a single string is provided, returns a dictionary of the form `{protac_smiles_col: protac_smiles, "default_pred_n0": "e3.linker.warhead", "model_name": "Transformer"|"XGBoost"|"Heuristic"}`.
- If a list of strings is provided, returns a list of dictionaries with the same format as above.
- If a DataFrame is provided, returns a DataFrame with columns: `protac_smiles_col`, `default_pred_n0`, and `model_name`. The `default_pred_n0` column contains the predicted split strings in the format `e3.linker.warhead`.
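
    Example (illustrative sketch; `my_protac` must be a valid PROTAC SMILES
    string, and the exact output depends on the selected model):

        result = split_protac(my_protac, use_transformer=False)
        # result is a dict like:
        # {"text": "<canonical PROTAC SMILES>",
        #  "default_pred_n0": "<e3>.<linker>.<warhead>",
        #  "model_name": "XGBoost"}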
"""
if use_xgboost:
representative_e3s_fp = get_representative_e3s_fp()
xgboost_model = load_graph_edge_classifier_from_cache()
# Generate a Dataset from the input PROTAC SMILES
if isinstance(protac_smiles, str):
protac_smiles_canon = canonize(protac_smiles)
if protac_smiles_canon is None:
raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
ds = Dataset.from_dict({protac_smiles_col: [protac_smiles_canon]})
elif isinstance(protac_smiles, list):
# Canonize and check if all PROTAC SMILES are valid
protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
if None in protac_smiles_canon:
wrong_protacs = [protac for protac, canon in zip(protac_smiles, protac_smiles_canon) if canon is None]
raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
ds = Dataset.from_dict({protac_smiles_col: protac_smiles_canon})
elif isinstance(protac_smiles, pd.DataFrame):
        # Check if the DataFrame contains a column named `protac_smiles_col`
if protac_smiles_col not in protac_smiles.columns:
raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
# Canonize and check if all PROTAC SMILES are valid
protac_smiles_canon = protac_smiles[protac_smiles_col].apply(canonize)
if protac_smiles_canon.isnull().any():
wrong_protacs = protac_smiles[protac_smiles_canon.isnull()]
raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
        ds = Dataset.from_pandas(protac_smiles_canon.to_frame(name=protac_smiles_col), preserve_index=False)
    else:
        raise TypeError(
            f"Unsupported type for `protac_smiles`: {type(protac_smiles)}. "
            "Expected str, list, or pd.DataFrame."
        )
if use_transformer:
pipe = get_pipeline(
model_name="ailab-bio/PROTAC-Splitter",
token=os.environ.get("HF_TOKEN", None),
is_causal_language_model=False,
num_return_sequences=beam_size,
device=device,
)
# preds will be a list of dictionaries, each containing the
# beam-size predictions for each input PROTAC SMILES. Format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
preds = run_pipeline(
pipe,
ds,
batch_size,
is_causal_language_model=False,
smiles_column=protac_smiles_col,
)
# Turn the predictions into a DataFrame and then into a Dataset
preds_df = pd.DataFrame(preds)
preds_df[protac_smiles_col] = ds[protac_smiles_col]
preds_ds = Dataset.from_pandas(preds_df)
def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
"""Fix the predictions for each row."""
protac = row[protac_smiles_col]
if fix_predictions:
preds = {k: fix_prediction(protac, v, verbose=verbose) for k, v in row.items() if k.startswith("pred_")}
else:
preds = {k: v for k, v in row.items() if k.startswith("pred_")}
            # If all preds are None, fall back to the XGBoost model
            if all(v is None for v in preds.values()):
                if use_xgboost:
                    pred = split_protac_graph_based(
                        protac_smiles=protac,
                        use_classifier=True,
                        classifier=xgboost_model,
                        representative_e3s_fp=representative_e3s_fp,
                    )
                    # Guard against an all-None split from the fallback model
                    if all(v is None for v in pred.values()):
                        split = None
                    else:
                        split = f"{pred['e3']}.{pred['linker']}.{pred['poi']}"
                    return {
                        protac_smiles_col: protac,
                        "default_pred_n0": split,
                        "model_name": "XGBoost",
                    }
else:
# If no predictions are valid, we return None for the default prediction
return {
protac_smiles_col: protac,
"default_pred_n0": None,
"model_name": "Transformer",
}
            else:
                # Select the non-None prediction with the lowest beam index.
                # NOTE: The HF predictions come in a list whose first element
                # has the highest likelihood.
for i in range(beam_size):
key = f"pred_n{i}"
if preds[key] is not None:
return {
protac_smiles_col: protac,
"default_pred_n0": preds[key],
"model_name": "Transformer",
}
# Map the function over the Dataset to fix the predictions and/or
# replace them with the XGBoost fallback predictions if they fail.
if fix_predictions or use_xgboost:
preds_ds = preds_ds.map(
mapping_func,
                num_proc=1 if use_xgboost else num_proc,  # using XGBoost in a map function might not be thread-safe
desc=f"{'Fixing predictions' if fix_predictions else ''}{' and ' if fix_predictions and use_xgboost else ''}{'Replacing predictions with XGBoost fallback' if use_xgboost else ''}",
)
elif use_xgboost:
# Use the XGBoost model only
def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
"""Split the PROTAC SMILES using the XGBoost model."""
protac = row[protac_smiles_col]
pred = split_protac_graph_based(
protac_smiles=protac,
use_classifier=True,
classifier=xgboost_model,
representative_e3s_fp=representative_e3s_fp,
)
if all(v is None for v in pred.values()):
split = None
else:
split = f"{pred['e3']}.{pred['linker']}.{pred['poi']}"
return {
protac_smiles_col: protac,
"default_pred_n0": split,
"model_name": "XGBoost",
}
preds_ds = ds.map(
mapping_func,
num_proc=1,
desc="Splitting PROTAC SMILES using XGBoost model",
)
else:
        # If neither the Transformer nor XGBoost is used, fall back to the
        # heuristic-based algorithm, which does not require any model.
def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
"""Split the PROTAC SMILES using the heuristic-based algorithm."""
protac = row[protac_smiles_col]
pred = split_protac_graph_based(
protac_smiles=protac,
use_classifier=False,
)
if all(v is None for v in pred.values()):
split = None
else:
split = f"{pred['e3']}.{pred['linker']}.{pred['poi']}"
return {
protac_smiles_col: protac,
"default_pred_n0": split,
"model_name": "Heuristic",
}
preds_ds = ds.map(
mapping_func,
num_proc=num_proc,
desc="Splitting PROTAC SMILES using heuristic-based algorithm",
)
if isinstance(protac_smiles, str):
# If the input was a single string, we return the first prediction
return preds_ds[0]
elif isinstance(protac_smiles, pd.DataFrame):
        # If the input was a DataFrame, return a DataFrame with the predictions
return preds_ds.to_pandas()
elif isinstance(protac_smiles, list):
# Convert the Dataset to a list of dictionaries
return [row for row in preds_ds]
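

if __name__ == "__main__":
    # Minimal usage sketch covering the three supported input types.
    # NOTE: the placeholder below is not a real PROTAC; replace it with a
    # valid PROTAC SMILES string before running, otherwise `canonize` will
    # fail and a ValueError is raised.
    example_smiles = "..."  # placeholder; substitute a valid PROTAC SMILES

    # Single string -> dict with `default_pred_n0` and `model_name`
    print(split_protac(example_smiles, use_transformer=False))

    # List of strings -> list of dicts
    print(split_protac([example_smiles], use_transformer=False))

    # DataFrame -> DataFrame with `default_pred_n0` and `model_name` columns
    df = pd.DataFrame({"text": [example_smiles]})
    print(split_protac(df, use_transformer=False))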