|
|
|
|
|
|
|
|
|
import glob
|
|
import os
|
|
import pickle
|
|
import random
|
|
from argparse import ArgumentParser
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
|
|
import torch
|
|
from Bio import SeqIO
|
|
from sklearn.preprocessing import OneHotEncoder
|
|
import numpy as np
|
|
|
|
from predict import check_path, one_hot_encode, get_bin_pred, doSavePredict
|
|
|
|
|
|
def get_bin_label(df_label,distance_cutoff):
    """Binarize a distance table into a 0/1 contact table.

    Args:
        df_label: Object holding distances; it must support ``<`` against a
            scalar and the result must offer ``.astype`` (a NumPy array or a
            pandas DataFrame both do).
        distance_cutoff: Entries strictly below this value are marked 1.

    Returns:
        An object shaped like ``df_label`` with integer entries: 1 where the
        distance is below the cutoff, 0 everywhere else.
    """
    return (df_label < distance_cutoff).astype(int)
|
|
|
|
def view_evaluate_contact_prob(df_label, bin_pred, ax=None, markersize=5):
    """Draw predicted contacts over the ground-truth contacts at 8/5/3.5 Å.

    Args:
        df_label: Two-dimensional distance table. It is compared with ``<``,
            transposed with ``.T`` and queried with ``.isna()``, so a pandas
            DataFrame is expected.
        bin_pred: Binary prediction table aligned with ``df_label``
            (presumably 0/1 integers -- it is subtracted from and added to the
            boolean contact masks; confirm against the caller).
        ax: Axes-like object to draw on. When ``None`` the ``pyplot`` module
            itself is used and no title is set; otherwise the title is set to
            ``'performance'``.
        markersize: Marker size forwarded to every ``plot`` call.

    Returns:
        A ``(summary, confusing_matrix)`` tuple. ``summary`` is the number of
        true-positive cells at the 3.5, 5 and 8 Å cutoffs joined with ``'/'``
        (in that order). ``confusing_matrix`` is an array shaped like
        ``df_label`` holding 1 for ground-truth contacts below 8 Å, 2 for
        cells that are true positives, and 0 elsewhere.
    """
    confusing_matrix = np.zeros_like(df_label)
    n_rows, n_cols = confusing_matrix.shape

    if ax is None:
        # The pyplot module exposes xlim/ylim instead of the Axes setters.
        ax = plt
        ax.xlim([-2, n_cols + 2])
        ax.ylim([-2, n_rows + 2])
    else:
        ax.set_xlim([-2, n_cols + 2])
        ax.set_ylim([-2, n_rows + 2])
        ax.set_title('performance')

    false_positive_color = '#f5e0c4'
    # (cutoff, ground-truth colour, true-positive colour, legend labels),
    # ordered from the loosest to the strictest cutoff so that stricter
    # levels are painted on top of looser ones.
    levels = (
        (8, '#b0d9db', '#ecbbd8', 'Ground truth (8Å)', 'True Positive (8Å)'),
        (5, '#61b3b6', '#9d4e7d', 'Ground truth (5Å)', 'True Positive (5Å)'),
        (3.5, 'k', 'r', 'Ground truth (3.5Å)', 'True Positive (3.5Å)'),
    )

    # False positives and the base labelling are defined against the loosest cutoff.
    loosest_contacts = df_label < levels[0][0]
    surplus = bin_pred - loosest_contacts
    ax.plot(*np.where(surplus.T == 1), ".", c=false_positive_color,
            markersize=markersize, label='False Positive')
    ax.plot(*np.where(df_label.T.isna()), ".", c='gray',
            markersize=markersize, label='Missing in PDB')
    confusing_matrix[loosest_contacts == 1] = 1

    tps = []
    for cutoff, truth_color, hit_color, truth_label, hit_label in levels:
        contacts = df_label < cutoff
        ax.plot(*np.where(contacts.T == 1), ".", c=truth_color,
                markersize=markersize, label=truth_label)

        # A cell is a true positive when both the label and the prediction are set.
        overlap = contacts + bin_pred
        hit_cells = np.where(overlap == 2)
        tps.append(len(confusing_matrix[hit_cells]))
        confusing_matrix[hit_cells] = 2

        hit_marker = ax.plot(*np.where(overlap.T == 2), "o", c=hit_color,
                             markersize=markersize, label=hit_label)[0]
        hit_marker.set_markerfacecolor(truth_color)
        hit_marker.set_markeredgecolor(hit_color)

    # Report the true-positive count of the strictest cutoff.
    print(tps[-1])
    return '/'.join(map(str, reversed(tps))), confusing_matrix
