Spaces:
Running
Running
import gradio as gr | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import os | |
from datetime import datetime | |
from functools import lru_cache | |
import torch | |
import requests | |
# Display name -> NLLB-200 language code (ISO 639-3 + script), used by the translator
# dropdowns and to pick the forced target-language token.
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
    "Chinese": "zho_Hans",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Italian": "ita_Latn",
    "Burmese": "mya_Mymr",
    "Thai": "tha_Thai",
}
# Translation history
class TranslationHistory:
    """In-memory, newest-first log of completed translations.

    Only the ``max_size`` most recent entries are retained; older entries are
    discarded as new ones arrive.
    """

    def __init__(self, max_size=100):
        """Create an empty history.

        Args:
            max_size: Maximum number of entries kept (default 100, the
                previously hard-coded cap).
        """
        self.max_size = max_size
        self.history = []

    def add(self, src, translated, src_lang, tgt_lang):
        """Record one translation at the front of the log and trim to ``max_size``.

        Args:
            src: Original text.
            translated: Translated text.
            src_lang: Source language as passed by the caller.
            tgt_lang: Target language as passed by the caller.
        """
        self.history.insert(0, {
            "source": src,
            "translated": translated,
            "src_lang": src_lang,
            "tgt_lang": tgt_lang,
            # Naive local time, ISO-8601 formatted.
            "timestamp": datetime.now().isoformat(),
        })
        # Slice-delete drops every entry past the cap, so the invariant still
        # holds if max_size is lowered after construction.
        del self.history[self.max_size:]

    def get(self):
        """Return the live history list, newest first (not a copy)."""
        return self.history

    def clear(self):
        """Discard all entries by rebinding to a fresh list."""
        self.history = []
history = TranslationHistory()

