RPcontact / evaluate.py
julse's picture
Upload 23 files
82d55c6 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by: julse@qq.com
# des : evaluate RPcontact
import glob
import os
import pickle
import random
from argparse import ArgumentParser
import matplotlib.pyplot as plt
import pandas as pd
import torch
from Bio import SeqIO
from sklearn.preprocessing import OneHotEncoder
import numpy as np
from predict import check_path, one_hot_encode, get_bin_pred, doSavePredict
def get_bin_label(df_label, distance_cutoff):
    """Binarize a distance map into a 0/1 contact map.

    Args:
        df_label: Distances; any object whose elementwise ``<`` comparison
            yields something with an ``astype`` method (for example a NumPy
            array or a pandas DataFrame).
        distance_cutoff: Entries strictly below this value become 1.

    Returns:
        The comparison result cast to ``int``: 1 where
        ``df_label < distance_cutoff`` holds, 0 elsewhere.
    """
    return (df_label < distance_cutoff).astype(int)
def view_evaluate_contact_prob(df_label, bin_pred, ax=None, markersize=5,
                               cutoffs=(8, 5, 3.5)):
    """Plot predicted contacts against ground-truth contacts at several cutoffs.

    Rows of the inputs are drawn on the y axis and columns on the x axis.
    False positives and missing (NaN) labels are drawn first, then for every
    cutoff the ground-truth contacts and the correctly predicted contacts.

    Args:
        df_label: 2-D distance map (NumPy array or pandas DataFrame). NaN marks
            pairs without a measured distance.
        bin_pred: 2-D binary prediction of the same shape; entries equal to 1
            are predicted contacts. Compared with ``df_label`` by position.
        ax: Object to draw on. When ``None`` the ``matplotlib.pyplot`` module
            itself is used; otherwise an Axes-like object providing
            ``set_xlim``, ``set_ylim``, ``set_title`` and ``plot``.
        markersize: Marker size forwarded to every ``plot`` call.
        cutoffs: Distance cutoffs in Å, evaluated in the given order. False
            positives and the value 1 in the returned matrix are defined with
            respect to the first cutoff.

    Returns:
        A tuple ``(summary, confusing_matrix)``. ``summary`` is the number of
        true positives per cutoff joined by ``'/'`` in reverse cutoff order
        (``'<3.5Å>/<5Å>/<8Å>'`` for the defaults). ``confusing_matrix`` is a
        NumPy array shaped like ``df_label`` holding 2 for true positives at
        any cutoff, 1 for remaining ground-truth contacts at the first cutoff
        and 0 elsewhere.

    Raises:
        ValueError: If ``cutoffs`` is empty or the input shapes differ.
    """
    if len(cutoffs) == 0:
        raise ValueError('cutoffs must contain at least one distance')

    raw_label = np.asarray(df_label)
    distances = raw_label.astype(float)
    is_predicted = np.asarray(bin_pred).astype(int) == 1
    if distances.shape != is_predicted.shape:
        raise ValueError(
            f'shape mismatch: labels {distances.shape} vs prediction {is_predicted.shape}'
        )

    confusing_matrix = np.zeros_like(raw_label)
    n_rows, n_cols = distances.shape

    if ax is None:
        ax = plt
        ax.xlim([-2, n_cols + 2])
        ax.ylim([-2, n_rows + 2])
    else:
        ax.set_xlim([-2, n_cols + 2])
        ax.set_ylim([-2, n_rows + 2])
        ax.set_title('performance')

    fp_color = '#f5e0c4'                        # pale beige: false positives
    truth_colors = ('#b0d9db', '#61b3b6', 'k')  # light teal, teal, black: ground truth
    hit_colors = ('#ecbbd8', '#9d4e7d', 'r')    # pink, plum, red: true positives

    def scatter(mask, fmt, color, label):
        # Transpose so that column indices land on x and row indices on y.
        return ax.plot(*np.where(mask.T), fmt, c=color,
                       markersize=markersize, label=label)[0]

    # NaN compares as False, so missing labels are never counted as contacts.
    with np.errstate(invalid='ignore'):
        contact_masks = [distances < cutoff for cutoff in cutoffs]

    reference = contact_masks[0]
    scatter(is_predicted & ~reference, '.', fp_color, 'False Positive')
    # Plot pairs with a NaN label in gray.
    scatter(np.isnan(distances), '.', 'gray', 'Missing in PDB')
    confusing_matrix[reference] = 1  # ground truth

    tps = []
    for k, (cutoff, truth) in enumerate(zip(cutoffs, contact_masks)):
        # Colours repeat when more cutoffs than colours are requested.
        truth_color = truth_colors[k % len(truth_colors)]
        hit_color = hit_colors[k % len(hit_colors)]

        scatter(truth, '.', truth_color, f'Ground truth ({cutoff:g}Å)')

        hits = truth & is_predicted
        tps.append(int(np.count_nonzero(hits)))
        confusing_matrix[hits] = 2  # true positive

        marker = scatter(hits, 'o', hit_color, f'True Positive ({cutoff:g}Å)')
        marker.set_markerfacecolor(truth_color)
        marker.set_markeredgecolor(hit_color)

    return '/'.join(str(count) for count in reversed(tps)), confusing_matrix
