|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from opt_einsum import contract as einsum |
|
import copy |
|
import dgl |
|
from util import base_indices, RTs_by_torsion, xyzs_in_base_frame, rigid_from_3_points |
|
|
|
def init_lecun_normal(module, scale=1.0):
    '''
    Replace ``module.weight`` with truncated LeCun-normal samples.

    Input:
        - module: object exposing a ``weight`` tensor (e.g. ``nn.Linear``)
        - scale: variance scaling factor; the target variance is
                 ``scale / fan_in`` with ``fan_in = weight.shape[-1]``
    Output:
        - module: the same object, whose ``weight`` is now a fresh
                  ``torch.nn.Parameter`` of the original shape
    '''
    def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
        # Inverse-CDF sampling: map U(0, 1) draws onto N(mu, sigma^2)
        # restricted to the interval [a, b].
        normal = torch.distributions.normal.Normal(0, 1)

        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma

        alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
        p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform

        # Nudge erfinv's argument away from +/-1; the final clamp enforces
        # the [a, b] bounds regardless.
        v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
        x = mu + sigma * float(np.sqrt(2)) * torch.erfinv(v)
        return torch.clamp(x, a, b)

    def sample_truncated_normal(shape, scale=1.0):
        # 0.8796... is the standard deviation of a unit normal truncated to
        # [-2, 2]; dividing by it restores the requested variance.
        stddev = float(np.sqrt(scale / shape[-1])) / .87962566103423978
        return stddev * truncated_normal(torch.rand(shape))

    # Forward ``scale`` explicitly: previously it was dropped here, so the
    # caller-supplied value never reached the sampler.
    module.weight = torch.nn.Parameter(
        sample_truncated_normal(module.weight.shape, scale=scale)
    )
    return module
|
|
|
def init_lecun_normal_param(weight, scale=1.0):
    '''
    Build a new parameter of ``weight``'s shape filled with truncated
    LeCun-normal samples.

    Input:
        - weight: tensor whose ``shape`` defines the output shape; its values
                  are not read
        - scale: variance scaling factor; the target variance is
                 ``scale / fan_in`` with ``fan_in = weight.shape[-1]``
    Output:
        - a fresh ``torch.nn.Parameter`` (the input tensor is not modified)
    '''
    def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
        # Inverse-CDF sampling: map U(0, 1) draws onto N(mu, sigma^2)
        # restricted to the interval [a, b].
        normal = torch.distributions.normal.Normal(0, 1)

        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma

        alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
        p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform

        # Nudge erfinv's argument away from +/-1; the final clamp enforces
        # the [a, b] bounds regardless.
        v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
        x = mu + sigma * float(np.sqrt(2)) * torch.erfinv(v)
        return torch.clamp(x, a, b)

    def sample_truncated_normal(shape, scale=1.0):
        # 0.8796... is the standard deviation of a unit normal truncated to
        # [-2, 2]; dividing by it restores the requested variance.
        stddev = float(np.sqrt(scale / shape[-1])) / .87962566103423978
        return stddev * truncated_normal(torch.rand(shape))

    # Forward ``scale`` explicitly: previously it was dropped here, so the
    # caller-supplied value never reached the sampler.
    return torch.nn.Parameter(sample_truncated_normal(weight.shape, scale=scale))
|
|
|
|
|
def create_custom_forward(module, **kwargs):
    '''
    Bind keyword arguments to a callable.

    Input:
        - module: any callable
        - kwargs: keyword arguments to pass on every call
    Output:
        - a function taking positional arguments only, which calls
          ``module(*args, **kwargs)`` and returns its result
    '''
    def custom_forward(*args):
        result = module(*args, **kwargs)
        return result

    return custom_forward
|
|
|
def get_clones(module, N):
    '''
    Input:
        - module: module to replicate
        - N: number of copies
    Output:
        - nn.ModuleList holding N independent deep copies of ``module``
    '''
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
|
|
|
class Dropout(nn.Module):
    '''
    Dropout whose mask can be shared along one dimension.

    In training mode each element is kept with probability ``1 - p_drop`` and
    the kept values are rescaled by ``1 / (1 - p_drop)``. When
    ``broadcast_dim`` is set, the mask has size 1 along that dimension and is
    broadcast, so the same elements are dropped at every position along it.
    In eval mode the input is returned unchanged.
    '''

    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super(Dropout, self).__init__()

        # Bernoulli over the *keep* probability: a sampled 1 keeps the element.
        self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1 - p_drop]))
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop

    def forward(self, x):
        if not self.training:
            return x

        shape = list(x.shape)
        # Identity check: ``broadcast_dim`` may legitimately be 0.
        if self.broadcast_dim is not None:
            shape[self.broadcast_dim] = 1

        # The sampler has batch shape (1,), so sample(shape) carries one extra
        # trailing axis; view() removes it.
        mask = self.sampler.sample(shape).to(x.device).view(shape)

        return mask * x / (1.0 - self.p_drop)
