# AllerTrans / app.py
import torch
import gradio as gr
import numpy as np
from transformers import T5Tokenizer, T5EncoderModel
import esm
from inference import load_models, predict_ensemble
# Load trained models
model_protT5, model_cat = load_models()
# Load ProtT5 model
tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model_t5 = model_t5.eval()
# Load ESM model
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()
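# Optional device placement (a minimal sketch, assuming a CUDA GPU may be
# available; the embedding functions below would then also need to move their
# input tensors to the same device with .to(device)):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model_t5 = model_t5.to(device)
#   esm_model = esm_model.to(device)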
def extract_prott5_embedding(sequence):
    """Mean-pooled per-protein embedding from ProtT5 (shape: 1 x 1024)."""
    # ProtT5 expects residues separated by spaces; note that the canonical
    # Rostlab usage example also maps rare amino acids (U, Z, O, B) to X
    # before tokenization.
    sequence = sequence.replace(" ", "")
    seq = " ".join(list(sequence))
    ids = tokenizer_t5(seq, return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = model_t5(**ids).last_hidden_state
    # Average over the sequence dimension to get one vector per protein.
    return torch.mean(embedding, dim=1)
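# Quick sanity check (a sketch with a hypothetical test sequence; ProtT5-XL
# produces 1024-dimensional hidden states, so the pooled embedding is (1, 1024)):
#
#   emb = extract_prott5_embedding("MKTAYIAKQRQISFVK")
#   assert emb.shape == (1, 1024)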
def extract_esm_embedding(sequence):
    """Mean-pooled per-protein embedding from ESM-2 650M (shape: 1 x 1280)."""
    batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", sequence)])
    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]
    # Skip the BOS token at position 0 and average the residue representations.
    return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
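# Analogous check (esm2_t33_650M_UR50D embeds each residue in 1280 dimensions,
# so the pooled vector is (1, 1280); the sequence is hypothetical):
#
#   emb = extract_esm_embedding("MKTAYIAKQRQISFVK")
#   assert emb.shape == (1, 1280)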
def classify(sequence):
    """Run the ensemble on one sequence and return a human-readable label."""
    protT5_emb = extract_prott5_embedding(sequence)
    esm_emb = extract_esm_embedding(sequence)
    # Concatenated features: ESM-2 (1280) + ProtT5 (1024) = 2304 dimensions.
    concat = torch.cat((esm_emb, protT5_emb), dim=1)
    pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
    return "Allergen" if pred.item() == 1 else "Non-Allergen"
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
    outputs=gr.Label(label="Prediction"),
)
if __name__ == "__main__":
    demo.launch()