sfaezella commited on
Commit
b0b116b
Β·
verified Β·
1 Parent(s): f745782

Update input sequence description

Browse files
Files changed (1) hide show
  1. app.py +83 -80
app.py CHANGED
@@ -1,81 +1,84 @@
1
- import torch
2
- import gradio as gr
3
- import numpy as np
4
- from transformers import T5Tokenizer, T5EncoderModel
5
- import esm
6
- from inference import load_models, predict_ensemble
7
- from transformers import AutoTokenizer, AutoModel
8
- import spaces
9
-
10
- # Load trained models
11
- model_protT5, model_cat = load_models()
12
-
13
- # Load ProtT5 model
14
- tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
15
- model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
16
- model_t5 = model_t5.eval()
17
-
18
- # Load the tokenizer and model
19
- model_name = "facebook/esm2_t33_650M_UR50D"
20
- tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
21
- esm_model = AutoModel.from_pretrained(model_name)
22
-
23
- def extract_prott5_embedding(sequence):
24
- sequence = sequence.replace(" ", "")
25
- seq = " ".join(list(sequence))
26
- ids = tokenizer_t5(seq, return_tensors="pt", padding=True)
27
- with torch.no_grad():
28
- embedding = model_t5(**ids).last_hidden_state
29
- return torch.mean(embedding, dim=1)
30
-
31
-
32
- # Extract ESM2 embedding
33
- def extract_esm_embedding(sequence):
34
- # Tokenize the sequence
35
- inputs = tokenizer_esm(sequence, return_tensors="pt", padding=True, truncation=True)
36
-
37
- # Forward pass through the model
38
- with torch.no_grad():
39
- outputs = esm_model(**inputs)
40
-
41
- # Extract the embeddings from the 33rd layer (ESM2 layer)
42
- token_representations = outputs.last_hidden_state # This is the default layer
43
- return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
44
-
45
- def estimate_duration(sequence):
46
- # Estimate duration based on sequence length
47
- base_time = 30 # Base time in seconds
48
- time_per_residue = 0.5 # Estimated time per residue
49
- estimated_time = base_time + len(sequence) * time_per_residue
50
- return min(int(estimated_time), 300) # Cap at 300 seconds
51
-
52
- @spaces.GPU(duration=120)
53
- def classify(sequence):
54
- protT5_emb = extract_prott5_embedding(sequence)
55
- esm_emb = extract_esm_embedding(sequence)
56
- concat = torch.cat((esm_emb, protT5_emb), dim=1)
57
- pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
58
- return "Potential Allergen" if pred.item() == 1 else "Non-Allergen"
59
-
60
- description_md = """
61
- ## πŸ“Œ **About AllerTrans – A Powerful Tool for Predicting the Allergenicity of Protein Sequences**
62
-
63
- **🧬 Input Format – FASTA Sequences:** This tool accepts protein sequences in FASTA format.
64
-
65
- **🧾 Output Explanation** – AllerTrans classifies your input sequence into one of the following categories:
66
- ###### **🟒 Non-Allergen:** The protein is unlikely to cause an allergic reaction and can be considered safe regarding allergenicity.
67
- ###### **πŸ”΄ Potential Allergen:** The protein has the potential to trigger an allergic response or exhibit cross-reactivity in some individuals.
68
-
69
- **πŸ”Ž Caution & Disclaimer:**
70
- ###### Our model has demonstrated promising performance on the AlgPred 2.0 validation set, which includes a wide range of allergenic and non-allergenic sequences from diverse sources. AllerTrans is also capable of handling recombinant proteins, as supported by additional evaluation using a recombinant protein dataset from UniProt. However, **we advise caution when using this tool on all constructs and modifications of recombinant proteins**. The model's generalizability across various recombinant scenarios has yet to be fully explored.
71
-
72
- ###### 🚨 Remember, AllerTrans is designed as a reliable screening tool. However, for clinical or regulatory decisions, always confirm the prediction results through experimental validation.
73
- """
74
-
75
- demo = gr.Interface(fn=classify,
76
- inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
77
- outputs=gr.Label(label="Prediction"),
78
- description=description_md)
79
-
80
- if __name__ == "__main__":
 
 
 
81
  demo.launch()
 
