Spaces:
Sleeping
Sleeping
File size: 6,694 Bytes
9dd777e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
from typing import Optional

import networkx as nx
import numpy as np
from numba import njit
from rdkit import Chem
def mol2graph(mol: Chem.Mol) -> nx.Graph:
    """ Build a NetworkX graph from an RDKit molecule.

    Every atom whose atomic number is not 0 becomes a node keyed by its RDKit
    atom index and carrying the element symbol as `label`. Every bond whose two
    end atoms both have a non-zero atomic number becomes an edge carrying the
    RDKit bond type as `label`.

    Args:
        mol (Chem.Mol): The RDKit molecule to convert. May be None.

    Returns:
        nx.Graph: The graph representation of the molecule, or an empty graph
            when `mol` is None.
    """
    # NOTE: https://github.com/maxhodak/keras-molecules/pull/32/files
    # TODO: Double check this implementation too: https://gist.github.com/jhjensen2/6450138cda3ab796a30850610843cfff
    graph = nx.Graph()
    if mol is None:
        return graph

    for atom in mol.GetAtoms():
        # Atomic number 0 marks a dummy/wildcard atom (e.g. `*`); those are
        # left out. Explicit hydrogens (atomic number 1) are kept.
        if atom.GetAtomicNum() == 0:
            continue
        graph.add_node(atom.GetIdx(), label=atom.GetSymbol())

    for bond in mol.GetBonds():
        begin_atom = bond.GetBeginAtom()
        end_atom = bond.GetEndAtom()
        # Only keep bonds whose both ends were kept as nodes above.
        if begin_atom.GetAtomicNum() != 0 and end_atom.GetAtomicNum() != 0:
            graph.add_edge(
                bond.GetBeginAtomIdx(),
                bond.GetEndAtomIdx(),
                label=bond.GetBondType(),
            )
    return graph
