import os
import joblib
from pathlib import Path
from typing import Optional, List, Dict, Tuple, Union, Any, Literal

import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.decomposition import TruncatedSVD
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report, confusion_matrix
from xgboost import XGBClassifier
import optuna
from optuna.samplers import QMCSampler

try:
    import seaborn as sns
    import matplotlib.pyplot as plt
    HAS_VISUALIZATION = True
except ImportError:
    HAS_VISUALIZATION = False

from .edge_features import extract_edge_features, get_edge_features


class GraphEdgeClassifier(BaseEstimator, ClassifierMixin):
    """
    Edge-level graph classifier for PROTACs with integrated pipeline building.
    """

    def __init__(
        self,
        graph_features: List[str],
        categorical_features: Optional[List[str]] = None,
        descriptor_features: Optional[List[str]] = None,
        fingerprint_features: Optional[List[str]] = None,
        use_descriptors: bool = True,
        use_fingerprints: bool = True,
        scaler_graph: Literal["passthrough", "standard"] = "passthrough",
        scaler_desc: Literal["passthrough", "standard"] = "passthrough",
        use_svd_fp: bool = True,
        n_svd_components: int = 100,
        binary: bool = False,
        smote_k_neighbors: Optional[int] = 5,
        xgb_params: Optional[dict] = None,
        n_bits: int = 512,
        radius: int = 6,
        descriptor_names: Optional[List[str]] = None,
    ):
        self.graph_features = graph_features
        self.categorical_features = categorical_features
        self.descriptor_features = descriptor_features
        self.fingerprint_features = fingerprint_features
        self.use_descriptors = use_descriptors
        self.use_fingerprints = use_fingerprints
        self.scaler_graph = scaler_graph
        self.scaler_desc = scaler_desc
        self.use_svd_fp = use_svd_fp
        self.n_svd_components = n_svd_components
        self.binary = binary
        self.smote_k_neighbors = smote_k_neighbors
        self.xgb_params = xgb_params or {}
        self.n_bits = n_bits
        self.radius = radius
        self.descriptor_names = descriptor_names or [
            "MolWt", "HeavyAtomCount", "NumHAcceptors", "NumHDonors",
            "TPSA", "NumRotatableBonds", "RingCount", "MolLogP"
        ]
        self.pipeline = self._build_pipeline()

    def _build_pipeline(self):
        transformers = []
        if self.categorical_features:
            transformers.append(("cat", OneHotEncoder(handle_unknown="ignore"), self.categorical_features))
        if self.scaler_graph == "standard":
            transformers.append(("num", StandardScaler(), self.graph_features))
        else:
            transformers.append(("num", "passthrough", self.graph_features))
        if self.use_descriptors and self.descriptor_features:
            desc_block = (
                ("desc", StandardScaler(), self.descriptor_features)
                if self.scaler_desc == "standard"
                else ("desc", "passthrough", self.descriptor_features)
            )
            transformers.append(desc_block)
        if self.use_fingerprints and self.fingerprint_features:
            if self.use_svd_fp:
                fp_block = ("fp", TruncatedSVD(n_components=self.n_svd_components, random_state=42), self.fingerprint_features)
            else:
                fp_block = ("fp", "passthrough", self.fingerprint_features)
            transformers.append(fp_block)
        preprocessor = ColumnTransformer(transformers)
        # Define the classifier
        classifier = XGBClassifier(
            random_state=42,
            eval_metric="logloss" if self.binary else "mlogloss",
            objective="binary:logistic" if self.binary else "multi:softprob",
            **self.xgb_params
        )
        if self.smote_k_neighbors is not None:
            return ImbPipeline([
                ("preprocess", preprocessor),
                ("smote", SMOTE(random_state=42, k_neighbors=self.smote_k_neighbors)),
                ("clf", classifier)
            ])
        else:
            return Pipeline([
                ("preprocess", preprocessor),
                ("clf", classifier)
            ])
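
    # The assembled pipeline (sketch, with SMOTE enabled) has the shape:
    #
    #   ImbPipeline([
    #       ("preprocess", ColumnTransformer([...])),  # one-hot / scaling / SVD blocks
    #       ("smote", SMOTE(...)),                     # resampling runs only at fit time
    #       ("clf", XGBClassifier(...)),
    #   ])
    #
    # SMOTE is placed after preprocessing so it oversamples in the encoded
    # feature space; imblearn pipelines skip resampling steps at predict time.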

    def fit(self, X: pd.DataFrame, y: pd.Series):
        self.pipeline.fit(X, y)
        return self

    def predict(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
        X_proc = self._ensure_features(X)
        return self.pipeline.predict(X_proc)

    def predict_proba(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
        X_proc = self._ensure_features(X)
        return self.pipeline.predict_proba(X_proc)

    def save(self, path: Union[str, Path]):
        joblib.dump(self, str(path))

    @classmethod
    def load(cls, path: Union[str, Path]) -> "GraphEdgeClassifier":
        return joblib.load(str(path))
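
    # Persistence round trip (sketch; the path is a placeholder):
    #
    #   clf.save("edge_clf.joblib")
    #   restored = GraphEdgeClassifier.load("edge_clf.joblib")
    #   restored.predict(X)  # same fitted pipeline as clf
    #
    # joblib pickles the whole estimator, so loading expects the same
    # sklearn/xgboost/imblearn versions that produced the file.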

    @staticmethod
    def extract_graph_features(
        protac_smiles: Union[str, List[str]],
        wh_smiles: Optional[Union[str, List[str]]] = None,
        lk_smiles: Optional[Union[str, List[str]]] = None,
        e3_smiles: Optional[Union[str, List[str]]] = None,
        n_bits: int = 512,
        radius: int = 6,
        descriptor_names: Optional[List[str]] = None,
        verbose: int = 0
    ) -> pd.DataFrame:
        if any(x is None for x in [wh_smiles, lk_smiles, e3_smiles]):
            # Get features from PROTAC only, for inference
            return extract_edge_features(
                protac_smiles=protac_smiles,
                n_bits=n_bits,
                radius=radius,
                descriptor_names=descriptor_names,
            )
        else:
            # Get features and labels from all components, for training
            return get_edge_features(
                protac_smiles=protac_smiles,
                wh_smiles=wh_smiles,
                lk_smiles=lk_smiles,
                e3_smiles=e3_smiles,
                n_bits=n_bits,
                radius=radius,
                descriptor_names=descriptor_names,
                verbose=verbose
            )

    @staticmethod
    def build_multiclass_target(
        df: pd.DataFrame,
        poi_attachment_id: int = 1,
        e3_attachment_id: int = 2,
    ) -> pd.Series:
        """
        Returns multiclass target: 0 = no split, 1 = WH split, 2 = E3 split.
        """
        assert ((df["label_e3_split"] + df["label_wh_split"]) <= 1).all()
        y = (
            df["label_wh_split"] * poi_attachment_id +
            df["label_e3_split"] * e3_attachment_id
        )
        return y.astype("int32")
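
    # Worked example of the target encoding (sketch):
    #
    #   label_wh_split  label_e3_split  ->  y
    #   0               0                   0  (no split)
    #   1               0                   1  (WH-linker split)
    #   0               1                   2  (E3-linker split)
    #
    # The assert above rejects rows where both labels are set, so the
    # weighted sum is unambiguous.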

    def _ensure_features(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> pd.DataFrame:
        """Filter out features/columns that are not used in the pipeline."""
        required_columns = (
            (self.graph_features or []) +
            (self.categorical_features or []) +
            (self.descriptor_features or []) +
            (self.fingerprint_features or [])
        )
        # If input is a DataFrame, assume it is already featurized
        if isinstance(X, pd.DataFrame):
            Xf = X
        elif isinstance(X, list) and X and isinstance(X[0], dict):
            Xf = pd.DataFrame(X)
        else:
            raise ValueError("Provide either a DataFrame or a non-empty list of feature dicts. Use extract_graph_features for SMILES.")
        missing = set(required_columns) - set(Xf.columns)
        if missing:
            raise ValueError(f"Input data missing required columns: {missing}")
        return Xf[required_columns].copy()

    def predict_proba_from_smiles(
        self,
        protac_smiles: Union[str, List[str]],
        wh_smiles: Union[str, List[str]],
        lk_smiles: Union[str, List[str]],
        e3_smiles: Union[str, List[str]],
        verbose: int = 0,
    ):
        features = self.extract_graph_features(
            protac_smiles, wh_smiles, lk_smiles, e3_smiles,
            n_bits=self.n_bits,
            radius=self.radius,
            descriptor_names=self.descriptor_names,
            verbose=verbose
        )
        Xf = self._ensure_features(features)
        return self.pipeline.predict_proba(Xf)

    def predict_from_smiles(
        self,
        protac_smiles: Union[str, List[str]],
        wh_smiles: Union[str, List[str]],
        lk_smiles: Union[str, List[str]],
        e3_smiles: Union[str, List[str]],
        top_n: int = 1,
        return_array: bool = True,
        verbose: int = 0,
    ) -> Union[pd.DataFrame, np.ndarray]:
        """
        For binary classification:
            For each SMILES, return the top_n edge chem_bond_idx indices, ranked by
            predicted split probability. If a molecule has fewer edges than top_n,
            pad with -1. A top_n of -1 returns all edges, padded to the largest molecule.
        For multiclass:
            For each SMILES, return the chem_bond_idx with highest probability for
            class 1 (WH split) and for class 2 (E3 split). Shape: (num_smiles, 2).
            If no edge is predicted as that class, the value is -1.
        """
        features = self.extract_graph_features(
            protac_smiles, wh_smiles, lk_smiles, e3_smiles,
            n_bits=self.n_bits,
            radius=self.radius,
            descriptor_names=self.descriptor_names,
            verbose=verbose
        )
        Xf = self._ensure_features(features)
        pred_proba = self.pipeline.predict_proba(Xf)
        pred_label = self.pipeline.predict(Xf)
        features = features.copy()
        features["pred_label"] = pred_label
        features["pred_proba"] = pred_proba[:, 1] if pred_proba.shape[1] > 1 else pred_proba[:, 0]
        # NOTE: The SMILES is repeated for each edge, so we can drop duplicates
        # and group by SMILES to get the top_n edges per SMILES.
        unique_smiles = pd.Series(features["chem_mol_smiles"]).drop_duplicates().tolist()
        groupby = features.groupby("chem_mol_smiles")
        results = []
        if return_array:
            if pred_proba.shape[1] == 2:  # Binary case
                # top_n < 0 means "all edges"; resolve it once, to the largest
                # group size, so every output row has the same width.
                if top_n < 0:
                    top_n = int(groupby.size().max())
                for mol_smiles in unique_smiles:
                    group = groupby.get_group(mol_smiles)
                    # Sort by probability, take the top_n edges
                    top_edges = group.nlargest(top_n, "pred_proba")
                    idxs = top_edges["chem_bond_idx"].to_numpy()
                    if len(idxs) < top_n:
                        idxs = np.pad(idxs, (0, top_n - len(idxs)), constant_values=-1)
                    results.append(idxs[:top_n])
                return np.vstack(results)
            else:  # Multiclass case
                for mol_smiles in unique_smiles:
                    group = groupby.get_group(mol_smiles)
                    # Class 1 (WH split): edge with the highest probability
                    class1_idx = -1
                    if (group["pred_label"] == 1).any():
                        mask = group["pred_label"] == 1
                        idx1 = group.loc[mask, "pred_proba"].idxmax()
                        class1_idx = group.loc[idx1, "chem_bond_idx"]
                    # Class 2 (E3 split): edge with the highest probability
                    class2_idx = -1
                    if (group["pred_label"] == 2).any():
                        mask = group["pred_label"] == 2
                        idx2 = group.loc[mask, "pred_proba"].idxmax()
                        class2_idx = group.loc[idx2, "chem_bond_idx"]
                    results.append([class1_idx, class2_idx])
                return np.array(results, dtype=int)
        else:
            return features
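

# Minimal usage sketch (feature/column names follow the conventions used in
# train_edge_classifier below; the SMILES inputs are placeholders):
#
#   features = GraphEdgeClassifier.extract_graph_features(
#       protac_smiles, wh_smiles, lk_smiles, e3_smiles,
#   )
#   label_cols = [c for c in features.columns if c.startswith("label_")]
#   clf = GraphEdgeClassifier(
#       graph_features=[c for c in features.columns if c.startswith("graph_")],
#       categorical_features=["chem_bond_type", "chem_atom_u", "chem_atom_v"],
#       fingerprint_features=[c for c in features.columns if c.startswith("chem_mol_fp_")],
#       binary=True,
#   )
#   clf.fit(features.drop(columns=label_cols), features["label_is_split"].astype("int32"))
#   bond_idxs = clf.predict_from_smiles(
#       protac_smiles, wh_smiles, lk_smiles, e3_smiles, top_n=3,
#   )  # -> (n_smiles, 3) array of candidate chem_bond_idx values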


def get_classification_report(y_true, y_pred, labels):
    report = classification_report(y_true, y_pred, target_names=labels, output_dict=True)
    # Keep the class names in a "Label" column so downstream code
    # (e.g. the Optuna objective) can select rows by label.
    df_report = pd.DataFrame(report).transpose().round(2)
    df_report = df_report.reset_index().rename(columns={"index": "Label"})
    print(df_report)
    return df_report


def plot_confusion_matrix(y_true, y_pred, labels):
    cm = confusion_matrix(y_true, y_pred)
    if HAS_VISUALIZATION:
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.title("Confusion Matrix")
        plt.show()
    else:
        print("Visualization libraries not available. Skipping confusion matrix plot.")
        print("Confusion Matrix:")
        print(cm)


def get_classification_report_and_plot(y_true, y_pred, labels):
    report = get_classification_report(y_true, y_pred, labels)
    plot_confusion_matrix(y_true, y_pred, labels)
    return report


def train_edge_classifier(
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame] = None,
    test_df: Optional[pd.DataFrame] = None,
    model_filename: Optional[Union[str, Path]] = None,
    edge_classifier_kwargs: Optional[Dict[str, Any]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    return_reports: bool = True,
    show_confusion_matrix: bool = False,
) -> Union[GraphEdgeClassifier, Tuple[GraphEdgeClassifier, Optional[pd.DataFrame]]]:
    """
    Train an edge-level graph classifier for PROTACs.

    Args:
        train_df (pd.DataFrame): Training data with columns:
            - 'PROTAC SMILES'
            - 'POI Ligand SMILES with direction'
            - 'Linker SMILES with direction'
            - 'E3 Binder SMILES with direction'
        val_df (Optional[pd.DataFrame]): Validation data, same format as train_df.
        test_df (Optional[pd.DataFrame]): Test data, same format as train_df.
        model_filename (Optional[Union[str, Path]]): Path to save the trained model.
        edge_classifier_kwargs (Optional[Dict[str, Any]]): Additional parameters for GraphEdgeClassifier.
        cache_dir (Optional[Union[str, Path]]): Directory for caching extracted features as CSV.
        return_reports (bool): Whether to also return the classification report of the last evaluated set.
        show_confusion_matrix (bool): Whether to plot a confusion matrix for each evaluated set.

    Returns:
        GraphEdgeClassifier, or (GraphEdgeClassifier, report DataFrame) if return_reports is True.
    """
    sets = {}
    smiles_columns = [
        'PROTAC SMILES',
        'POI Ligand SMILES with direction',
        'Linker SMILES with direction',
        'E3 Binder SMILES with direction',
    ]
    for set_name, df in [
        ("train", train_df),
        ("val", val_df),
        ("test", test_df),
    ]:
        if cache_dir is not None:
            cache_path = Path(cache_dir) / f"{set_name}.csv"
            if cache_path.exists():
                print(f"Loading cached features for {set_name} from {cache_path}")
                sets[set_name] = pd.read_csv(cache_path)
                continue
            else:
                print(f"Cache not found for {set_name}, extracting features...")
        if df is None or df.empty:
            continue
        print(f"Set: {set_name}, size: {len(df):,}")
        missing_cols = [c for c in smiles_columns if c not in df.columns]
        if missing_cols:
            raise ValueError(f"DataFrame for {set_name} is missing required columns: {missing_cols}")
        sets[set_name] = GraphEdgeClassifier.extract_graph_features(
            df['PROTAC SMILES'].tolist(),
            df['POI Ligand SMILES with direction'].tolist(),
            df['Linker SMILES with direction'].tolist(),
            df['E3 Binder SMILES with direction'].tolist(),
            verbose=1,
        )
        # Drop rows where both split labels are set (label_e3_split + label_wh_split > 1)
        sets[set_name] = sets[set_name][(sets[set_name]["label_e3_split"] + sets[set_name]["label_wh_split"]) <= 1]
        print(f"Set: {set_name}, size: {len(sets[set_name]):,}")
        if cache_dir is not None:
            cache_path = Path(cache_dir) / f"{set_name}.csv"
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            sets[set_name].to_csv(cache_path, index=False)
            print(f"Saved {set_name} features to {cache_path}")
    train_set = sets["train"]
    label_cols = [c for c in train_set.columns if c.startswith("label_")]
    train_set = train_set.dropna(subset=label_cols)
    train_set = train_set[(train_set["label_e3_split"] + train_set["label_wh_split"]) <= 1]
    X_train = train_set.drop(columns=label_cols)
    # Instantiate and train. Feature lists left as None (e.g. by the Optuna
    # objective below) are inferred from the extracted feature columns.
    default_kwargs = {
        "graph_features": [c for c in X_train.columns if c.startswith("graph_")],
        "categorical_features": ["chem_bond_type", "chem_atom_u", "chem_atom_v"],
        "fingerprint_features": [c for c in X_train.columns if c.startswith("chem_mol_fp_")],
        "use_descriptors": False,
        "use_fingerprints": True,
        "n_svd_components": 50,
        "binary": True,
        "smote_k_neighbors": 10,
        "xgb_params": {
            "max_depth": 6,
            "learning_rate": 0.3,
            "alpha": 0.1,   # Default: 0
            "lambda": 0.5,  # Default: 1
            "gamma": 0.1,   # Default: 0
        },
    }
    kwargs = dict(edge_classifier_kwargs) if edge_classifier_kwargs else default_kwargs
    for key in ("graph_features", "categorical_features", "fingerprint_features"):
        if kwargs.get(key) is None:
            kwargs[key] = default_kwargs[key]
    clf = GraphEdgeClassifier(**kwargs)
    # Prepare target variable according to classification type
    if clf.binary:
        y_train = train_set["label_is_split"].astype("int32")
    else:
        y_train = GraphEdgeClassifier.build_multiclass_target(train_set)
    print(f"Training set size: {len(X_train):,}, labels: {y_train.unique()}")
    clf.fit(X_train, y_train)
    print("Training complete.")
    if model_filename is not None:
        clf.save(model_filename)
        print(f"Model saved to {model_filename}")
    target_labels = ["No Split", "Split"] if clf.binary else ["No Split", "WH-Linker", "E3-Linker"]
    report = None
    if "val" in sets:
        # Get validation data
        val_set = sets["val"].dropna(subset=label_cols)
        val_set = val_set[(val_set["label_e3_split"] + val_set["label_wh_split"]) <= 1]
        X_val = val_set.drop(columns=label_cols)
        y_val = val_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(val_set)
        y_pred = clf.predict(X_val)
        if show_confusion_matrix:
            report = get_classification_report_and_plot(y_val, y_pred, target_labels)
        else:
            report = get_classification_report(y_val, y_pred, target_labels)
        print(f"Validation set classification report:\n{report.to_markdown(index=False)}")
    if "test" in sets:
        # Get test data
        test_set = sets["test"].dropna(subset=label_cols)
        test_set = test_set[(test_set["label_e3_split"] + test_set["label_wh_split"]) <= 1]
        X_test = test_set.drop(columns=label_cols)
        y_test = test_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(test_set)
        y_pred = clf.predict(X_test)
        if show_confusion_matrix:
            report = get_classification_report_and_plot(y_test, y_pred, target_labels)
        else:
            report = get_classification_report(y_test, y_pred, target_labels)
        print(f"Test set classification report:\n{report.to_markdown(index=False)}")
    if return_reports:
        return clf, report
    else:
        return clf
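

# Example call (sketch; "train.csv"/"val.csv" are placeholder files that must
# contain the four SMILES columns listed in the docstring):
#
#   train_df = pd.read_csv("train.csv")
#   val_df = pd.read_csv("val.csv")
#   clf, val_report = train_edge_classifier(
#       train_df,
#       val_df=val_df,
#       model_filename="edge_clf.joblib",
#       cache_dir="./feature_cache",  # reuses extracted features on re-runs
#   )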


def objective(trial, train_df, val_df):
    # HP search space
    max_depth = trial.suggest_int("max_depth", 3, 10)
    learning_rate = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
    alpha = trial.suggest_float("alpha", 0.0, 2.0)
    reg_lambda = trial.suggest_float("lambda", 0.0, 2.0)
    gamma = trial.suggest_float("gamma", 0.0, 1.0)
    n_svd_components = trial.suggest_int("n_svd_components", 16, 128)
    smote_k_neighbors = trial.suggest_int("smote_k_neighbors", 3, 15)
    use_descriptors = trial.suggest_categorical("use_descriptors", [False, True])
    use_fingerprints = trial.suggest_categorical("use_fingerprints", [True, False])
    edge_classifier_kwargs = {
        "graph_features": None,  # Will be inferred in train_edge_classifier
        "categorical_features": None,
        "fingerprint_features": None,
        "use_descriptors": use_descriptors,
        "use_fingerprints": use_fingerprints,
        "n_svd_components": n_svd_components,
        "binary": True,
        "smote_k_neighbors": smote_k_neighbors,
        "xgb_params": {
            "max_depth": max_depth,
            "learning_rate": learning_rate,
            "alpha": alpha,
            "lambda": reg_lambda,
            "gamma": gamma,
        },
    }
    _, val_report = train_edge_classifier(
        train_df=train_df,
        val_df=val_df,
        edge_classifier_kwargs=edge_classifier_kwargs,
        return_reports=True,
    )
    # Evaluate metrics on the validation set. val_report has columns
    # ['Label', 'precision', 'recall', 'f1-score', 'support'], and the
    # binary positive class is labeled "Split".
    try:
        mask = val_report["Label"].astype(str).isin(["Split", "1"])
        f1_1 = float(val_report.loc[mask, "f1-score"].iloc[0])
    except (KeyError, IndexError):
        f1_1 = 0.0
    try:
        acc = float(val_report.loc[val_report["Label"] == "accuracy", "f1-score"].iloc[0])
    except (KeyError, IndexError):
        acc = 0.0
    # Weighted objective: prioritize F1 for the minority class, but keep accuracy.
    # Adjust the weights depending on the task (here equal).
    score = 0.5 * acc + 0.5 * f1_1
    return score


def run_optuna_search(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    n_trials: int = 50,
    study_name: str = "edge_classifier_hp_search",
    study_dir: str = "./optuna_studies",
    seed: int = 42,
) -> Tuple[GraphEdgeClassifier, optuna.Study]:
    os.makedirs(study_dir, exist_ok=True)
    storage = f"sqlite:///{os.path.join(study_dir, study_name)}.db"
    study = optuna.create_study(
        study_name=study_name,
        direction="maximize",
        sampler=QMCSampler(seed=seed, qmc_type="sobol"),
        storage=storage,
        load_if_exists=True,
    )
    study.optimize(lambda trial: objective(trial, train_df, val_df), n_trials=n_trials, show_progress_bar=True)
    print("Best trial:")
    print(study.best_trial)
    # Retrain the classifier with the best hyperparameters and return it
    best_params = study.best_trial.params
    edge_classifier_kwargs = {
        "graph_features": None,
        "categorical_features": None,
        "fingerprint_features": None,
        "use_descriptors": best_params["use_descriptors"],
        "use_fingerprints": best_params["use_fingerprints"],
        "n_svd_components": best_params["n_svd_components"],
        "binary": True,
        "smote_k_neighbors": best_params["smote_k_neighbors"],
        "xgb_params": {
            "max_depth": best_params["max_depth"],
            "learning_rate": best_params["learning_rate"],
            "alpha": best_params["alpha"],
            "lambda": best_params["lambda"],
            "gamma": best_params["gamma"],
        },
    }
    clf, _ = train_edge_classifier(
        train_df=train_df,
        val_df=val_df,
        edge_classifier_kwargs=edge_classifier_kwargs,
        return_reports=True,
    )
    study_file = os.path.join(study_dir, f"{study_name}_study.pkl")
    joblib.dump(study, study_file)
    print(f"Optuna study saved to {study_file}")
    return clf, study
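

if __name__ == "__main__":
    # Smoke-test entry point (sketch): the CSV paths below are placeholders
    # and must point to splits with the four required SMILES columns.
    train_df = pd.read_csv("data/train.csv")
    val_df = pd.read_csv("data/val.csv")
    clf, study = run_optuna_search(
        train_df=train_df,
        val_df=val_df,
        n_trials=20,
    )
    clf.save("edge_classifier_best.joblib")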