"""Gradio demo: classify a protein sequence as "Allergen" or "Non-Allergen".

The sequence is embedded with ProtT5 and ESM-2; the ProtT5 embedding and the
ESM+ProtT5 concatenation are passed to ``predict_ensemble`` together with the
two trained models returned by ``load_models``.
"""
import torch
import gradio as gr
import numpy as np
from transformers import T5Tokenizer, T5EncoderModel
import esm

from inference import load_models, predict_ensemble

PROT_T5_CHECKPOINT = "Rostlab/prot_t5_xl_uniref50"
# Index of the ESM-2 layer whose representations are pooled.
ESM_REPR_LAYER = 33

# Load trained models
model_protT5, model_cat = load_models()

# Load ProtT5 model
tokenizer_t5 = T5Tokenizer.from_pretrained(PROT_T5_CHECKPOINT, do_lower_case=False)
model_t5 = T5EncoderModel.from_pretrained(PROT_T5_CHECKPOINT)
model_t5 = model_t5.eval()

# Load ESM model
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()


def _clean_sequence(sequence):
    """Return *sequence* with all whitespace removed and letters upper-cased.

    Both embedding functions go through this helper so they always see the
    same residue string. ``None`` is treated like an empty string.

    Raises:
        ValueError: if nothing is left after removing whitespace.
    """
    # str.split() with no argument splits on any whitespace run, so this also
    # drops the newlines/tabs that come with a pasted multi-line sequence.
    cleaned = "".join((sequence or "").split()).upper()
    if not cleaned:
        raise ValueError("Please enter a protein sequence.")
    # NOTE(review): the ProtT5 model card additionally maps rare residues
    # (U, Z, O, B) to X before tokenizing. Not done here because the training
    # pipeline is not visible from this file -- confirm and align if needed.
    return cleaned


def extract_prott5_embedding(sequence):
    """Return the mean-pooled ProtT5 encoder embedding of *sequence*.

    Args:
        sequence: Amino-acid sequence; whitespace and letter case are
            normalized before tokenization.

    Returns:
        The encoder's last hidden state averaged over dim 1.

    Raises:
        ValueError: if *sequence* is empty after whitespace removal.
    """
    residues = _clean_sequence(sequence)
    # The tokenizer is fed one residue per whitespace-separated token.
    spaced = " ".join(residues)
    encoded = tokenizer_t5(spaced, return_tensors="pt", padding=True)
    with torch.no_grad():
        hidden = model_t5(**encoded).last_hidden_state
    # NOTE(review): this averages over every position the tokenizer emits
    # (including any special tokens it appends). Left unchanged so the
    # embedding stays consistent with what the classifiers presumably saw
    # during training.
    return torch.mean(hidden, dim=1)


def extract_esm_embedding(sequence):
    """Return the mean-pooled ESM-2 embedding of *sequence*.

    Args:
        sequence: Amino-acid sequence; whitespace and letter case are
            normalized before tokenization.

    Returns:
        The layer-33 token representations averaged over the residue
        positions, with a leading dimension of size 1 added.

    Raises:
        ValueError: if *sequence* is empty after whitespace removal.
    """
    residues = _clean_sequence(sequence)
    _, _, batch_tokens = batch_converter([("protein1", residues)])
    with torch.no_grad():
        results = esm_model(
            batch_tokens,
            repr_layers=[ESM_REPR_LAYER],
            return_contacts=False,
        )
    token_representations = results["representations"][ESM_REPR_LAYER]
    # Position 0 is skipped (presumably the start token the batch converter
    # prepends). The slice length must be the number of residues actually
    # tokenized, which is why it is taken from the cleaned string rather than
    # from the raw user input.
    per_residue = token_representations[0, 1:len(residues) + 1]
    return torch.mean(per_residue, dim=0).unsqueeze(0)


def classify(sequence):
    """Predict whether *sequence* is an allergen.

    Args:
        sequence: Raw text from the UI textbox.

    Returns:
        ``"Allergen"`` if the ensemble prediction equals 1, otherwise
        ``"Non-Allergen"``.

    Raises:
        gr.Error: if the input contains no sequence, so the UI shows a
            readable message instead of a prediction on empty input.
    """
    try:
        residues = _clean_sequence(sequence)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    protT5_emb = extract_prott5_embedding(residues)
    esm_emb = extract_esm_embedding(residues)
    concat = torch.cat((esm_emb, protT5_emb), dim=1)
    pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
    return "Allergen" if pred.item() == 1 else "Non-Allergen"


demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
    outputs=gr.Label(label="Prediction"),
)

if __name__ == "__main__":
    demo.launch()