|
|
|
def rbf(D):
    '''
    Expand distances into Gaussian radial-basis features.

    Input:
        - D: distance tensor (at least one dimension)
    Output:
        - tensor with one extra trailing axis of size 36: Gaussians with
          centres evenly spaced over [0, 20] and width 20/36
    '''
    lo, hi, n_bins = 0., 20., 36
    width = (hi - lo) / n_bins
    centers = torch.linspace(lo, hi, n_bins).to(D.device).unsqueeze(0)
    z = (D.unsqueeze(-1) - centers) / width
    return torch.exp(-(z ** 2))
|
|
|
def get_seqsep(idx): |
|
''' |
|
Input: |
|
- idx: residue indices of given sequence (B,L) |
|
Output: |
|
- seqsep: sequence separation feature with sign (B, L, L, 1) |
|
Sergey found that having sign in seqsep features helps a little |
|
''' |
|
seqsep = idx[:,None,:] - idx[:,:,None] |
|
sign = torch.sign(seqsep) |
|
neigh = torch.abs(seqsep) |
|
neigh[neigh > 1] = 0.0 |
|
neigh = sign * neigh |
|
return neigh.unsqueeze(-1) |
|
|
|
def make_full_graph(xyz, pair, idx, top_k=64, kmin=9):
    '''
    Connect every ordered pair of residues (within a batch element) whose
    residue indices differ.

    Input:
        - xyz: current backbone coordinates, indexed as xyz[b, l]
               (originally documented as (B, L, 3, 3) -- TODO confirm
               against callers)
        - pair: pair features from Trunk (B, L, L, E)
        - idx: residue index from ground truth pdb
        - top_k, kmin: not used by this function
    Output:
        - G: dgl graph with B*L nodes; edge data 'rel_pos' holds
             xyz[target] - xyz[source], detached from autograd
        - pair features gathered per edge, with a trailing singleton axis
    '''
    n_batch, n_res = xyz.shape[:2]
    dev = xyz.device

    # Every (i, j) with a non-zero index offset becomes an edge i -> j.
    offset = idx[:, None, :] - idx[:, :, None]
    bi, ri, rj = torch.where(offset.abs() > 0)

    # Flatten (batch, residue) into one node id so the batch is one graph.
    sources = bi * n_res + ri
    targets = bi * n_res + rj
    graph = dgl.graph((sources, targets), num_nodes=n_batch * n_res).to(dev)
    graph.edata['rel_pos'] = (xyz[bi, rj, :] - xyz[bi, ri, :]).detach()

    edge_feats = pair[bi, ri, rj][..., None]
    return graph, edge_feats
|
|
|
def make_topk_graph(xyz, pair, idx, top_k=64, kmin=32, eps=1e-6):
    '''
    Connect each residue to its nearest spatial neighbours and to its
    neighbours along the sequence.

    Input:
        - xyz: current backbone coordinates. The cdist result is combined
               with a (1, L, L) mask, so this looks like (B, L, 3) here; the
               original docstring said (B, L, 3, 3) -- TODO confirm
        - pair: pair features from Trunk (B, L, L, E)
        - idx: residue index from ground truth pdb
        - top_k: number of nearest neighbours kept per residue (capped at L)
        - kmin: pairs with sequence separation below this are always kept
        - eps: weight of the sequence-separation term added to distances
    Output:
        - G: dgl graph with B*L nodes; edge data 'rel_pos' holds
             xyz[target] - xyz[source], detached from autograd
        - pair features gathered per edge, with a trailing singleton axis
    '''
    n_batch, n_res = xyz.shape[:2]
    dev = xyz.device

    # Large diagonal offset pushes self-pairs to the back of the ranking and
    # out of the kmin test.
    self_mask = torch.eye(n_res, device=dev).unsqueeze(0) * 999.9

    dist = torch.cdist(xyz, xyz) + self_mask
    seq_sep = (idx[:, None, :] - idx[:, :, None]).abs() + self_mask
    # Small sequence-separation term added on top of the spatial distance.
    dist = dist + seq_sep * eps

    # NOTE(review): when L <= top_k every column, diagonal included, is
    # selected here.
    _, nbr_idx = torch.topk(dist, min(top_k, n_res), largest=False)
    is_topk = torch.zeros((n_batch, n_res, n_res), device=dev)
    is_topk.scatter_(2, nbr_idx, 1.0)

    # Keep an edge if it is a top-k neighbour OR close along the sequence.
    keep = torch.logical_or(is_topk > 0.0, seq_sep < kmin)
    bi, ri, rj = torch.where(keep)

    # Flatten (batch, residue) into one node id so the batch is one graph.
    sources = bi * n_res + ri
    targets = bi * n_res + rj
    graph = dgl.graph((sources, targets), num_nodes=n_batch * n_res).to(dev)
    graph.edata['rel_pos'] = (xyz[bi, rj, :] - xyz[bi, ri, :]).detach()

    edge_feats = pair[bi, ri, rj][..., None]
    return graph, edge_feats
