File size: 8,017 Bytes
28b803a
 
652f4e8
28b803a
 
652f4e8
 
28b803a
f20a943
28b803a
 
 
 
 
 
694d5e0
28b803a
 
 
 
 
 
 
 
 
 
 
f20a943
 
28b803a
 
 
694d5e0
28b803a
7398e14
 
 
 
28b803a
 
 
652f4e8
7398e14
 
 
 
 
f20a943
 
7398e14
652f4e8
 
7398e14
 
 
 
 
652f4e8
 
 
 
694d5e0
652f4e8
 
 
694d5e0
652f4e8
f20a943
652f4e8
 
 
 
 
 
 
 
 
 
 
 
694d5e0
0c44d3f
 
 
 
 
 
 
 
 
f82feee
694d5e0
 
 
 
 
 
 
 
f82feee
 
694d5e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f82feee
694d5e0
 
 
 
 
 
 
f82feee
 
694d5e0
 
 
 
 
 
 
 
 
f82feee
ce1e5cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datetime import datetime
from functools import lru_cache
import torch
import requests

# Human-readable language names mapped to NLLB-200 language codes
# (ISO 639-3 code plus script tag, e.g. "eng_Latn").
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
    "Chinese": "zho_Hans",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Italian": "ita_Latn",
    "Burmese": "mya_Mymr",
    "Thai": "tha_Thai",
}

# In-memory, capped, most-recent-first log of translations.
class TranslationHistory:
    """Keeps up to 100 translation records, newest at index 0."""

    _MAX_ENTRIES = 100

    def __init__(self):
        # Newest entry always sits at the front of this list.
        self.history = []

    def add(self, src, translated, src_lang, tgt_lang):
        """Record one translation, evicting the oldest entry past the cap."""
        entry = {
            "source": src,
            "translated": translated,
            "src_lang": src_lang,
            "tgt_lang": tgt_lang,
            "timestamp": datetime.now().isoformat(),
        }
        self.history.insert(0, entry)
        while len(self.history) > self._MAX_ENTRIES:
            self.history.pop()

    def get(self):
        """Return the live history list (newest first)."""
        return self.history

    def clear(self):
        """Discard all recorded entries."""
        self.history = []

# Single process-wide history shared by all translation calls.
history = TranslationHistory()

# Translation model
# NLLB-200 distilled 600M checkpoint; weights are downloaded from the
# Hugging Face Hub on first run.
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Prefer GPU when available; the paraphrase and grammar models below
# reuse this same device string.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