|
|
def seed_everything(seed=2022):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and current CUDA
    device), exports ``PYTHONHASHSEED``, and switches cuDNN to deterministic
    mode with benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    print('seed_everything to ',seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
|
def getParam():
    """Parse the command-line options of this script.

    Options (all optional):
        --rootdir: string, default ``''``.
        --fasta:   path of the input FASTA file.
        --out:     output directory.
        --ffeat:   path template for the embedding pickles; ``{pdbid}`` is
                   substituted by the caller.
        --fmodel:  path (glob pattern) of the model weight file(s).
        --device:  torch device string, default ``'cpu'``.
        --flabel:  path template for the label pickles; ``{pdbid}`` is
                   substituted by the caller.
        --draw:    boolean, default ``True``. Accepts true/false, yes/no,
                   on/off, t/f, y/n and 1/0 (case-insensitive).

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    from argparse import ArgumentTypeError

    def _parse_bool(text):
        # ``type=bool`` would call bool("False"), which is True for every
        # non-empty string, so the flag could never be switched off.
        lowered = text.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        # The empty string stays falsy, as it was with ``type=bool``.
        if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise ArgumentTypeError(f'expected a boolean value, got {text!r}')

    string_options = (
        ('--rootdir', ''),
        ('--fasta', './example/inputs/8DMB_W.8DMB_P.fasta'),
        ('--out', './example/outputs/'),
        ('--ffeat', './example/inputs/{pdbid}.pickle'),
        ('--fmodel', './weight/model_roc_0_38=0.845.pt'),
        ('--device', 'cpu'),
        ('--flabel', './example/inputs/{pdbid}.pickle'),
    )

    parser = ArgumentParser()
    for flag, default in string_options:
        parser.add_argument(flag, default=default, type=str)
    parser.add_argument('--draw', default=True, type=_parse_bool)
    return parser.parse_args()
|
|
if __name__ == '__main__':

    def _to_model_input(one_hot, embedding, device):
        """Join one-hot and embedding features into a batched float tensor.

        The two arrays are concatenated along axis 1, given a leading batch
        axis, the last two axes are swapped, and the result is moved to
        ``device`` as ``torch.float``.
        """
        features = np.concatenate([one_hot, embedding], axis=1)
        batch = torch.from_numpy(np.expand_dims(features, 0)).transpose(-1, -2)
        return batch.to(device, dtype=torch.float)

    args = getParam()
    rootdir = args.rootdir  # parsed but not used by the code below
    fasta = args.fasta
    ffeat = args.ffeat
    fmodel = args.fmodel
    device = args.device
    flabel = args.flabel
    draw = args.draw
    out = args.out
    check_path(out)

    seed_everything(seed=2022)

    # NOTE(security): torch.load unpickles the checkpoint, which can execute
    # arbitrary code -- only point --fmodel at trusted files.
    # --fmodel is expanded with glob, so a pattern may select several models.
    models = [(model_path, torch.load(model_path, map_location=torch.device(device)))
              for model_path in glob.glob(fmodel)]
    print('loading existed model', fmodel)

    with torch.no_grad():
        for record in SeqIO.parse(fasta, 'fasta'):
            pdbid, seq = record.id, record.seq
            # Both the record id and the sequence are "<rna>.<protein>" pairs.
            rnaid, proid = pdbid.split('.')
            rnaseq, proseq = seq.split('.')

            # NOTE(security): pickle.load can execute arbitrary code -- the
            # feature files must come from a trusted source.
            with open(ffeat.format_map({'pdbid': rnaid}), 'rb') as f:
                rna_emb = pickle.load(f)
            with open(ffeat.format_map({'pdbid': proid}), 'rb') as f:
                pro_emb = pickle.load(f)

            rna_oh = one_hot_encode(rnaseq, alpha='ACGU')
            pro_oh = one_hot_encode(proseq, alpha='GAVLIFWYDNEKQMSTCPHR')

            x_rna = _to_model_input(rna_oh, rna_emb, device)
            x_pro = _to_model_input(pro_oh, pro_emb, device)
            print('input data shape for rna and protein:', x_rna.shape, x_pro.shape)

            for i, (model_path, model) in enumerate(models):
                model.eval()
                outputs = model(x_pro, x_rna)
                outputs = torch.squeeze(outputs, -1)
                outputs = outputs.permute(0, 2, 1)
                df_pred = outputs[0].cpu().detach().numpy()

                des = f'predict by {__file__}\n#{model_path}'
                doSavePredict(pdbid, {'rna': rnaseq, 'protein': proseq}, df_pred,
                              out,
                              des
                              )

                # Keep the top L = r + p scores: the threshold is the L-th
                # largest predicted value.
                top = sum(df_pred.shape)
                df_pred = pd.DataFrame(df_pred)
                threshold = df_pred.stack().nlargest(top).iloc[-1]

                if draw:
                    # NOTE(security): same pickle caveat as for the features.
                    with open(flabel.format_map({'pdbid': pdbid}), 'rb') as f:
                        df_label = pickle.load(f)
                    df_label = df_label.squeeze()
                    bin_pred = get_bin_pred(df_pred, threshold=threshold)

                    # One fresh figure per model: sharing a single figure made
                    # every saved image accumulate the markers of the models
                    # drawn before it.
                    fig = plt.figure(figsize=(20, 15))
                    view_evaluate_contact_prob(df_label, bin_pred, ax=None)
                    plt.title(f'Predicted contact map of {pdbid}\nPredidcted by RPcontact, top L=r+p')
                    plt.xlabel(proid)
                    plt.ylabel(rnaid)
                    handles, labels = plt.gca().get_legend_handles_labels()
                    plt.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1, borderaxespad=1,
                               frameon=False)
                    ax = plt.gca()
                    ax.set_aspect('equal')
                    plt.tight_layout()
                    plt.savefig(f'{out}/{pdbid}_{i}_evaluate.png', dpi=900)
                    plt.show()
                    # Release the figure so long runs do not pile up open figures.
                    plt.close(fig)

            print(f'predict {pdbid} with {len(seq)} nts')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|