entropy25 committed on
Commit 07f92fc · verified · 1 Parent(s): f3eeb2c

Update app.py

Files changed (1)
  1. app.py +122 -42
app.py CHANGED
@@ -3,62 +3,142 @@ import torch
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from peft import PeftModel
 
- # Base model
- base_model_name = "facebook/nllb-200-distilled-600M"
- # LoRA adapter
- adapter_model_name = "entropy25/mt_en_no_oil"
 
- # Load with optimizations to reduce storage
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
-     base_model_name,
-     torch_dtype=torch.float16,  # Use half precision
      low_cpu_mem_usage=True,
      device_map="auto"
  )
 
- # Apply the LoRA adapter
- model = PeftModel.from_pretrained(base_model, adapter_model_name)
 
  def translate(text, source_lang, target_lang):
      if not text.strip():
-         return "Please enter text to translate."
 
-     lang_map = {
-         "English": "eng_Latn",
-         "Norwegian": "nob_Latn"
-     }
 
-     inputs = tokenizer(
-         text,
-         return_tensors="pt",
-         truncation=True,
-         max_length=512
      )
 
-     # Move inputs to the same device as model
-     if hasattr(model, 'device'):
-         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-     outputs = model.generate(
-         **inputs,
-         forced_bos_token_id=tokenizer.convert_tokens_to_ids(lang_map[target_lang]),
-         max_length=512,
-         num_beams=5
      )
 
-     result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return result
 
- # Simple Gradio UI
- gr.Interface(
-     fn=lambda text, src, tgt: translate(text, src, tgt),
-     inputs=[
-         gr.Textbox(label="Input text", lines=6),
-         gr.Dropdown(choices=["English", "Norwegian"], label="Source language", value="English"),
-         gr.Dropdown(choices=["English", "Norwegian"], label="Target language", value="Norwegian")
-     ],
-     outputs=gr.Textbox(label="Translation", lines=6),
-     title="LoRA-Enhanced English↔Norwegian Translator",
-     description="Fine-tuned NLLB-200 model with LoRA adapter: entropy25/mt_en_no_oil"
- ).launch()
 
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from peft import PeftModel
 
+ BASE_MODEL = "facebook/nllb-200-distilled-600M"
+ ADAPTER_NO_TO_EN = "entropy25/mt_en_no_oil"
+ ADAPTER_EN_TO_NO = "entropy25/no_en"
 
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
+     BASE_MODEL,
+     torch_dtype=torch.float16,
      low_cpu_mem_usage=True,
      device_map="auto"
  )
 
+ model_no_to_en = PeftModel.from_pretrained(base_model, ADAPTER_NO_TO_EN)
+ model_en_to_no = PeftModel.from_pretrained(base_model, ADAPTER_EN_TO_NO)
+
+ LANG_CODES = {
+     "English": "eng_Latn",
+     "Norwegian": "nob_Latn"
+ }
 
  def translate(text, source_lang, target_lang):
      if not text.strip():
+         return "Please enter text to translate"
+
+     if source_lang == target_lang:
+         return "Source and target languages must be different"
+
+     try:
+         model = model_no_to_en if source_lang == "Norwegian" else model_en_to_no
+
+         inputs = tokenizer(
+             text,
+             return_tensors="pt",
+             truncation=True,
+             max_length=512
+         )
+
+         if hasattr(model, 'device'):
+             inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+         outputs = model.generate(
+             **inputs,
+             forced_bos_token_id=tokenizer.convert_tokens_to_ids(LANG_CODES[target_lang]),
+             max_length=512,
+             num_beams=5
+         )
+
+         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return result
+
+     except Exception as e:
+         return f"Translation error: {str(e)}"
+
+ def swap_languages(source, target, text, translation):
+     return target, source, translation, text
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# Oil & Gas Professional Translation")
+     gr.Markdown("English ↔ Norwegian translation specialized for petroleum industry")
+
+     with gr.Row():
+         source_lang = gr.Dropdown(
+             choices=["English", "Norwegian"],
+             label="Source Language",
+             value="English"
+         )
+
+         swap_btn = gr.Button("⇄", scale=0, size="sm")
+
+         target_lang = gr.Dropdown(
+             choices=["English", "Norwegian"],
+             label="Target Language",
+             value="Norwegian"
+         )
 
+     with gr.Row():
+         with gr.Column():
+             input_text = gr.Textbox(
+                 label="Input Text",
+                 placeholder="Enter text to translate",
+                 lines=8
+             )
+             input_chars = gr.Textbox(
+                 label="Character Count",
+                 value="0",
+                 interactive=False,
+                 max_lines=1
+             )
+
+         with gr.Column():
+             output_text = gr.Textbox(
+                 label="Translation",
+                 lines=8,
+                 interactive=False
+             )
+             with gr.Row():
+                 copy_btn = gr.Button("📋 Copy", scale=1)
+                 clear_btn = gr.Button("🗑️ Clear", scale=1)
 
+     translate_btn = gr.Button("Translate", variant="primary", size="lg")
+
+     gr.Examples(
+         examples=[
+             ["The drilling operation encountered high pressure", "English", "Norwegian"],
+             ["Reservoaret viser god permeabilitet", "Norwegian", "English"]
+         ],
+         inputs=[input_text, source_lang, target_lang]
+     )
+
+     input_text.change(
+         fn=lambda x: str(len(x)),
+         inputs=input_text,
+         outputs=input_chars
      )
 
+     translate_btn.click(
+         fn=translate,
+         inputs=[input_text, source_lang, target_lang],
+         outputs=output_text
+     )
 
+     swap_btn.click(
+         fn=swap_languages,
+         inputs=[source_lang, target_lang, input_text, output_text],
+         outputs=[source_lang, target_lang, input_text, output_text]
      )
 
+     copy_btn.click(
+         fn=lambda x: x,
+         inputs=output_text,
+         outputs=input_text
+     )
+
+     clear_btn.click(
+         fn=lambda: ("", ""),
+         outputs=[input_text, output_text]
+     )
 
+ demo.launch()
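
Two details of the updated `translate()` path deserve a note. The NLLB tokenizer prefixes the input with a source-language token and defaults to `eng_Latn` unless `tokenizer.src_lang` is set, and PEFT can hold both LoRA adapters on a single `PeftModel` (via `load_adapter`/`set_adapter`) instead of wrapping the same base model twice. The sketch below is a minimal illustration of both points, not the committed code; it assumes the adapter repos named in the commit (`entropy25/mt_en_no_oil`, `entropy25/no_en`) load as standard LoRA adapters for the NLLB base, and the adapter names `no_to_en`/`en_to_no` are made up for the example.

```python
# Sketch only: not part of the commit above. Assumes the adapter repos named
# in the commit are ordinary LoRA adapters for facebook/nllb-200-distilled-600M.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "facebook/nllb-200-distilled-600M"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL, low_cpu_mem_usage=True)

# One PeftModel holding both adapters; switch per request instead of
# wrapping the same base model in two separate PeftModel objects.
model = PeftModel.from_pretrained(base_model, "entropy25/mt_en_no_oil", adapter_name="no_to_en")
model.load_adapter("entropy25/no_en", adapter_name="en_to_no")

LANG_CODES = {"English": "eng_Latn", "Norwegian": "nob_Latn"}

def translate(text, source_lang, target_lang):
    # Pick the LoRA adapter for this direction (adapter names are illustrative).
    model.set_adapter("no_to_en" if source_lang == "Norwegian" else "en_to_no")

    # NLLB tags the input with a language token; without this the tokenizer
    # falls back to its default source language (eng_Latn).
    tokenizer.src_lang = LANG_CODES[source_lang]
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(LANG_CODES[target_lang]),
        max_length=512,
        num_beams=5,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print(translate("The drilling operation encountered high pressure", "English", "Norwegian"))
```

Keeping a single `PeftModel` means only one copy of the 600M base weights stays resident, and the active translation direction is explicit at each call.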