Spaces:
Running
Running
File size: 8,017 Bytes
28b803a 652f4e8 28b803a 652f4e8 28b803a f20a943 28b803a 694d5e0 28b803a f20a943 28b803a 694d5e0 28b803a 7398e14 28b803a 652f4e8 7398e14 f20a943 7398e14 652f4e8 7398e14 652f4e8 694d5e0 652f4e8 694d5e0 652f4e8 f20a943 652f4e8 694d5e0 0c44d3f f82feee 694d5e0 f82feee 694d5e0 f82feee 694d5e0 f82feee 694d5e0 f82feee ce1e5cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datetime import datetime
from functools import lru_cache
import torch
import requests
# Language codes
# Maps the display name shown in the UI dropdowns to the language tag that is
# handed to the translation tokenizer. Insertion order is the dropdown order.
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
    "Chinese": "zho_Hans",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Italian": "ita_Latn",
    "Burmese": "mya_Mymr",
    "Thai": "tha_Thai",
}
# Translation history
class TranslationHistory:
    """In-memory log of recent translations, newest entry first.

    Each entry is a dict with the keys ``source``, ``translated``,
    ``src_lang``, ``tgt_lang`` and ``timestamp`` (local time, ISO 8601).

    Args:
        max_items: Maximum number of entries to keep. Older entries are
            dropped once the limit is exceeded. Defaults to 100.

    Raises:
        ValueError: If ``max_items`` is negative.
    """

    def __init__(self, max_items=100):
        if max_items < 0:
            raise ValueError("max_items must be non-negative")
        self.max_items = max_items
        self.history = []

    def add(self, src, translated, src_lang, tgt_lang):
        """Record one translation at the front of the history."""
        self.history.insert(0, {
            "source": src,
            "translated": translated,
            "src_lang": src_lang,
            "tgt_lang": tgt_lang,
            "timestamp": datetime.now().isoformat(),
        })
        # Drop everything past the cap; the oldest entries sit at the tail.
        del self.history[self.max_items:]

    def get(self):
        """Return the list of entries (the live list, not a copy)."""
        return self.history

    def clear(self):
        """Forget all entries by rebinding to a fresh empty list."""
        self.history = []