def seed_everything(seed=2022):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and current CUDA
    device), exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    print('seed_everything to ',seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # torch.manual_seed makes results repeat across program runs; successive
    # random draws inside one run still differ from each other.
    # See https://blog.csdn.net/qq_42951560/article/details/112174334
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # The mini-batch length keeps changing, so cuDNN autotuning would only
    # waste time.
    cudnn.benchmark = False
def _str2bool(value):
    """Convert a command-line string such as ``'False'`` or ``'1'`` to a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool('False')``,
    which is ``True`` because the string is non-empty.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def getParam():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with the attributes ``rootdir``, ``fasta``, ``out``,
        ``ffeat``, ``fmodel``, ``device``, ``flabel`` (all ``str``) and
        ``draw`` (``bool``).
    """
    parser = ArgumentParser()
    # data
    parser.add_argument('--rootdir', default='', type=str,
                        help='root directory')
    parser.add_argument('--fasta', default='./example/inputs/8DMB_W.8DMB_P.fasta',
                        type=str, help='input FASTA file')
    parser.add_argument('--out', default='./example/outputs/', type=str,
                        help='output directory')
    parser.add_argument('--ffeat', default='./example/inputs/{pdbid}.pickle',
                        type=str,
                        help='feature pickle path template containing {pdbid}')
    parser.add_argument('--fmodel', default='./weight/model_roc_0_38=0.845.pt',
                        type=str, help='glob pattern of model files to load')
    parser.add_argument('--device', default='cpu', type=str,
                        help='torch device name')
    parser.add_argument('--flabel', default='./example/inputs/{pdbid}.pickle',
                        type=str,
                        help='label pickle path template containing {pdbid}')
    # A dedicated parser is needed so that "--draw False" really disables drawing.
    parser.add_argument('--draw', default=True, type=_str2bool,
                        help='true/false: draw and save the evaluation figure')
    args = parser.parse_args()
    return args
def _load_pickle(path):
    """Return the object unpickled from ``path``."""
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _to_model_input(one_hot, embedding, device):
    """Join one-hot and embedding features into one float tensor on ``device``.

    The two arrays are concatenated along axis 1, a leading batch axis is
    added and the last two axes are swapped.
    NOTE(review): assumes both inputs are 2-D (length, n_features) arrays, so
    the result is (1, n_features_total, length) - confirm against ``predict``.
    """
    features = np.expand_dims(np.concatenate([one_hot, embedding], axis=1), 0)
    return torch.from_numpy(features).transpose(-1, -2).to(device, dtype=torch.float)


def _draw_evaluation(df_pred, df_label, pdbid, rnaid, proid, fig_path):
    """Draw the top-L predicted contacts against the labels and save the figure."""
    # Keep the L = rows + columns highest scoring pairs as predicted contacts.
    top = sum(df_pred.shape)
    threshold = df_pred.stack().nlargest(top).iloc[-1]
    bin_pred = get_bin_pred(df_pred, threshold=threshold)

    # One fresh figure per saved image, so that results from several models
    # are never drawn on top of each other.
    fig = plt.figure(figsize=(20, 15))
    view_evaluate_contact_prob(df_label, bin_pred, ax=None)
    plt.title(f'Predicted contact map of {pdbid}\nPredidcted by RPcontact, top L=r+p')
    plt.xlabel(proid)
    plt.ylabel(rnaid)
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1,
               borderaxespad=1, frameon=False)
    # Use the same scale on both axes.
    plt.gca().set_aspect('equal')
    plt.tight_layout()
    plt.savefig(fig_path, dpi=900)
    plt.show()
    # Release the figure; otherwise every processed record leaks one.
    plt.close(fig)


