javi8979 commited on
Commit
3e11881
·
verified ·
1 Parent(s): 43dc82a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -159
app.py CHANGED
@@ -1,170 +1,78 @@
1
  import gradio as gr
2
- from datetime import datetime
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
 
 
5
 
6
- # ------------------------
7
- # 1) Load the Model
8
- # ------------------------
9
  model_id = "BSC-LT/salamandraTA-7b-instruct"
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_id,
13
  device_map="auto",
14
- torch_dtype=torch.bfloat16
15
  )
16
 
17
- # Common function to generate text using transformers
18
- def call_model(prompt: str, max_tokens: int = 256):
19
- message = [{"role": "user", "content": prompt}]
20
- date_string = datetime.today().strftime('%Y-%m-%d')
21
-
22
- chat_prompt = tokenizer.apply_chat_template(
23
- message,
24
- tokenize=False,
25
- add_generation_prompt=True,
26
- date_string=date_string
27
- )
28
-
29
- inputs = tokenizer.encode(chat_prompt, return_tensors="pt").to(model.device)
30
- input_length = inputs.shape[1]
31
- outputs = model.generate(
32
- input_ids=inputs,
33
- max_new_tokens=max_tokens,
34
- do_sample=True,
35
- num_beams=5,
36
- early_stopping=True
37
- )
38
- return tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
39
-
40
- # ------------------------
41
- # 2) Task-specific functions
42
- # ------------------------
43
-
44
- def general_translation(source_lang, target_lang, text):
45
- prompt = (
46
- f"Translate the following text from {source_lang} into {target_lang}.\n"
47
- f"{source_lang}: {text}\n"
48
- f"{target_lang}:"
49
- )
50
- return call_model(prompt)
51
-
52
- def post_editing(source_lang, target_lang, source_text, machine_translation):
53
- prompt = (
54
- f"Please fix any mistakes in the following {source_lang}-{target_lang} machine translation or keep it unedited if it's correct.\n"
55
- f"Source: {source_text}\n"
56
- f"MT: {machine_translation}\n"
57
- f"Corrected:"
58
- )
59
- return call_model(prompt, temperature=0.1)
60
-
61
- def document_level_translation(source_lang, target_lang, document_text):
62
- prompt = (
63
- f"Please translate this text from {source_lang} into {target_lang}.\n"
64
- f"{source_lang}: {document_text}\n"
65
- f"{target_lang}:"
66
- )
67
- return call_model(prompt)
68
-
69
- def named_entity_recognition(tokenized_text):
70
- tokens = tokenized_text.strip().split()
71
- prompt = (
72
- "Analyse the following tokenized text and mark the tokens containing named entities.\n"
73
- "Use the following annotation guidelines with these tags for named entities:\n"
74
- "- ORG (Refers to named groups or organizations)\n"
75
- "- PER (Refers to individual people or named groups of people)\n"
76
- "- LOC (Refers to physical places or natural landmarks)\n"
77
- "- MISC (Refers to entities that don't fit into standard categories).\n"
78
- "Prepend B- to the first token of a given entity and I- to the remaining ones if they exist.\n"
79
- "If a token is not a named entity, label it as O.\n"
80
- f"Input: {tokens}\n"
81
- "Marked:"
82
- )
83
- return call_model(prompt)
84
-
85
- def grammar_checker(source_lang, sentence):
86
- prompt = (
87
- f"Please fix any mistakes in the following {source_lang} sentence or keep it unedited if it's correct.\n"
88
- f"Sentence: {sentence}\n"
89
- f"Corrected:"
90
- )
91
- return call_model(prompt)
92
-
93
- # ------------------------
94
- # 3) Gradio UI
95
- # ------------------------
96
- with gr.Blocks() as demo:
97
- gr.Markdown("## SalamandraTA-7B-Instruct Demo")
98
- gr.Markdown(
99
- "This Gradio app demonstrates various use-cases for the **SalamandraTA-7B-Instruct** model, including:\n"
100
- "1. General Translation\n"
101
- "2. Post-editing\n"
102
- "3. Document-level Translation\n"
103
- "4. Named-Entity Recognition (NER)\n"
104
- "5. Grammar Checking"
105
- )
106
-
107
- with gr.Tab("1. General Translation"):
108
- gr.Markdown("### General Translation")
109
- src_lang_gt = gr.Textbox(label="Source Language", value="Spanish")
110
- tgt_lang_gt = gr.Textbox(label="Target Language", value="English")
111
- text_gt = gr.Textbox(label="Text to Translate", lines=4, value="Ayer se fue, tomó sus cosas y se puso a navegar.")
112
- translate_button = gr.Button("Translate")
113
- output_gt = gr.Textbox(label="Translation Output", lines=4)
114
- translate_button.click(fn=general_translation,
115
- inputs=[src_lang_gt, tgt_lang_gt, text_gt],
116
- outputs=output_gt)
117
-
118
- with gr.Tab("2. Post-editing"):
119
- gr.Markdown("### Post-editing (Source → Target)")
120
- src_lang_pe = gr.Textbox(label="Source Language", value="Catalan")
121
- tgt_lang_pe = gr.Textbox(label="Target Language", value="English")
122
- source_text_pe = gr.Textbox(label="Source Text", lines=2, value="Rafael Nadal i Maria Magdalena van inspirar a una generació sencera.")
123
- mt_text_pe = gr.Textbox(label="Machine Translation", lines=2, value="Rafael Christmas and Maria the Muffin inspired an entire generation each in their own way.")
124
- post_edit_button = gr.Button("Post-edit")
125
- output_pe = gr.Textbox(label="Post-edited Text", lines=4)
126
- post_edit_button.click(fn=post_editing,
127
- inputs=[src_lang_pe, tgt_lang_pe, source_text_pe, mt_text_pe],
128
- outputs=output_pe)
129
-
130
- with gr.Tab("3. Document-level Translation"):
131
- gr.Markdown("### Document-level Translation")
132
- src_lang_doc = gr.Textbox(label="Source Language", value="English")
133
- tgt_lang_doc = gr.Textbox(label="Target Language", value="Asturian")
134
- doc_text = gr.Textbox(label="Document Text (multiple paragraphs allowed)",
135
- lines=8,
136
- value=("President Donald Trump, who campaigned on promises to crack down on illegal immigration, "
137
- "has raised alarms in the U.S. dairy industry with his threat to impose 25% tariffs on Mexico "
138
- "and Canada by February 2025."))
139
- doc_button = gr.Button("Translate Document")
140
- doc_output = gr.Textbox(label="Document-level Translation Output", lines=8)
141
- doc_button.click(fn=document_level_translation,
142
- inputs=[src_lang_doc, tgt_lang_doc, doc_text],
143
- outputs=doc_output)
144
-
145
- with gr.Tab("4. Named-Entity Recognition"):
146
- gr.Markdown("### Named-Entity Recognition (NER)")
147
- text_ner = gr.Textbox(
148
- label="Tokenized Text (space-separated tokens)",
149
- lines=2,
150
- value="La defensa del antiguo responsable de la RFEF confirma que interpondrá un recurso."
151
  )
