javi8979 committed on
Commit
87d8688
·
verified ·
1 Parent(s): 7371949

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import snapshot_download
3
+ from vllm import LLM, SamplingParams
4
+
5
# ------------------------
# 1) Load the Model
# ------------------------
# Fetch a local snapshot of the GGUF repository, pinned to the "main" revision.
model_dir = snapshot_download(
    repo_id="BSC-LT/salamandraTA-7B-instruct-GGUF",
    revision="main",
)
model_name = "salamandrata_7b_inst_q4.gguf"

# Build the vLLM engine from the GGUF file inside the snapshot; the snapshot
# directory itself is handed over as the tokenizer location.
llm = LLM(model="/".join((model_dir, model_name)), tokenizer=model_dir)
14
+
15
+ # We can define a single helper function to call the model:
16
def call_model(prompt: str, temperature: float = 0.1, max_tokens: int = 256, stop_token_ids=None):
    """Run a single-turn chat completion with vLLM and return the generated text.

    Args:
        prompt: Text sent as the only user message of the conversation.
        temperature: Sampling temperature forwarded to ``SamplingParams``.
        max_tokens: Maximum number of tokens to generate.
        stop_token_ids: Optional iterable of token IDs that end generation.
            When ``None`` the historical default ``[5]`` is used.
            NOTE(review): ID 5 is presumably the model's end-of-turn token;
            confirm against the tokenizer.

    Returns:
        The text of the first completion of the first request, or an empty
        string when the engine returns no request outputs or no completions.
    """
    if stop_token_ids is None:
        stop_token_ids = [5]

    sampling_params = SamplingParams(
        temperature=temperature,
        stop_token_ids=list(stop_token_ids),
        max_tokens=max_tokens,
    )
    messages = [{'role': 'user', 'content': prompt}]
    request_outputs = llm.chat(messages, sampling_params=sampling_params)

    # llm.chat returns one RequestOutput per conversation, each holding its
    # completions in `.outputs`; guard both levels so an empty result yields
    # "" instead of raising IndexError.
    if not request_outputs or not request_outputs[0].outputs:
        return ""
    return request_outputs[0].outputs[0].text
31
+
32
+ # ------------------------
33
+ # 2) Task-specific functions
34
+ # ------------------------
35
+
36
def general_translation(source_lang, target_lang, text):
    """Translate ``text`` from ``source_lang`` into ``target_lang``.

    Builds a three-line prompt (instruction, labelled source text, target
    language cue) and returns whatever ``call_model`` generates for it.
    """
    instruction = f"Translate the following text from {source_lang} into {target_lang}."
    source_line = f"{source_lang}: {text}"
    target_cue = f"{target_lang}:"
    prompt = "\n".join((instruction, source_line, target_cue))
    return call_model(prompt, temperature=0.1)