1
+ import torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ from transformers import T5Tokenizer, T5EncoderModel
5
+ import esm
6
+ from inference import load_models, predict_ensemble
7
+ from transformers import AutoTokenizer, AutoModel
8
+ import spaces
9
+
10
+ # Load trained models
11
+ model_protT5, model_cat = load_models()
12
+
13
+ # Load ProtT5 model
14
+ tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
15
+ model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
16
+ model_t5 = model_t5.eval()
17
+
18
+ # Load the tokenizer and model
19
+ model_name = "facebook/esm2_t33_650M_UR50D"
20
+ tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
21
+ esm_model = AutoModel.from_pretrained(model_name)
22
+
23
+ def extract_prott5_embedding(sequence):
24
+ sequence = sequence.replace(" ", "")
25
+ seq = " ".join(list(sequence))
26
+ ids = tokenizer_t5(seq, return_tensors="pt", padding=True)
27
+ with torch.no_grad():
28
+ embedding = model_t5(**ids).last_hidden_state
29
+ return torch.mean(embedding, dim=1)
30
+
31
+
32
+ # Extract ESM2 embedding
33
+ def extract_esm_embedding(sequence):
34
+ # Tokenize the sequence
35
+ inputs = tokenizer_esm(sequence, return_tensors="pt", padding=True, truncation=True)
36
+
37
+ # Forward pass through the model
38
+ with torch.no_grad():
39
+ outputs = esm_model(**inputs)
40
+
41
+ # Extract the embeddings from the 33rd layer (ESM2 layer)
42
+ token_representations = outputs.last_hidden_state # This is the default layer
43
+ return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
44
+
45
+ def estimate_duration(sequence):
46
+ # Estimate duration based on sequence length
47
+ base_time = 30 # Base time in seconds
48
+ time_per_residue = 0.5 # Estimated time per residue
49
+ estimated_time = base_time + len(sequence) * time_per_residue
50
+ return min(int(estimated_time), 300) # Cap at 300 seconds
51
+
52
+ @spaces.GPU(duration=120)
53
+ def classify(sequence):
54
+ protT5_emb = extract_prott5_embedding(sequence)
55
+ esm_emb = extract_esm_embedding(sequence)
56
+ concat = torch.cat((esm_emb, protT5_emb), dim=1)
57
+ pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
58
+ return "Potential Allergen" if pred.item() == 1 else "Non-Allergen"
59
+
60
+ description_md = """
61
+ ## πŸ“Œ **About AllerTrans – A Powerful Tool for Predicting the Allergenicity of Protein Sequences**
62
+
63
+ **🧬 Input Format – FASTA Sequences:** This tool accepts protein sequences in FASTA format.
64
+
65
+ Please provide a single protein sequence at a time.
66
+ For faster predictions, you may enter only the amino acid sequence, without the FASTA header or any additional information.
67
+
68
+ **🧾 Output Explanation** – AllerTrans classifies your input sequence into one of the following categories:
69
+ ###### **🟒 Non-Allergen:** The protein is unlikely to cause an allergic reaction and can be considered safe regarding allergenicity.
70
+ ###### **πŸ”΄ Potential Allergen:** The protein has the potential to trigger an allergic response or exhibit cross-reactivity in some individuals.
71
+
72
+ **πŸ”Ž Caution & Disclaimer:**
73
+ ###### Our model has demonstrated promising performance on the AlgPred 2.0 validation set, which includes a wide range of allergenic and non-allergenic sequences from diverse sources. AllerTrans is also capable of handling recombinant proteins, as supported by additional evaluation using a recombinant protein dataset from UniProt. However, **we advise caution when using this tool on all constructs and modifications of recombinant proteins**. The model's generalizability across various recombinant scenarios has yet to be fully explored.
74
+
75
+ ###### 🚨 Remember, AllerTrans is designed as a reliable screening tool. However, for clinical or regulatory decisions, always confirm the prediction results through experimental validation.
76
+ """
77
+
78
+ demo = gr.Interface(fn=classify,
79
+ inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
80
+ outputs=gr.Label(label="Prediction"),
81
+ description=description_md)
82
+
83
+ if __name__ == "__main__":
84
  demo.launch()