entropy25 committed on
Commit
e9db9b3
·
verified ·
1 Parent(s): 56b2235

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -8
app.py CHANGED
@@ -3,8 +3,13 @@ import torch
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  from peft import PeftModel
5
 
 
6
  base_model_name = "facebook/nllb-200-distilled-600M"
7
- adapter_model_name = "entropy25/mt_en_no_oil"
 
 
 
 
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(base_model_name)
10
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -14,7 +19,8 @@ base_model = AutoModelForSeq2SeqLM.from_pretrained(
14
  device_map="auto"
15
  )
16
 
17
- model = PeftModel.from_pretrained(base_model, adapter_model_name)
 
18
 
19
  def translate(text, source_lang, target_lang):
20
  if not text.strip():
@@ -23,11 +29,18 @@ def translate(text, source_lang, target_lang):
23
  if source_lang == target_lang:
24
  return text
25
 
26
- lang_map = {
27
- "English": "eng_Latn",
28
- "Norwegian": "nob_Latn"
29
- }
 
 
 
 
 
 
30
 
 
31
  sentences = text.split('\n')
32
  translated_sentences = []
33
 
@@ -48,7 +61,7 @@ def translate(text, source_lang, target_lang):
48
 
49
  outputs = model.generate(
50
  **inputs,
51
- forced_bos_token_id=tokenizer.convert_tokens_to_ids(lang_map[target_lang]),
52
  max_length=512,
53
  num_beams=5
54
  )
@@ -212,7 +225,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
212
 
213
  gr.HTML(
214
  "<div class='footer-info'>"
215
- "Oil & Gas Translation • English ↔ Norwegian"
216
  "</div>"
217
  )
218
 
 
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  from peft import PeftModel
5
 
6
+
7
  base_model_name = "facebook/nllb-200-distilled-600M"
8
+
9
+
10
+ adapter_en_to_no = "entropy25/mt_en_no_oil"
11
+ adapter_no_to_en = "entropy25/mt_no_en_oil"
12
+
13
 
14
  tokenizer = AutoTokenizer.from_pretrained(base_model_name)
15
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
 
19
  device_map="auto"
20
  )
21
 
22
+ model_en_to_no = PeftModel.from_pretrained(base_model, adapter_en_to_no)
23
+ model_no_to_en = PeftModel.from_pretrained(base_model, adapter_no_to_en)
24
 
25
  def translate(text, source_lang, target_lang):
26
  if not text.strip():
 
29
  if source_lang == target_lang:
30
  return text
31
 
32
+ if source_lang == "English" and target_lang == "Norwegian":
33
+ model = model_en_to_no
34
+ src_code = "eng_Latn"
35
+ tgt_code = "nob_Latn"
36
+ elif source_lang == "Norwegian" and target_lang == "English":
37
+ model = model_no_to_en
38
+ src_code = "nob_Latn"
39
+ tgt_code = "eng_Latn"
40
+ else:
41
+ return "Unsupported language pair"
42
 
43
+
44
  sentences = text.split('\n')
45
  translated_sentences = []
46
 
 
61
 
62
  outputs = model.generate(
63
  **inputs,
64
+ forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
65
  max_length=512,
66
  num_beams=5
67
  )
 
225
 
226
  gr.HTML(
227
  "<div class='footer-info'>"
228
+ "Oil & Gas Translation • English ↔ Norwegian • Bidirectional Model"
229
  "</div>"
230
  )
231