def smiles2graph(smiles: str) -> nx.Graph:
    """ Parse a SMILES string and return its NetworkX graph.

    Args:
        smiles (str): The SMILES string to convert.

    Returns:
        nx.Graph: The graph produced by `mol2graph` for the parsed molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    return mol2graph(molecule)
def get_smiles2graph_edit_distance(smi1: str, smi2: str, **kwargs) -> float:
    """ Graph edit distance between the molecules given by two SMILES strings.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        **kwargs: Forwarded unchanged to `nx.graph_edit_distance`.

    Returns:
        float: The graph edit distance, or `np.inf` when
            `nx.graph_edit_distance` returns None.
    """
    graph_a = smiles2graph(smi1)
    graph_b = smiles2graph(smi2)
    distance = nx.graph_edit_distance(graph_a, graph_b, **kwargs)
    if distance is None:
        # No distance was produced; report it as infinitely far.
        return np.inf
    return distance
def get_mol2graph_edit_distance(mol1: Chem.Mol, mol2: Chem.Mol, **kwargs) -> float:
    """ Compute the graph edit distance between two RDKit molecules.

    Args:
        mol1 (Chem.Mol): The first RDKit molecule.
        mol2 (Chem.Mol): The second RDKit molecule.
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.

    Returns:
        float: The graph edit distance between the two RDKit molecules, or
            `np.inf` when `nx.graph_edit_distance` returns None.
    """
    ged = nx.graph_edit_distance(mol2graph(mol1), mol2graph(mol2), **kwargs)
    # `nx.graph_edit_distance` may return None; map that to infinity so the
    # caller always receives a number.
    return ged if ged is not None else np.inf
def get_smiles2graph_edit_distance_norm(
    smi1: str,
    smi2: str,
    ged_G1_G2: Optional[float] = None,
    eps: float = 1e-9,
    **kwargs,
) -> float:
    """ Compute the normalized graph edit distance between two SMILES strings.

    The distance between the two molecular graphs is divided by the sum of
    each graph's distance to the empty graph (plus `eps`).

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        ged_G1_G2 (Optional[float]): A precomputed graph edit distance between
            the two graphs. If None (the default), it is computed using
            `nx.graph_edit_distance`.
        eps (float): A small value to avoid division by zero.
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.
            They are applied to all distance computations, including the ones
            against the empty graph.

    Returns:
        float: The normalized graph edit distance between the two SMILES
            strings, or `np.inf` if any of the underlying distances is None.
    """
    G1 = smiles2graph(smi1)
    G2 = smiles2graph(smi2)
    G0 = nx.empty_graph()

    if ged_G1_G2 is None:
        ged_G1_G2 = nx.graph_edit_distance(G1, G2, **kwargs)
    # Distances to the empty graph serve as the normalization term.
    ged_G1_G0 = nx.graph_edit_distance(G1, G0, **kwargs)
    ged_G2_G0 = nx.graph_edit_distance(G2, G0, **kwargs)

    # Identity checks avoid the `==` comparisons that `None in [...]` performs.
    if ged_G1_G2 is None or ged_G1_G0 is None or ged_G2_G0 is None:
        return np.inf
    return ged_G1_G2 / (ged_G1_G0 + ged_G2_G0 + eps)
def smiles2adjacency_matrix(smiles: str) -> np.ndarray:
    """ Return the dense adjacency matrix of the molecule in a SMILES string.

    Args:
        smiles (str): The SMILES string to convert.

    Returns:
        np.ndarray: The dense adjacency matrix of the molecular graph, with
            rows and columns in the graph's node iteration order.
    """
    # NOTE(review): networkx may raise for a graph with no nodes (which is what
    # `smiles2graph` yields for an unparsable SMILES) — confirm against the
    # installed networkx version.
    sparse_adjacency = nx.adjacency_matrix(smiles2graph(smiles))
    # `toarray()` always yields a plain ndarray, whereas `todense()` returns
    # an `np.matrix` when networkx hands back a scipy sparse *matrix*.
    return sparse_adjacency.toarray()
def build_label_mapping(G1, G2):
    """ Assign a consecutive integer code to every node label in two graphs.

    Args:
        G1: First graph; every node must carry a 'label' attribute.
        G2: Second graph; every node must carry a 'label' attribute.

    Returns:
        dict: Maps each distinct label found in either graph to its position
            in the sorted list of labels.
    """
    distinct_labels = set()
    for graph in (G1, G2):
        distinct_labels.update(graph.nodes[node]['label'] for node in graph.nodes())
    # Sorting makes the codes independent of node iteration order.
    return {label: code for code, label in enumerate(sorted(distinct_labels))}
def preprocess_graph(G, label_to_int):
    """ Convert a labelled graph into an adjacency matrix and a label vector.

    Args:
        G: Graph whose nodes carry a 'label' attribute.
        label_to_int (dict): Maps every node label in `G` to an integer code.

    Returns:
        tuple: `(adj, labels)` where `adj` is an `(n, n)` int32 array with 1
            for every pair of nodes joined by an edge (set in both
            directions), and `labels` is a length-`n` int32 array holding the
            integer code of each node's label. Positions follow the node
            iteration order of `G`.
    """
    # Position of each node in the output arrays.
    position_of = {node: position for position, node in enumerate(G.nodes())}
    size = G.number_of_nodes()

    labels = np.zeros(size, dtype=np.int32)
    for node, position in position_of.items():
        labels[position] = label_to_int[G.nodes[node]['label']]

    adj = np.zeros((size, size), dtype=np.int32)
    for u, v in G.edges():
        row, col = position_of[u], position_of[v]
        # Mirror the entry: the graph is assumed to be undirected.
        adj[row, col] = 1
        adj[col, row] = 1
    return adj, labels
@njit
def compute_cost_matrix(labels1, labels2, degrees1, degrees2):
    """ Build the pairwise node-matching cost matrix between two graphs.

    The cost of pairing node `i` of the first graph with node `j` of the
    second is 1.0 if their labels differ (0.0 otherwise) plus the absolute
    difference of their degrees.

    Args:
        labels1: 1-D array of label codes for the first graph.
        labels2: 1-D array of label codes for the second graph.
        degrees1: Per-node degrees of the first graph, indexed like `labels1`.
        degrees2: Per-node degrees of the second graph, indexed like `labels2`.

    Returns:
        A float64 array of shape `(len(labels1), len(labels2))`.
    """
    num_rows = labels1.shape[0]
    num_cols = labels2.shape[0]
    costs = np.zeros((num_rows, num_cols), dtype=np.float64)
    for row in range(num_rows):
        row_label = labels1[row]
        row_degree = degrees1[row]
        for col in range(num_cols):
            # Start from the degree mismatch, then add the label penalty.
            costs[row, col] = abs(row_degree - degrees2[col])
            if row_label != labels2[col]:
                costs[row, col] += 1.0
    return costs
@njit
def greedy_assignment(C):
    """ Greedily match each row of a cost matrix to a distinct column.

    Rows are processed in order; each takes the cheapest column not already
    taken by an earlier row (the lowest column index wins ties).

    Args:
        C: 2-D cost matrix of shape `(n1, n2)`.

    Returns:
        An int32 array of length `n1` holding the column chosen for each row,
        or -1 for rows that found no free column.
    """
    num_rows, num_cols = C.shape
    col_taken = np.zeros(num_cols, dtype=np.bool_)
    row_ind = np.full(num_rows, -1, dtype=np.int32)
    for row in range(num_rows):
        best_col = -1
        best_cost = np.inf
        for col in range(num_cols):
            if col_taken[col]:
                continue
            candidate = C[row, col]
            # Strict comparison keeps the first column among equal costs.
            if candidate < best_cost:
                best_cost = candidate
                best_col = col
        if best_col != -1:
            row_ind[row] = best_col
            col_taken[best_col] = True
    return row_ind
@njit
def compute_total_cost(C, row_ind, n1, n2, c_node_del, c_node_ins):
    """ Sum the cost of a row-to-column assignment.

    Args:
        C: 2-D cost matrix of shape `(n1, n2)`.
        row_ind: For each row, the assigned column index or -1 if unassigned.
        n1: Number of rows.
        n2: Number of columns.
        c_node_del: Cost charged for every unassigned row.
        c_node_ins: Cost charged for every column no row was assigned to.

    Returns:
        The total cost as a float.
    """
    col_used = np.zeros(n2, dtype=np.bool_)
    total = 0.0
    for row in range(n1):
        col = row_ind[row]
        if col == -1:
            # Row has no partner: pay the deletion cost.
            total += c_node_del
            continue
        total += C[row, col]
        col_used[col] = True
    for col in range(n2):
        # Column has no partner: pay the insertion cost.
        if not col_used[col]:
            total += c_node_ins
    return total
def approximate_graph_edit_distance(adj1, labels1, adj2, labels2, c_node_del=1.0, c_node_ins=1.0):
    """ Approximate the graph edit distance with a greedy node assignment.

    Node degrees are taken as the row sums of each adjacency matrix. A
    label-and-degree cost matrix is built, nodes are matched greedily, and the
    cost of that matching (plus deletions and insertions of unmatched nodes)
    is returned.

    Args:
        adj1: Adjacency matrix of the first graph.
        labels1: 1-D array of label codes for the first graph.
        adj2: Adjacency matrix of the second graph.
        labels2: 1-D array of label codes for the second graph.
        c_node_del (float): Cost of an unmatched node of the first graph.
        c_node_ins (float): Cost of an unmatched node of the second graph.

    Returns:
        The total cost of the greedy assignment.
    """
    node_count1 = labels1.shape[0]
    node_count2 = labels2.shape[0]
    cost_matrix = compute_cost_matrix(
        labels1,
        labels2,
        adj1.sum(axis=1),
        adj2.sum(axis=1),
    )
    assignment = greedy_assignment(cost_matrix)
    return compute_total_cost(
        cost_matrix, assignment, node_count1, node_count2, c_node_del, c_node_ins
    )
def get_approximate_ged(G1, G2):
    """ Approximate the graph edit distance between two labelled graphs.

    Args:
        G1: First graph; nodes must carry a 'label' attribute.
        G2: Second graph; nodes must carry a 'label' attribute.

    Returns:
        The cost returned by `approximate_graph_edit_distance` with its
        default node deletion and insertion costs.
    """
    # Both graphs share one label encoding so their codes are comparable.
    shared_codes = build_label_mapping(G1, G2)
    first_adj, first_labels = preprocess_graph(G1, shared_codes)
    second_adj, second_labels = preprocess_graph(G2, shared_codes)
    return approximate_graph_edit_distance(
        first_adj, first_labels, second_adj, second_labels
    )
def get_smiles2graph_edit_distance_approx(smi1: str, smi2: str) -> float:
    """ Approximate graph edit distance between two SMILES strings.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.

    Returns:
        float: The approximate graph edit distance computed by
            `get_approximate_ged` on the two molecular graphs.
    """
    graph_a, graph_b = smiles2graph(smi1), smiles2graph(smi2)
    return get_approximate_ged(graph_a, graph_b)
|