File size: 6,694 Bytes
9dd777e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from numba import njit
import numpy as np
import networkx as nx
from rdkit import Chem


def mol2graph(mol: Chem.Mol) -> nx.Graph:
    """ Build an undirected NetworkX graph from an RDKit molecule.

    Every atom whose atomic number is non-zero becomes a node keyed by its
    RDKit atom index, with its element symbol stored under ``label``. Atoms
    with atomic number 0 (RDKit's dummy/wildcard atoms, e.g. ``*``) are left
    out, together with every bond touching one of them. Note that this is NOT
    a hydrogen filter: explicit hydrogen atoms (atomic number 1) are kept if
    the molecule contains them. Each remaining bond becomes an edge whose
    ``label`` is the value returned by ``bond.GetBondType()``.

    Args:
        mol (Chem.Mol): The RDKit molecule to convert. ``None`` is accepted.

    Returns:
        nx.Graph: The molecular graph. If ``mol`` is ``None``, the result is an
        empty graph.
    """
    # NOTE: https://github.com/maxhodak/keras-molecules/pull/32/files
    # TODO: Double check this implementation too: https://gist.github.com/jhjensen2/6450138cda3ab796a30850610843cfff
    if mol is None:
        return nx.empty_graph()

    graph = nx.Graph()

    for atom in mol.GetAtoms():
        if atom.GetAtomicNum() == 0:
            continue  # dummy atom: not part of the graph
        graph.add_node(atom.GetIdx(), label=atom.GetSymbol())

    for bond in mol.GetBonds():
        begin_atom = bond.GetBeginAtom()
        end_atom = bond.GetEndAtom()
        touches_dummy = begin_atom.GetAtomicNum() == 0 or end_atom.GetAtomicNum() == 0
        if touches_dummy:
            continue  # one endpoint was dropped above, so drop the bond too
        graph.add_edge(
            bond.GetBeginAtomIdx(),
            bond.GetEndAtomIdx(),
            label=bond.GetBondType(),
        )

    return graph

