|
|
|
|
|
|
|
|
|
import glob
|
|
import pickle
|
|
import random
|
|
import re
|
|
from argparse import ArgumentParser
|
|
import matplotlib.pyplot as plt
|
|
|
|
import torch
|
|
from Bio import SeqIO
|
|
from sklearn.preprocessing import OneHotEncoder
|
|
import numpy as np
|
|
import os
|
|
import pandas as pd
|
|
class bcolors:
    """ANSI escape sequences for coloured / styled terminal output.

    Prefix a string with one of these codes and end it with ``RESET``.
    """

    _esc = "\x1b["  # CSI introducer; identical to "\033["

    RED = _esc + "1;31m"      # bold red
    BLUE = _esc + "1;34m"     # bold blue
    CYAN = _esc + "1;36m"     # bold cyan
    GREEN = _esc + "0;32m"    # normal green
    RESET = _esc + "0;0m"     # reset all attributes
    BOLD = _esc + ";1m"       # bold
    REVERSE = _esc + ";7m"    # reverse video

    # Build-time helper only; keep it out of the class namespace.
    del _esc
|
|
|
|
|
|
def check_path(dirout,file=False):
    """Make sure a directory exists, creating it (and any parents) if needed.

    Args:
        dirout: directory path, or a file path when ``file`` is True.
        file: if True, ``dirout`` is a file path and its parent directory is
            ensured instead.

    An empty or ``None`` path (for example a bare file name with
    ``file=True``, which lives in the current directory) needs no directory
    and is ignored. Errors other than a creation race (e.g. permission
    denied) propagate instead of being silently swallowed.
    """
    if file and dirout:
        # dirname handles bare file names ('' -> nothing to create) and
        # platform separators, unlike a manual rsplit on '/'.
        dirout = os.path.dirname(dirout)
    if not dirout or os.path.exists(dirout):
        return
    print(f'make dir {dirout}')
    try:
        os.makedirs(dirout)
    except FileExistsError:
        # Another process created it between the exists() check and makedirs().
        print(f'{dirout} have been made by other process')