47
+
48
def post_editing(source_lang, target_lang, source_text, machine_translation):
    """Ask the model to correct a machine translation, or leave it as is.

    The prompt carries the language pair, the source text and the machine
    translation, and ends with a "Corrected:" cue for the model to complete.
    """
    prompt_lines = [
        f"Please fix any mistakes in the following {source_lang}-{target_lang} machine translation or keep it unedited if it's correct.",
        f"Source: {source_text}",
        f"MT: {machine_translation}",
        "Corrected:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
60
+
61
def document_level_translation(source_lang, target_lang, document_text):
    """Translate a whole document from ``source_lang`` into ``target_lang``.

    The document text is embedded verbatim in the prompt, so it may span
    several paragraphs; the model's completion is returned unchanged.
    """
    request = f"Please translate this text from {source_lang} into {target_lang}."
    labelled_document = f"{source_lang}: {document_text}"
    answer_cue = f"{target_lang}:"
    prompt = "\n".join([request, labelled_document, answer_cue])
    return call_model(prompt, temperature=0.1)
72
+
73
def named_entity_recognition(tokenized_text):
    """Tag the tokens of ``tokenized_text`` with named-entity labels.

    ``tokenized_text`` is a string that is split on whitespace into a list of
    tokens. That list is rendered into the prompt in its Python list form,
    together with the ORG / PER / LOC / MISC / O annotation guidelines and the
    B-/I- prefix convention. Returns the model's completion.
    """
    # str.split() with no separator already discards leading/trailing whitespace.
    tokens = tokenized_text.split()

    prompt_lines = (
        "Analyse the following tokenized text and mark the tokens containing named entities.",
        "Use the following annotation guidelines with these tags for named entities:",
        "- ORG (Refers to named groups or organizations)",
        "- PER (Refers to individual people or named groups of people)",
        "- LOC (Refers to physical places or natural landmarks)",
        "- MISC (Refers to entities that don't fit into standard categories).",
        "Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.",
        "If a token is not a named entity, label it as O.",
        f"Input: {tokens}",
        "Marked:",
    )
    return call_model("\n".join(prompt_lines), temperature=0.1)
97
+
98
def grammar_checker(source_lang, sentence):
    """Ask the model to correct a ``source_lang`` sentence, or leave it as is.

    The prompt states the language, quotes the sentence and ends with a
    "Corrected:" cue; the model's completion is returned unchanged.
    """
    prompt_lines = [
        f"Please fix any mistakes in the following {source_lang} sentence or keep it unedited if it's correct.",
        f"Sentence: {sentence}",
        "Corrected:",
    ]
    return call_model("\n".join(prompt_lines), temperature=0.1)
109
+
110
+ # ------------------------
111
+ # 3) Gradio UI
112
+ # ------------------------
113
# Assemble the demo UI: a title, a short overview, and one tab per task.
# Each tab declares its input textboxes (pre-filled with an example), a button,
# and an output textbox, then wires the button to the matching task function.
with gr.Blocks() as demo:
    gr.Markdown("## SalamandraTA-7B-Instruct Demo")
    gr.Markdown(
        "This Gradio app demonstrates various use-cases for the **SalamandraTA-7B-Instruct** model, including:\n"
        "1. General Translation\n"
        "2. Post-editing\n"
        "3. Document-level Translation\n"
        "4. Named-Entity Recognition (NER)\n"
        "5. Grammar Checking"
    )

    # Tab 1: source/target language + text -> general_translation.
    with gr.Tab("1. General Translation"):
        gr.Markdown("### General Translation")
        src_lang_gt = gr.Textbox(label="Source Language", value="Spanish")
        tgt_lang_gt = gr.Textbox(label="Target Language", value="English")
        text_gt = gr.Textbox(label="Text to Translate", lines=4, value="Ayer se fue, tomó sus cosas y se puso a navegar.")
        translate_button = gr.Button("Translate")
        output_gt = gr.Textbox(label="Translation Output", lines=4)
        # The inputs list order must match general_translation's parameter order.
        translate_button.click(fn=general_translation,
                               inputs=[src_lang_gt, tgt_lang_gt, text_gt],
                               outputs=output_gt)

    # Tab 2: language pair + source text + machine translation -> post_editing.
    with gr.Tab("2. Post-editing"):
        gr.Markdown("### Post-editing (Source → Target)")
        src_lang_pe = gr.Textbox(label="Source Language", value="Catalan")
        tgt_lang_pe = gr.Textbox(label="Target Language", value="English")
        source_text_pe = gr.Textbox(label="Source Text", lines=2, value="Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.")
        # The example machine translation is deliberately wrong so there is something to fix.
        mt_text_pe = gr.Textbox(label="Machine Translation", lines=2, value="Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way.")
        post_edit_button = gr.Button("Post-edit")
        output_pe = gr.Textbox(label="Post-edited Text", lines=4)
        post_edit_button.click(fn=post_editing,
                               inputs=[src_lang_pe, tgt_lang_pe, source_text_pe, mt_text_pe],
                               outputs=output_pe)

    # Tab 3: language pair + longer text -> document_level_translation.
    with gr.Tab("3. Document-level Translation"):
        gr.Markdown("### Document-level Translation")
        src_lang_doc = gr.Textbox(label="Source Language", value="English")
        tgt_lang_doc = gr.Textbox(label="Target Language", value="Asturian")
        doc_text = gr.Textbox(label="Document Text (multiple paragraphs allowed)",
                              lines=8,
                              value=("President Donald Trump, who campaigned on promises to crack down on illegal immigration, "
                                     "has raised alarms in the U.S. dairy industry with his threat to impose 25% tariffs on Mexico "
                                     "and Canada by February 2025."))
        doc_button = gr.Button("Translate Document")
        doc_output = gr.Textbox(label="Document-level Translation Output", lines=8)
        doc_button.click(fn=document_level_translation,
                         inputs=[src_lang_doc, tgt_lang_doc, doc_text],
                         outputs=doc_output)

    # Tab 4: a single whitespace-separated text -> named_entity_recognition.
    with gr.Tab("4. Named-Entity Recognition"):
        gr.Markdown("### Named-Entity Recognition (NER)")
        text_ner = gr.Textbox(
            label="Tokenized Text (space-separated tokens)",
            lines=2,
            value="La defensa del antiguo responsable de la RFEF confirma que interpondrá un recurso."
        )
        ner_button = gr.Button("Run NER")
        ner_output = gr.Textbox(label="NER Output", lines=6)
        ner_button.click(fn=named_entity_recognition,
                         inputs=[text_ner],
                         outputs=ner_output)

    # Tab 5: language + sentence -> grammar_checker.
    with gr.Tab("5. Grammar Checker"):
        gr.Markdown("### Grammar Checker")
        src_lang_gc = gr.Textbox(label="Source Language", value="Catalan")
        # The example sentence contains mistakes on purpose.
        text_gc = gr.Textbox(label="Sentence to Check",
                             lines=2,
                             value="Entonses, el meu jefe m’ha dit que he de treballar els fins de setmana.")
        gc_button = gr.Button("Check Grammar")
        gc_output = gr.Textbox(label="Corrected Sentence", lines=2)
        gc_button.click(fn=grammar_checker,
                        inputs=[src_lang_gc, text_gc],
                        outputs=gc_output)

# Launch the app; this call is not behind a __main__ guard, so it runs
# unconditionally when the module is executed.
demo.launch()