Spaces:
Sleeping
Sleeping
from numba import njit | |
import numpy as np | |
import networkx as nx | |
from rdkit import Chem | |
def mol2graph(mol: Chem.Mol) -> nx.Graph:
    """ Build an undirected NetworkX graph from an RDKit molecule.

    Nodes are keyed by RDKit atom index and carry the element symbol under
    the ``label`` attribute; edges carry the RDKit bond type under ``label``.
    Atoms whose atomic number is 0 (RDKit dummy / wildcard atoms such as
    ``*``) are left out, as are bonds touching them. Hydrogens present as
    explicit atoms are NOT filtered.

    Args:
        mol (Chem.Mol): The RDKit molecule to convert, or None.
    Returns:
        nx.Graph: The graph representation; an empty graph when ``mol`` is None
        (e.g. the result of parsing an invalid SMILES).
    """
    # NOTE: https://github.com/maxhodak/keras-molecules/pull/32/files
    # TODO: Double check this implementation too: https://gist.github.com/jhjensen2/6450138cda3ab796a30850610843cfff
    if mol is None:
        return nx.empty_graph()

    graph = nx.Graph()
    graph.add_nodes_from(
        (atom.GetIdx(), {"label": atom.GetSymbol()})
        for atom in mol.GetAtoms()
        if atom.GetAtomicNum() != 0
    )

    for bond in mol.GetBonds():
        # A bond is dropped if either endpoint is a dummy (atomic number 0) atom.
        touches_dummy = (
            bond.GetBeginAtom().GetAtomicNum() == 0
            or bond.GetEndAtom().GetAtomicNum() == 0
        )
        if not touches_dummy:
            graph.add_edge(
                bond.GetBeginAtomIdx(),
                bond.GetEndAtomIdx(),
                label=bond.GetBondType(),
            )
    return graph
def smiles2graph(smiles: str) -> nx.Graph:
    """ Parse a SMILES string with RDKit and convert it to a NetworkX graph.

    Args:
        smiles (str): The SMILES string to convert.
    Returns:
        nx.Graph: The graph produced by ``mol2graph``. If RDKit cannot parse
        the string, ``Chem.MolFromSmiles`` yields None and the graph is empty.
    """
    molecule = Chem.MolFromSmiles(smiles)
    return mol2graph(molecule)
def get_smiles2graph_edit_distance(smi1: str, smi2: str, **kwargs) -> float:
    """ Compute the graph edit distance between the graphs of two SMILES strings.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        **kwargs: Forwarded unchanged to `nx.graph_edit_distance`.
    Returns:
        float: The graph edit distance, or ``np.inf`` when
        `nx.graph_edit_distance` returns None.
    """
    G1 = smiles2graph(smi1)
    G2 = smiles2graph(smi2)
    distance = nx.graph_edit_distance(G1, G2, **kwargs)
    if distance is None:
        return np.inf
    return distance
def get_mol2graph_edit_distance(mol1: Chem.Mol, mol2: Chem.Mol, **kwargs) -> float:
    """ Compute the graph edit distance between two RDKit molecules.

    Both molecules are converted with ``mol2graph``, which maps a None molecule
    to an empty graph.

    Args:
        mol1 (Chem.Mol): The first RDKit molecule.
        mol2 (Chem.Mol): The second RDKit molecule.
        **kwargs: Forwarded unchanged to `nx.graph_edit_distance`.
    Returns:
        float: The graph edit distance, or ``np.inf`` when
        `nx.graph_edit_distance` returns None.
    """
    # The annotations previously claimed ``str``; the arguments are RDKit
    # molecules (they go straight into ``mol2graph``).
    ged = nx.graph_edit_distance(mol2graph(mol1), mol2graph(mol2), **kwargs)
    return ged if ged is not None else np.inf
def get_smiles2graph_edit_distance_norm(
    smi1: str,
    smi2: str,
    ged_G1_G2: "float | None" = None,
    eps: float = 1e-9,
    **kwargs,
) -> float:
    """ Compute the normalized graph edit distance between two SMILES strings.

    The result is ``GED(G1, G2) / (GED(G1, G0) + GED(G2, G0) + eps)``, where G0
    is the empty graph. Under networkx's default unit costs the ratio is at
    most 1, because deleting G1 entirely and then inserting G2 is always a
    valid edit path.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        ged_G1_G2 (float | None): A precomputed graph edit distance between
            the two graphs. If None (the default), it is computed with
            `nx.graph_edit_distance`.
        eps (float): A small value to avoid division by zero, e.g. when both
            graphs are empty.
        **kwargs: Forwarded unchanged to every `nx.graph_edit_distance` call.
    Returns:
        float: The normalized graph edit distance, or ``np.inf`` if any of the
        graph edit distance computations returned None.
    """
    G1 = smiles2graph(smi1)
    G2 = smiles2graph(smi2)
    G0 = nx.empty_graph()
    if ged_G1_G2 is None:
        ged_G1_G2 = nx.graph_edit_distance(G1, G2, **kwargs)
    # Distance of each graph from the empty graph acts as its "size".
    ged_G1_G0 = nx.graph_edit_distance(G1, G0, **kwargs)
    ged_G2_G0 = nx.graph_edit_distance(G2, G0, **kwargs)
    if any(ged is None for ged in (ged_G1_G2, ged_G1_G0, ged_G2_G0)):
        return np.inf
    return ged_G1_G2 / (ged_G1_G0 + ged_G2_G0 + eps)