|
|
|
|
def load_label_pred(fin_label,fin_pred):
    """Load a pickled label matrix and a tab-separated prediction matrix and align them.

    Args:
        fin_label: path to a pickle file holding the label object; it is
            squeezed and then used as a DataFrame (see notes below).
        fin_pred: path to a tab-separated prediction table whose first column
            is the row index; '#' starts a comment, so the '#'-prefixed header
            lines written by ``doSavePredict`` are skipped.

    Returns:
        (df_label, df_pred): the labels with all-NaN rows and columns dropped,
        and the predictions restricted to the same rows/columns, with both
        axes of ``df_pred`` relabelled as ``<name part><1-based position>``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted label files.
    with open(fin_label, 'rb') as f:
        df_label = pickle.load(f)
    df_label = df_label.squeeze()
    df_pred = pd.read_table(fin_pred, comment='#', index_col=[0])
    if type(df_label) == pd.DataFrame:
        # Reuse the label's axis labels so both matrices can be indexed together.
        # Assumes the prediction has the same shape and ordering as the label — TODO confirm.
        df_pred.index = df_label.index
        df_pred.columns = df_label.columns

    # Drop rows, then columns, whose label values are all NaN.
    # NOTE(review): if squeeze() produced a Series, dropna(axis=1) and .columns below fail.
    df_label = df_label.dropna(how='all')
    df_label = df_label.dropna(axis=1, how='all')
    df_pred = df_pred.loc[df_label.index, df_label.columns]
    # Pick which '.'-separated part of each label to keep: the last part when the
    # FIRST column label has exactly two dots, otherwise the first part. The same
    # choice is applied to the row labels. Assumes labels are strings — TODO confirm.
    keep=0
    if df_pred.columns[0].count('.')==2:
        keep=-1
    # Relabel as <part><1-based position>, e.g. 'A' at position 0 -> 'A1'.
    df_pred.columns = [e.split('.')[keep] + str(i+1) for i, e in enumerate(df_pred.columns)]
    df_pred.index = [e.split('.')[keep] + str(i+1) for i, e in enumerate(df_pred.index)]
    return df_label,df_pred
|
|
def doSavePredict(_id,seq,predict,fout,des):
    """Write a predicted contact matrix and its top-L entries into directory ``fout``.

    Args:
        _id: identifier used to build the output file names.
        seq: falsy, or a mapping with ``'rna'`` and ``'protein'`` sequences. When
            given, rows are labelled ``<rna residue><1-based pos>`` and columns
            ``<protein residue><1-based pos>``, and a '#'-prefixed header is written.
        predict: 2-D array-like of scores (rows = RNA, columns = protein when
            ``seq`` is given).
        fout: output directory (created if missing).
        des: free-text description written as the first header line (only when
            ``seq`` is given).

    Writes ``<fout>/<_id>.txt`` (full tab-separated matrix) and
    ``<fout>/<_id>_topL.txt`` (the L = rows + cols highest entries as
    rna / protein / pred rows).
    """
    check_path(fout)
    df = pd.DataFrame(predict)
    # Join instead of concatenating: check_path treats fout as a directory, so a
    # value without a trailing separator must still place files inside it.
    fmatrix = os.path.join(fout, f'{_id}.txt')
    if not seq:
        df.to_csv(fmatrix, sep='\t', mode='w', float_format='%.5f')
    else:
        df.columns = [f'{elem}{index+1}' for index, elem in enumerate(seq['protein'])]
        df.index = [f'{elem}{index+1}' for index, elem in enumerate(seq['rna'])]
        # Header first, then the matrix appended below it; readers skip '#' lines.
        with open(fmatrix, 'w') as f:
            f.write(f'#{des}\n')
            f.write(f"# row =rna:{seq['rna']}\n")
            f.write(f"# col=protein:{seq['protein']}\n")
        df.to_csv(fmatrix, sep='\t', mode='a', float_format='%.5f')

    # L = number of rows + number of columns.
    top = get_top_l_triplets(df, sum(df.shape))
    top.to_csv(os.path.join(fout, f'{_id}_topL.txt'), sep='\t', mode='w',
               float_format='%.5f', index=False)
|
|
|
|
|
|
|
|
def get_top_l_triplets(df_pred, L):
    """Return the ``L`` largest entries of a contact matrix as (rna, protein, pred) rows.

    Args:
        df_pred: pandas DataFrame contact matrix; its index labels are reported
            in the ``rna`` column and its column labels in the ``protein`` column.
        L: number of entries to keep (all entries if the matrix has fewer).

    Returns:
        DataFrame with columns ``['rna', 'protein', 'pred']``, sorted by
        ``pred`` in descending order. Ties keep row-major matrix order, so the
        result is reproducible. With pandas' legacy ``stack()`` behaviour, NaN
        cells are dropped before ranking.
    """
    df = df_pred.stack().reset_index()
    df.columns = ['rna', 'protein', 'pred']
    # mergesort is stable (the default quicksort is not), giving a deterministic tie order.
    return df.sort_values(by='pred', ascending=False, kind='mergesort').head(L)
|
|
|
|
def doSavePredict_single(_id,seq,predict_rsa,fout,des,pred_asa=None):
    """Convert between relative (RSA) and absolute (ASA) accessibility for an RNA; optionally save.

    If ``pred_asa`` is None, ASA = RSA * per-base scale (then transposed);
    otherwise RSA = ASA / per-base scale and ``predict_rsa`` is ignored.

    Args:
        _id: identifier used for the output file name.
        seq: RNA sequence (string or re-iterable of one-letter codes).
        predict_rsa: per-nucleotide relative accessibility (used when ``pred_asa`` is None).
        fout: output directory; when falsy nothing is created or written.
        des: description written as the file's header line.
        pred_asa: optional per-nucleotide absolute accessibility.

    Returns:
        (pred_asa, predict_rsa)

    Raises:
        SystemExit: if any ASA value is exactly zero.
    """
    # Only touch the filesystem when an output directory was actually given.
    if fout:
        check_path(fout)
    BASES = 'AUCG'
    # Normalisation constants in BASES order: A=400, U=350, C=350, G=400.
    asa_std = [400, 350, 350, 400]
    dict_rnam1_ASA = dict(zip(BASES, asa_std))
    # T is treated as U. Any other non-AUCG symbol is replaced by ONE base drawn
    # once per call, so all unknown symbols in the sequence get the same base.
    sequence = re.sub(r"[T]", "U", ''.join(seq))
    sequence = re.sub(r"[^AGCU]", BASES[random.randint(0, 3)], sequence)
    ASA_scale = np.array([dict_rnam1_ASA[i] for i in sequence])

    if pred_asa is None:
        pred_asa = np.multiply(predict_rsa, ASA_scale).T
    else:
        predict_rsa = pred_asa / ASA_scale

    col1 = np.array([i + 1 for i, _ in enumerate(seq)])[None, :]  # 1-based positions
    col2 = np.array(list(seq))[None, :]                           # original symbols
    col3 = pred_asa
    col4 = predict_rsa
    if np.any(np.asarray(col3) == 0):
        # Equivalent to the site-provided exit(), but works without the site module.
        raise SystemExit(f'error in predict\t {_id},{seq}')
    temp = np.vstack((np.char.mod('%d', col1), col2,
                      np.char.mod('%.2f', col3), np.char.mod('%.3f', col4))).T
    if fout:
        np.savetxt(os.path.join(fout, f'{_id}.txt'), temp, delimiter='\t\t', fmt="%s",
                   header=f'#{des}',
                   comments='')

    return pred_asa, predict_rsa
|
|
|
|
def one_hot_encode(sequences,alpha='ACGU'):
    """One-hot encode a sequence over a fixed alphabet.

    Columns follow the *sorted* alphabet (code-point order), the same category
    order scikit-learn's ``OneHotEncoder`` produces. So 'ACGU' gives columns
    A, C, G, U and the 20 amino acids give A, C, D, ..., Y. Symbols not in
    ``alpha`` (e.g. 'N', lowercase letters) become an all-zero row.

    Args:
        sequences: a string (or iterable of symbols) of length L; may be empty.
        alpha: the alphabet; duplicate symbols are ignored.

    Returns:
        float64 ndarray of shape (L, number of distinct symbols in ``alpha``).
    """
    categories = np.array(sorted(set(alpha)), dtype=str)
    symbols = np.array(list(sequences), dtype=str).reshape(-1, 1)
    # Broadcast (L, 1) against (1, K): True where symbol equals category.
    return (symbols == categories[None, :]).astype(np.float64)
|
|
|
|
def get_bin_pred(df_pred,threshold):
    """Binarise a score matrix: 1 where a score is >= ``threshold``, else 0.

    Args:
        df_pred: pandas object exposing ``.values`` (e.g. a DataFrame of scores).
        threshold: scores at or above this value count as positive.

    Returns:
        numpy int array with the same shape as ``df_pred.values``.
    """
    scores = df_pred.values
    return (scores >= threshold).astype(int)
|
|
|
|
def seed_everything(seed=2022):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Also exports ``PYTHONHASHSEED``. Hash randomisation of the running
    interpreter is fixed at startup, so that variable only affects processes
    launched afterwards.

    Args:
        seed: integer seed applied to every generator.
    """
    print('seed_everything to ', seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
def getParam():
    """Parse the command line for batch RNA-protein contact prediction.

    Returns:
        argparse.Namespace with string options rootdir, rna_fasta, pro_fasta,
        csv, col, out, ffeat, fmodel, device and the boolean flag ``draw``.
    """
    string_options = (
        ('--rootdir', ''),
        ('--rna_fasta', './example/inputs_batch/rna.fasta'),
        ('--pro_fasta', './example/inputs_batch/protein.fasta'),
        ('--csv', './example/inputs_batch/pairs.csv'),
        ('--col', '_id'),
        ('--out', './example/outputs_batch/'),
        ('--ffeat', './example/inputs_batch/embedding/{element}/{pdbid}.pickle'),
        ('--fmodel', './weight/model_roc_0_38=0.845.pt'),
        ('--device', 'cpu'),
    )
    parser = ArgumentParser()
    for flag, default_value in string_options:
        parser.add_argument(flag, default=default_value, type=str)
    parser.add_argument('--draw', action='store_true')
    return parser.parse_args()
|
|
def _to_model_input(onehot, emb, device):
    """Concatenate one-hot and embedding features into a (1, features, length) float32 tensor.

    ``onehot`` is the (length, alphabet) array from ``one_hot_encode``; ``emb`` is
    assumed to be a (length, dim) per-residue embedding — TODO confirm against the
    embedding pickles.
    """
    feats = np.expand_dims(np.concatenate([onehot, emb], axis=1), 0)
    return torch.from_numpy(feats).transpose(-1, -2).to(device, dtype=torch.float32)


def _draw_contact_maps(pred, pdbid, rnaid, proid, out, i):
    """Save a score heatmap and a binary top-L contact plot for one prediction.

    ``pred`` is the 2-D score matrix (rows RNA, columns protein, as labelled by
    ``doSavePredict``). The L = rows + cols highest scores are treated as
    predicted contacts. Figures are written to ``<out>/<pdbid>_<i>_prob.png``
    and ``<out>/<pdbid>_<i>_binary.png``.
    """
    import seaborn as sns

    plt.figure(figsize=(20, 15))
    top = sum(pred.shape)
    df_pred = pd.DataFrame(pred)
    threshold = df_pred.stack().nlargest(top).iloc[-1]
    bin_pred = get_bin_pred(df_pred, threshold=threshold)
    title = f'Predicted contact map of {pdbid}\nPredicted by RPcontact, top L=r+p'

    # NOTE(review): seaborn's ``mask`` HIDES cells where it is True, so this heatmap
    # hides exactly the top-L contacts; if the intent is to show only them, use
    # ``mask=bin_pred == 0`` — confirm.
    sns.heatmap(df_pred, mask=bin_pred)
    plt.title(title)
    plt.xlabel(proid)
    plt.ylabel(rnaid)
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1,
               borderaxespad=1, frameon=False)
    ax = plt.gca()
    ax.set_aspect('equal')
    plt.tight_layout()
    plt.savefig(os.path.join(out, f'{pdbid}_{i}_prob.png'), dpi=300)
    plt.show()

    # Binary map: one marker per predicted contact, x = column index, y = row index.
    plt.clf()
    ax = plt.gca()
    tp = ax.plot(*np.where(bin_pred.T == 1), ".", c='r', markersize=1,
                 label='Predicted contact')[0]
    tp.set_markerfacecolor('w')
    tp.set_markeredgecolor('r')
    h, w = bin_pred.shape
    plt.xlim([0, w])
    plt.ylim([0, h])
    plt.title(title)
    plt.xlabel(proid)
    plt.ylabel(rnaid)
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1,
               borderaxespad=1, frameon=False)
    ax.set_aspect('equal')
    plt.tight_layout()
    plt.savefig(os.path.join(out, f'{pdbid}_{i}_binary.png'), dpi=300)
    plt.show()
    # Release the figure; otherwise one figure per (pair, model) accumulates in memory.
    plt.close('all')


