import joblib
from pathlib import Path
from typing import Optional, List, Dict, Tuple, Union, Any, Literal
import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.decomposition import TruncatedSVD
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from xgboost import XGBClassifier
import optuna
from optuna.samplers import QMCSampler
from sklearn.metrics import accuracy_score, f1_score

try:
    import seaborn as sns
    import matplotlib.pyplot as plt

    HAS_VISUALIZATION = True
except ImportError:
    HAS_VISUALIZATION = False

from .edge_features import extract_edge_features, get_edge_features


# Raw input columns needed by ``train_edge_classifier`` to featurize a dataset.
_REQUIRED_INPUT_COLUMNS = (
    "PROTAC SMILES",
    "POI Ligand SMILES with direction",
    "Linker SMILES with direction",
    "E3 Binder SMILES with direction",
)

# GraphEdgeClassifier feature-list arguments that ``train_edge_classifier``
# fills in from the extracted feature columns when they are passed as None.
_AUTO_FEATURE_KEYS = ("graph_features", "categorical_features", "fingerprint_features")

# XGBoost hyperparameters tuned by ``objective`` / ``run_optuna_search``.
_TUNED_XGB_PARAMS = ("max_depth", "learning_rate", "alpha", "lambda", "gamma")


class GraphEdgeClassifier(BaseEstimator, ClassifierMixin):
    """
    Edge-level graph classifier for PROTACs with integrated pipeline building.

    Each input row describes one bond (edge). The wrapped pipeline is a
    ``ColumnTransformer`` (one-hot categorical features, graph / descriptor
    features either standard-scaled or passed through, fingerprint features
    optionally compressed with ``TruncatedSVD``), an optional SMOTE
    oversampling step, and an ``XGBClassifier``.

    Parameters
    ----------
    graph_features : numeric graph-feature columns (always used).
    categorical_features : columns one-hot encoded (unknown categories ignored).
    descriptor_features : descriptor columns, used only if ``use_descriptors``.
    fingerprint_features : fingerprint columns, used only if ``use_fingerprints``.
    scaler_graph, scaler_desc : ``"standard"`` applies ``StandardScaler``;
        anything else passes the columns through.
    use_svd_fp, n_svd_components : compress fingerprints with TruncatedSVD.
    binary : ``binary:logistic`` objective if True, ``multi:softprob`` if False.
    smote_k_neighbors : ``k_neighbors`` for SMOTE; ``None`` disables SMOTE.
    xgb_params : extra ``XGBClassifier`` keyword arguments (override defaults).
    n_bits, radius, descriptor_names : forwarded to ``extract_graph_features``
        by the ``*_from_smiles`` methods.
    """

    def __init__(
        self,
        graph_features: List[str],
        categorical_features: Optional[List[str]] = None,
        descriptor_features: Optional[List[str]] = None,
        fingerprint_features: Optional[List[str]] = None,
        use_descriptors: bool = True,
        use_fingerprints: bool = True,
        scaler_graph: Literal["passthrough", "standard"] = "passthrough",
        scaler_desc: Literal["passthrough", "standard"] = "passthrough",
        use_svd_fp: bool = True,
        n_svd_components: int = 100,
        binary: bool = False,
        smote_k_neighbors: Optional[int] = 5,
        xgb_params: Optional[dict] = None,
        n_bits: int = 512,
        radius: int = 6,
        descriptor_names: Optional[List[str]] = None,
    ):
        self.graph_features = graph_features
        self.categorical_features = categorical_features
        self.descriptor_features = descriptor_features
        self.fingerprint_features = fingerprint_features
        self.use_descriptors = use_descriptors
        self.use_fingerprints = use_fingerprints
        self.scaler_graph = scaler_graph
        self.scaler_desc = scaler_desc
        self.use_svd_fp = use_svd_fp
        self.n_svd_components = n_svd_components
        self.binary = binary
        self.smote_k_neighbors = smote_k_neighbors
        # Keep the caller's dict object when one is given: sklearn's ``clone``
        # requires __init__ to store parameters unmodified, and
        # ``xgb_params or {}`` swapped an empty dict for a new one.
        self.xgb_params = xgb_params if xgb_params is not None else {}
        self.n_bits = n_bits
        self.radius = radius
        self.descriptor_names = descriptor_names or [
            "MolWt",
            "HeavyAtomCount",
            "NumHAcceptors",
            "NumHDonors",
            "TPSA",
            "NumRotatableBonds",
            "RingCount",
            "MolLogP",
        ]
        self.pipeline = self._build_pipeline()

    def _used_feature_columns(self) -> List[str]:
        """Return the input columns the pipeline reads, in order, without duplicates."""
        groups = [self.graph_features, self.categorical_features]
        if self.use_descriptors:
            groups.append(self.descriptor_features)
        if self.use_fingerprints:
            groups.append(self.fingerprint_features)
        return list(dict.fromkeys(col for group in groups for col in (group or [])))

    def _build_pipeline(self):
        """Build an unfitted preprocess -> (SMOTE) -> XGBoost pipeline from the parameters."""
        transformers = []
        if self.categorical_features:
            transformers.append(
                ("cat", OneHotEncoder(handle_unknown="ignore"), self.categorical_features)
            )

        graph_step = StandardScaler() if self.scaler_graph == "standard" else "passthrough"
        transformers.append(("num", graph_step, self.graph_features))

        if self.use_descriptors and self.descriptor_features:
            desc_step = StandardScaler() if self.scaler_desc == "standard" else "passthrough"
            transformers.append(("desc", desc_step, self.descriptor_features))

        if self.use_fingerprints and self.fingerprint_features:
            if self.use_svd_fp:
                # A plain sklearn Pipeline suffices here: no resampling happens
                # inside the ColumnTransformer.
                fp_step = Pipeline([
                    ("svd", TruncatedSVD(n_components=self.n_svd_components, random_state=42)),
                ])
            else:
                fp_step = "passthrough"
            transformers.append(("fp", fp_step, self.fingerprint_features))

        preprocessor = ColumnTransformer(transformers)

        # Defaults first so matching keys in ``xgb_params`` override them
        # instead of raising "got multiple values for keyword argument".
        xgb_kwargs = {
            "random_state": 42,
            "eval_metric": "logloss" if self.binary else "mlogloss",
            "objective": "binary:logistic" if self.binary else "multi:softprob",
            **(self.xgb_params or {}),
        }
        classifier = XGBClassifier(**xgb_kwargs)

        if self.smote_k_neighbors is not None:
            return ImbPipeline([
                ("preprocess", preprocessor),
                ("smote", SMOTE(random_state=42, k_neighbors=self.smote_k_neighbors)),
                ("clf", classifier),
            ])
        return Pipeline([
            ("preprocess", preprocessor),
            ("clf", classifier),
        ])

    def fit(self, X: Union[pd.DataFrame, List[Dict]], y: pd.Series) -> "GraphEdgeClassifier":
        """Fit a freshly built pipeline on the used feature columns of ``X``."""
        # Rebuild so parameters changed via ``set_params`` after construction
        # take effect.
        self.pipeline = self._build_pipeline()
        self.pipeline.fit(self._ensure_features(X), y)
        self.classes_ = np.unique(y)
        return self

    def predict(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
        """Predict a class label for every edge row in ``X``."""
        X_proc = self._ensure_features(X)
        return self.pipeline.predict(X_proc)

    def predict_proba(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
        """Return class probabilities (one column per class) for every edge row in ``X``."""
        X_proc = self._ensure_features(X)
        return self.pipeline.predict_proba(X_proc)

    def save(self, path: Union[str, Path]):
        """Serialize the whole estimator with joblib."""
        joblib.dump(self, str(path))

    @classmethod
    def load(cls, path: Union[str, Path]) -> "GraphEdgeClassifier":
        """Load an estimator written by ``save``.

        joblib uses pickle: only load files from trusted sources.
        """
        return joblib.load(str(path))

    @staticmethod
    def extract_graph_features(
        protac_smiles: Union[str, List[str]],
        wh_smiles: Optional[Union[str, List[str]]] = None,
        lk_smiles: Optional[Union[str, List[str]]] = None,
        e3_smiles: Optional[Union[str, List[str]]] = None,
        n_bits: int = 512,
        radius: int = 6,
        descriptor_names: Optional[List[str]] = None,
        verbose: int = 0,
    ) -> pd.DataFrame:
        """Extract edge features.

        If any of ``wh_smiles`` / ``lk_smiles`` / ``e3_smiles`` is None, only the
        PROTAC is featurized (``extract_edge_features``, inference). Otherwise
        ``get_edge_features`` is used, which also receives the component SMILES
        (training path; the training code reads ``label_*`` columns from it).
        """
        if any(x is None for x in [wh_smiles, lk_smiles, e3_smiles]):
            # Get features from PROTAC only, for inference
            return extract_edge_features(
                protac_smiles=protac_smiles,
                n_bits=n_bits,
                radius=radius,
                descriptor_names=descriptor_names,
            )
        # Get features and labels from all components, for training
        return get_edge_features(
            protac_smiles=protac_smiles,
            wh_smiles=wh_smiles,
            lk_smiles=lk_smiles,
            e3_smiles=e3_smiles,
            n_bits=n_bits,
            radius=radius,
            descriptor_names=descriptor_names,
            verbose=verbose,
        )

    @staticmethod
    def build_multiclass_target(
        df: pd.DataFrame,
        poi_attachment_id: int = 1,
        e3_attachment_id: int = 2,
    ) -> pd.Series:
        """
        Return the multiclass target as int32.

        With the default ids: 0 = no split, 1 = WH (POI ligand) split,
        2 = E3 split. Computed as
        ``label_wh_split * poi_attachment_id + label_e3_split * e3_attachment_id``.

        Raises:
            ValueError: if any row has both split labels set (or missing labels).
        """
        valid = (df["label_e3_split"] + df["label_wh_split"]) <= 1
        if not valid.all():
            raise ValueError(
                f"{int((~valid).sum())} rows are labelled as both an E3 split and a "
                "WH split (or have missing labels); expected at most one split label per edge."
            )
        y = (
            df["label_wh_split"] * poi_attachment_id
            + df["label_e3_split"] * e3_attachment_id
        )
        return y.astype("int32")

    def _ensure_features(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> pd.DataFrame:
        """
        Return a copy of ``X`` restricted to the columns the pipeline uses.

        Accepts a DataFrame or a non-empty list of feature dicts. SMILES strings
        must be featurized first with ``extract_graph_features``.

        Raises:
            ValueError: for unsupported/empty input or missing required columns.
        """
        required_columns = self._used_feature_columns()

        if isinstance(X, pd.DataFrame):
            Xf = X
        elif isinstance(X, list) and not X:
            raise ValueError("Cannot build features from an empty list.")
        elif isinstance(X, list) and isinstance(X[0], dict):
            Xf = pd.DataFrame(X)
        else:
            raise ValueError(
                "Provide either a DataFrame or list of feature dicts. "
                "Use extract_graph_features for SMILES."
            )

        missing = set(required_columns) - set(Xf.columns)
        if missing:
            raise ValueError(f"Input data missing required columns: {missing}")
        return Xf[required_columns].copy()

    def predict_proba_from_smiles(
        self,
        protac_smiles: Union[str, List[str]],
        wh_smiles: Union[str, List[str]],
        lk_smiles: Union[str, List[str]],
        e3_smiles: Union[str, List[str]],
        verbose: int = 0,
    ):
        """Featurize the SMILES and return per-edge class probabilities."""
        features = self.extract_graph_features(
            protac_smiles, wh_smiles, lk_smiles, e3_smiles,
            n_bits=self.n_bits, radius=self.radius,
            descriptor_names=self.descriptor_names, verbose=verbose,
        )
        Xf = self._ensure_features(features)
        return self.pipeline.predict_proba(Xf)

    def predict_from_smiles(
        self,
        protac_smiles: Union[str, List[str]],
        wh_smiles: Union[str, List[str]],
        lk_smiles: Union[str, List[str]],
        e3_smiles: Union[str, List[str]],
        top_n: int = 1,
        return_array: bool = True,
        verbose: int = 0,
    ) -> Union[pd.DataFrame, np.ndarray]:
        """
        Predict split bonds per molecule directly from SMILES.

        Rows of the returned array follow the order in which molecules
        (``chem_mol_smiles``) first appear in the extracted features.

        Binary model (two probability columns): for each molecule, the
        ``chem_bond_idx`` of its ``top_n`` edges ranked by class-1 (split)
        probability, regardless of the predicted label. A row is padded with
        -1 when the molecule has fewer than ``top_n`` edges. A negative
        ``top_n`` returns all edges, padded to the largest molecule.
        Shape: (num_molecules, top_n).

        Multiclass model: for each molecule ``[class-1 edge, class-2 edge]``,
        where each entry is the ``chem_bond_idx`` of the edge predicted as that
        class with the highest probability for that class, or -1 if no edge is
        predicted as that class. With the default ``build_multiclass_target``
        encoding, class 1 is a WH split and class 2 an E3 split.
        Shape: (num_molecules, 2).

        If ``return_array`` is False, return the extracted features with added
        ``pred_label`` and ``pred_proba`` (class-1 probability) columns.
        """
        features = self.extract_graph_features(
            protac_smiles, wh_smiles, lk_smiles, e3_smiles,
            n_bits=self.n_bits, radius=self.radius,
            descriptor_names=self.descriptor_names, verbose=verbose,
        )
        Xf = self._ensure_features(features)
        pred_proba = self.pipeline.predict_proba(Xf)
        pred_label = self.pipeline.predict(Xf)

        features = features.copy()
        features["pred_label"] = pred_label
        features["pred_proba"] = pred_proba[:, 1] if pred_proba.shape[1] > 1 else pred_proba[:, 0]

        if not return_array:
            return features

        # NOTE: The SMILES is repeated for each edge; keep first-seen order.
        unique_smiles = features["chem_mol_smiles"].drop_duplicates().tolist()
        groups = features.groupby("chem_mol_smiles", sort=False)

        if pred_proba.shape[1] == 2:
            return self._top_edges_binary(groups, unique_smiles, top_n)
        return self._best_edge_per_class(features, pred_proba, groups, unique_smiles)

    @staticmethod
    def _top_edges_binary(groups, unique_smiles: List[str], top_n: int) -> np.ndarray:
        """Stack, per molecule, the ``top_n`` bond indices by ``pred_proba`` (-1 padded)."""
        if top_n < 0:
            # "All edges": width of the largest molecule so every row stacks.
            width = int(groups.size().max()) if unique_smiles else 0
        else:
            width = top_n

        rows = []
        for mol_smiles in unique_smiles:
            group = groups.get_group(mol_smiles)
            idxs = group.nlargest(width, "pred_proba")["chem_bond_idx"].to_numpy()
            if len(idxs) < width:
                idxs = np.pad(idxs, (0, width - len(idxs)), constant_values=-1)
            rows.append(idxs[:width])

        if not rows:
            return np.empty((0, width), dtype=int)
        return np.vstack(rows)

    @staticmethod
    def _best_edge_per_class(
        features: pd.DataFrame, pred_proba: np.ndarray, groups, unique_smiles: List[str]
    ) -> np.ndarray:
        """Per molecule, the best edge predicted as class 1 and as class 2 (-1 if none).

        "Best" uses the probability of the class being selected. Positional
        indices are used so duplicate index labels cannot break the lookup.
        """
        positions = groups.indices  # group name -> positional row indices
        labels = features["pred_label"].to_numpy()
        bond_idx = features["chem_bond_idx"].to_numpy()

        rows = []
        for mol_smiles in unique_smiles:
            pos = positions[mol_smiles]
            row = []
            for cls in (1, 2):
                candidates = pos[labels[pos] == cls]
                if len(candidates) == 0:
                    row.append(-1)
                else:
                    best = candidates[np.argmax(pred_proba[candidates, cls])]
                    row.append(bond_idx[best])
            rows.append(row)
        return np.array(rows, dtype=int).reshape(-1, 2)


def _label_codes(y_true, y_pred, labels) -> Optional[List[int]]:
    """Return ``[0, ..., len(labels) - 1]`` if every observed label is such a code, else None.

    Passing the full code list keeps reports/confusion matrices aligned with
    ``labels`` even when a class is absent from both ``y_true`` and ``y_pred``.
    """
    codes = list(range(len(labels)))
    observed = np.unique(np.concatenate([np.ravel(np.asarray(y_true)), np.ravel(np.asarray(y_pred))]))
    return codes if set(observed.tolist()) <= set(codes) else None


def get_classification_report(y_true, y_pred, labels):
    """Print and return sklearn's classification report as a DataFrame indexed by class name."""
    report = classification_report(
        y_true,
        y_pred,
        labels=_label_codes(y_true, y_pred, labels),
        target_names=labels,
        output_dict=True,
    )
    df_report = pd.DataFrame(report).transpose().round(2)
    print(df_report)
    return df_report


def plot_confusion_matrix(y_true, y_pred, labels):
    """Plot the confusion matrix as a heatmap, or print it if plotting libraries are missing."""
    cm = confusion_matrix(y_true, y_pred, labels=_label_codes(y_true, y_pred, labels))
    if HAS_VISUALIZATION:
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.title("Confusion Matrix")
        plt.show()
    else:
        print("Visualization libraries not available. Skipping confusion matrix plot.")
        print("Confusion Matrix:")
        print(cm)


def get_classification_report_and_plot(y_true, y_pred, labels):
    """Return the classification report and also plot the confusion matrix."""
    report = get_classification_report(y_true, y_pred, labels)
    plot_confusion_matrix(y_true, y_pred, labels)
    return report


def _format_report(report: pd.DataFrame) -> str:
    """Render a report DataFrame as text, keeping the class-name index."""
    try:
        return report.to_markdown()
    except ImportError:
        # DataFrame.to_markdown requires the optional ``tabulate`` package.
        return report.to_string()


def _featurize_sets(
    raw_sets: Dict[str, Optional[pd.DataFrame]],
    cache_dir: Optional[Union[str, Path]],
) -> Dict[str, pd.DataFrame]:
    """Featurize each raw dataset (or load it from ``cache_dir/<name>.csv``), keyed by name."""
    sets = {}
    for set_name, df in raw_sets.items():
        cache_path = Path(cache_dir) / f"{set_name}.csv" if cache_dir is not None else None
        if cache_path is not None:
            if cache_path.exists():
                print(f"Loading cached features for {set_name} from {cache_path}")
                sets[set_name] = pd.read_csv(cache_path)
                continue
            print(f"Cache not found for {set_name}, extracting features...")

        if df is None or df.empty:
            continue

        print(f"Set: {set_name}, size: {len(df):,}")
        missing = [c for c in _REQUIRED_INPUT_COLUMNS if c not in df.columns]
        if missing:
            raise ValueError(f"DataFrame for {set_name} is missing required columns: {missing}")

        features = GraphEdgeClassifier.extract_graph_features(
            df["PROTAC SMILES"].tolist(),
            df["POI Ligand SMILES with direction"].tolist(),
            df["Linker SMILES with direction"].tolist(),
            df["E3 Binder SMILES with direction"].tolist(),
            verbose=1,
        )
        # Drop rows with label_e3_split + label_wh_split > 1
        features = features[(features["label_e3_split"] + features["label_wh_split"]) <= 1]
        print(f"Set: {set_name}, size: {len(features):,}")

        if cache_path is not None:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            features.to_csv(cache_path, index=False)
            print(f"Saved {set_name} features to {cache_path}")
        sets[set_name] = features
    return sets


def _valid_labelled_rows(df: pd.DataFrame, label_cols: List[str]) -> pd.DataFrame:
    """Drop rows with missing labels or with both split labels set."""
    df = df.dropna(subset=label_cols)
    return df[(df["label_e3_split"] + df["label_wh_split"]) <= 1]


def _build_target(df: pd.DataFrame, binary: bool) -> pd.Series:
    """Binary ``label_is_split`` target, or the multiclass target from ``build_multiclass_target``."""
    if binary:
        return df["label_is_split"].astype("int32")
    return GraphEdgeClassifier.build_multiclass_target(df)


def _default_classifier_kwargs(columns) -> Dict[str, Any]:
    """Default GraphEdgeClassifier arguments inferred from the feature column names."""
    return {
        "graph_features": [c for c in columns if c.startswith("graph_")],
        "categorical_features": ["chem_bond_type", "chem_atom_u", "chem_atom_v"],
        "fingerprint_features": [c for c in columns if c.startswith("chem_mol_fp_")],
        "use_descriptors": False,
        "use_fingerprints": True,
        "n_svd_components": 50,
        "binary": True,
        "smote_k_neighbors": 10,
        "xgb_params": {
            "max_depth": 6,
            "learning_rate": 0.3,
            "alpha": 0.1,  # Default: 0
            "lambda": 0.5,  # Default: 1
            "gamma": 0.1,  # Default: 0
        },
    }


def _resolve_classifier_kwargs(
    edge_classifier_kwargs: Optional[Dict[str, Any]], columns
) -> Dict[str, Any]:
    """Merge user kwargs with inferred feature lists.

    With no kwargs, the inferred defaults are used. Otherwise the user kwargs
    are used as given, except that feature-list keys explicitly set to None
    (and a missing ``graph_features``) are filled from the inferred defaults.
    """
    defaults = _default_classifier_kwargs(columns)
    if not edge_classifier_kwargs:
        return defaults
    kwargs = dict(edge_classifier_kwargs)
    for key in _AUTO_FEATURE_KEYS:
        if key in kwargs and kwargs[key] is None:
            kwargs[key] = defaults[key]
    # graph_features is a mandatory GraphEdgeClassifier argument.
    if kwargs.get("graph_features") is None:
        kwargs["graph_features"] = defaults["graph_features"]
    return kwargs


def train_edge_classifier(
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame] = None,
    test_df: Optional[pd.DataFrame] = None,
    model_filename: Optional[Union[str, Path]] = None,
    edge_classifier_kwargs: Optional[Dict[str, Any]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    return_reports: bool = True,
    plot_confusion_matrix: bool = False,
) -> Union[GraphEdgeClassifier, Tuple[GraphEdgeClassifier, Optional[pd.DataFrame]]]:
    """
    Train an edge-level graph classifier for PROTACs.

    Args:
        train_df (pd.DataFrame): Training data with columns:
            - 'PROTAC SMILES'
            - 'POI Ligand SMILES with direction'
            - 'Linker SMILES with direction'
            - 'E3 Binder SMILES with direction'
        val_df (Optional[pd.DataFrame]): Validation data, same format as train_df.
        test_df (Optional[pd.DataFrame]): Test data, same format as train_df.
        model_filename (Optional[Union[str, Path]]): Path to save the trained model.
        edge_classifier_kwargs (Optional[Dict[str, Any]]): Parameters for
            GraphEdgeClassifier. If empty/None, defaults inferred from the
            feature columns are used. Feature-list keys ('graph_features',
            'categorical_features', 'fingerprint_features') set to None are
            filled from those defaults; pass [] to disable a group.
        cache_dir (Optional[Union[str, Path]]): Directory for '<set>.csv'
            feature caches. An existing cache file is loaded instead of
            featurizing, even if the corresponding DataFrame is None.
        return_reports (bool): If True, return ``(clf, report)``.
        plot_confusion_matrix (bool): Also plot a confusion matrix per evaluated set.

    Returns:
        The trained GraphEdgeClassifier, or ``(clf, report)`` if
        ``return_reports``. ``report`` is the test-set report when a test set
        is available, otherwise the validation report, otherwise None.

    Raises:
        ValueError: if an input DataFrame lacks required columns or no
            training data is available.
    """
    sets = _featurize_sets({"train": train_df, "val": val_df, "test": test_df}, cache_dir)
    if "train" not in sets:
        raise ValueError("No training data: train_df is empty and no cached train features exist.")

    label_cols = [c for c in sets["train"].columns if c.startswith("label_")]
    train_set = _valid_labelled_rows(sets["train"], label_cols)
    X_train = train_set.drop(columns=label_cols)

    # Instantiate and train
    clf = GraphEdgeClassifier(**_resolve_classifier_kwargs(edge_classifier_kwargs, X_train.columns))
    y_train = _build_target(train_set, clf.binary)

    print(f"Training set size: {len(X_train):,}, labels: {y_train.unique()}")
    clf.fit(X_train, y_train)
    print("Training complete.")

    if model_filename is not None:
        clf.save(model_filename)
        print(f"Model saved to {model_filename}")

    target_labels = ["No Split", "Split"] if clf.binary else ["No Split", "WH-Linker", "E3-Linker"]
    report = None
    for set_name, title in (("val", "Validation"), ("test", "Test")):
        if set_name not in sets:
            continue
        eval_set = _valid_labelled_rows(sets[set_name], label_cols)
        X_eval = eval_set.drop(columns=label_cols)
        y_eval = _build_target(eval_set, clf.binary)
        y_pred = clf.predict(X_eval)
        # NOTE: the ``plot_confusion_matrix`` argument shadows the module-level
        # function of the same name inside this function only.
        if plot_confusion_matrix:
            report = get_classification_report_and_plot(y_eval, y_pred, target_labels)
        else:
            report = get_classification_report(y_eval, y_pred, target_labels)
        print(f"{title} set classification report:\n{_format_report(report)}")

    if return_reports:
        return clf, report
    return clf


def _edge_classifier_kwargs_from_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Build binary GraphEdgeClassifier kwargs from an Optuna parameter dict.

    Feature lists are None so ``train_edge_classifier`` infers them.
    NOTE(review): 'descriptor_features' is never inferred, so 'use_descriptors'
    currently has no effect on the pipeline. TODO: decide on descriptor columns.
    """
    return {
        "graph_features": None,
        "categorical_features": None,
        "fingerprint_features": None,
        "use_descriptors": params["use_descriptors"],
        "use_fingerprints": params["use_fingerprints"],
        "n_svd_components": params["n_svd_components"],
        "binary": True,
        "smote_k_neighbors": params["smote_k_neighbors"],
        "xgb_params": {name: params[name] for name in _TUNED_XGB_PARAMS},
    }


def _report_value(report: Optional[pd.DataFrame], row: str, column: str) -> float:
    """Read ``report.loc[row, column]`` as float; 0.0 if the report or cell is missing."""
    if report is None:
        return 0.0
    try:
        return float(report.loc[row, column])
    except KeyError:
        return 0.0


def objective(trial, train_df, val_df):
    """Optuna objective: 0.5 * validation accuracy + 0.5 * F1 of the "Split" class.

    Note the report values are rounded to 2 decimals by
    ``get_classification_report``.
    """
    # HP space (suggestion order kept stable for the sampler)
    params = {
        "max_depth": trial.suggest_int("max_depth", 3, 10),
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
        "alpha": trial.suggest_float("alpha", 0.0, 2.0),
        "lambda": trial.suggest_float("lambda", 0.0, 2.0),
        "gamma": trial.suggest_float("gamma", 0.0, 1.0),
        "n_svd_components": trial.suggest_int("n_svd_components", 16, 128),
        "smote_k_neighbors": trial.suggest_int("smote_k_neighbors", 3, 15),
        "use_descriptors": trial.suggest_categorical("use_descriptors", [False, True]),
        "use_fingerprints": trial.suggest_categorical("use_fingerprints", [True, False]),
    }

    _, val_report = train_edge_classifier(
        train_df=train_df,
        val_df=val_df,
        edge_classifier_kwargs=_edge_classifier_kwargs_from_params(params),
        return_reports=True,
    )

    # The report is indexed by class name ("No Split", "Split", "accuracy", ...).
    f1_split = _report_value(val_report, "Split", "f1-score")
    # classification_report stores accuracy as a scalar, which becomes a row
    # holding the accuracy in every column.
    acc = _report_value(val_report, "accuracy", "f1-score")

    # Prioritize F1 for the minority class, but keep accuracy (equal weights)
    return 0.5 * acc + 0.5 * f1_split


def run_optuna_search(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    n_trials: int = 50,
    study_name: str = "edge_classifier_hp_search",
    study_dir: str = "./optuna_studies",
    seed: int = 42,
) -> Tuple[GraphEdgeClassifier, "optuna.Study"]:
    """Run (or resume) an Optuna study, then retrain with the best parameters.

    The study is stored in ``<study_dir>/<study_name>.db`` (SQLite) and also
    dumped with joblib to ``<study_dir>/<study_name>_study.pkl``.

    Returns:
        ``(clf, study)``: the classifier retrained with the best trial's
        parameters, and the Optuna study.
    """
    study_path_dir = Path(study_dir)
    study_path_dir.mkdir(parents=True, exist_ok=True)
    storage = f"sqlite:///{study_path_dir / study_name}.db"

    study = optuna.create_study(
        study_name=study_name,
        direction="maximize",
        sampler=QMCSampler(seed=seed, qmc_type="sobol"),
        storage=storage,
        load_if_exists=True,
    )

    def _objective(trial):
        return objective(trial, train_df, val_df)

    study.optimize(_objective, n_trials=n_trials, show_progress_bar=True)

    print("Best trial:")
    print(study.best_trial)

    # Train classifier with best HP and return it
    clf, _ = train_edge_classifier(
        train_df=train_df,
        val_df=val_df,
        edge_classifier_kwargs=_edge_classifier_kwargs_from_params(study.best_trial.params),
        return_reports=True,
    )

    study_file = study_path_dir / f"{study_name}_study.pkl"
    joblib.dump(study, str(study_file))
    print(f"Optuna study saved to {study_file}")
    return clf, study