def main():
    """Predict contact maps for every FASTA record and optionally plot them.

    Each FASTA record is expected to carry an id of the form ``<rna>.<protein>``
    and a sequence of the form ``<rna_seq>.<protein_seq>``.

    Raises:
        FileNotFoundError: If the ``--fmodel`` pattern matches no file.
    """
    args = getParam()
    fasta = args.fasta
    ffeat = args.ffeat
    fmodel = args.fmodel
    device = args.device
    flabel = args.flabel
    draw = args.draw
    out = args.out
    check_path(out)

    seed_everything(seed=2022)

    model_paths = glob.glob(fmodel)
    if not model_paths:
        # Without this check the script would finish without predicting anything.
        raise FileNotFoundError(f'no model file matches {fmodel!r}')
    # NOTE(review): torch.load unpickles the checkpoint, so only trusted files
    # should be loaded. Recent PyTorch releases default to weights_only=True,
    # which may reject a fully pickled model - confirm for the target version.
    models = [(model_path, torch.load(model_path, map_location=torch.device(device)))
              for model_path in model_paths]
    print('loading existed model', fmodel)
    for _, model in models:
        model.eval()

    with torch.no_grad():
        for record in SeqIO.parse(fasta, 'fasta'):
            pdbid, seq = record.id, record.seq
            rnaid, proid = pdbid.split('.')
            rnaseq, proseq = seq.split('.')

            rna_emb = _load_pickle(ffeat.format_map({'pdbid': rnaid}))
            pro_emb = _load_pickle(ffeat.format_map({'pdbid': proid}))
            rna_oh = one_hot_encode(rnaseq, alpha='ACGU')
            pro_oh = one_hot_encode(proseq, alpha='GAVLIFWYDNEKQMSTCPHR')

            x_rna = _to_model_input(rna_oh, rna_emb, device)
            x_pro = _to_model_input(pro_oh, pro_emb, device)
            print('input data shape for rna and protein:', x_rna.shape, x_pro.shape)

            # The label does not depend on the model, so load it once per record.
            df_label = None
            if draw:
                df_label = _load_pickle(flabel.format_map({'pdbid': pdbid})).squeeze()

            for i, (model_path, model) in enumerate(models):
                outputs = model(x_pro, x_rna)  # e.g. [1, 299, 74, 1]
                outputs = torch.squeeze(outputs, -1).permute(0, 2, 1)
                df_pred = outputs[0].cpu().detach().numpy()

                des = f'predict by {__file__}\n#{model_path}'
                doSavePredict(pdbid, {'rna': rnaseq, 'protein': proseq}, df_pred,
                              out,
                              des
                              )

                if draw:
                    _draw_evaluation(
                        pd.DataFrame(df_pred), df_label, pdbid, rnaid, proid,
                        os.path.join(out, f'{pdbid}_{i}_evaluate.png'),
                    )

            print(f'predict {pdbid} with {len(seq)} nts')


if __name__ == '__main__':
    main()