# Module-wide history instance used by the translation function.
history = TranslationHistory()
# Translation model
model_name = "facebook/nllb-200-distilled-600M"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the seq2seq weights and move them to the selected device in one step.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
@lru_cache(maxsize=512)
def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Results are memoised by ``lru_cache`` (up to 512 distinct argument
    combinations). A cache hit returns the stored string without running the
    model and therefore without adding a new entry to ``history``.

    Args:
        text: Text to translate. ``None``, empty or whitespace-only input
            yields ``""`` without touching the model.
        src_lang: Source language; either a key of ``LANGUAGE_CODES`` or a
            language code passed through unchanged.
        tgt_lang: Target language, resolved the same way as ``src_lang``.
        max_length: Maximum generated length; coerced to ``int`` because UI
            sliders may deliver a float.
        temperature: Forwarded to ``model.generate`` unchanged.

    Returns:
        The decoded translation with special tokens removed.
    """
    if not text or not text.strip():
        return ""
    src_code = LANGUAGE_CODES.get(src_lang, src_lang)
    tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
    # Tell the tokenizer which language the input is written in. Previously
    # src_code was computed but never applied, so the tokenizer's default
    # source language was used regardless of the "From" selection.
    tokenizer.src_lang = src_code
    input_tokens = tokenizer(text, return_tensors="pt", padding=True)
    input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
    # The target language is selected by forcing its tag as the first token.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
    # NOTE(review): no do_sample flag is passed, so confirm whether
    # temperature has any effect together with beam search.
    output = model.generate(
        **input_tokens,
        forced_bos_token_id=forced_bos_token_id,
        max_length=int(max_length),
        temperature=temperature,
        num_beams=5,
        early_stopping=True,
    )
    result = tokenizer.decode(output[0], skip_special_tokens=True)
    history.add(text, result, src_lang, tgt_lang)
    return result
def translate_file(file, src_lang, tgt_lang, max_length, temperature):
    """Translate the contents of an uploaded text file line by line.

    Args:
        file: UTF-8 encoded file contents as ``bytes``/``bytearray``, or text
            that is already decoded as ``str``. ``None`` (nothing uploaded)
            produces an error message.
        src_lang: Source language, forwarded to ``cached_translate``.
        tgt_lang: Target language, forwarded to ``cached_translate``.
        max_length: Forwarded to ``cached_translate``.
        temperature: Forwarded to ``cached_translate``.

    Returns:
        The translated non-blank lines joined with newlines (blank lines of
        the input are skipped), or a string starting with
        ``"File translation error: "`` if anything goes wrong.
    """
    if file is None:
        return "File translation error: no file provided"
    try:
        if isinstance(file, str):
            # Drop a stray byte-order mark from pre-decoded text as well.
            text = file.lstrip("\ufeff")
        else:
            # utf-8-sig strips a leading BOM, which would otherwise be sent
            # to the translator as part of the first line.
            text = file.decode("utf-8-sig")
        translated = [
            cached_translate(line, src_lang, tgt_lang, max_length, temperature)
            for line in text.splitlines()
            if line.strip()
        ]
        return "\n".join(translated)
    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        return f"File translation error: {e}"
# Summarizer API
API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
# Access token read from the environment; None when the variable is unset.
HF_API_KEY = os.environ.get("HF_API_KEY")
# Only build the Authorization header when a token exists. Formatting None
# into the f-string would send the literal value "Bearer None".
# NOTE(review): confirm the endpoint accepts requests with no Authorization header.
headers = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {}
def summarize_text(text, max_length):
    """Summarize ``text`` by POSTing it to the endpoint at ``API_URL``.

    Args:
        text: Text to summarize. ``None``, empty or whitespace-only input
            returns ``""`` without making a request.
        max_length: Upper bound sent as the ``max_length`` parameter; coerced
            to ``int`` because UI sliders may deliver a float.

    Returns:
        The ``summary_text`` of the first result, or a string starting with
        ``"Error: "`` when the request fails or the response has another shape.
    """
    if not text or not text.strip():
        return ""
    max_length = int(max_length)
    # A quarter of the maximum, at least 10, but never above the maximum.
    min_length = min(max(10, max_length // 4), max_length)
    payload = {
        "inputs": text,
        "parameters": {"min_length": min_length, "max_length": max_length},
    }
    try:
        # The timeout keeps a stalled connection from blocking the UI forever.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # Network failure or a body that is not valid JSON.
        return "Error: " + str(e)
    if (
        isinstance(result, list)
        and result
        and isinstance(result[0], dict)
        and "summary_text" in result[0]
    ):
        return result[0]["summary_text"]
    # Anything else (e.g. an error object from the API) is shown verbatim.
    return "Error: " + str(result)
# Paraphraser
# Tokenizer and seq2seq model come from the same checkpoint; the model is
# moved to the shared device as soon as it is loaded.
paraphrase_tokenizer = AutoTokenizer.from_pretrained("tuner007/pegasus_paraphrase")
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(
    "tuner007/pegasus_paraphrase"
).to(device)
def paraphrase_text(input_text, num_return_sequences, num_beams):
    """Generate paraphrases of ``input_text`` with the paraphrase model.

    Args:
        input_text: Sentence to paraphrase; it is truncated to 60 tokens.
            ``None``, empty or whitespace-only input returns ``[]``.
        num_return_sequences: How many paraphrases to return (at least 1);
            coerced to ``int``.
        num_beams: Beam width; coerced to ``int`` and raised to
            ``num_return_sequences`` when smaller.

    Returns:
        A list of decoded paraphrase strings.
    """
    if not input_text or not input_text.strip():
        return []
    # UI sliders may deliver floats, so coerce before handing to generate().
    num_return_sequences = max(1, int(num_return_sequences))
    # Beam search cannot return more sequences than it keeps beams, so widen
    # the beam instead of letting generate() reject the combination.
    num_beams = max(int(num_beams), num_return_sequences)
    batch = paraphrase_tokenizer(
        [input_text],
        truncation=True,
        padding="longest",
        max_length=60,
        return_tensors="pt",
    ).to(device)
    # NOTE(review): no do_sample flag is passed, so confirm whether
    # temperature has any effect here.
    generated = paraphrase_model.generate(
        **batch,
        max_length=60,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=1.5,
    )
    return paraphrase_tokenizer.batch_decode(generated, skip_special_tokens=True)
# Grammar Corrector
grammar_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
grammar_tokenizer = AutoTokenizer.from_pretrained(grammar_model_name)
# Load the checkpoint and place it on the shared device in a single expression.
grammar_model = AutoModelForSeq2SeqLM.from_pretrained(grammar_model_name).to(device)
def correct_grammar(text):
    """Return a grammar-corrected version of ``text``.

    The input is prefixed with ``"grammar: "``, tokenized with truncation and
    decoded from the first sequence produced by a 5-beam search capped at 256
    tokens.

    Args:
        text: Sentence(s) to correct. ``None``, empty or whitespace-only input
            returns ``""`` without running the model, matching the other
            text tools in this file.

    Returns:
        The corrected text with special tokens removed.
    """
    if not text or not text.strip():
        return ""
    encoded = grammar_tokenizer(f"grammar: {text}", return_tensors="pt", truncation=True)
    input_ids = encoded.input_ids.to(device)
    output_ids = grammar_model.generate(input_ids, max_length=256, num_beams=5)
    return grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
# UI Style
# Custom CSS passed to gr.Blocks: rounded bold buttons and coloured input borders.
gradio_style = """
.gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }
textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }
textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }
"""
# Gradio UI
# One tab per tool; each tab wires its button to the matching function above.
with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🤖 AI Toolbox: Translate, Summarize, Paraphrase, Correct Grammar")
    with gr.Tab("🌐 Translator"):
        src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="From", value="English")
        swap = gr.Button("⇄")
        tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="To", value="Korean")
        input_text = gr.Textbox(lines=3, label="Input Text")
        output_text = gr.Textbox(lines=3, label="Translated Output", interactive=False)
        translate_btn = gr.Button("🚀 Translate")
        clear_btn = gr.Button("🧽 Clear")
        # step=1 keeps the slider on whole numbers; the value is a token limit.
        max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
        temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
        translate_btn.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
        clear_btn.click(lambda: ("", ""), None, [input_text, output_text])
        # The swap button previously had no handler; exchange the two dropdowns.
        swap.click(lambda src, tgt: (tgt, src), [src_lang, tgt_lang], [src_lang, tgt_lang])
    with gr.Tab("📁 File Translator"):
        # type="binary" hands the handler the uploaded bytes directly, so it no
        # longer depends on the upload object exposing a .read() method.
        file_input = gr.File(label="Upload .txt File", type="binary")
        file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="From", value="English")
        file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="To", value="Korean")
        f_max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
        f_temp = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
        file_btn = gr.Button("Translate File")
        file_result = gr.Textbox(lines=10, label="File Output", interactive=False)
        file_btn.click(translate_file,
                       [file_input, file_src, file_tgt, f_max_length, f_temp], file_result)
    with gr.Tab("📝 Summarizer"):
        summary_input = gr.Textbox(lines=5, label="Enter text to summarize")
        summary_length = gr.Slider(32, 512, value=128, step=8, label="Max Length")
        summary_output = gr.Textbox(lines=5, label="Summary", interactive=False)
        summarize_btn = gr.Button("Summarize")
        summarize_btn.click(summarize_text, [summary_input, summary_length], summary_output)
    with gr.Tab("🔁 Paraphraser"):
        para_input = gr.Textbox(lines=4, label="Enter text to paraphrase")
        num_outputs = gr.Slider(1, 5, value=3, step=1, label="Number of Paraphrases")
        beam_width = gr.Slider(1, 10, value=5, step=1, label="Beam Width")
        para_output = gr.Textbox(label="Paraphrased Sentences", lines=6)
        para_btn = gr.Button("Paraphrase")
        # The paraphrases come back as a list; show them separated by blank lines.
        para_btn.click(lambda text, num, beams: "\n\n".join(paraphrase_text(text, num, beams)),
                       [para_input, num_outputs, beam_width], para_output)
    with gr.Tab("🛠 Grammar Corrector"):
        grammar_input = gr.Textbox(lines=5, label="Enter sentence to correct")
        grammar_output = gr.Textbox(label="Corrected Sentence", lines=5)
        grammar_btn = gr.Button("Correct Grammar")
        grammar_btn.click(correct_grammar, grammar_input, grammar_output)
    # Footer listing the models in use and whether an API token was found.
    gr.Markdown(f"""
### ℹ️ Info
- Translator: `{model_name}` on `{device}`
- Paraphraser: `tuner007/pegasus_paraphrase`
- Summarizer: `facebook/bart-large-cnn`
- Grammar Corrector: `{grammar_model_name}`
- API Token: {'✅ Found' if HF_API_KEY else '❌ Not Found'}
""")
if __name__ == "__main__":
    demo.launch(share=True)
|