|
import math |
|
import os |
|
import csv |
|
import random |
|
import torch |
|
from torch.utils import data |
|
import numpy as np |
|
from dateutil import parser |
|
import contigs |
|
from util import * |
|
from kinematics import * |
|
import pandas as pd |
|
import sys |
|
import torch.nn as nn |
|
from icecream import ic |
|
def write_pdb(filename, seq, atoms, Bfacts=None, prefix=None, chains=None): |
|
L = len(seq) |
|
ctr = 1 |
|
    seq = seq.long()
    if Bfacts is None:
        # the ATOM records below index Bfacts per residue, so default to zeros
        Bfacts = torch.zeros(L)
|
with open(filename, 'w+') as f: |
|
for i,s in enumerate(seq): |
|
if chains is None: |
|
chain='A' |
|
else: |
|
chain=chains[i] |
|
|
|
if (len(atoms.shape)==2): |
|
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( |
|
"ATOM", ctr, " CA ", util.num2aa[s], |
|
chain, i+1, atoms[i,0], atoms[i,1], atoms[i,2], |
|
1.0, Bfacts[i] ) ) |
|
ctr += 1 |
|
|
|
elif atoms.shape[1]==3: |
|
for j,atm_j in enumerate((" N "," CA "," C ")): |
|
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( |
|
"ATOM", ctr, atm_j, num2aa[s], |
|
chain, i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2], |
|
1.0, Bfacts[i] ) ) |
|
ctr += 1 |
|
else: |
|
atms = aa2long[s] |
|
for j,atm_j in enumerate(atms): |
|
if (atm_j is not None): |
|
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( |
|
"ATOM", ctr, atm_j, num2aa[s], |
|
chain, i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2], |
|
1.0, Bfacts[i] ) ) |
|
ctr += 1 |
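
# Illustrative usage sketch (not part of the original pipeline): writing a three-residue
# chain as CA-only coordinates. The residue alphabet (num2aa) comes from the `util` star
# import above; the coordinates below are dummy values purely for demonstration.
def _example_write_pdb(path='example_ca_only.pdb'):
    seq = torch.zeros(3, dtype=torch.long)                   # three residues of type 0
    ca = torch.arange(9, dtype=torch.float).reshape(3, 3)    # dummy CA coordinates
    write_pdb(path, seq, ca, Bfacts=torch.zeros(3))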
|
|
|
def preprocess(xyz_t, t1d, DEVICE, masks_1d, ti_dev=None, ti_flip=None, ang_ref=None): |
|
|
|
B, _, L, _, _ = xyz_t.shape |
|
|
|
seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L).to(DEVICE, non_blocking=True) |
|
alpha, _, alpha_mask,_ = get_torsions(xyz_t.reshape(-1,L,27,3), seq_tmp, ti_dev, ti_flip, ang_ref) |
|
alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) |
|
alpha[torch.isnan(alpha)] = 0.0 |
|
alpha = alpha.reshape(B,-1,L,10,2) |
|
alpha_mask = alpha_mask.reshape(B,-1,L,10,1) |
|
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B,-1,L,30) |
|
|
|
xyz_t = get_init_xyz(xyz_t) |
|
xyz_prev = xyz_t[:,0] |
|
state = t1d[:,0] |
|
alpha = alpha[:,0] |
|
t2d=xyz_to_t2d(xyz_t) |
|
return (t2d, alpha, alpha_mask, alpha_t, t1d, xyz_t, xyz_prev, state) |
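
# Shape bookkeeping for preprocess, inferred from the code above (for reference only):
#   xyz_t   : (B, T, L, 27, 3) template coordinates
#   alpha_t : (B, T, L, 30) = 10 torsions x (2 angle components + 1 validity mask)
#   t2d     : pairwise template features from xyz_to_t2d(xyz_t)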
|
|
|
def TemplFeaturizeFixbb(seq, conf_1d=None): |
|
""" |
|
    Template 1D featurizer for fixed-backbone (fixed BB) examples:
|
Parameters: |
|
seq (torch.tensor, required): Integer sequence |
|
        conf_1d (torch.tensor, optional): Precalculated confidence tensor
|
""" |
|
L = seq.shape[-1] |
|
t1d = torch.nn.functional.one_hot(seq, num_classes=21) |
|
if conf_1d is None: |
|
conf = torch.ones_like(seq)[...,None] |
|
else: |
|
conf = conf_1d[:,None] |
|
t1d = torch.cat((t1d, conf), dim=-1) |
|
return t1d |
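
# Quick shape sketch (illustrative only): for an L-residue integer sequence the fixed-BB
# template features are a 21-way one-hot plus one confidence column, i.e. (L, 22).
def _example_templ_featurize_fixbb():
    seq = torch.randint(0, 21, (10,))
    t1d = TemplFeaturizeFixbb(seq)
    assert t1d.shape == (10, 22)
    return t1d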
|
|
|
def MSAFeaturize_fixbb(msa, params): |
|
''' |
|
Input: full msa information |
|
    Output: features for the query sequence only, with no residues masked or mutated
|
|
|
This is modified from autofold2, to remove mutations of the single sequence |
|
''' |
|
N, L = msa.shape |
|
|
|
raw_profile = torch.nn.functional.one_hot(msa, num_classes=22) |
|
raw_profile = raw_profile.float().mean(dim=0) |
|
|
|
b_seq = list() |
|
b_msa_clust = list() |
|
b_msa_seed = list() |
|
b_msa_extra = list() |
|
b_mask_pos = list() |
|
for i_cycle in range(params['MAXCYCLE']): |
|
assert torch.max(msa) < 22 |
|
msa_onehot = torch.nn.functional.one_hot(msa[:1],num_classes=22) |
|
msa_fakeprofile_onehot = torch.nn.functional.one_hot(msa[:1],num_classes=26) |
|
msa_full_onehot = torch.cat((msa_onehot, msa_fakeprofile_onehot), dim=-1) |
|
|
|
|
|
msa_extra_onehot = torch.nn.functional.one_hot(msa[:1],num_classes=25) |
|
|
|
|
|
msa_clust = msa[:1] |
|
mask_pos = torch.full_like(msa_clust, 1).bool() |
|
b_seq.append(msa[0].clone()) |
|
b_msa_seed.append(msa_full_onehot[:1].clone()) |
|
b_msa_extra.append(msa_extra_onehot[:1].clone()) |
|
b_msa_clust.append(msa_clust[:1].clone()) |
|
b_mask_pos.append(mask_pos[:1].clone()) |
|
|
|
b_seq = torch.stack(b_seq) |
|
b_msa_clust = torch.stack(b_msa_clust) |
|
b_msa_seed = torch.stack(b_msa_seed) |
|
b_msa_extra = torch.stack(b_msa_extra) |
|
b_mask_pos = torch.stack(b_mask_pos) |
|
|
|
return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos |
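
# Shape sketch (illustrative): the fixbb featurizer keeps only the query row of the MSA,
# so for an (N, L) integer input and MAXCYCLE cycles the seed/extra stacks come out as
# (MAXCYCLE, 1, L, 48) and (MAXCYCLE, 1, L, 25).
def _example_msa_featurize_fixbb():
    msa = torch.randint(0, 21, (8, 30))     # toy (N, L) integer MSA
    b_seq, b_clust, b_seed, b_extra, b_pos = MSAFeaturize_fixbb(msa, {'MAXCYCLE': 3})
    assert b_seed.shape == (3, 1, 30, 48) and b_extra.shape == (3, 1, 30, 25)
    return b_seq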
|
|
|
def MSAFeaturize(msa, params): |
|
''' |
|
Input: full msa information |
|
    Output: features for the full MSA, with no residues masked or mutated
|
|
|
This is modified from autofold2, to remove mutations of the single sequence |
|
''' |
|
N, L = msa.shape |
|
|
|
raw_profile = torch.nn.functional.one_hot(msa, num_classes=22) |
|
raw_profile = raw_profile.float().mean(dim=0) |
|
|
|
b_seq = list() |
|
b_msa_clust = list() |
|
b_msa_seed = list() |
|
b_msa_extra = list() |
|
b_mask_pos = list() |
|
for i_cycle in range(params['MAXCYCLE']): |
|
assert torch.max(msa) < 22 |
|
msa_onehot = torch.nn.functional.one_hot(msa,num_classes=22) |
|
msa_fakeprofile_onehot = torch.nn.functional.one_hot(msa,num_classes=26) |
|
msa_full_onehot = torch.cat((msa_onehot, msa_fakeprofile_onehot), dim=-1) |
|
|
|
|
|
msa_extra_onehot = torch.nn.functional.one_hot(msa,num_classes=25) |
|
|
|
|
|
msa_clust = msa |
|
mask_pos = torch.full_like(msa_clust, 1).bool() |
|
b_seq.append(msa[0].clone()) |
|
b_msa_seed.append(msa_full_onehot.clone()) |
|
b_msa_extra.append(msa_extra_onehot.clone()) |
|
b_msa_clust.append(msa_clust.clone()) |
|
b_mask_pos.append(mask_pos.clone()) |
|
|
|
b_seq = torch.stack(b_seq) |
|
b_msa_clust = torch.stack(b_msa_clust) |
|
b_msa_seed = torch.stack(b_msa_seed) |
|
b_msa_extra = torch.stack(b_msa_extra) |
|
b_mask_pos = torch.stack(b_mask_pos) |
|
|
|
return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos |
|
|
|
def mask_inputs(seq, msa_masked, msa_full, xyz_t, t1d, input_seq_mask=None, input_str_mask=None, input_t1dconf_mask=None, loss_seq_mask=None, loss_str_mask=None): |
|
""" |
|
Parameters: |
|
seq (torch.tensor, required): (B,I,L) integer sequence |
|
msa_masked (torch.tensor, required): (B,I,N_short,L,46) |
|
        msa_full (torch.tensor, required): (B,I,N_long,L,23)
|
|
|
        xyz_t (torch.tensor, required): (B,T,L,14,3) template coordinates BEFORE they go into get_init_xyz
|
|
|
t1d (torch.tensor, required): (B,I,L,22) this is the t1d before tacking on the chi angles |
|
|
|
        input_str_mask (torch.tensor, required): (B,L) boolean tensor; structure is masked where False

        input_seq_mask (torch.tensor, required): (B,L) boolean tensor; sequence is masked where False
|
""" |
|
|
|
|
|
B,_,_ = seq.shape |
|
assert B == 1, 'batch sizes > 1 not supported' |
|
seq_mask = input_seq_mask[0] |
|
seq[:,:,~seq_mask] = 21 |
|
|
|
|
|
|
|
msa_masked[:,:,:,~seq_mask,:20] = 0 |
|
msa_masked[:,:,:,~seq_mask,20] = 0 |
|
msa_masked[:,:,:,~seq_mask,21] = 1 |
|
|
|
|
|
|
|
|
|
msa_masked[:,:,:,~seq_mask,22:42] = 0 |
|
msa_masked[:,:,:,~seq_mask,43] = 1 |
|
msa_masked[:,:,:,~seq_mask,42] = 0 |
|
|
|
|
|
msa_masked[:,:,:,~seq_mask,44:] = 0 |
|
|
|
|
|
|
|
msa_full[:,:,:,~seq_mask,:20] = 0 |
|
msa_full[:,:,:,~seq_mask,21] = 1 |
|
msa_full[:,:,:,~seq_mask,20] = 0 |
|
msa_full[:,:,:,~seq_mask,-1] = 0 |
|
|
|
|
|
|
|
|
|
t1d[:,:,~seq_mask,:20] = 0 |
|
t1d[:,:,~seq_mask,20] = 1 |
|
|
|
t1d[:,:,:,21] *= input_t1dconf_mask |
|
|
|
|
|
print('expanding t1d to 24 dims') |
|
|
|
t1d = torch.cat((t1d, torch.zeros((t1d.shape[0],t1d.shape[1],t1d.shape[2],2)).float()), -1).to(seq.device) |
|
|
|
xyz_t[:,:,~seq_mask,3:,:] = float('nan') |
|
|
|
|
|
str_mask = input_str_mask[0] |
|
xyz_t[:,:,~str_mask,:,:] = float('nan') |
|
|
|
return seq, msa_masked, msa_full, xyz_t, t1d |
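
# Minimal sketch of calling mask_inputs (shapes taken from the docstring above; the
# tensors here are zeros purely to exercise the masking logic, not real features):
def _example_mask_inputs():
    B, I, N, T, L = 1, 1, 1, 1, 5
    seq = torch.zeros(B, I, L, dtype=torch.long)
    msa_masked = torch.zeros(B, I, N, L, 46)
    msa_full = torch.zeros(B, I, N, L, 23)
    xyz_t = torch.zeros(B, T, L, 14, 3)
    t1d = torch.zeros(B, T, L, 22)
    seq_mask = torch.tensor([[True, True, True, False, False]])   # mask the last two positions
    str_mask = torch.ones(B, L, dtype=torch.bool)                  # keep all structure
    conf_mask = torch.ones(L)
    return mask_inputs(seq, msa_masked, msa_full, xyz_t, t1d,
                       input_seq_mask=seq_mask, input_str_mask=str_mask,
                       input_t1dconf_mask=conf_mask)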
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_translated_coords(args): |
|
''' |
|
Parses args.res_translate |
|
''' |
|
|
|
res_translate = [] |
|
for res in args.res_translate.split(":"): |
|
temp_str = [] |
|
for i in res.split(','): |
|
temp_str.append(i) |
|
        if temp_str[-1][0].isalpha():
|
temp_str.append(2.0) |
|
for i in temp_str[:-1]: |
|
if '-' in i: |
|
start = int(i.split('-')[0][1:]) |
|
while start <= int(i.split('-')[1]): |
|
res_translate.append((i.split('-')[0][0] + str(start),float(temp_str[-1]))) |
|
start += 1 |
|
else: |
|
res_translate.append((i, float(temp_str[-1]))) |
|
start = 0 |
|
|
|
output = [] |
|
for i in res_translate: |
|
temp = (i[0], i[1], start) |
|
output.append(temp) |
|
start += 1 |
|
|
|
return output |
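
# The expected --res_translate syntax, inferred from the parser above (illustrative only):
#   "A15,2.5:B10-12,4"  ->  residue A15 moved by up to 2.5 A; residues B10-B12 by up to 4 A.
# A block with no trailing number falls back to a 2.0 A maximum. Each residue ends up in
# its own tie block, so every residue is translated independently.
# (ObjectView is defined further down in this module.)
def _example_get_translated_coords():
    args = ObjectView({'res_translate': 'A15,2.5:B10-12,4'})
    return get_translated_coords(args)   # [('A15', 2.5, b), ('B10', 4.0, b+1), ...]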
|
|
|
def get_tied_translated_coords(args, untied_translate=None): |
|
''' |
|
Parses args.tie_translate |
|
''' |
|
|
|
|
|
|
|
res_translate = [] |
|
block = 0 |
|
for res in args.tie_translate.split(":"): |
|
temp_str = [] |
|
for i in res.split(','): |
|
temp_str.append(i) |
|
        if temp_str[-1][0].isalpha():
|
temp_str.append(2.0) |
|
for i in temp_str[:-1]: |
|
if '-' in i: |
|
start = int(i.split('-')[0][1:]) |
|
while start <= int(i.split('-')[1]): |
|
res_translate.append((i.split('-')[0][0] + str(start),float(temp_str[-1]), block)) |
|
start += 1 |
|
else: |
|
res_translate.append((i, float(temp_str[-1]), block)) |
|
block += 1 |
|
|
|
|
|
    if untied_translate is not None:
|
checker = [i[0] for i in res_translate] |
|
untied_check = [i[0] for i in untied_translate] |
|
for i in checker: |
|
if i in untied_check: |
|
print(f'WARNING: residue {i} is specified both in --res_translate and --tie_translate. Residue {i} will be ignored in --res_translate, and instead only moved in a tied block (--tie_translate)') |
|
|
|
final_output = res_translate |
|
for i in untied_translate: |
|
if i[0] not in checker: |
|
final_output.append((i[0],i[1],i[2] + block + 1)) |
|
else: |
|
final_output = res_translate |
|
|
|
return final_output |
|
|
|
|
|
|
|
def translate_coords(parsed_pdb, res_translate): |
|
''' |
|
    Takes a parsed list in the format [(chain_residue, distance, tie_block)] and randomly translates residues accordingly.
|
''' |
|
|
|
pdb_idx = parsed_pdb['pdb_idx'] |
|
xyz = np.copy(parsed_pdb['xyz']) |
|
translated_coord_dict = {} |
|
|
|
temp = [int(i[2]) for i in res_translate] |
|
blocks = np.max(temp) |
|
|
|
for block in range(blocks + 1): |
|
init_dist = 1.01 |
|
while init_dist > 1: |
|
x = random.uniform(-1,1) |
|
y = random.uniform(-1,1) |
|
z = random.uniform(-1,1) |
|
init_dist = np.sqrt(x**2 + y**2 + z**2) |
|
x=x/init_dist |
|
y=y/init_dist |
|
z=z/init_dist |
|
translate_dist = random.uniform(0,1) |
|
for res in res_translate: |
|
if res[2] == block: |
|
res_idx = pdb_idx.index((res[0][0],int(res[0][1:]))) |
|
original_coords = np.copy(xyz[res_idx,:,:]) |
|
for i in range(14): |
|
if parsed_pdb['mask'][res_idx, i]: |
|
xyz[res_idx,i,0] += np.float32(x * translate_dist * float(res[1])) |
|
xyz[res_idx,i,1] += np.float32(y * translate_dist * float(res[1])) |
|
xyz[res_idx,i,2] += np.float32(z * translate_dist * float(res[1])) |
|
translated_coords = xyz[res_idx,:,:] |
|
translated_coord_dict[res[0]] = (original_coords.tolist(), translated_coords.tolist()) |
|
|
|
return xyz[:,:,:], translated_coord_dict |
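
# The random step above samples a direction by rejection-sampling a point inside the unit
# ball and normalising it, then scales by U(0,1) times the per-residue maximum distance.
# A standalone version of that one step (same idea, written here only for illustration):
def _sample_translation(max_dist):
    d = 2.0
    while d > 1.0:
        v = np.array([random.uniform(-1, 1) for _ in range(3)])
        d = np.sqrt((v ** 2).sum())
    return (v / d) * random.uniform(0, 1) * max_dist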
|
|
|
def parse_block_rotate(args): |
|
block_translate = [] |
|
block = 0 |
|
for res in args.block_rotate.split(":"): |
|
temp_str = [] |
|
for i in res.split(','): |
|
temp_str.append(i) |
|
        if temp_str[-1][0].isalpha():
|
temp_str.append(10) |
|
for i in temp_str[:-1]: |
|
if '-' in i: |
|
start = int(i.split('-')[0][1:]) |
|
while start <= int(i.split('-')[1]): |
|
block_translate.append((i.split('-')[0][0] + str(start),float(temp_str[-1]), block)) |
|
start += 1 |
|
else: |
|
block_translate.append((i, float(temp_str[-1]), block)) |
|
block += 1 |
|
return block_translate |
|
|
|
def rotate_block(xyz, block_rotate,pdb_index): |
|
rotated_coord_dict = {} |
|
|
|
temp = [int(i[2]) for i in block_rotate] |
|
blocks = np.max(temp) |
|
for block in range(blocks + 1): |
|
idxs = [pdb_index.index((i[0][0],int(i[0][1:]))) for i in block_rotate if i[2] == block] |
|
angle = [i[1] for i in block_rotate if i[2] == block][0] |
|
block_xyz = xyz[idxs,:,:] |
|
com = [float(torch.mean(block_xyz[:,:,i])) for i in range(3)] |
|
origin_xyz = np.copy(block_xyz) |
|
for i in range(np.shape(origin_xyz)[0]): |
|
for j in range(14): |
|
origin_xyz[i,j] = origin_xyz[i,j] - com |
|
rotated_xyz = rigid_rotate(origin_xyz,angle,angle,angle) |
|
recovered_xyz = np.copy(rotated_xyz) |
|
for i in range(np.shape(origin_xyz)[0]): |
|
for j in range(14): |
|
recovered_xyz[i,j] = rotated_xyz[i,j] + com |
|
recovered_xyz=torch.tensor(recovered_xyz) |
|
rotated_coord_dict[f'rotated_block_{block}_original'] = block_xyz |
|
rotated_coord_dict[f'rotated_block_{block}_rotated'] = recovered_xyz |
|
xyz_out = torch.clone(xyz) |
|
for i in range(len(idxs)): |
|
xyz_out[idxs[i]] = recovered_xyz[i] |
|
return xyz_out,rotated_coord_dict |
|
|
|
def rigid_rotate(xyz,a=180,b=180,c=180): |
|
|
|
a=(a/180)*math.pi |
|
b=(b/180)*math.pi |
|
c=(c/180)*math.pi |
|
alpha = random.uniform(-a, a) |
|
beta = random.uniform(-b, b) |
|
gamma = random.uniform(-c, c) |
|
rotated = [] |
|
for i in range(np.shape(xyz)[0]): |
|
for j in range(14): |
|
try: |
|
x = xyz[i,j,0] |
|
y = xyz[i,j,1] |
|
z = xyz[i,j,2] |
|
x2 = x*math.cos(alpha) - y*math.sin(alpha) |
|
y2 = x*math.sin(alpha) + y*math.cos(alpha) |
|
x3 = x2*math.cos(beta) - z*math.sin(beta) |
|
z2 = x2*math.sin(beta) + z*math.cos(beta) |
|
y3 = y2*math.cos(gamma) - z2*math.sin(gamma) |
|
z3 = y2*math.sin(gamma) + z2*math.cos(gamma) |
|
rotated.append([x3,y3,z3]) |
|
            except Exception:  # missing or ill-defined atom coordinates become NaN
|
rotated.append([float('nan'),float('nan'),float('nan')]) |
|
rotated=np.array(rotated) |
|
rotated=np.reshape(rotated, [np.shape(xyz)[0],14,3]) |
|
|
|
return rotated |
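
# Matrix form of the element-wise rotation above (illustration only): the three angles
# compose as a z-rotation, then a y-rotation (with the sign convention used in the code),
# then an x-rotation, applied to every atom. Since the angles are drawn symmetrically
# about zero, the sign convention does not change the sampling distribution.
def _rotation_matrix(alpha, beta, gamma):
    Rz = np.array([[math.cos(alpha), -math.sin(alpha), 0.0],
                   [math.sin(alpha),  math.cos(alpha), 0.0],
                   [0.0, 0.0, 1.0]])
    Ry = np.array([[math.cos(beta), 0.0, -math.sin(beta)],
                   [0.0, 1.0, 0.0],
                   [math.sin(beta), 0.0,  math.cos(beta)]])
    Rx = np.array([[1.0, 0.0, 0.0],
                   [0.0, math.cos(gamma), -math.sin(gamma)],
                   [0.0, math.sin(gamma),  math.cos(gamma)]])
    return Rx @ Ry @ Rz   # apply as R @ xyz_column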
|
|
|
|
|
|
|
def find_contigs(mask): |
|
""" |
|
Find contiguous regions in a mask that are True with no False in between |
|
|
|
Parameters: |
|
mask (torch.tensor or np.array, required): 1D boolean array |
|
|
|
Returns: |
|
        contigs (list): List of (start, end) tuples (end-exclusive), one per contiguous True region
|
""" |
|
assert len(mask.shape) == 1 |
|
|
|
contigs = [] |
|
found_contig = False |
|
for i,b in enumerate(mask): |
|
|
|
|
|
if b and not found_contig: |
|
contig = [i] |
|
found_contig = True |
|
|
|
elif b and found_contig: |
|
pass |
|
|
|
elif not b and found_contig: |
|
contig.append(i) |
|
found_contig = False |
|
contigs.append(tuple(contig)) |
|
|
|
else: |
|
pass |
|
|
|
|
|
|
|
if b: |
|
contig.append(i+1) |
|
found_contig = False |
|
contigs.append(tuple(contig)) |
|
|
|
return contigs |
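
# Behaviour sketch for find_contigs (illustrative): returns half-open (start, end) pairs.
def _example_find_contigs():
    mask = torch.tensor([True, True, False, True])
    return find_contigs(mask)   # -> [(0, 2), (3, 4)]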
|
|
|
|
|
def reindex_chains(pdb_idx): |
|
""" |
|
    Given a list of (chain, index) tuples, create a reindexed numbering that inserts a large
    (+500) index jump at every chain break.

    Parameters:

        pdb_idx (list, required): List of (chainID, index) tuples
|
""" |
|
|
|
new_breaks, new_idx = [],[] |
|
current_chain = None |
|
|
|
chain_and_idx_to_torch = {} |
|
|
|
for i,T in enumerate(pdb_idx): |
|
|
|
chain, idx = T |
|
|
|
if chain != current_chain: |
|
new_breaks.append(i) |
|
current_chain = chain |
|
|
|
|
|
chain_and_idx_to_torch[chain] = {} |
|
|
|
|
|
chain_and_idx_to_torch[chain][idx] = i |
|
|
|
|
|
new_idx.append(i) |
|
|
|
new_idx = np.array(new_idx) |
|
|
|
num_additions = 0 |
|
for i in new_breaks[1:]: |
|
new_idx[np.where(new_idx==(i+ num_additions*500))[0][0]:] += 500 |
|
num_additions += 1 |
|
|
|
return new_idx, chain_and_idx_to_torch,new_breaks[1:] |
|
|
|
class ObjectView(object): |
|
''' |
|
    Simple wrapper to access dictionary values with "dot" notation instead of bracket indexing
|
''' |
|
def __init__(self, d): |
|
self.__dict__ = d |
|
|
|
def split_templates(xyz_t, t1d, multi_templates,mappings,multi_tmpl_conf=None): |
|
templates = multi_templates.split(":") |
|
if multi_tmpl_conf is not None: |
|
multi_tmpl_conf = [float(i) for i in multi_tmpl_conf.split(",")] |
|
assert len(templates) == len(multi_tmpl_conf), "Number of templates must equal number of confidences specified in --multi_tmpl_conf flag" |
|
for idx, template in enumerate(templates): |
|
parts = template.split(",") |
|
template_mask = torch.zeros(xyz_t.shape[2]).bool() |
|
for part in parts: |
|
start = int(part.split("-")[0][1:]) |
|
end = int(part.split("-")[1]) + 1 |
|
chain = part[0] |
|
for i in range(start, end): |
|
try: |
|
ref_pos = mappings['complex_con_ref_pdb_idx'].index((chain, i)) |
|
hal_pos_0 = mappings['complex_con_hal_idx0'][ref_pos] |
|
                except (KeyError, ValueError):  # fall back to the single-chain mappings
|
ref_pos = mappings['con_ref_pdb_idx'].index((chain, i)) |
|
hal_pos_0 = mappings['con_hal_idx0'][ref_pos] |
|
template_mask[hal_pos_0] = True |
|
|
|
xyz_t_temp = torch.clone(xyz_t) |
|
xyz_t_temp[:,:,~template_mask,:,:] = float('nan') |
|
t1d_temp = torch.clone(t1d) |
|
t1d_temp[:,:,~template_mask,:20] =0 |
|
t1d_temp[:,:,~template_mask,20] = 1 |
|
if multi_tmpl_conf is not None: |
|
t1d_temp[:,:,template_mask,21] = multi_tmpl_conf[idx] |
|
if idx != 0: |
|
xyz_t_out = torch.cat((xyz_t_out, xyz_t_temp),dim=1) |
|
t1d_out = torch.cat((t1d_out, t1d_temp),dim=1) |
|
else: |
|
xyz_t_out = xyz_t_temp |
|
t1d_out = t1d_temp |
|
return xyz_t_out, t1d_out |
|
|
|
|
|
class ContigMap(): |
|
''' |
|
    Class for mapping between reference (input PDB) and hallucinated (output) residue indexing.

    Supports multichain inputs, or multiple crops from a single receptor chain.

    Also supports an indexing jump (+200) or not, based on the contig input.

    By default, inpainted chains are output as chain A (then B, C etc. if there are multiple
    inpainted chains), with all fragments of the receptor chain on the next chain (generally B).

    Output chains can be specified; the specification must have the same number of elements as
    the contig string.
|
''' |
|
def __init__(self, parsed_pdb, contigs=None, inpaint_seq=None, inpaint_str=None, length=None, ref_idx=None, hal_idx=None, idx_rf=None, inpaint_seq_tensor=None, inpaint_str_tensor=None, topo=False): |
|
|
|
if contigs is None and ref_idx is None: |
|
sys.exit("Must either specify a contig string or precise mapping") |
|
if idx_rf is not None or hal_idx is not None or ref_idx is not None: |
|
if idx_rf is None or hal_idx is None or ref_idx is None: |
|
sys.exit("If you're specifying specific contig mappings, the reference and output positions must be specified, AND the indexing for RoseTTAFold (idx_rf)") |
|
|
|
self.chain_order='ABCDEFGHIJKLMNOPQRSTUVWXYZ' |
|
if length is not None: |
|
if '-' not in length: |
|
self.length = [int(length),int(length)+1] |
|
else: |
|
self.length = [int(length.split("-")[0]),int(length.split("-")[1])+1] |
|
else: |
|
self.length = None |
|
self.ref_idx = ref_idx |
|
self.hal_idx=hal_idx |
|
self.idx_rf=idx_rf |
|
self.inpaint_seq = ','.join(inpaint_seq).split(",") if inpaint_seq is not None else None |
|
self.inpaint_str = ','.join(inpaint_str).split(",") if inpaint_str is not None else None |
|
self.inpaint_seq_tensor=inpaint_seq_tensor |
|
self.inpaint_str_tensor=inpaint_str_tensor |
|
self.parsed_pdb = parsed_pdb |
|
self.topo=topo |
|
if ref_idx is None: |
|
|
|
self.contigs=contigs |
|
self.sampled_mask,self.contig_length,self.n_inpaint_chains = self.get_sampled_mask() |
|
self.receptor_chain = self.chain_order[self.n_inpaint_chains] |
|
self.receptor, self.receptor_hal, self.receptor_rf, self.inpaint, self.inpaint_hal, self.inpaint_rf= self.expand_sampled_mask() |
|
self.ref = self.inpaint + self.receptor |
|
self.hal = self.inpaint_hal + self.receptor_hal |
|
self.rf = self.inpaint_rf + self.receptor_rf |
|
else: |
|
|
|
self.ref=ref_idx |
|
self.hal=hal_idx |
|
            self.rf = idx_rf
|
self.mask_1d = [False if i == ('_','_') else True for i in self.ref] |
|
|
|
|
|
if self.inpaint_seq_tensor is None: |
|
if self.inpaint_seq is not None: |
|
self.inpaint_seq = self.get_inpaint_seq_str(self.inpaint_seq) |
|
else: |
|
self.inpaint_seq = np.array([True if i != ('_','_') else False for i in self.ref]) |
|
else: |
|
self.inpaint_seq = self.inpaint_seq_tensor |
|
|
|
if self.inpaint_str_tensor is None: |
|
if self.inpaint_str is not None: |
|
self.inpaint_str = self.get_inpaint_seq_str(self.inpaint_str) |
|
else: |
|
self.inpaint_str = np.array([True if i != ('_','_') else False for i in self.ref]) |
|
else: |
|
self.inpaint_str = self.inpaint_str_tensor |
|
|
|
self.ref_idx0,self.hal_idx0, self.ref_idx0_inpaint, self.hal_idx0_inpaint, self.ref_idx0_receptor, self.hal_idx0_receptor=self.get_idx0() |
|
|
|
def get_sampled_mask(self): |
|
''' |
|
Function to get a sampled mask from a contig. |
|
''' |
|
length_compatible=False |
|
count = 0 |
|
while length_compatible is False: |
|
inpaint_chains=0 |
|
contig_list = self.contigs |
|
sampled_mask = [] |
|
sampled_mask_length = 0 |
|
|
|
if all([i[0].isalpha() for i in contig_list[-1].split(",")]): |
|
contig_list[-1] = f'{contig_list[-1]},0' |
|
for con in contig_list: |
|
if ((all([i[0].isalpha() for i in con.split(",")[:-1]]) and con.split(",")[-1] == '0')) or self.topo is True: |
|
|
|
sampled_mask.append(con) |
|
else: |
|
inpaint_chains += 1 |
|
|
|
subcons = con.split(",") |
|
subcon_out = [] |
|
for subcon in subcons: |
|
if subcon[0].isalpha(): |
|
subcon_out.append(subcon) |
|
if '-' in subcon: |
|
sampled_mask_length += (int(subcon.split("-")[1])-int(subcon.split("-")[0][1:])+1) |
|
else: |
|
sampled_mask_length += 1 |
|
|
|
else: |
|
if '-' in subcon: |
|
length_inpaint=random.randint(int(subcon.split("-")[0]),int(subcon.split("-")[1])) |
|
subcon_out.append(f'{length_inpaint}-{length_inpaint}') |
|
sampled_mask_length += length_inpaint |
|
elif subcon == '0': |
|
subcon_out.append('0') |
|
else: |
|
length_inpaint=int(subcon) |
|
subcon_out.append(f'{length_inpaint}-{length_inpaint}') |
|
sampled_mask_length += int(subcon) |
|
sampled_mask.append(','.join(subcon_out)) |
|
|
|
if self.length is not None: |
|
if sampled_mask_length >= self.length[0] and sampled_mask_length < self.length[1]: |
|
length_compatible = True |
|
else: |
|
length_compatible = True |
|
count+=1 |
|
if count == 100000: |
|
sys.exit("Contig string incompatible with --length range") |
|
return sampled_mask, sampled_mask_length, inpaint_chains |
|
|
|
def expand_sampled_mask(self): |
|
chain_order='ABCDEFGHIJKLMNOPQRSTUVWXYZ' |
|
receptor = [] |
|
inpaint = [] |
|
receptor_hal = [] |
|
inpaint_hal = [] |
|
receptor_idx = 1 |
|
inpaint_idx = 1 |
|
inpaint_chain_idx=-1 |
|
receptor_chain_break=[] |
|
inpaint_chain_break = [] |
|
for con in self.sampled_mask: |
|
if (all([i[0].isalpha() for i in con.split(",")[:-1]]) and con.split(",")[-1] == '0') or self.topo is True: |
|
|
|
subcons = con.split(",")[:-1] |
|
assert all([i[0] == subcons[0][0] for i in subcons]), "If specifying fragmented receptor in a single block of the contig string, they MUST derive from the same chain" |
|
assert all(int(subcons[i].split("-")[0][1:]) < int(subcons[i+1].split("-")[0][1:]) for i in range(len(subcons)-1)), "If specifying multiple fragments from the same chain, pdb indices must be in ascending order!" |
|
for idx, subcon in enumerate(subcons): |
|
ref_to_add = [(subcon[0], i) for i in np.arange(int(subcon.split("-")[0][1:]),int(subcon.split("-")[1])+1)] |
|
receptor.extend(ref_to_add) |
|
receptor_hal.extend([(self.receptor_chain,i) for i in np.arange(receptor_idx, receptor_idx+len(ref_to_add))]) |
|
receptor_idx += len(ref_to_add) |
|
if idx != len(subcons)-1: |
|
idx_jump = int(subcons[idx+1].split("-")[0][1:]) - int(subcon.split("-")[1]) -1 |
|
receptor_chain_break.append((receptor_idx-1,idx_jump)) |
|
else: |
|
receptor_chain_break.append((receptor_idx-1,200)) |
|
else: |
|
inpaint_chain_idx += 1 |
|
for subcon in con.split(","): |
|
if subcon[0].isalpha(): |
|
ref_to_add=[(subcon[0], i) for i in np.arange(int(subcon.split("-")[0][1:]),int(subcon.split("-")[1])+1)] |
|
inpaint.extend(ref_to_add) |
|
inpaint_hal.extend([(chain_order[inpaint_chain_idx], i) for i in np.arange(inpaint_idx,inpaint_idx+len(ref_to_add))]) |
|
inpaint_idx += len(ref_to_add) |
|
|
|
else: |
|
inpaint.extend([('_','_')] * int(subcon.split("-")[0])) |
|
inpaint_hal.extend([(chain_order[inpaint_chain_idx], i) for i in np.arange(inpaint_idx,inpaint_idx+int(subcon.split("-")[0]))]) |
|
inpaint_idx += int(subcon.split("-")[0]) |
|
inpaint_chain_break.append((inpaint_idx-1,200)) |
|
|
|
if self.topo is True or inpaint_hal == []: |
|
receptor_hal = [(i[0], i[1]) for i in receptor_hal] |
|
else: |
|
receptor_hal = [(i[0], i[1] + inpaint_hal[-1][1]) for i in receptor_hal] |
|
|
|
inpaint_rf = np.arange(0,len(inpaint)) |
|
receptor_rf = np.arange(len(inpaint)+200,len(inpaint)+len(receptor)+200) |
|
for ch_break in inpaint_chain_break[:-1]: |
|
receptor_rf[:] += 200 |
|
inpaint_rf[ch_break[0]:] += ch_break[1] |
|
for ch_break in receptor_chain_break[:-1]: |
|
receptor_rf[ch_break[0]:] += ch_break[1] |
|
|
|
return receptor, receptor_hal, receptor_rf.tolist(), inpaint, inpaint_hal, inpaint_rf.tolist() |
|
|
|
def get_inpaint_seq_str(self, inpaint_s): |
|
''' |
|
function to generate inpaint_str or inpaint_seq masks specific to this contig |
|
''' |
|
s_mask = np.copy(self.mask_1d) |
|
inpaint_s_list = [] |
|
for i in inpaint_s: |
|
if '-' in i: |
|
inpaint_s_list.extend([(i[0],p) for p in range(int(i.split("-")[0][1:]), int(i.split("-")[1])+1)]) |
|
else: |
|
inpaint_s_list.append((i[0],int(i[1:]))) |
|
for res in inpaint_s_list: |
|
if res in self.ref: |
|
s_mask[self.ref.index(res)] = False |
|
|
|
return np.array(s_mask) |
|
|
|
def get_idx0(self): |
|
ref_idx0=[] |
|
hal_idx0=[] |
|
ref_idx0_inpaint=[] |
|
hal_idx0_inpaint=[] |
|
ref_idx0_receptor=[] |
|
hal_idx0_receptor=[] |
|
for idx, val in enumerate(self.ref): |
|
if val != ('_','_'): |
|
assert val in self.parsed_pdb['pdb_idx'],f"{val} is not in pdb file!" |
|
hal_idx0.append(idx) |
|
ref_idx0.append(self.parsed_pdb['pdb_idx'].index(val)) |
|
for idx, val in enumerate(self.inpaint): |
|
if val != ('_','_'): |
|
hal_idx0_inpaint.append(idx) |
|
ref_idx0_inpaint.append(self.parsed_pdb['pdb_idx'].index(val)) |
|
for idx, val in enumerate(self.receptor): |
|
if val != ('_','_'): |
|
hal_idx0_receptor.append(idx) |
|
ref_idx0_receptor.append(self.parsed_pdb['pdb_idx'].index(val)) |
|
|
|
|
|
return ref_idx0, hal_idx0, ref_idx0_inpaint, hal_idx0_inpaint, ref_idx0_receptor, hal_idx0_receptor |
|
|
|
def get_mappings(rm): |
|
mappings = {} |
|
mappings['con_ref_pdb_idx'] = [i for i in rm.inpaint if i != ('_','_')] |
|
mappings['con_hal_pdb_idx'] = [rm.inpaint_hal[i] for i in range(len(rm.inpaint_hal)) if rm.inpaint[i] != ("_","_")] |
|
mappings['con_ref_idx0'] = rm.ref_idx0_inpaint |
|
mappings['con_hal_idx0'] = rm.hal_idx0_inpaint |
|
if rm.inpaint != rm.ref: |
|
mappings['complex_con_ref_pdb_idx'] = [i for i in rm.ref if i != ("_","_")] |
|
mappings['complex_con_hal_pdb_idx'] = [rm.hal[i] for i in range(len(rm.hal)) if rm.ref[i] != ("_","_")] |
|
mappings['receptor_con_ref_pdb_idx'] = [i for i in rm.receptor if i != ("_","_")] |
|
mappings['receptor_con_hal_pdb_idx'] = [rm.receptor_hal[i] for i in range(len(rm.receptor_hal)) if rm.receptor[i] != ("_","_")] |
|
mappings['complex_con_ref_idx0'] = rm.ref_idx0 |
|
mappings['complex_con_hal_idx0'] = rm.hal_idx0 |
|
mappings['receptor_con_ref_idx0'] = rm.ref_idx0_receptor |
|
mappings['receptor_con_hal_idx0'] = rm.hal_idx0_receptor |
|
mappings['inpaint_str'] = rm.inpaint_str |
|
mappings['inpaint_seq'] = rm.inpaint_seq |
|
mappings['sampled_mask'] = rm.sampled_mask |
|
mappings['mask_1d'] = rm.mask_1d |
|
return mappings |
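
# Sketch of how ContigMap and get_mappings fit together (illustrative; `parsed_pdb` is the
# dict produced by this repo's PDB parser, with 'pdb_idx', 'xyz' and 'mask' keys, and the
# contig below assumes residues A15-A25 exist in that structure):
def _example_contig_mapping(parsed_pdb):
    rm = ContigMap(parsed_pdb, contigs=['10-10,A15-25,10-10'])
    mappings = get_mappings(rm)
    return mappings['con_hal_idx0']   # 0-indexed positions of the kept A15-25 motif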
|
|
|
def lddt_unbin(pred_lddt): |
|
nbin = pred_lddt.shape[1] |
|
bin_step = 1.0 / nbin |
|
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device) |
|
|
|
pred_lddt = nn.Softmax(dim=1)(pred_lddt) |
|
return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1) |
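
# Sanity sketch (illustrative): lddt_unbin is the expectation of the bin centres under a
# softmax over the bin dimension, so uniform logits give the mean of the bin grid (~0.51).
def _example_lddt_unbin():
    pred = torch.zeros(1, 50, 10)     # (batch, n_bins, L) all-equal logits
    return lddt_unbin(pred)           # ~0.51 for every residue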
|
|
|
|