@lru_cache(maxsize=512)
def _translate_cached(text, src_code, tgt_code, max_length, temperature):
    """Run the NLLB model on *text*; memoized on all arguments.

    *src_code*/*tgt_code* are NLLB language codes (e.g. "eng_Latn").
    """
    # Fix: the original computed the source code but never used it, so the
    # tokenizer fell back to its default source language. NLLB needs the
    # source language set before encoding.
    tokenizer.src_lang = src_code
    input_tokens = tokenizer(text, return_tensors="pt", padding=True)
    input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
    # Force generation to start with the target-language token.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
    output = model.generate(
        **input_tokens,
        forced_bos_token_id=forced_bos_token_id,
        max_length=max_length,
        temperature=temperature,  # NOTE: ignored by beam search unless do_sample=True
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
    """Translate *text* from *src_lang* to *tgt_lang*.

    Accepts either human-readable names from LANGUAGE_CODES or raw NLLB
    codes. Returns "" for blank input. History is recorded on every call
    (the original skipped recording on cache hits because the history
    side effect lived inside the lru_cache'd function).
    """
    if not text.strip():
        return ""
    src_code = LANGUAGE_CODES.get(src_lang, src_lang)
    tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
    result = _translate_cached(text, src_code, tgt_code, max_length, temperature)
    history.add(text, result, src_lang, tgt_lang)
    return result

def translate_file(file, src_lang, tgt_lang, max_length, temperature):
    """Translate each non-blank line of a UTF-8 text payload.

    *file* is the raw bytes of an uploaded .txt file. Returns the
    translated lines joined with newlines, or a "File translation
    error: ..." message string on any failure.
    """
    try:
        text = file.decode("utf-8")
        results = []
        for raw_line in text.splitlines():
            # Blank lines are dropped rather than translated.
            if not raw_line.strip():
                continue
            results.append(
                cached_translate(raw_line, src_lang, tgt_lang, max_length, temperature)
            )
        return "\n".join(results)
    except Exception as e:
        return f"File translation error: {e}"

# Summarizer API
# Hosted Hugging Face Inference API endpoint for BART summarization.
# Authenticated via the HF_API_KEY environment variable; if it is unset,
# the header literally carries "Bearer None" and API calls will fail.
API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
HF_API_KEY = os.environ.get("HF_API_KEY")
headers = {"Authorization": f"Bearer {HF_API_KEY}"}

def summarize_text(text, max_length, timeout=60):
    """Summarize *text* via the hosted BART inference API.

    Returns the summary string, "" for blank input, or an
    "Error: ..." message when the request fails or the response
    cannot be parsed. *timeout* (seconds) bounds the HTTP call so a
    stalled request cannot hang the UI forever (the original had no
    timeout, and response.json() could raise on an HTML error page).
    """
    if not text.strip():
        return ""
    # Keep min_length a quarter of max_length, but never below 10.
    min_length = max(10, max_length // 4)
    payload = {
        "inputs": text,
        "parameters": {"min_length": min_length, "max_length": max_length},
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        result = response.json()
    except requests.RequestException as e:
        return f"Error: request failed ({e})"
    except ValueError as e:  # body was not JSON (e.g. an HTML error page)
        return f"Error: invalid API response ({e})"
    # Success responses are a list like [{"summary_text": ...}]; anything
    # else (e.g. {"error": ...} or an empty list) is reported verbatim.
    if isinstance(result, list) and result and "summary_text" in result[0]:
        return result[0]["summary_text"]
    return "Error: " + str(result)

# Paraphraser
# Pegasus fine-tuned for paraphrasing; loaded eagerly at import time and
# moved to the same device chosen for the translation model above.
paraphrase_tokenizer = AutoTokenizer.from_pretrained("tuner007/pegasus_paraphrase")
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("tuner007/pegasus_paraphrase")
paraphrase_model.to(device)

def paraphrase_text(input_text, num_return_sequences, num_beams):
    """Generate *num_return_sequences* paraphrases of *input_text*.

    Uses beam search with *num_beams* beams on the Pegasus paraphrase
    model; inputs and outputs are truncated/limited to 60 tokens.
    Returns a list of paraphrase strings.
    """
    encoded = paraphrase_tokenizer(
        [input_text],
        truncation=True,
        padding="longest",
        max_length=60,
        return_tensors="pt",
    ).to(device)
    generated = paraphrase_model.generate(
        **encoded,
        max_length=60,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=1.5,
    )
    return paraphrase_tokenizer.batch_decode(generated, skip_special_tokens=True)

# Grammar Corrector
# FLAN-T5 fine-tuned for grammar correction; shares the device selected
# for the translation model above.
grammar_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
grammar_tokenizer = AutoTokenizer.from_pretrained(grammar_model_name)
grammar_model = AutoModelForSeq2SeqLM.from_pretrained(grammar_model_name)
grammar_model.to(device)

def correct_grammar(text):
    """Return a grammatically corrected version of *text*.

    Prefixes the input with the model's "grammar: " task prompt and
    decodes a single beam-search result (5 beams, up to 256 tokens).
    """
    prompt = f"grammar: {text}"
    encoded = grammar_tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = encoded.input_ids.to(device)
    output_ids = grammar_model.generate(input_ids, max_length=256, num_beams=5)
    return grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)

# UI Style
# Custom CSS injected into the Blocks app below: rounded bold buttons and
# teal-bordered text inputs that highlight orange on focus.
gradio_style = """
.gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }
textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }
textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }
"""

# Gradio UI
# Five-tab Blocks app wiring the helpers above to widgets. Event wiring is
# order-sensitive, so the layout is documented rather than restructured.
with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🤖 AI Toolbox: Translate, Summarize, Paraphrase, Correct Grammar")

    # Text translation: dropdowns for language pair, sliders feed straight
    # into cached_translate's max_length/temperature parameters.
    with gr.Tab("🌐 Translator"):
        src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="From", value="English")
        # NOTE(review): this swap button is never wired to a handler —
        # swapping the From/To dropdowns is not implemented.
        swap = gr.Button("⇄")
        tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="To", value="Korean")
        input_text = gr.Textbox(lines=3, label="Input Text")
        output_text = gr.Textbox(lines=3, label="Translated Output", interactive=False)
        translate_btn = gr.Button("🚀 Translate")
        clear_btn = gr.Button("🧽 Clear")
        max_length = gr.Slider(10, 512, value=128, label="Max Length")
        temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
        translate_btn.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
        clear_btn.click(lambda: ("", ""), None, [input_text, output_text])

    # Whole-file translation: the handler below reads the upload and hands
    # bytes to translate_file, which decodes UTF-8 line by line.
    with gr.Tab("📁 File Translator"):
        file_input = gr.File(label="Upload .txt File")
        file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="From", value="English")
        file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="To", value="Korean")
        f_max_length = gr.Slider(10, 512, value=128, label="Max Length")
        f_temp = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
        file_btn = gr.Button("Translate File")
        file_result = gr.Textbox(lines=10, label="File Output", interactive=False)
        # NOTE(review): depending on the Gradio version, gr.File delivers a
        # tempfile wrapper, a filepath string, or raw bytes — file.read()
        # only exists on the wrapper. Confirm against the installed Gradio
        # version (type="binary" would make translate_file's decode correct).
        file_btn.click(lambda file, src, tgt, ml, temp: translate_file(file.read(), src, tgt, ml, temp),
                       [file_input, file_src, file_tgt, f_max_length, f_temp], file_result)

    # Summarization via the hosted Inference API (network call, needs token).
    with gr.Tab("📝 Summarizer"):
        summary_input = gr.Textbox(lines=5, label="Enter text to summarize")
        summary_length = gr.Slider(32, 512, value=128, step=8, label="Max Length")
        summary_output = gr.Textbox(lines=5, label="Summary", interactive=False)
        summarize_btn = gr.Button("Summarize")
        summarize_btn.click(summarize_text, [summary_input, summary_length], summary_output)

    # Paraphrasing: multiple candidates joined with blank lines for display.
    with gr.Tab("🔁 Paraphraser"):
        para_input = gr.Textbox(lines=4, label="Enter text to paraphrase")
        num_outputs = gr.Slider(1, 5, value=3, step=1, label="Number of Paraphrases")
        beam_width = gr.Slider(1, 10, value=5, step=1, label="Beam Width")
        para_output = gr.Textbox(label="Paraphrased Sentences", lines=6)
        para_btn = gr.Button("Paraphrase")
        para_btn.click(lambda text, num, beams: "\n\n".join(paraphrase_text(text, num, beams)),
                       [para_input, num_outputs, beam_width], para_output)

    # Grammar correction: single input, single output.
    with gr.Tab("🛠 Grammar Corrector"):
        grammar_input = gr.Textbox(lines=5, label="Enter sentence to correct")
        grammar_output = gr.Textbox(label="Corrected Sentence", lines=5)
        grammar_btn = gr.Button("Correct Grammar")
        grammar_btn.click(correct_grammar, grammar_input, grammar_output)

    # Static footer describing which models/device are in use.
    gr.Markdown(f"""
    ### ℹ️ Info
    - Translator: `{model_name}` on `{device}`
    - Paraphraser: `tuner007/pegasus_paraphrase`
    - Summarizer: `facebook/bart-large-cnn`
    - Grammar Corrector: `{grammar_model_name}`
    - API Token: {'✅ Found' if HF_API_KEY else '❌ Not Found'}
    """)

if __name__ == "__main__":
    # share=True publishes a public Gradio tunnel URL — anyone with the
    # link can drive the app (and its models) remotely.
    demo.launch(share=True)