# ribesstefano's picture
# Setup the spaces app
# 9dd777e
from typing import Optional

import networkx as nx
import numpy as np
from numba import njit
from rdkit import Chem
def mol2graph(mol: Chem.Mol) -> nx.Graph:
    """Convert an RDKit molecule to a NetworkX graph.

    Every atom whose atomic number is not 0 becomes a node keyed by its RDKit
    atom index, with the element symbol stored under the ``label`` attribute.
    Every bond whose two end atoms both have a non-zero atomic number becomes
    an edge, with the RDKit bond type stored under the ``label`` attribute.

    Args:
        mol (Chem.Mol): The RDKit molecule to convert. May be None.

    Returns:
        nx.Graph: The NetworkX graph representation of the molecule, or an
            empty graph when `mol` is None.
    """
    # NOTE: https://github.com/maxhodak/keras-molecules/pull/32/files
    # TODO: Double check this implementation too: https://gist.github.com/jhjensen2/6450138cda3ab796a30850610843cfff
    if mol is None:
        return nx.empty_graph()

    graph = nx.Graph()

    # Only atoms with atomic number 0 are left out. Hydrogens (atomic number
    # 1) are NOT filtered by this check.
    graph.add_nodes_from(
        (atom.GetIdx(), {"label": atom.GetSymbol()})
        for atom in mol.GetAtoms()
        if atom.GetAtomicNum() != 0
    )

    for bond in mol.GetBonds():
        # Keep a bond only when neither end is an atomic-number-0 atom, so
        # that no edge re-introduces a node that was skipped above.
        if bond.GetBeginAtom().GetAtomicNum() != 0 and bond.GetEndAtom().GetAtomicNum() != 0:
            graph.add_edge(
                bond.GetBeginAtomIdx(),
                bond.GetEndAtomIdx(),
                label=bond.GetBondType(),
            )
    return graph
def smiles2graph(smiles: str) -> nx.Graph:
    """Convert a SMILES string to a NetworkX graph.

    The string is parsed with `Chem.MolFromSmiles` and the result is handed
    to `mol2graph` unchanged.

    Args:
        smiles (str): The SMILES string to convert.

    Returns:
        nx.Graph: The NetworkX graph representation of the molecule.
    """
    parsed_mol = Chem.MolFromSmiles(smiles)
    return mol2graph(parsed_mol)
def get_smiles2graph_edit_distance(smi1: str, smi2: str, **kwargs) -> float:
    """Compute the graph edit distance between two SMILES strings.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.

    Returns:
        float: The graph edit distance between the two molecular graphs, or
            `np.inf` when `nx.graph_edit_distance` returns None.
    """
    first_graph = smiles2graph(smi1)
    second_graph = smiles2graph(smi2)
    distance = nx.graph_edit_distance(first_graph, second_graph, **kwargs)
    # A missing result is reported as an infinite distance.
    if distance is None:
        return np.inf
    return distance
def get_mol2graph_edit_distance(mol1: Chem.Mol, mol2: Chem.Mol, **kwargs) -> float:
    """Compute the graph edit distance between two RDKit molecules.

    Both molecules are converted with `mol2graph` before being compared.

    Args:
        mol1 (Chem.Mol): The first RDKit molecule.
        mol2 (Chem.Mol): The second RDKit molecule.
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.

    Returns:
        float: The graph edit distance between the two RDKit molecules, or
            `np.inf` when `nx.graph_edit_distance` returns None.
    """
    ged = nx.graph_edit_distance(mol2graph(mol1), mol2graph(mol2), **kwargs)
    # A missing result is reported as an infinite distance.
    return np.inf if ged is None else ged