def smiles2graph(smiles: str) -> nx.Graph:
    """ Parse a SMILES string with RDKit and convert it with ``mol2graph``.

    ``Chem.MolFromSmiles`` yields ``None`` for input it cannot parse.
    ``mol2graph`` maps ``None`` to an empty graph, so invalid SMILES produce an
    empty graph rather than an exception.

    Args:
        smiles (str): The SMILES string to convert.

    Returns:
        nx.Graph: The molecular graph produced by ``mol2graph``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    return mol2graph(molecule)

def get_smiles2graph_edit_distance(smi1: str, smi2: str, **kwargs) -> float:
    """ Exact graph edit distance between the graphs of two SMILES strings.

    Both SMILES are converted with ``smiles2graph``. The resulting graphs are
    passed to ``nx.graph_edit_distance``. Node and edge ``label`` attributes
    are only taken into account if a ``node_match`` / ``edge_match`` callable
    is supplied via ``kwargs``; by default networkx treats all nodes and edges
    as equal.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        **kwargs: Forwarded unchanged to ``nx.graph_edit_distance``.

    Returns:
        float: The graph edit distance. If networkx returns ``None`` (e.g. no
        edit path found within ``timeout``/``upper_bound``), ``np.inf`` is
        returned instead.
    """
    first = smiles2graph(smi1)
    second = smiles2graph(smi2)
    distance = nx.graph_edit_distance(first, second, **kwargs)
    if distance is None:
        return np.inf
    return distance

def get_mol2graph_edit_distance(mol1: "Chem.Mol", mol2: "Chem.Mol", **kwargs) -> float:
    """ Compute the graph edit distance between two RDKit molecules.

    Each molecule is converted with ``mol2graph``, and the two graphs are
    compared with ``nx.graph_edit_distance``. Node and edge ``label``
    attributes only influence the result when a ``node_match`` /
    ``edge_match`` callable is passed through ``kwargs``.

    Args:
        mol1 (Chem.Mol): The first RDKit molecule (``None`` maps to an empty graph).
        mol2 (Chem.Mol): The second RDKit molecule (``None`` maps to an empty graph).
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.

    Returns:
        float: The graph edit distance between the two RDKit molecules. If
        networkx returns ``None`` (no edit path found, e.g. because of
        ``timeout``/``upper_bound``), ``np.inf`` is returned.
    """
    # Annotations are quoted so that defining this function never evaluates
    # ``Chem``; the previous ``str`` annotations were simply wrong.
    ged = nx.graph_edit_distance(mol2graph(mol1), mol2graph(mol2), **kwargs)
    return ged if ged is not None else np.inf

def get_smiles2graph_edit_distance_norm(
        smi1: str,
        smi2: str,
        ged_G1_G2: "float | None" = None,
        eps: float = 1e-9,
        **kwargs,
) -> float:
    """ Compute the normalized graph edit distance between two SMILES strings.

    The raw distance GED(G1, G2) is divided by the cost of editing each graph
    into the empty graph G0, i.e. ``GED(G1, G2) / (GED(G1, G0) + GED(G2, G0) + eps)``.
    Deleting all of G1 and then inserting all of G2 is always a valid edit
    path. So when GED(G1, G2) is computed here with symmetric costs (e.g.
    networkx's default unit costs), the result lies in ``[0, 1)``.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        ged_G1_G2 (float | None): A precomputed graph edit distance between the
            two graphs. If None (the default), it is computed with
            `nx.graph_edit_distance` using ``kwargs``.
        eps (float): A small value to avoid division by zero (e.g. when both
            SMILES yield empty graphs).
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.

    Returns:
        float: The normalized graph edit distance between the two SMILES
        strings. ``np.inf`` is returned if any of the three distances could
        not be computed (networkx returned ``None``).
    """
    G1 = smiles2graph(smi1)
    G2 = smiles2graph(smi2)
    G0 = nx.empty_graph()
    if ged_G1_G2 is None:
        ged_G1_G2 = nx.graph_edit_distance(G1, G2, **kwargs)
    ged_G1_G0 = nx.graph_edit_distance(G1, G0, **kwargs)
    ged_G2_G0 = nx.graph_edit_distance(G2, G0, **kwargs)
    # Identity checks: networkx signals "no edit path found" with None.
    if ged_G1_G2 is None or ged_G1_G0 is None or ged_G2_G0 is None:
        return np.inf
    return ged_G1_G2 / (ged_G1_G0 + ged_G2_G0 + eps)

def smiles2adjacency_matrix(smiles: str) -> np.ndarray:
    """ Return the dense adjacency matrix of the molecular graph of ``smiles``.

    Rows and columns follow the node order of the graph built by
    ``smiles2graph``. The graph's edges carry no ``weight`` attribute, so
    ``nx.adjacency_matrix`` uses 1 for every bond and 0 elsewhere.

    Args:
        smiles (str): The SMILES string to convert.

    Returns:
        np.ndarray: A dense ``(n, n)`` adjacency matrix. The shape is ``(0, 0)``
        when the graph has no nodes: an unparsable SMILES, an empty SMILES,
        or a molecule made only of dummy atoms.
    """
    graph = smiles2graph(smiles)
    if graph.number_of_nodes() == 0:
        # nx.adjacency_matrix raises NetworkXError("Graph has no nodes or
        # edges") on a node-less graph; return an empty matrix instead.
        return np.zeros((0, 0), dtype=int)
    return nx.adjacency_matrix(graph).todense()

def build_label_mapping(G1, G2):
    """ Assign a dense integer code to every node label seen in either graph.

    Labels are read from each node's ``'label'`` attribute in both graphs.
    The distinct labels are sorted, and each is numbered ``0, 1, 2, ...`` in
    that sorted order. The two graphs therefore share one consistent
    encoding.

    Args:
        G1: First graph; every node must have a ``'label'`` attribute.
        G2: Second graph; every node must have a ``'label'`` attribute.

    Returns:
        dict: Mapping ``label -> int`` covering the union of both graphs' labels.
    """
    distinct_labels = {
        graph.nodes[node]['label']
        for graph in (G1, G2)
        for node in graph.nodes()
    }
    return {label: code for code, label in enumerate(sorted(distinct_labels))}

def preprocess_graph(G, label_to_int):
    """ Turn a labelled graph into a dense adjacency matrix plus label codes.

    Node ``k`` in the outputs is the ``k``-th node yielded by ``G.nodes()``.
    Every edge ``(u, v)`` sets both ``adj[u, v]`` and ``adj[v, u]`` to 1, so
    the graph is treated as undirected. Edge attributes are ignored.

    Args:
        G: Graph whose nodes carry a ``'label'`` attribute.
        label_to_int: Mapping from label to integer code (e.g. from
            ``build_label_mapping``). A missing label raises ``KeyError``.

    Returns:
        tuple: ``(adj, labels)``. ``adj`` is an ``(n, n)`` int32 matrix of 0/1.
        ``labels`` is a length-``n`` int32 vector of label codes.
    """
    node_order = list(G.nodes())
    position = {node: k for k, node in enumerate(node_order)}
    size = G.number_of_nodes()

    adj = np.zeros((size, size), dtype=np.int32)
    labels = np.array(
        [label_to_int[G.nodes[node]['label']] for node in node_order],
        dtype=np.int32,
    )

    for u, v in G.edges():
        a = position[u]
        b = position[v]
        adj[a, b] = adj[b, a] = 1  # mirror: graph is undirected

    return adj, labels

@njit
def compute_cost_matrix(labels1, labels2, degrees1, degrees2):
    """ Pairwise node-substitution cost matrix for the approximate GED.

    The entry ``[i, j]`` is 1.0 when the label codes of node ``i`` (graph 1)
    and node ``j`` (graph 2) differ, else 0.0. The absolute difference of
    their degrees is added on top as a rough proxy for edge edits.

    Args:
        labels1, labels2: Integer label-code vectors of the two graphs.
        degrees1, degrees2: Per-node degree vectors of the two graphs (in
            ``approximate_graph_edit_distance`` these come from ``adj.sum(axis=1)``).

    Returns:
        np.ndarray: ``(len(labels1), len(labels2))`` float64 cost matrix.
    """
    rows = labels1.shape[0]
    cols = labels2.shape[0]
    cost = np.zeros((rows, cols), dtype=np.float64)
    for r in range(rows):
        label_r = labels1[r]
        degree_r = degrees1[r]
        for c in range(cols):
            mismatch = 1.0
            if label_r == labels2[c]:
                mismatch = 0.0
            cost[r, c] = mismatch + abs(degree_r - degrees2[c])
    return cost

@njit
def greedy_assignment(C):
    """ Greedily match each row of ``C`` to its cheapest still-free column.

    Rows are processed in order 0..n1-1. Each row takes the free column with
    the strictly smallest cost; on ties the lowest column index wins.
    A row is left unmatched (``-1``) when no free column has a cost below
    ``inf``, which happens when all columns are already taken (n1 > n2).
    This is a heuristic, not an optimal (Hungarian) assignment.

    Args:
        C: ``(n1, n2)`` cost matrix.

    Returns:
        np.ndarray: int32 vector of length n1 holding the matched column per
        row, or ``-1`` if the row is unmatched.
    """
    n_rows, n_cols = C.shape
    taken = np.full(n_cols, False)
    matches = np.full(n_rows, -1, dtype=np.int32)
    for r in range(n_rows):
        best_col = -1
        best_cost = np.inf
        for c in range(n_cols):
            if taken[c]:
                continue
            if C[r, c] < best_cost:
                best_cost = C[r, c]
                best_col = c
        if best_col >= 0:
            matches[r] = best_col
            taken[best_col] = True
    return matches

@njit
def compute_total_cost(C, row_ind, n1, n2, c_node_del, c_node_ins):
    """ Total cost of a node assignment produced by ``greedy_assignment``.

    Each matched row ``i`` contributes ``C[i, row_ind[i]]``. Each unmatched
    row (``-1``) contributes ``c_node_del``. Each column not used by any row
    contributes ``c_node_ins``. Edges incident to deleted/inserted nodes are
    not charged separately.

    Args:
        C: ``(n1, n2)`` substitution cost matrix.
        row_ind: Length-n1 vector of assigned columns (``-1`` = unmatched).
        n1: Number of nodes in graph 1 (rows of ``C``).
        n2: Number of nodes in graph 2 (columns of ``C``).
        c_node_del: Cost charged per unmatched graph-1 node.
        c_node_ins: Cost charged per unmatched graph-2 node.

    Returns:
        float: The accumulated cost.
    """
    used = np.full(n2, False)
    cost = 0.0
    for i in range(n1):
        target = row_ind[i]
        if target == -1:
            cost += c_node_del
            continue
        cost += C[i, target]
        used[target] = True
    for j in range(n2):
        if not used[j]:
            cost += c_node_ins
    return cost

def approximate_graph_edit_distance(adj1, labels1, adj2, labels2, c_node_del=1.0, c_node_ins=1.0):
    """ Fast heuristic graph edit distance between two preprocessed graphs.

    The steps are:
    1. Node degrees are the row sums of the adjacency matrices.
    2. A substitution cost matrix is built (label mismatch + degree
       difference).
    3. Nodes are matched greedily.
    4. Each leftover node is charged ``c_node_del`` (graph 1) or
       ``c_node_ins`` (graph 2).

    Edge labels are not considered; only edge existence (via degrees) is.

    Args:
        adj1, labels1: Adjacency matrix and label codes of graph 1.
        adj2, labels2: Adjacency matrix and label codes of graph 2.
        c_node_del: Cost per unmatched graph-1 node.
        c_node_ins: Cost per unmatched graph-2 node.

    Returns:
        float: The approximate edit cost.
    """
    cost_matrix = compute_cost_matrix(labels1, labels2, adj1.sum(axis=1), adj2.sum(axis=1))
    assignment = greedy_assignment(cost_matrix)
    size1, size2 = labels1.shape[0], labels2.shape[0]
    return compute_total_cost(cost_matrix, assignment, size1, size2, c_node_del, c_node_ins)

def get_approximate_ged(G1, G2):
    """ Approximate GED between two labelled NetworkX graphs.

    Both graphs share one label encoding from ``build_label_mapping``. They
    are converted to (adjacency, label-code) arrays and scored with
    ``approximate_graph_edit_distance`` using its default unit node
    deletion/insertion costs.

    Args:
        G1: First graph; nodes must carry a ``'label'`` attribute.
        G2: Second graph; nodes must carry a ``'label'`` attribute.

    Returns:
        float: The approximate edit cost.
    """
    shared_mapping = build_label_mapping(G1, G2)
    (adj_a, codes_a), (adj_b, codes_b) = (
        preprocess_graph(graph, shared_mapping) for graph in (G1, G2)
    )
    return approximate_graph_edit_distance(adj_a, codes_a, adj_b, codes_b)

def get_smiles2graph_edit_distance_approx(smi1: str, smi2: str) -> float:
    """ Approximate (heuristic) graph edit distance between two SMILES strings.

    This is the fast counterpart of ``get_smiles2graph_edit_distance``.
    Both SMILES are converted with ``smiles2graph`` and compared with
    ``get_approximate_ged``. An unparsable SMILES yields an empty graph, so
    every node of the other molecule is counted as an insertion/deletion.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.

    Returns:
        float: The approximate edit cost.
    """
    return get_approximate_ged(smiles2graph(smi1), smiles2graph(smi2))