|
|
|
def make_rotX(angs, eps=1e-6):
    '''
    Homogeneous 4x4 rotations about the x axis.

    Input:
        - angs: (B, L, 2) unnormalised pairs; after division by their norm
                (+eps) component 0 is used as cosine and component 1 as sine
        - eps: added to the norm to avoid division by zero
    Output:
        - (B, L, 4, 4) float tensor on angs' device
    '''
    n_batch, n_res = angs.shape[:2]
    length = torch.linalg.norm(angs, dim=-1) + eps
    cos_t = angs[:, :, 0] / length
    sin_t = angs[:, :, 1] / length

    mats = torch.eye(4, device=angs.device).repeat(n_batch, n_res, 1, 1)
    mats[:, :, 1, 1] = cos_t
    mats[:, :, 1, 2] = -sin_t
    mats[:, :, 2, 1] = sin_t
    mats[:, :, 2, 2] = cos_t
    return mats
|
|
|
|
|
def make_rotZ(angs, eps=1e-6):
    '''
    Homogeneous 4x4 rotations about the z axis.

    Input:
        - angs: (B, L, 2) unnormalised pairs; after division by their norm
                (+eps) component 0 is used as cosine and component 1 as sine
        - eps: added to the norm to avoid division by zero
    Output:
        - (B, L, 4, 4) float tensor on angs' device
    '''
    n_batch, n_res = angs.shape[:2]
    length = torch.linalg.norm(angs, dim=-1) + eps
    cos_t = angs[:, :, 0] / length
    sin_t = angs[:, :, 1] / length

    mats = torch.eye(4, device=angs.device).repeat(n_batch, n_res, 1, 1)
    mats[:, :, 0, 0] = cos_t
    mats[:, :, 0, 1] = -sin_t
    mats[:, :, 1, 0] = sin_t
    mats[:, :, 1, 1] = cos_t
    return mats
|
|
|
|
|
def make_rot_axis(angs, u, eps=1e-6):
    '''
    Homogeneous 4x4 rotations about per-residue axes (axis-angle formula).

    Input:
        - angs: (B, L, 2) unnormalised pairs; after division by their norm
                (+eps) component 0 is used as cosine and component 1 as sine
        - u: (B, L, 3) rotation axes; used as given (not normalised here)
        - eps: added to the norm to avoid division by zero
    Output:
        - (B, L, 4, 4) float tensor on angs' device
    '''
    n_batch, n_res = angs.shape[:2]
    length = torch.linalg.norm(angs, dim=-1) + eps
    ct = angs[:, :, 0] / length
    st = angs[:, :, 1] / length
    ux, uy, uz = u[:, :, 0], u[:, :, 1], u[:, :, 2]
    versine = 1 - ct

    mats = torch.eye(4, device=angs.device).repeat(n_batch, n_res, 1, 1)

    # Row 0
    mats[:, :, 0, 0] = ct + ux * ux * versine
    mats[:, :, 0, 1] = ux * uy * versine - uz * st
    mats[:, :, 0, 2] = ux * uz * versine + uy * st
    # Row 1
    mats[:, :, 1, 0] = ux * uy * versine + uz * st
    mats[:, :, 1, 1] = ct + uy * uy * versine
    mats[:, :, 1, 2] = uy * uz * versine - ux * st
    # Row 2
    mats[:, :, 2, 0] = ux * uz * versine - uy * st
    mats[:, :, 2, 1] = uy * uz * versine + ux * st
    mats[:, :, 2, 2] = ct + uz * uz * versine
    return mats