if __name__ == '__main__':
    args = getParam()
    csv = args.csv
    col = args.col
    rna_fasta = args.rna_fasta
    pro_fasta = args.pro_fasta
    ffeat = args.ffeat
    fmodel = args.fmodel
    device = args.device
    out = args.out
    draw = args.draw
    check_path(out)

    seed_everything(seed=2022)

    # ``fmodel`` may be a glob pattern; every matching checkpoint is evaluated.
    # NOTE(review): the loaded objects are called as modules (whole pickled models);
    # torch >= 2.6 defaults torch.load to weights_only=True, which rejects such files —
    # pass weights_only=False for trusted checkpoints if loading fails.
    models = [(model_path, torch.load(model_path, map_location=torch.device(device)))
              for model_path in glob.glob(fmodel)]
    if not models:
        raise SystemExit(f'no model weights match {fmodel}')
    print('loading existed model', fmodel)
    for _, model in models:
        model.eval()

    # Later records with a duplicate id overwrite earlier ones.
    rna_dict = {record.id: str(record.seq) for record in SeqIO.parse(rna_fasta, 'fasta')}
    pro_dict = {record.id: str(record.seq) for record in SeqIO.parse(pro_fasta, 'fasta')}

    pairs = pd.read_csv(csv)
    predict_scores = []
    with torch.no_grad():
        for pdbid in pairs[col]:
            # Pair ids have the form '<rna id>.<protein id>'.
            rnaid, proid = pdbid.split('.')
            rnaseq, proseq = rna_dict[rnaid], pro_dict[proid]

            # NOTE(review): pickle.load executes code from the file; only use trusted embeddings.
            with open(ffeat.format_map({'pdbid': rnaid, 'element': 'rna'}), 'rb') as f:
                rna_emb = pickle.load(f)
            with open(ffeat.format_map({'pdbid': proid, 'element': 'protein'}), 'rb') as f:
                pro_emb = pickle.load(f)

            rna_oh = one_hot_encode(rnaseq.replace('T', 'U'), alpha='ACGU')
            pro_oh = one_hot_encode(proseq, alpha='GAVLIFWYDNEKQMSTCPHR')
            x_rna = _to_model_input(rna_oh, rna_emb, device)
            x_pro = _to_model_input(pro_oh, pro_emb, device)

            for i, (model_path, model) in enumerate(models):
                outputs = model(x_pro, x_rna)
                # Drop a trailing singleton dim and swap the last two axes so the
                # first sample's matrix has rows = RNA, columns = protein
                # (the orientation doSavePredict labels) — shape assumed, TODO confirm.
                outputs = torch.squeeze(outputs, -1).permute(0, 2, 1)
                df_pred = outputs[0].cpu().detach().numpy()
                des = f'predict by {__file__}\n#{model_path}'
                doSavePredict(pdbid, {'rna': rnaseq, 'protein': proseq}, df_pred, out, des)

                # Contact score = sum of the L = rows + cols largest pair scores.
                top = sum(df_pred.shape)
                score = sum(np.sort(df_pred, axis=None)[::-1][:top])
                predict_scores.append((pdbid, score))
                print(pdbid, score)

                if draw:
                    _draw_contact_maps(df_pred, pdbid, rnaid, proid, out, i)

            print(f'predict {pdbid} with {len(rnaseq)} nts')

    scores_df = pd.DataFrame(predict_scores, columns=['pdbid', 'contact_score'])
    scores_df.to_csv(os.path.join(out, 'predict_scores.tsv'), index=False, sep='\t',
                     mode='w', float_format='%.5f')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|