def smiles2adjacency_matrix(smiles: str) -> np.ndarray:
    """ Return the dense adjacency matrix of the graph of a SMILES string.

    Rows and columns follow the node order of ``smiles2graph(smiles)``. Edges
    have no ``weight`` attribute, so each bond contributes a 1.

    Args:
        smiles (str): The SMILES string to convert.
    Returns:
        np.ndarray: A square adjacency matrix; a 0x0 array when the graph has
        no nodes (e.g. an unparsable SMILES).
    """
    G = smiles2graph(smiles)
    if G.number_of_nodes() == 0:
        # nx.adjacency_matrix raises NetworkXError for a graph with no nodes.
        return np.zeros((0, 0), dtype=np.int64)
    # `.toarray()` always returns a plain ndarray; `.todense()` returns
    # np.matrix for the scipy sparse *matrix* produced by networkx 2.x.
    return nx.adjacency_matrix(G).toarray()
def build_label_mapping(G1, G2):
    """ Assign a dense integer code to every node label found in either graph.

    Args:
        G1: First graph; every node must have a ``label`` attribute.
        G2: Second graph; every node must have a ``label`` attribute.
    Returns:
        dict: Maps each distinct label to its position in the sorted list of
        all labels (0, 1, 2, ...).
    """
    all_labels = {G.nodes[n]['label'] for G in (G1, G2) for n in G.nodes()}
    return {lab: code for code, lab in enumerate(sorted(all_labels))}
def preprocess_graph(G, label_to_int):
    """ Encode a labelled graph as a symmetric int32 adjacency matrix plus label codes.

    Matrix rows/columns and label entries follow the iteration order of
    ``G.nodes()``; node identifiers themselves are discarded.

    Args:
        G: Graph whose nodes carry a ``label`` attribute.
        label_to_int (dict): Maps each label to an integer code (see
            ``build_label_mapping``).
    Returns:
        tuple[np.ndarray, np.ndarray]: ``(adj, labels)`` where ``adj`` is an
        (n, n) int32 0/1 matrix with both directions of every edge set, and
        ``labels`` is a length-n int32 array of label codes.
    """
    order = list(G.nodes())
    position = {node: pos for pos, node in enumerate(order)}

    n = G.number_of_nodes()
    adj = np.zeros((n, n), dtype=np.int32)
    labels = np.array(
        [label_to_int[G.nodes[node]['label']] for node in order],
        dtype=np.int32,
    )

    for u, v in G.edges():
        a, b = position[u], position[v]
        # Mirror every edge so the matrix is symmetric (undirected graph).
        adj[a, b] = adj[b, a] = 1
    return adj, labels
def compute_cost_matrix(labels1, labels2, degrees1, degrees2):
    """ Build the node-substitution cost matrix used by the approximate GED.

    Entry (i, j) is 1.0 if node i of graph 1 and node j of graph 2 have
    different label codes (0.0 otherwise), plus the absolute difference of
    their degrees.

    Args:
        labels1 (np.ndarray): Label codes of graph 1's nodes, length n1.
        labels2 (np.ndarray): Label codes of graph 2's nodes, length n2.
        degrees1 (np.ndarray): Degrees of graph 1's nodes, length n1.
        degrees2 (np.ndarray): Degrees of graph 2's nodes, length n2.
    Returns:
        np.ndarray: A float64 matrix of shape (n1, n2).
    """
    # Broadcast column vectors (graph 1) against row vectors (graph 2).
    lab_a = np.asarray(labels1)[:, np.newaxis]
    lab_b = np.asarray(labels2)[np.newaxis, :]
    deg_a = np.asarray(degrees1)[:, np.newaxis]
    deg_b = np.asarray(degrees2)[np.newaxis, :]

    label_mismatch = np.where(lab_a == lab_b, 0.0, 1.0)
    degree_gap = np.abs(deg_a - deg_b)
    return (label_mismatch + degree_gap).astype(np.float64)