152
- ner_button = gr.Button("Run NER")
153
- ner_output = gr.Textbox(label="NER Output", lines=6)
154
- ner_button.click(fn=named_entity_recognition,
155
- inputs=[text_ner],
156
- outputs=ner_output)
157
-
158
- with gr.Tab("5. Grammar Checker"):
159
- gr.Markdown("### Grammar Checker")
160
- src_lang_gc = gr.Textbox(label="Source Language", value="Catalan")
161
- text_gc = gr.Textbox(label="Sentence to Check",
162
- lines=2,
163
- value="Entonses, el meu jefe m’ha dit que he de treballar els fins de setmana.")
164
- gc_button = gr.Button("Check Grammar")
165
- gc_output = gr.Textbox(label="Corrected Sentence", lines=2)
166
- gc_button.click(fn=grammar_checker,
167
- inputs=[src_lang_gc, text_gc],
168
- outputs=gc_output)
169
-
170
- demo.launch()
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import spaces
 
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from datetime import datetime
6
 
 
 
 
7
  model_id = "BSC-LT/salamandraTA-7b-instruct"
8
+
9
+ # Load tokenizer and model
10
  tokenizer = AutoTokenizer.from_pretrained(model_id)
11
+
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_id,
14
  device_map="auto",
15
+ torch_dtype=torch.bfloat16 # Usa bf16 como en el ejemplo original
16
  )