# Translation model: NLLB-200 distilled checkpoint, placed on the GPU when one is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with the NLLB model.

    Despite its name, this function does not cache results.

    Args:
        text: Source text. ``None`` or whitespace-only input returns "".
        src_lang: Display name from LANGUAGE_CODES (e.g. "English") or a raw
            NLLB code, which is passed through unchanged.
        tgt_lang: Same as ``src_lang``, for the target language.
        max_length: Maximum generated length in tokens. It is coerced to int
            because UI sliders may deliver floats.
        temperature: Accepted for interface compatibility only. Generation uses
            beam search without sampling, where transformers ignores
            temperature, so it is not forwarded to ``generate``.

    Returns:
        The decoded translation ("" for blank input).

    Side effects:
        Each non-blank translation is recorded in the module-level ``history``.
    """
    if not text or not text.strip():
        return ""
    src_code = LANGUAGE_CODES.get(src_lang, src_lang)
    tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
    # NLLB tokenizers prepend the source-language token according to
    # ``tokenizer.src_lang``. Previously src_code was computed but never
    # applied, so every input was tagged with the tokenizer's default language.
    # NOTE(review): this mutates the shared tokenizer; confirm concurrent
    # Gradio events cannot interleave here.
    tokenizer.src_lang = src_code
    # truncation=True keeps over-long inputs within the model's limit instead of erroring.
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    output = model.generate(
        **encoded,
        # Force the first decoded token to be the target-language code.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
        max_length=int(max_length),
        num_beams=5,
        early_stopping=True,
    )
    result = tokenizer.decode(output[0], skip_special_tokens=True)
    history.add(text, result, src_lang, tgt_lang)
    return result
def _read_upload(file):
    """Return the raw bytes of an upload given as bytes, a filesystem path, or a file-like object."""
    if isinstance(file, (bytes, bytearray)):
        return bytes(file)
    if isinstance(file, (str, os.PathLike)):
        with open(file, "rb") as fh:
            return fh.read()
    # Temp-file wrappers expose the on-disk path as ``.name``; prefer reopening
    # it because the wrapper itself may already be closed.
    name = getattr(file, "name", None)
    if isinstance(name, str) and os.path.isfile(name):
        with open(name, "rb") as fh:
            return fh.read()
    return file.read()


def translate_file(file, src_lang, tgt_lang, max_length, temperature):
    """Translate an uploaded UTF-8 text file line by line.

    Args:
        file: Upload as raw bytes, a path, or a file-like object. ``None``
            (nothing uploaded) returns "".
        src_lang: Source language, forwarded to ``cached_translate``.
        tgt_lang: Target language, forwarded to ``cached_translate``.
        max_length: Forwarded to ``cached_translate``.
        temperature: Forwarded to ``cached_translate``.

    Returns:
        The translated lines joined by "\\n". Blank lines are kept as empty
        lines so paragraph structure survives. If anything fails, returns a
        string starting with "File translation error:"; this is a UI boundary,
        so the message is shown to the user instead of raising.
    """
    if file is None:
        return ""
    try:
        # utf-8-sig also strips a leading BOM that would otherwise be fed to the model.
        lines = _read_upload(file).decode("utf-8-sig").splitlines()
        translated = [
            cached_translate(line, src_lang, tgt_lang, max_length, temperature) if line.strip() else ""
            for line in lines
        ]
        return "\n".join(translated)
    except Exception as e:
        return f"File translation error: {e}"
# Summarizer API: hosted Hugging Face Inference endpoint for facebook/bart-large-cnn.
# NOTE(review): Hugging Face has been migrating serverless inference to
# router.huggingface.co; confirm this URL still resolves.
API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
# Token is read from the environment (unset -> None).
HF_API_KEY = os.environ.get("HF_API_KEY")
# Only send an Authorization header when a token exists. Otherwise the literal
# string "Bearer None" would be sent.
headers = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {}
def summarize_text(text, max_length, timeout=60):
    """Summarize ``text`` through the Hugging Face Inference API.

    Args:
        text: Text to summarize. ``None`` or whitespace-only input returns "".
        max_length: Upper bound on summary length sent to the API; coerced to int.
        timeout: Seconds to wait for the HTTP response. Without it a stalled
            request would block the Gradio worker indefinitely.

    Returns:
        The summary string on success. Otherwise a string starting with
        "Error: " describing a network failure, a non-JSON reply, or the API's
        error payload.
    """
    if not text or not text.strip():
        return ""
    max_length = int(max_length)
    # Aim for ~1/4 of the budget with a floor of 10, but never above max_length.
    min_length = min(max(10, max_length // 4), max_length)
    try:
        response = requests.post(
            API_URL,
            headers=headers,
            json={
                "inputs": text,
                "parameters": {"min_length": min_length, "max_length": max_length},
            },
            timeout=timeout,
        )
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decode failures (e.g. an HTML 5xx page).
        return f"Error: {e}"
    if isinstance(result, list) and result and isinstance(result[0], dict) and "summary_text" in result[0]:
        return result[0]["summary_text"]
    return "Error: " + str(result)
# Paraphraser: PEGASUS checkpoint fine-tuned for paraphrasing, on the shared device.
_PARAPHRASE_CHECKPOINT = "tuner007/pegasus_paraphrase"
paraphrase_tokenizer = AutoTokenizer.from_pretrained(_PARAPHRASE_CHECKPOINT)
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASE_CHECKPOINT).to(device)
def paraphrase_text(input_text, num_return_sequences, num_beams):
    """Generate paraphrases of ``input_text`` with the PEGASUS paraphrase model.

    Args:
        input_text: Sentence to paraphrase; truncated to 60 tokens.
            ``None`` or whitespace-only input returns [].
        num_return_sequences: Number of paraphrases wanted; coerced to int, minimum 1.
        num_beams: Beam width; coerced to int and raised to at least
            ``num_return_sequences``. Beam search cannot return more sequences
            than it has beams, and transformers raises ValueError in that case.
            The UI sliders allow exactly that combination (e.g. 5 outputs, 3 beams).

    Returns:
        A list of paraphrase strings.
    """
    if not input_text or not input_text.strip():
        return []
    n_out = max(1, int(num_return_sequences))
    beams = max(int(num_beams), n_out)
    batch = paraphrase_tokenizer(
        [input_text], truncation=True, padding="longest", max_length=60, return_tensors="pt"
    ).to(device)
    # The former temperature=1.5 was dropped: without do_sample=True, beam
    # search ignores temperature, so it had no effect on the output.
    generated = paraphrase_model.generate(
        **batch, max_length=60, num_beams=beams, num_return_sequences=n_out
    )
    return paraphrase_tokenizer.batch_decode(generated, skip_special_tokens=True)
# Grammar corrector: FLAN-T5-large fine-tuned for grammar synthesis, on the shared device.
grammar_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
grammar_tokenizer, grammar_model = (
    AutoTokenizer.from_pretrained(grammar_model_name),
    AutoModelForSeq2SeqLM.from_pretrained(grammar_model_name).to(device),
)
def correct_grammar(text, max_length=256, num_beams=5):
    """Return a grammar-corrected version of ``text``.

    Args:
        text: Input sentence(s). ``None`` or whitespace-only input returns ""
            without running the model.
        max_length: Maximum generated length in tokens (default 256, as before).
        num_beams: Beam width for generation (default 5, as before).

    Returns:
        The decoded corrected text.
    """
    if not text or not text.strip():
        return ""
    # NOTE(review): the "grammar: " task prefix is kept from the original;
    # confirm against the model card whether this checkpoint expects a prefix.
    input_ids = grammar_tokenizer(
        f"grammar: {text}", return_tensors="pt", truncation=True
    ).input_ids.to(device)
    output_ids = grammar_model.generate(input_ids, max_length=int(max_length), num_beams=int(num_beams))
    return grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Custom CSS injected into gr.Blocks: rounded bold buttons and teal text inputs
# that turn orange with a glow while focused.
gradio_style = (
    "\n"
    ".gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }\n"
    "textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }\n"
    "textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }\n"
)
# Gradio UI: one tab per tool, each wired to the corresponding function above.
with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🤖 AI Toolbox: Translate, Summarize, Paraphrase, Correct Grammar")

    with gr.Tab("🌐 Translator"):
        src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="From", value="English")
        swap = gr.Button("⇄")
        tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="To", value="Korean")
        input_text = gr.Textbox(lines=3, label="Input Text")
        output_text = gr.Textbox(lines=3, label="Translated Output", interactive=False)
        translate_btn = gr.Button("🚀 Translate")
        clear_btn = gr.Button("🧽 Clear")
        max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
        temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
        translate_btn.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
        clear_btn.click(lambda: ("", ""), None, [input_text, output_text])
        # The ⇄ button had no handler; exchange the two language selections.
        swap.click(lambda src, tgt: (tgt, src), [src_lang, tgt_lang], [src_lang, tgt_lang])

    with gr.Tab("📁 File Translator"):
        # type="binary" makes Gradio pass the upload as raw bytes, which
        # translate_file decodes. The old lambda called .read() on the upload
        # and crashed when Gradio supplied a path string (NOTE(review): the
        # default type in recent Gradio) or None (no file uploaded).
        file_input = gr.File(label="Upload .txt File", file_types=[".txt"], type="binary")
        file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="From", value="English")
        file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="To", value="Korean")
        f_max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
        f_temp = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
        file_btn = gr.Button("Translate File")
        file_result = gr.Textbox(lines=10, label="File Output", interactive=False)
        file_btn.click(translate_file, [file_input, file_src, file_tgt, f_max_length, f_temp], file_result)

    with gr.Tab("📝 Summarizer"):
        summary_input = gr.Textbox(lines=5, label="Enter text to summarize")
        summary_length = gr.Slider(32, 512, value=128, step=8, label="Max Length")
        summary_output = gr.Textbox(lines=5, label="Summary", interactive=False)
        summarize_btn = gr.Button("Summarize")
        summarize_btn.click(summarize_text, [summary_input, summary_length], summary_output)

    with gr.Tab("🔁 Paraphraser"):
        para_input = gr.Textbox(lines=4, label="Enter text to paraphrase")
        num_outputs = gr.Slider(1, 5, value=3, step=1, label="Number of Paraphrases")
        beam_width = gr.Slider(1, 10, value=5, step=1, label="Beam Width")
        para_output = gr.Textbox(label="Paraphrased Sentences", lines=6)
        para_btn = gr.Button("Paraphrase")
        # Show each paraphrase as its own paragraph.
        para_btn.click(lambda text, num, beams: "\n\n".join(paraphrase_text(text, num, beams)),
                       [para_input, num_outputs, beam_width], para_output)

    with gr.Tab("🛠 Grammar Corrector"):
        grammar_input = gr.Textbox(lines=5, label="Enter sentence to correct")
        grammar_output = gr.Textbox(label="Corrected Sentence", lines=5)
        grammar_btn = gr.Button("Correct Grammar")
        grammar_btn.click(correct_grammar, grammar_input, grammar_output)

    # Footer listing the loaded models, the device and whether an API token was found.
    gr.Markdown(f"""
### ℹ️ Info
- Translator: `{model_name}` on `{device}`
- Paraphraser: `tuner007/pegasus_paraphrase`
- Summarizer: `facebook/bart-large-cnn`
- Grammar Corrector: `{grammar_model_name}`
- API Token: {'✅ Found' if HF_API_KEY else '❌ Not Found'}
""")

if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link when run locally;
    # confirm this is intended outside hosted Spaces.
    demo.launch(share=True)