def greedy_assignment(C):
    """ Greedily map each row of a cost matrix to a distinct column.

    Rows are processed in order; each takes the still-free column with the
    strictly smallest finite cost (ties go to the lowest column index). A row
    that finds no such column stays unassigned.

    Args:
        C (np.ndarray): Cost matrix of shape (n1, n2).
    Returns:
        np.ndarray: int32 array of length n1 holding the chosen column per
        row, or -1 for unassigned rows.
    """
    n_rows, n_cols = C.shape
    taken = np.zeros(n_cols, dtype=bool)
    assignment = np.full(n_rows, -1, dtype=np.int32)

    for row in range(n_rows):
        best_col, best_cost = -1, np.inf
        for col in range(n_cols):
            if taken[col]:
                continue
            cost = C[row, col]
            # Strict '<' keeps the first minimum and never picks an inf cost.
            if cost < best_cost:
                best_col, best_cost = col, cost
        if best_col >= 0:
            assignment[row] = best_col
            taken[best_col] = True
    return assignment
def compute_total_cost(C, row_ind, n1, n2, c_node_del, c_node_ins):
    """ Sum the cost of a node assignment between two graphs.

    Matched rows contribute their ``C[row, col]`` entry, unmatched rows (-1)
    contribute a deletion, and every column left unmatched contributes an
    insertion.

    Args:
        C (np.ndarray): Cost matrix of shape (n1, n2).
        row_ind (np.ndarray): Column chosen for each row, or -1.
        n1 (int): Number of rows (nodes in graph 1).
        n2 (int): Number of columns (nodes in graph 2).
        c_node_del (float): Cost charged per unmatched row.
        c_node_ins (float): Cost charged per unmatched column.
    Returns:
        float: The total assignment cost.
    """
    used = np.zeros(n2, dtype=bool)
    total = 0.0
    for row in range(n1):
        col = row_ind[row]
        if col == -1:
            total += c_node_del
            continue
        total += C[row, col]
        used[col] = True

    # One insertion per target node that no source node was mapped onto.
    for is_used in used:
        if not is_used:
            total += c_node_ins
    return total
def approximate_graph_edit_distance(adj1, labels1, adj2, labels2, c_node_del=1.0, c_node_ins=1.0):
    """ Heuristic graph edit distance from adjacency matrices and label codes.

    Nodes of graph 1 are greedily matched to nodes of graph 2. A match costs a
    0/1 label mismatch plus the absolute degree difference, an unmatched
    graph-1 node costs ``c_node_del``, and an unmatched graph-2 node costs
    ``c_node_ins``. Edges enter only through the degree term, so this is a
    heuristic and not the exact GED.

    Args:
        adj1 (np.ndarray): (n1, n1) adjacency matrix of graph 1.
        labels1 (np.ndarray): Length-n1 label codes of graph 1.
        adj2 (np.ndarray): (n2, n2) adjacency matrix of graph 2.
        labels2 (np.ndarray): Length-n2 label codes of graph 2.
        c_node_del (float): Cost per unmatched graph-1 node.
        c_node_ins (float): Cost per unmatched graph-2 node.
    Returns:
        float: The approximate edit cost.
    """
    # Row sums of the adjacency matrices are the node degrees.
    cost_matrix = compute_cost_matrix(labels1, labels2, adj1.sum(axis=1), adj2.sum(axis=1))
    mapping = greedy_assignment(cost_matrix)
    n_source, n_target = labels1.shape[0], labels2.shape[0]
    return compute_total_cost(cost_matrix, mapping, n_source, n_target, c_node_del, c_node_ins)
def get_approximate_ged(G1, G2):
    """ Approximate graph edit distance between two labelled graphs.

    Encodes both graphs with a shared label mapping and delegates to
    ``approximate_graph_edit_distance`` with its default node costs.

    Args:
        G1: First graph; nodes carry a ``label`` attribute.
        G2: Second graph; nodes carry a ``label`` attribute.
    Returns:
        float: The approximate edit cost.
    """
    mapping = build_label_mapping(G1, G2)
    encoded = [preprocess_graph(G, mapping) for G in (G1, G2)]
    (adj1, labels1), (adj2, labels2) = encoded
    return approximate_graph_edit_distance(adj1, labels1, adj2, labels2)
def get_smiles2graph_edit_distance_approx(smi1: str, smi2: str) -> float:
    """ Approximate graph edit distance between the graphs of two SMILES strings.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
    Returns:
        float: The heuristic edit cost from ``get_approximate_ged``.
    """
    return get_approximate_ged(smiles2graph(smi1), smiles2graph(smi2))