import math
from collections import defaultdict
from typing import Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from scipy.sparse import coo_matrix
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import add_self_loops, remove_self_loops


class CoaDTIPro(nn.Module):
    def __init__(self, esm_model_and_alphabet, n_fingerprint, dim, n_word, layer_output, layer_coa,
                 nhead=8, dropout=0.1,
                 co_attention: Literal['stack', 'encoder', 'inter'] = 'inter',
                 gcn_pooling=False):
        super().__init__()
        self.co_attention = co_attention
        self.layer_output = layer_output
        self.layer_coa = layer_coa
        self.embed_word = nn.Embedding(n_word, dim)
        self.gnn = GNN(n_fingerprint, gcn_pooling)
        self.esm_model, self.alphabet = esm_model_and_alphabet
        self.batch_converter = self.alphabet.get_batch_converter()
        self.W_attention = nn.Linear(dim, dim)
        self.W_out = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.Linear(dim, 128),
            nn.Linear(128, 64)
        )
        self.coa_layers = CoAttention(dim, nhead, dropout, layer_coa, co_attention)
        self.lin = nn.Linear(768, 512)  # project protein embeddings to model dim (BERT: 1024, ESM: 768)
        self.W_interaction = nn.Linear(64, 2)

    def attention_cnn(self, x, xs, layer):
        """The attention mechanism is applied to the last layer of CNN.

        Note: legacy method from the CNN-based protein encoder; it requires a
        `self.W_cnn` ModuleList that is not defined in this class and it is
        not used by `forward`.
        """
        xs = torch.unsqueeze(torch.unsqueeze(xs, 0), 0)
        for i in range(layer):
            xs = torch.relu(self.W_cnn[i](xs))
        xs = torch.squeeze(torch.squeeze(xs, 0), 0)
        h = torch.relu(self.W_attention(x))
        hs = torch.relu(self.W_attention(xs))
        weights = torch.tanh(F.linear(h, hs))
        ys = torch.t(weights) * hs
        return torch.unsqueeze(torch.mean(ys, 0), 0)

    def forward(self, inputs, proteins):
        """Compound vector with GNN."""
        compound_vector = self.gnn(inputs)
        compound_vector = torch.unsqueeze(compound_vector, 0)  # sequence-like GNN output

        _, _, proteins = self.batch_converter([(None, protein) for protein in proteins])
        with torch.no_grad():
            results = self.esm_model(proteins.to(compound_vector.device), repr_layers=[6])
        token_representations = results["representations"][6]
        protein_vector = token_representations[:, 1:, :]  # drop the BOS token
        protein_vector = self.lin(torch.squeeze(protein_vector, 1))

        protein_vector, compound_vector = self.coa_layers(protein_vector, compound_vector)
        protein_vector = protein_vector.mean(dim=1)
        compound_vector = compound_vector.mean(dim=1)

        """Concatenate the above two vectors and output the interaction."""
        cat_vector = torch.cat((compound_vector, protein_vector), 1)
        cat_vector = torch.tanh(self.W_out(cat_vector))
        interaction = self.W_interaction(cat_vector)

        return interaction


class CoAttention(nn.Module):
    def __init__(self, dim, nhead, dropout, layer_coa, co_attention):
        super().__init__()
        self.co_attention = co_attention
        if self.co_attention == 'encoder':
            self.coa_layers = EncoderCrossAtt(dim, nhead, dropout, layer_coa)
        elif self.co_attention == 'stack':
            self.coa_layers = nn.ModuleList([StackCrossAtt(dim, nhead, dropout) for _ in range(layer_coa)])
        elif self.co_attention == 'inter':
            self.coa_layers = nn.ModuleList([InterCrossAtt(dim, nhead, dropout) for _ in range(layer_coa)])

    def forward(self, protein_vector, compound_vector):
        # protein_vector and compound_vector are the input tensors for the two modalities
        if self.co_attention == 'encoder':
            return self.coa_layers(protein_vector, compound_vector)
        else:
            # loop over the stacked layers and pass both modalities through each one
            for layer in self.coa_layers:
                protein_vector, compound_vector = layer(protein_vector, compound_vector)
            return protein_vector, compound_vector
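
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model itself): exercising the
# CoAttention wrapper on random tensors. The sizes below (batch of 1,
# 100 protein tokens, 30 compound nodes, dim=512, 2 layers) are assumptions
# chosen for the demo, not values required by the model.
def _example_co_attention():
    dim, nhead, dropout, layer_coa = 512, 8, 0.1, 2
    coa = CoAttention(dim, nhead, dropout, layer_coa, co_attention='inter')
    protein = torch.randn(1, 100, dim)   # (batch, protein tokens, dim)
    compound = torch.randn(1, 30, dim)   # (batch, compound atoms, dim)
    protein_out, compound_out = coa(protein, compound)
    # shapes are preserved: (1, 100, 512) and (1, 30, 512)
    return protein_out.shape, compound_out.shape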
class EncoderCrossAtt(nn.Module):
    def __init__(self, dim, nhead, dropout, layers):
        super().__init__()
        # self.encoder_layers = nn.ModuleList([SEA(dim, dropout) for _ in range(layers)])
        self.encoder_layers = nn.ModuleList([SA(dim, nhead, dropout) for _ in range(layers)])
        self.decoder_sa = nn.ModuleList([SA(dim, nhead, dropout) for _ in range(layers)])
        self.decoder_coa = nn.ModuleList([DPA(dim, nhead, dropout) for _ in range(layers)])
        self.layer_coa = layers

    def forward(self, protein_vector, compound_vector):
        for i in range(self.layer_coa):
            compound_vector = self.encoder_layers[i](compound_vector, None)  # self-attention
        for i in range(self.layer_coa):
            protein_vector = self.decoder_sa[i](protein_vector, None)
            protein_vector = self.decoder_coa[i](protein_vector, compound_vector, None)  # co-attention
        return protein_vector, compound_vector


class InterCrossAtt(nn.Module):
    def __init__(self, dim, nhead, dropout):
        super().__init__()
        self.sca = SA(dim, nhead, dropout)
        self.spa = SA(dim, nhead, dropout)
        self.coa_pc = DPA(dim, nhead, dropout)
        self.coa_cp = DPA(dim, nhead, dropout)

    def forward(self, protein_vector, compound_vector):
        compound_vector = self.sca(compound_vector, None)  # self-attention
        protein_vector = self.spa(protein_vector, None)  # self-attention
        compound_covector = self.coa_pc(compound_vector, protein_vector, None)  # co-attention
        protein_covector = self.coa_cp(protein_vector, compound_vector, None)  # co-attention
        return protein_covector, compound_covector


class StackCrossAtt(nn.Module):
    def __init__(self, dim, nhead, dropout):
        super().__init__()
        self.sca = SA(dim, nhead, dropout)
        self.spa = SA(dim, nhead, dropout)
        self.coa_cp = DPA(dim, nhead, dropout)

    def forward(self, protein_vector, compound_vector):
        compound_vector = self.sca(compound_vector, None)  # self-attention
        protein_vector = self.spa(protein_vector, None)  # self-attention
        protein_covector = self.coa_cp(protein_vector, compound_vector, None)  # co-attention
        return protein_covector, compound_vector


class MHAtt(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        self.linear_v = nn.Linear(hid_dim, hid_dim)
        self.linear_k = nn.Linear(hid_dim, hid_dim)
        self.linear_q = nn.Linear(hid_dim, hid_dim)
        self.linear_merge = nn.Linear(hid_dim, hid_dim)
        self.hid_dim = hid_dim
        self.nhead = n_heads
        self.dropout = nn.Dropout(dropout)
        self.hidden_size_head = int(self.hid_dim / self.nhead)

    def forward(self, v, k, q, mask):
        n_batches = q.size(0)
        v = self.linear_v(v).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
        k = self.linear_k(k).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
        q = self.linear_q(q).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
        atted = self.att(v, k, q, mask)
        atted = atted.transpose(1, 2).contiguous().view(n_batches, -1, self.hid_dim)
        atted = self.linear_merge(atted)
        return atted

    def att(self, value, key, query, mask):
        # scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)
        return torch.matmul(att_map, value)
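
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model itself): MHAtt implements standard
# multi-head scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, split
# over n_heads heads. The shapes below are assumptions for the demo only.
def _example_mhatt():
    att = MHAtt(hid_dim=512, n_heads=8, dropout=0.1)
    q = torch.randn(1, 30, 512)    # queries, e.g. compound tokens
    kv = torch.randn(1, 100, 512)  # keys/values, e.g. protein tokens
    out = att(kv, kv, q, mask=None)
    # one output row per query: (1, 30, 512)
    return out.shape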
class DPA(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        self.mhatt1 = MHAtt(hid_dim, n_heads, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hid_dim)

    def forward(self, x, y, y_mask=None):
        # cross-attention: x attends to y (queries from x, keys/values from y)
        x = self.norm1(x + self.dropout1(self.mhatt1(y, y, x, y_mask)))
        return x


class SA(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        self.mhatt1 = MHAtt(hid_dim, n_heads, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hid_dim)

    def forward(self, x, mask=None):
        # self-attention with a residual connection and layer norm
        x = self.norm1(x + self.dropout1(self.mhatt1(x, x, x, mask)))
        return x


class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')  # "Max" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]
        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]
        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        return new_embedding


class GNN(nn.Module):
    def __init__(self, n_fingerprint, pooling, embed_dim=128):
        super().__init__()
        self.pooling = pooling
        self.embed_fingerprint = nn.Embedding(num_embeddings=n_fingerprint, embedding_dim=embed_dim)
        self.conv1 = SAGEConv(embed_dim, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
        self.linp1 = torch.nn.Linear(256, 128)
        self.linp2 = torch.nn.Linear(128, 512)
        self.lin = torch.nn.Linear(128, 512)
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.embed_fingerprint(x)
        x = x.squeeze(1)

        x = F.relu(self.conv1(x, edge_index))
        if self.pooling:
            x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
            x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            x = F.relu(self.conv2(x, edge_index))
            x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
            x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            # conv3 is defined above but not applied here; pool3 operates on the conv2 features
            x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
            x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            x = x1 + x2 + x3
            x = self.linp1(x)
            x = self.act1(x)
            x = self.linp2(x)
        else:
            x = F.relu(self.conv2(x, edge_index))
            x = self.lin(x)
        return x


atom_dict = defaultdict(lambda: len(atom_dict))  # 51; bindingdb: 26
bond_dict = defaultdict(lambda: len(bond_dict))  # 4; bindingdb: 4
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))  # 6341; bindingdb: 20366
edge_dict = defaultdict(lambda: len(edge_dict))  # 17536; bindingdb: 77916
word_dict = defaultdict(lambda: len(word_dict))  # 22; bindingdb: 21
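
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model itself): running the GNN on a
# tiny synthetic 3-node graph. The fingerprint indices and edges are made up
# for the demo; in the real pipeline they come from drug_featurizer below.
def _example_gnn():
    gnn = GNN(n_fingerprint=100, pooling=False)
    x = torch.LongTensor([[0], [1], [2]])                        # 3 nodes, fingerprint ids
    edge_index = torch.LongTensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # undirected edges
    data = Data(x=x, edge_index=edge_index)
    data.batch = torch.zeros(3, dtype=torch.long)                # a single graph in the batch
    out = gnn(data)
    # one 512-d vector per node: (3, 512)
    return out.shape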
def drug_featurizer(smiles, radius=2):
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    atoms = create_atoms(mol)
    i_jbond_dict = create_ijbonddict(mol)
    fingerprints = extract_fingerprints(atoms, i_jbond_dict, radius)
    adjacency = coo_matrix(Chem.GetAdjacencyMatrix(mol))
    edge_index = np.array([adjacency.row, adjacency.col])
    return Data(x=torch.LongTensor(fingerprints).unsqueeze(1),
                edge_index=torch.LongTensor(edge_index))


def create_atoms(mol):
    """Create a list of atom (e.g., hydrogen and oxygen) IDs
    considering the aromaticity."""
    # GetSymbol: obtain the symbol of the atom
    atoms = [a.GetSymbol() for a in mol.GetAtoms()]
    for a in mol.GetAromaticAtoms():
        i = a.GetIdx()
        atoms[i] = (atoms[i], 'aromatic')
    # turn the symbols into indices
    atoms = [atom_dict[a] for a in atoms]
    return np.array(atoms)


def create_ijbonddict(mol):
    """Create a dictionary in which each key is a node ID and each value is
    a list of tuples of its neighboring node and bond (e.g., single and double) IDs."""
    i_jbond_dict = defaultdict(lambda: [])
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))
    return i_jbond_dict


def extract_fingerprints(atoms, i_jbond_dict, radius=2):
    """Extract the r-radius subgraphs (i.e., fingerprints)
    from a molecular graph using the Weisfeiler-Lehman algorithm."""
    fingerprints = None

    if (len(atoms) == 1) or (radius == 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]
    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict

        for _ in range(radius):
            """Update each node ID considering its neighboring nodes and edges
            (i.e., r-radius subgraphs or fingerprints)."""
            fingerprints = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                fingerprints.append(fingerprint_dict[fingerprint])
            nodes = fingerprints

            """Also update each edge ID considering the two nodes on both of its sides."""
            _i_jedge_dict = defaultdict(lambda: [])
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    _i_jedge_dict[i].append((j, edge))
            i_jedge_dict = _i_jedge_dict

    return np.array(fingerprints)
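
# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (not part of the model itself). It assumes the
# fair-esm package is installed and that a 6-layer, 768-dim ESM-1 model
# (esm.pretrained.esm1_t6_43M_UR50S) is used, matching the 768->512 projection
# and repr_layers=[6] in CoaDTIPro. The SMILES string, protein sequence, and
# the n_word / layer_output values are placeholders for the demo only.
def _example_end_to_end():
    import esm  # fair-esm; assumed dependency, downloads weights on first use

    esm_model_and_alphabet = esm.pretrained.esm1_t6_43M_UR50S()
    compound = drug_featurizer('CC(=O)Oc1ccccc1C(=O)O')  # aspirin
    compound.batch = torch.zeros(compound.x.size(0), dtype=torch.long)

    model = CoaDTIPro(esm_model_and_alphabet,
                      n_fingerprint=len(fingerprint_dict) + 1,
                      dim=512, n_word=30, layer_output=3, layer_coa=1)
    model.eval()
    with torch.no_grad():
        logits = model(compound, ['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ'])
    # two logits (interaction / no interaction) for the single pair: (1, 2)
    return logits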