|
|
|
class ComputeAllAtomCoords(nn.Module):
    '''
    Build per-residue rigid frames from backbone coordinates and torsion
    pairs, then place atoms by applying each atom's frame to its template
    coordinates.

    The lookup tables (``base_indices``, ``RTs_by_torsion``,
    ``xyzs_in_base_frame``) come from ``util`` and are stored as frozen
    parameters so they follow the module across devices.
    '''

    def __init__(self):
        super(ComputeAllAtomCoords, self).__init__()

        # Frozen lookup tables, all indexed by ``seq`` in forward().
        self.base_indices = nn.Parameter(base_indices, requires_grad=False)
        self.RTs_in_base_frame = nn.Parameter(RTs_by_torsion, requires_grad=False)
        self.xyzs_in_base_frame = nn.Parameter(xyzs_in_base_frame, requires_grad=False)

    def forward(self, seq, xyz, alphas, non_ideal=False, use_H=True):
        '''
        Input:
            - seq: integer tensor indexing the lookup tables, used as (B, L)
                   (presumably residue types -- confirm against caller)
            - xyz: coordinates; atoms 0, 1, 2 along the second-to-last axis
                   define each residue's base frame
            - alphas: torsion pairs; alphas[:, :, k, :] is read for
                      k = 0..9, each a 2-vector normalised into
                      (cos, sin) by the rotation builders
            - non_ideal: forwarded to rigid_from_3_points
            - use_H: if False, only the first 14 atoms are returned
        Output:
            - RTframes: (B, L, 9, 4, 4) homogeneous frames RTF0..RTF8
            - xyzs: atom coordinates with a trailing axis of size 3
        '''
        B, L = xyz.shape[:2]

        Rs, Ts = rigid_from_3_points(
            xyz[..., 0, :], xyz[..., 1, :], xyz[..., 2, :], non_ideal=non_ideal
        )

        # Frame 0: the base frame (rotation Rs, translation Ts) as a 4x4.
        RTF0 = torch.eye(4).repeat(B, L, 1, 1).to(device=Rs.device)
        RTF0[:, :, :3, :3] = Rs
        RTF0[:, :, :3, 3] = Ts

        def torsion_frame(parent, k):
            # parent frame, then the template transform for torsion k, then
            # an x-axis rotation by torsion k.
            return torch.einsum(
                'brij,brjk,brkl->bril',
                parent,
                self.RTs_in_base_frame[seq, k, :],
                make_rotX(alphas[:, :, k, :]),
            )

        # Frames 1-3 hang directly off the base frame.
        RTF1 = torsion_frame(RTF0, 0)
        RTF2 = torsion_frame(RTF0, 1)
        RTF3 = torsion_frame(RTF0, 2)

        # Two rotation axes derived from template atoms 0, 1, 2 and 4
        # (variable names suggest N/CA/C/CB -- confirm against util tables).
        basexyzs = self.xyzs_in_base_frame[seq]
        NCr = 0.5 * (basexyzs[:, :, 2, :3] + basexyzs[:, :, 0, :3])
        CAr = basexyzs[:, :, 1, :3]
        CBr = basexyzs[:, :, 4, :3]

        # dim=-1 is explicit: without it the legacy default picks the FIRST
        # dimension of size 3, which is wrong whenever B == 3 or L == 3.
        CBrotaxis1 = (CBr - CAr).cross(NCr - CAr, dim=-1)
        CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True) + 1e-8

        # Component of (atom2 - atom0) orthogonal to NCr.
        NCp = basexyzs[:, :, 2, :3] - basexyzs[:, :, 0, :3]
        NCpp = NCp - torch.sum(NCp * NCr, dim=-1, keepdim=True) / torch.sum(NCr * NCr, dim=-1, keepdim=True) * NCr
        CBrotaxis2 = (CBr - CAr).cross(NCpp, dim=-1)
        CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True) + 1e-8

        CBrot1 = make_rot_axis(alphas[:, :, 7, :], CBrotaxis1)
        CBrot2 = make_rot_axis(alphas[:, :, 8, :], CBrotaxis2)

        # Frame 8: base frame followed by the two axis rotations.
        RTF8 = torch.einsum('brij,brjk,brkl->bril', RTF0, CBrot1, CBrot2)

        # Frame 4 builds on frame 8 and adds a z-rotation by torsion 9.
        RTF4 = torch.einsum(
            'brij,brjk,brkl,brlm->brim',
            RTF8,
            self.RTs_in_base_frame[seq, 3, :],
            make_rotX(alphas[:, :, 3, :]),
            make_rotZ(alphas[:, :, 9, :]),
        )

        # Frames 5-7 form a chain: each builds on the previous one.
        RTF5 = torsion_frame(RTF4, 4)
        RTF6 = torsion_frame(RTF5, 5)
        RTF7 = torsion_frame(RTF6, 6)

        RTframes = torch.stack(
            (RTF0, RTF1, RTF2, RTF3, RTF4, RTF5, RTF6, RTF7, RTF8), dim=2
        )

        # For every atom pick the frame named by base_indices, then apply it
        # to the atom's template coordinates.
        frame_idx = self.base_indices[seq][..., None, None].repeat(1, 1, 1, 4, 4)
        xyzs = torch.einsum(
            'brtij,brtj->brti',
            RTframes.gather(2, frame_idx),
            basexyzs,
        )

        if use_H:
            return RTframes, xyzs[..., :3]
        return RTframes, xyzs[..., :14, :3]
|
|