17
 
18
+ languages = [ "Spanish", "Catalan", "English", "French", "German", "Italian", "Portuguese", "Euskera", "Galician",
19
+ "Bulgarian", "Czech", "Lithuanian", "Croatian", "Dutch", "Romanian", "Danish", "Greek", "Finnish",
20
+ "Hungarian", "Slovak", "Slovenian", "Estonian", "Polish", "Latvian", "Swedish", "Maltese",
21
+ "Irish", "Aranese", "Aragonese", "Asturian" ]
22
+
23
+ example_sentence = ["Ahir se'n va anar, va agafar les seves coses i es va posar a navegar."]
24
+
25
+ @spaces.GPU(duration=120)
26
+ def translate(input_text, source, target):
27
+ sentences = [s for s in input_text.strip().split('\n') if s.strip()]
28
+ translated_sentences = []
29
+
30
+ for sentence in sentences:
31
+ prompt_text = f"Translate the following text from {source} into {target}.\n{source}: {sentence} \n{target}:"
32
+ messages = [{"role": "user", "content": prompt_text}]
33
+ date_string = datetime.today().strftime('%Y-%m-%d')
34
+
35
+ prompt = tokenizer.apply_chat_template(
36
+ messages,
37
+ tokenize=False,
38
+ add_generation_prompt=True,
39
+ date_string=date_string
40
+ )
41
+
42
+ inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
43
+ input_length = inputs.input_ids.shape[1]
44
+
45
+ output = model.generate(
46
+ input_ids=inputs.input_ids,
47
+ max_new_tokens=400,
48
+ early_stopping=True,
49
+ num_beams=5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  )
51
+
52
+ decoded = tokenizer.decode(output[0, input_length:], skip_special_tokens=True).strip()
53
+ translated_sentences.append(decoded)
54
+
55
+ return '\n'.join(translated_sentences), ""
56
+
57
+ with gr.Blocks() as demo:
58
+ gr.HTML("""<html>
59
+ <head><style>h1 { text-align: center; }</style></head>
60
+ <body><h1>SalamandraTA 7B Translate</h1></body>
61
+ </html>""")
62
+
63
+ with gr.Row():
64
+ with gr.Column():
65
+ source_language_dropdown = gr.Dropdown(choices=languages, value="Catalan", label="Source Language")
66
+ input_textbox = gr.Textbox(lines=5, placeholder="Enter text to translate", label="Input Text")
67
+ with gr.Column():
68
+ target_language_dropdown = gr.Dropdown(choices=languages, value="English", label="Target Language")
69
+ translated_textbox = gr.Textbox(lines=5, placeholder="", label="Translated Text")
70
+
71
+ info_label = gr.HTML("")
72
+ btn = gr.Button("Translate")
73
+ btn.click(translate, inputs=[input_textbox, source_language_dropdown, target_language_dropdown],
74
+ outputs=[translated_textbox, info_label])
75
+ gr.Examples(example_sentence, inputs=[input_textbox])
76
+
77
+ if __name__ == "__main__":
78
+ demo.launch()