def get_smiles2graph_edit_distance_norm(
    smi1: str,
    smi2: str,
    ged_G1_G2: Optional[float] = None,
    eps: float = 1e-9,
    **kwargs,
) -> float:
    """Compute the normalized graph edit distance between two SMILES strings.

    The distance between the two molecular graphs is divided by the sum of
    each graph's own distance to the empty graph (plus `eps`).

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        ged_G1_G2 (Optional[float]): The graph edit distance between the two
            graphs. If None (the default), it is computed using
            `nx.graph_edit_distance`.
        eps (float): A small value to avoid division by zero.
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.
            They are forwarded to every distance computation, including the
            two against the empty graph.

    Returns:
        float: The normalized graph edit distance between the two SMILES
            strings, or `np.inf` when any of the distances is None.
    """
    G1 = smiles2graph(smi1)
    G2 = smiles2graph(smi2)
    G0 = nx.empty_graph()

    # Only compute the pairwise distance when the caller did not supply one.
    if ged_G1_G2 is None:
        ged_G1_G2 = nx.graph_edit_distance(G1, G2, **kwargs)
    ged_G1_G0 = nx.graph_edit_distance(G1, G0, **kwargs)
    ged_G2_G0 = nx.graph_edit_distance(G2, G0, **kwargs)

    # Identity checks: a legitimate distance of 0 must not be treated as
    # missing.
    if ged_G1_G2 is None or ged_G1_G0 is None or ged_G2_G0 is None:
        return np.inf
    return ged_G1_G2 / (ged_G1_G0 + ged_G2_G0 + eps)
def smiles2adjacency_matrix(smiles: str) -> np.ndarray:
    """Build the dense adjacency matrix of the molecular graph of a SMILES.

    Args:
        smiles (str): The SMILES string to convert.

    Returns:
        np.ndarray: The result of calling `.todense()` on
            `nx.adjacency_matrix` of the graph built by `smiles2graph`.
    """
    graph = smiles2graph(smiles)
    sparse_adjacency = nx.adjacency_matrix(graph)
    return sparse_adjacency.todense()
def build_label_mapping(G1, G2):
    """Assign a consecutive integer id to every node label used in two graphs.

    Args:
        G1: First graph; every node must carry a ``label`` attribute.
        G2: Second graph; every node must carry a ``label`` attribute.

    Returns:
        dict: Mapping from label to integer id. Ids follow the sorted order
            of the distinct labels found across both graphs, starting at 0.
    """
    distinct_labels = {
        graph.nodes[node]['label']
        for graph in (G1, G2)
        for node in graph.nodes()
    }
    # Sorting makes the ids independent of node iteration order.
    return dict(zip(sorted(distinct_labels), range(len(distinct_labels))))
def preprocess_graph(G, label_to_int):
    """Turn a labelled graph into an adjacency matrix and a label vector.

    Nodes are numbered 0..n-1 in the order `G.nodes()` yields them; that
    position is used as the row/column index in both returned arrays.

    Args:
        G: Graph whose nodes carry a ``label`` attribute.
        label_to_int (dict): Mapping from node label to integer id; it must
            contain every label present in `G`.

    Returns:
        tuple: ``(adj, labels)`` where ``adj`` is an ``(n, n)`` int32 array
            holding 1 for every pair of nodes joined by an edge (set in both
            directions) and ``labels`` is a length-``n`` int32 array of
            label ids.
    """
    size = G.number_of_nodes()
    adjacency = np.zeros((size, size), dtype=np.int32)
    label_ids = np.zeros(size, dtype=np.int32)

    # Remember where each node id landed so edges can be translated below.
    position_of = {}
    for position, node in enumerate(G.nodes()):
        position_of[node] = position
        label_ids[position] = label_to_int[G.nodes[node]['label']]

    for source, target in G.edges():
        row = position_of[source]
        col = position_of[target]
        # The graph is treated as undirected: mirror every edge.
        adjacency[row, col] = 1
        adjacency[col, row] = 1

    return adjacency, label_ids
@njit
def compute_cost_matrix(labels1, labels2, degrees1, degrees2):
    """Build the pairwise node substitution cost matrix of two graphs.

    The cost of matching node ``i`` of the first graph with node ``j`` of the
    second graph is 1.0 if their labels differ (0.0 otherwise) plus the
    absolute difference of their degrees.

    Args:
        labels1: 1-D array of label ids for the first graph.
        labels2: 1-D array of label ids for the second graph.
        degrees1: Per-node degrees of the first graph, indexed like `labels1`.
        degrees2: Per-node degrees of the second graph, indexed like `labels2`.

    Returns:
        A float64 array of shape ``(len(labels1), len(labels2))``.
    """
    rows = labels1.shape[0]
    cols = labels2.shape[0]
    costs = np.zeros((rows, cols), dtype=np.float64)
    for r in range(rows):
        for c in range(cols):
            # Label mismatch contributes a unit penalty.
            if labels1[r] == labels2[c]:
                mismatch = 0.0
            else:
                mismatch = 1.0
            degree_gap = abs(degrees1[r] - degrees2[c])
            costs[r, c] = mismatch + degree_gap
    return costs
