# AllerTrans / app.py — Gradio app for protein allergenicity prediction
import re

import torch
import gradio as gr
import spaces
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModel

from inference import load_models, predict_ensemble
# Load the trained ensemble heads (ProtT5-only and concatenated-feature models)
model_protT5, model_cat = load_models()

# Load the ProtT5 encoder for per-residue embeddings
tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").eval()

# Load the ESM-2 tokenizer and encoder (650M-parameter, 33-layer checkpoint)
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
esm_model = AutoModel.from_pretrained(model_name).eval()
# Extract ProtT5 embedding
def extract_prott5_embedding(sequence):
    # ProtT5 expects space-separated residues, with rare amino acids
    # (U, Z, O, B) mapped to X per the Rostlab model card
    sequence = re.sub(r"[UZOB]", "X", sequence.replace(" ", ""))
    seq = " ".join(list(sequence))
    ids = tokenizer_t5(seq, return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = model_t5(**ids).last_hidden_state
    # Mean-pool over the residue dimension
    return torch.mean(embedding, dim=1)
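# Shape sanity check (hypothetical toy sequence): ProtT5-XL hidden states are
# 1024-dimensional, so mean pooling should yield a [1, 1024] tensor, e.g.:
#   assert extract_prott5_embedding("MKTAYIAKQR").shape == (1, 1024)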
# Extract ESM-2 embedding
def extract_esm_embedding(sequence):
    # Tokenize the sequence (adds CLS/EOS special tokens)
    inputs = tokenizer_esm(sequence, return_tensors="pt", padding=True, truncation=True)
    # Forward pass through the model
    with torch.no_grad():
        outputs = esm_model(**inputs)
    # last_hidden_state is the output of the final (33rd) transformer layer
    token_representations = outputs.last_hidden_state
    # Mean-pool over residue tokens only, skipping the leading CLS token
    return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
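# Corresponding check for ESM-2: the 650M checkpoint has a 1280-dimensional
# hidden state, so the pooled embedding should be [1, 1280]; concatenated with
# the ProtT5 vector in classify() below, the ensemble input becomes [1, 2304]:
#   assert extract_esm_embedding("MKTAYIAKQR").shape == (1, 1280)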
def estimate_duration(sequence):
    # Estimate the GPU time needed from the sequence length; passed to
    # @spaces.GPU below so longer sequences get a longer allocation
    base_time = 30          # base time in seconds
    time_per_residue = 0.5  # estimated seconds per residue
    estimated_time = base_time + len(sequence) * time_per_residue
    return min(int(estimated_time), 300)  # cap at 300 seconds
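# For example, a 200-residue sequence gets 30 + 200 * 0.5 = 130 seconds,
# while anything longer than 540 residues hits the 300-second cap.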
# ZeroGPU also accepts a callable duration that receives the same arguments
# as the decorated function, so the estimate above sizes the GPU allocation
@spaces.GPU(duration=estimate_duration)
def classify(sequence):
    # Accept plain or FASTA-formatted input: drop header lines and whitespace
    lines = [line.strip() for line in sequence.strip().splitlines()]
    sequence = "".join(line for line in lines if line and not line.startswith(">"))
    protT5_emb = extract_prott5_embedding(sequence)
    esm_emb = extract_esm_embedding(sequence)
    # Concatenate ESM-2 (1280-dim) and ProtT5 (1024-dim) features
    concat = torch.cat((esm_emb, protT5_emb), dim=1)
    pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
    return "Potential Allergen" if pred.item() == 1 else "Non-Allergen"
description_md = """
### ℹ️ **About AllerTrans – Allergenicity Prediction Tool**

**🧬 Input Format – FASTA Sequences**
This tool accepts protein sequences in FASTA format.

**💡 Accepted Proteins**
- Natural and recombinant proteins
- Pharmaceutical and industrial proteins
- Synthetic sequences (tags or mutations allowed)

🔎 **Note of Caution**:
While our model demonstrates promising performance, particularly with recombinant proteins (as evidenced by our additional evaluation on a recombinant-protein dataset from UniProt), **we advise caution when generalizing the results to all recombinant-protein scenarios**. The model's specificity to various recombinant constructs and modifications has not been explored.

**🧠 Prediction Process**
- Embeddings via ProtT5 + ESM-2
- Deep neural network for classification

**⚠️ Disclaimer**
Although AllerTrans provides highly accurate predictions, it is intended as a screening tool. For clinical or regulatory decisions, always confirm results with experimental validation.
"""
with gr.Blocks() as demo:
gr.Markdown(description_md)
with gr.Row():
input_box = gr.Textbox(lines=3, placeholder="Enter protein sequence...")
output_label = gr.Label(label="Prediction")
classify_btn = gr.Button("Run Prediction")
classify_btn.click(classify, inputs=input_box, outputs=output_label)
if __name__ == "__main__":
demo.launch()