|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
|
|
|
# Shared base checkpoint for both translation directions.
BASE_MODEL = "facebook/nllb-200-distilled-600M"

# LoRA adapter used when the source language is Norwegian.
ADAPTER_NO_TO_EN = "entropy25/mt_en_no_oil"

# TODO(review): ADAPTER_EN_TO_NO was referenced but never defined in this file,
# which made startup fail with NameError. Until the correct repo id is
# confirmed, fall back to the only adapter id this file names.
ADAPTER_EN_TO_NO = ADAPTER_NO_TO_EN

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

base_model = AutoModelForSeq2SeqLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)


class _AdapterRoute:
    """Present one named LoRA adapter of a shared PeftModel as a standalone model.

    Exposes the two members the translation code uses, ``generate`` and
    ``device``. ``generate`` activates this route's adapter before each call,
    so both directions share a single copy of the base weights.

    NOTE(review): adapter switching is shared state, so this assumes requests
    run one at a time (e.g. via Gradio's event queue). Confirm the app's
    concurrency settings before allowing parallel events.
    """

    def __init__(self, peft_model, adapter_name):
        self._peft_model = peft_model
        self._adapter_name = adapter_name

    @property
    def device(self):
        """Device of the wrapped model, used to place the tokenized inputs."""
        return self._peft_model.device

    def generate(self, *args, **kwargs):
        """Switch to this route's adapter, then delegate to ``PeftModel.generate``."""
        self._peft_model.set_adapter(self._adapter_name)
        return self._peft_model.generate(*args, **kwargs)


# Wrap the base model ONCE and attach both adapters under distinct names.
# Calling PeftModel.from_pretrained twice on the same base model injects LoRA
# layers into the same modules under the same default adapter name. The
# adapter loaded last then ends up active for both "models".
_peft_model = PeftModel.from_pretrained(
    base_model, ADAPTER_NO_TO_EN, adapter_name="no_to_en"
)
_peft_model.load_adapter(ADAPTER_EN_TO_NO, adapter_name="en_to_no")
_peft_model.eval()

model_no_to_en = _AdapterRoute(_peft_model, "no_to_en")
model_en_to_no = _AdapterRoute(_peft_model, "en_to_no")
|
|
|
|
|
# UI display name -> NLLB-200 language code used for tokenizer/decoder tags
# ("nob" is the ISO 639-3 code for Norwegian Bokmål).
LANG_CODES = dict(
    English="eng_Latn",
    Norwegian="nob_Latn",
)
|
|
|
|
|
def translate(text, source_lang, target_lang):
    """Translate ``text`` between English and Norwegian.

    Args:
        text: Text to translate. Empty, whitespace-only, or ``None`` input is
            answered with a prompt message instead of reaching the model.
        source_lang: Display name of the input language (a key of LANG_CODES).
        target_lang: Display name of the output language (a key of LANG_CODES).

    Returns:
        The translated string, or a human-readable message when the input is
        rejected or translation fails. Errors are shown in the UI, not raised.
    """
    if not text or not text.strip():
        return "Please enter text to translate"

    if source_lang == target_lang:
        return "Source and target languages must be different"

    try:
        model = model_no_to_en if source_lang == "Norwegian" else model_en_to_no

        # NLLB tags the encoder input with a source-language token. The global
        # tokenizer is loaded without src_lang, so it must be told the actual
        # source language here. Otherwise Norwegian input gets its default
        # (English) tag.
        tokenizer.src_lang = LANG_CODES[source_lang]

        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,  # inputs longer than 512 tokens are cut off
            max_length=512,
        )

        if hasattr(model, "device"):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        outputs = model.generate(
            **inputs,
            # Force the first decoded token to be the target-language tag so
            # NLLB generates in the requested language.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(LANG_CODES[target_lang]),
            max_length=512,
            num_beams=5,
        )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    except Exception as e:
        # UI boundary: report any failure (unknown language, device/memory
        # errors, ...) as text in the output box instead of crashing the event.
        return f"Translation error: {e}"
|
|
|
|
|
def swap_languages(source, target, text, translation):
    """Reverse the translation direction.

    The two languages trade places, and the previous translation becomes the
    new input while the previous input becomes the new output.

    Returns:
        (new_source, new_target, new_input_text, new_output_text)
    """
    new_source, new_target = target, source
    new_input, new_output = translation, text
    return new_source, new_target, new_input, new_output
|
|
|
|
|
# Gradio UI: language pickers, input/output boxes, helper buttons, examples.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Oil & Gas Professional Translation")
    gr.Markdown("English ↔ Norwegian translation specialized for petroleum industry")

    # Offer exactly the languages translate() can map to an NLLB code.
    language_choices = list(LANG_CODES)

    with gr.Row():
        source_lang = gr.Dropdown(
            choices=language_choices,
            label="Source Language",
            value="English",
        )

        # Swaps the two dropdowns and the input/output texts.
        swap_btn = gr.Button("↔", scale=0, size="sm")

        target_lang = gr.Dropdown(
            choices=language_choices,
            label="Target Language",
            value="Norwegian",
        )

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to translate",
                lines=8,
            )
            input_chars = gr.Textbox(
                label="Character Count",
                value="0",
                interactive=False,
                max_lines=1,
            )

        with gr.Column():
            output_text = gr.Textbox(
                label="Translation",
                lines=8,
                interactive=False,
            )
            with gr.Row():
                # NOTE(review): "Copy" writes the translation back into the
                # input box; it does not place anything on the clipboard.
                copy_btn = gr.Button("📋 Copy", scale=1)
                clear_btn = gr.Button("🗑️ Clear", scale=1)

    translate_btn = gr.Button("Translate", variant="primary", size="lg")

    gr.Examples(
        examples=[
            ["The drilling operation encountered high pressure", "English", "Norwegian"],
            ["Reservoaret viser god permeabilitet", "Norwegian", "English"],
        ],
        inputs=[input_text, source_lang, target_lang],
    )

    # Live character count; `or ""` keeps it working if the value is None.
    input_text.change(
        fn=lambda x: str(len(x or "")),
        inputs=input_text,
        outputs=input_chars,
    )

    translate_btn.click(
        fn=translate,
        inputs=[input_text, source_lang, target_lang],
        outputs=output_text,
    )

    # Output order must match swap_languages' return tuple.
    swap_btn.click(
        fn=swap_languages,
        inputs=[source_lang, target_lang, input_text, output_text],
        outputs=[source_lang, target_lang, input_text, output_text],
    )

    copy_btn.click(
        fn=lambda x: x,
        inputs=output_text,
        outputs=input_text,
    )

    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[input_text, output_text],
    )

demo.launch()