@njit
def greedy_assignment(C):
    """Greedily assign each row of a cost matrix to a distinct column.

    Rows are processed in order; each one takes the cheapest column that is
    still free, with ties going to the lowest column index. A row is left
    unassigned when no free column has a cost strictly below infinity.

    Args:
        C: 2-D cost matrix of shape ``(n1, n2)``.

    Returns:
        An int32 array of length ``n1`` holding the column chosen for each
        row, or -1 for rows left unassigned.
    """
    n_rows, n_cols = C.shape
    taken = np.full(n_cols, False)
    row_ind = np.full(n_rows, -1, dtype=np.int32)
    for row in range(n_rows):
        best_col = -1
        best_cost = np.inf
        for col in range(n_cols):
            if taken[col]:
                continue
            # Strict comparison keeps the first of several equal minima.
            if C[row, col] < best_cost:
                best_cost = C[row, col]
                best_col = col
        if best_col != -1:
            row_ind[row] = best_col
            taken[best_col] = True
    return row_ind
@njit
def compute_total_cost(C, row_ind, n1, n2, c_node_del, c_node_ins):
    """Sum the cost of an assignment, including deletions and insertions.

    Args:
        C: 2-D cost matrix of shape ``(n1, n2)``.
        row_ind: For each of the ``n1`` rows, the assigned column or -1.
        n1: Number of rows (nodes of the first graph).
        n2: Number of columns (nodes of the second graph).
        c_node_del: Cost charged for every row without an assigned column.
        c_node_ins: Cost charged for every column no row was assigned to.

    Returns:
        The total cost as a float.
    """
    running_total = 0.0
    column_used = np.full(n2, False)
    for row in range(n1):
        col = row_ind[row]
        if col == -1:
            # Unmatched row: counted as a node deletion.
            running_total += c_node_del
            continue
        running_total += C[row, col]
        column_used[col] = True
    for col in range(n2):
        if column_used[col]:
            continue
        # Unmatched column: counted as a node insertion.
        running_total += c_node_ins
    return running_total
def approximate_graph_edit_distance(adj1, labels1, adj2, labels2, c_node_del=1.0, c_node_ins=1.0):
    """Approximate the graph edit distance from adjacency and label arrays.

    Node degrees are taken as the row sums of the adjacency matrices. They
    feed `compute_cost_matrix`; the resulting matrix is matched with
    `greedy_assignment` and priced with `compute_total_cost`.

    Args:
        adj1: Adjacency matrix of the first graph.
        labels1: 1-D array of label ids of the first graph.
        adj2: Adjacency matrix of the second graph.
        labels2: 1-D array of label ids of the second graph.
        c_node_del (float): Cost of a node of the first graph left unmatched.
        c_node_ins (float): Cost of a node of the second graph left unmatched.

    Returns:
        float: The total cost of the greedy assignment.
    """
    cost_matrix = compute_cost_matrix(
        labels1,
        labels2,
        adj1.sum(axis=1),
        adj2.sum(axis=1),
    )
    assignment = greedy_assignment(cost_matrix)
    return compute_total_cost(
        cost_matrix,
        assignment,
        labels1.shape[0],
        labels2.shape[0],
        c_node_del,
        c_node_ins,
    )
def get_approximate_ged(G1, G2):
    """Approximate the graph edit distance between two labelled graphs.

    A label-to-integer mapping shared by both graphs is built first, so that
    equal labels get equal ids in both array encodings.

    Args:
        G1: First graph; nodes must carry a ``label`` attribute.
        G2: Second graph; nodes must carry a ``label`` attribute.

    Returns:
        float: The value returned by `approximate_graph_edit_distance` with
            its default node deletion and insertion costs.
    """
    shared_mapping = build_label_mapping(G1, G2)
    first_adj, first_labels = preprocess_graph(G1, shared_mapping)
    second_adj, second_labels = preprocess_graph(G2, shared_mapping)
    return approximate_graph_edit_distance(first_adj, first_labels, second_adj, second_labels)
def get_smiles2graph_edit_distance_approx(smi1: str, smi2: str) -> float:
    """Approximate the graph edit distance between two SMILES strings.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.

    Returns:
        float: The result of `get_approximate_ged` on the two graphs built
            by `smiles2graph`.
    """
    first_graph, second_graph = smiles2graph(smi1), smiles2graph(smi2)
    return get_approximate_ged(first_graph, second_graph)