# Spaces: Sleeping
# app.py
import logging
import os
import shutil
import tempfile
from datetime import datetime

import gradio as gr
import requests
import torch
from gtts import gTTS
from IndicTransToolkit.processor import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# Run on the GPU when torch reports one; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load models
# Hub checkpoint names, one per translation direction.
_EN_INDIC_CKPT = "ai4bharat/indictrans2-en-indic-1B"
_INDIC_EN_CKPT = "ai4bharat/indictrans2-indic-en-1B"

# English -> Indic: the model is moved to DEVICE; the tokenizer is used as loaded.
model_en_to_indic = AutoModelForSeq2SeqLM.from_pretrained(
    _EN_INDIC_CKPT,
    trust_remote_code=True,
).to(DEVICE)
tokenizer_en_to_indic = AutoTokenizer.from_pretrained(
    _EN_INDIC_CKPT,
    trust_remote_code=True,
)

# Indic -> English: same loading recipe with the reverse-direction checkpoint.
model_indic_to_en = AutoModelForSeq2SeqLM.from_pretrained(
    _INDIC_EN_CKPT,
    trust_remote_code=True,
).to(DEVICE)
tokenizer_indic_to_en = AutoTokenizer.from_pretrained(
    _INDIC_EN_CKPT,
    trust_remote_code=True,
)

# Shared pre-/post-processor for both translation directions.
ip = IndicProcessor(inference=True)

# Speech-to-text pipeline. No device argument is passed here, unlike the
# translation models above which are explicitly moved to DEVICE.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")
# --- Supabase settings ---
# Both values can be overridden through environment variables of the same
# name; the literals below are only fallbacks so existing deployments keep
# working unchanged.
# NOTE(review): the fallback key is committed in source. Its JWT payload
# appears to declare the "anon" role, but it should still be supplied via the
# environment (and rotated) rather than shipped in the repository — confirm.
SUPABASE_URL = os.environ.get(
    "SUPABASE_URL",
    "https://gptmdbhzblfybdnohqnh.supabase.co",
)
SUPABASE_API_KEY = os.environ.get(
    "SUPABASE_API_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImdwdG1kYmh6YmxmeWJkbm9ocW5oIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDc0NjY1NDgsImV4cCI6MjA2MzA0MjU0OH0.CfWArts6Kd_x7Wj0a_nAyGJfrFt8F7Wdy_MdYDj9e7U",
)
# --- Supabase utilities ---
def save_to_supabase(input_text, output_text, direction):
    """Insert one translation pair into the Supabase table for ``direction``.

    Args:
        input_text: Source text as entered by the user.
        output_text: Translation currently shown in the output box.
        direction: ``"en_to_ks"`` writes to ``translations``; any other value
            writes to ``ks_to_en_translations``.

    Returns:
        A short status string for the UI: ``"Nothing to save."`` when either
        text is empty/blank/None, ``"Saved successfully!"`` on HTTP 201, the
        failure message on any other status, or ``"Save error."`` when the
        request itself raises.
    """
    # Treat None like an empty string so a cleared component cannot crash .strip().
    if not (input_text or "").strip() or not (output_text or "").strip():
        return "Nothing to save."

    table = "translations" if direction == "en_to_ks" else "ks_to_en_translations"
    payload = {
        # NOTE: naive UTC timestamp; datetime.utcnow() is deprecated since
        # Python 3.12 but is kept so the stored format does not change.
        "timestamp": datetime.utcnow().isoformat(),
        "input_text": input_text,
        "output_text": output_text,
    }
    headers = {
        "apikey": SUPABASE_API_KEY,
        "Authorization": f"Bearer {SUPABASE_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        # A timeout keeps a stalled connection from blocking the UI worker forever.
        response = requests.post(
            f"{SUPABASE_URL}/rest/v1/{table}",
            json=payload,
            headers=headers,
            timeout=10,
        )
    except Exception as e:  # UI boundary: report the failure instead of raising.
        logging.exception("Save error: %s", e)
        return "Save error."

    if response.status_code == 201:
        return "Saved successfully!"
    logging.error("Save failed with HTTP status %s", response.status_code)
    # NOTE(review): the leading character below looks like a mis-decoded
    # symbol; it is reproduced as found — confirm the intended glyph.
    return "β Failed to save."
# --- Save verified translation ---
def save_verified_translation(original_text, verified_text):
    """Insert a human-verified translation into ``verified_translations``.

    Args:
        original_text: The machine translation as it was produced.
        verified_text: The user's corrected/approved version.

    Returns:
        A short status string for the UI: ``"Nothing to save."`` when either
        text is empty/blank/None, ``"Verified translation saved!"`` on HTTP
        201, the failure message on any other status, or
        ``"Verified save error."`` when the request itself raises.
    """
    # Treat None like an empty string so a cleared component cannot crash .strip().
    if not (original_text or "").strip() or not (verified_text or "").strip():
        return "Nothing to save."

    payload = {
        # NOTE: naive UTC timestamp; datetime.utcnow() is deprecated since
        # Python 3.12 but is kept so the stored format does not change.
        "timestamp": datetime.utcnow().isoformat(),
        "original_translation": original_text,
        "verified_translation": verified_text,
    }
    headers = {
        "apikey": SUPABASE_API_KEY,
        "Authorization": f"Bearer {SUPABASE_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        # A timeout keeps a stalled connection from blocking the UI worker forever.
        response = requests.post(
            f"{SUPABASE_URL}/rest/v1/verified_translations",
            json=payload,
            headers=headers,
            timeout=10,
        )
    except Exception as e:  # UI boundary: report the failure instead of raising.
        logging.exception("Verified Save error: %s", e)
        return "Verified save error."

    if response.status_code == 201:
        return "Verified translation saved!"
    logging.error("Verified save failed with HTTP status %s", response.status_code)
    # NOTE(review): the leading character below looks like a mis-decoded
    # symbol; it is reproduced as found — confirm the intended glyph.
    return "β Failed to save verified translation."
def get_translation_history(direction): | |
headers = { | |
"apikey": SUPABASE_API_KEY, | |
"Authorization": f"Bearer {SUPABASE_API_KEY}" | |
} | |
table = "translations" if direction == "en_to_ks" else "ks_to_en_translations" | |
try: | |
res = requests.get(f"{SUPABASE_URL}/rest/v1/{table}?order=timestamp.desc&limit=20", headers=headers) | |
normal_data = res.json() if res.status_code == 200 else [] | |
vres = requests.get(f"{SUPABASE_URL}/rest/v1/verified_translations?order=timestamp.desc&limit=20", headers=headers) | |
verified_data = vres.json() if vres.status_code == 200 else [] | |
normal_history = "\n".join([ | |
f"Input: {r['input_text']} β Output: {r['output_text']}" | |
for r in normal_data | |
]) or "No regular translations yet." | |
verified_history = "\n".join([ | |
f"Verified: {r['original_translation']} β {r['verified_translation']}" | |
for r in verified_data | |
]) or "No verified translations yet." | |
return f"--- Regular Translations ---\n{normal_history}\n\n--- Verified Translations ---\n{verified_history}" | |
except Exception as e: | |
logging.error("History error: %s", e) | |
return "Error loading history." | |
# --- Translation with TTS integration ---
def translate(text, direction, generate_tts=False):
    """Translate ``text`` between English and Kashmiri, optionally with speech.

    Args:
        text: Text to translate.
        direction: ``"en_to_ks"`` translates ``eng_Latn`` -> ``kas_Arab`` with
            the en->indic model; any other value translates ``kas_Arab`` ->
            ``eng_Latn`` with the indic->en model.
        generate_tts: When true and ``direction == "ks_to_en"``, the English
            result is also passed to ``synthesize_tts``.

    Returns:
        ``(translation, audio_path)``. ``audio_path`` is ``None`` unless
        speech was requested and produced. Returns
        ``("Enter some text.", None)`` for empty/blank/None input and
        ``("Translation failed.", None)`` when any step raises.
    """
    # Treat None like an empty string so a cleared component cannot crash .strip().
    if not (text or "").strip():
        return "Enter some text.", None

    if direction == "en_to_ks":
        src_lang, tgt_lang = "eng_Latn", "kas_Arab"
        model, tokenizer = model_en_to_indic, tokenizer_en_to_indic
    else:
        src_lang, tgt_lang = "kas_Arab", "eng_Latn"
        model, tokenizer = model_indic_to_en, tokenizer_indic_to_en

    try:
        batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
        tokens = tokenizer(batch, return_tensors="pt", padding=True).to(DEVICE)
        # Inference only: gradients are not needed for generation.
        with torch.no_grad():
            output = model.generate(**tokens, max_length=256, num_beams=5)
        result = tokenizer.batch_decode(output, skip_special_tokens=True)
        final = ip.postprocess_batch(result, lang=tgt_lang)[0]

        # Generate TTS for KS→EN direction if requested
        audio_path = None
        if generate_tts and direction == "ks_to_en":
            audio_path = synthesize_tts(final)
        return final, audio_path
    except Exception as e:  # UI boundary: report the failure instead of raising.
        logging.exception("Translation error: %s", e)
        return "Translation failed.", None
# --- TTS for English output ---
def synthesize_tts(text):
    """Render ``text`` as English speech and return the path of the MP3 file.

    Args:
        text: English text to speak.

    Returns:
        Path of a newly created ``.mp3`` temp file, or ``None`` when ``text``
        is empty/blank/None or synthesis fails. The caller owns the returned
        file; it is not deleted here.
    """
    # Nothing to speak: skip synthesis rather than creating a stray temp file.
    if not (text or "").strip():
        return None

    out_path = None
    try:
        tts = gTTS(text=text, lang="en")
        # Reserve a unique path, then close the handle before writing to it by
        # name: reopening a still-open NamedTemporaryFile is not portable.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
            out_path = f.name
        tts.save(out_path)
        return out_path
    except Exception as e:  # UI boundary: report the failure instead of raising.
        logging.exception("TTS error: %s", e)
        # delete=False means nobody else removes a half-written file.
        if out_path and os.path.exists(out_path):
            os.unlink(out_path)
        return None
# --- STT for English audio ---
def transcribe_audio(audio_path):
    """Transcribe the audio file at ``audio_path`` with the ``asr`` pipeline.

    Args:
        audio_path: Path of the recorded audio file, or a falsy value when
            nothing was recorded.

    Returns:
        ``(transcription, None)`` on success, or ``(None, error_message)``
        when no path was given or any step raises.
    """
    if not audio_path:
        return None, "No audio file provided"

    temp_path = None
    try:
        # Create a persistent copy of the audio file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            temp_path = f.name
        shutil.copy(audio_path, temp_path)
        transcription = asr(temp_path)["text"]
        return transcription, None
    except Exception as e:  # UI boundary: report the failure instead of raising.
        logging.exception("STT error: %s", e)
        return None, f"Transcription failed: {str(e)}"
    finally:
        # Clean up the copy on every path; previously it leaked whenever the
        # copy or the ASR call raised.
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)
# --- Store audio file path ---
def store_audio(audio_path):
    """Hand the recorded clip's path back unchanged.

    Used as an event handler so the path coming from the audio component can
    be written into a state component as-is.
    """
    remembered_path = audio_path
    return remembered_path
# --- Handle audio translation ---
def handle_audio_translation(audio_path, direction):
    """Transcribe a recording and translate the transcript (EN -> KS only).

    Args:
        audio_path: Path of the recorded audio file.
        direction: Current translation direction; only ``"en_to_ks"`` is
            accepted for audio input.

    Returns:
        ``(status, transcription, translation, audio_path)``. ``status`` is an
        error/notice message or ``""`` on success; ``audio_path`` is always
        passed back unchanged.
    """
    if direction == "en_to_ks":
        heard_text, failure = transcribe_audio(audio_path)
        if failure:
            return failure, "", "", audio_path
        # Speech output is never requested on this path, so the audio slot is ignored.
        kashmiri_text, _unused_audio = translate(heard_text, direction, generate_tts=False)
        return "", heard_text, kashmiri_text, audio_path

    notice = "Audio input is only supported for English to Kashmiri."
    return notice, "", "", audio_path
# --- Switch UI direction ---
def switch_direction(direction, input_text_val, output_text_val, audio_path):
    """Flip the translation direction and swap the two text boxes.

    Args:
        direction: Current direction (``"en_to_ks"`` or anything else).
        input_text_val: Current contents of the input box.
        output_text_val: Current contents of the output box.
        audio_path: Accepted for the event signature; not used here.

    Returns:
        ``(new_direction, input_update, output_update, None)``: the input box
        receives the old output text and the output box the old input text,
        each with the label for the new direction. The trailing ``None`` is
        the value for the fourth output component.
    """
    if direction == "en_to_ks":
        flipped = "ks_to_en"
        source_label, target_label = "Kashmiri Text", "English Translation"
    else:
        flipped = "en_to_ks"
        source_label, target_label = "English Text", "Kashmiri Translation"

    swapped_input = gr.update(value=output_text_val, label=source_label)
    swapped_output = gr.update(value=input_text_val, label=target_label)
    return flipped, swapped_input, swapped_output, None
# === Gradio Interface ===
# Builds the Blocks UI and wires each button to its handler. Components are
# created in the order they appear, so statement order matters here.
# NOTE(review): the nesting under each `with gr.Row():` was reconstructed from
# a copy of this file that had lost its indentation — confirm row membership,
# in particular that save_status/history sit outside the button row.
# NOTE(review): several labels below contain characters that look like
# mis-decoded emoji/arrows; they are reproduced as found — confirm the glyphs.
with gr.Blocks() as interface:
    # Page header: left logo, title, right logo.
    gr.HTML("""
    <div style="display: flex; justify-content: space-between; align-items: center; padding: 10px;">
    <img src="https://raw.githubusercontent.com/BurhaanRasheedZargar/Images/211321a234613a9c3dd944fe9367cf13d1386239/assets/left_logo.png" style="height:150px; width:auto;">
    <h2 style="margin: 0; text-align: center;">English β Kashmiri Translator</h2>
    <img src="https://raw.githubusercontent.com/BurhaanRasheedZargar/Images/77797f7f7cbee328fa0f9d31cf3e290441e04cd3/assets/right_logo.png">
    </div>
    """)
    # Per-session state: current direction (starts as English -> Kashmiri)
    # and the path of the last recorded audio clip.
    translation_direction = gr.State(value="en_to_ks")
    stored_audio = gr.State()
    # Source text and its translation.
    with gr.Row():
        input_text = gr.Textbox(label="English Text", placeholder="Enter text here...", lines=2)
        output_text = gr.Textbox(label="Kashmiri Translation", placeholder="Translated text...", lines=2)
    # Editable copy of the translation, used by "Verify & Save".
    with gr.Row():
        verified_text = gr.Textbox(label="βοΈ Edit Translation", placeholder="Edit translation here...", lines=2)
    with gr.Row():
        translate_button = gr.Button("Translate")
        save_button = gr.Button("Save Translation")
        switch_button = gr.Button("Switch Direction")
        verify_button = gr.Button("β Verify & Save")
    # Read-only feedback areas.
    save_status = gr.Textbox(label="Save Status", interactive=False)
    history = gr.Textbox(label="Translation History", lines=8, interactive=False)
    # Microphone input (delivered to handlers as a file path) and TTS playback.
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="ποΈ Record English audio", sources=["microphone"])
        audio_output = gr.Audio(label="π English TTS", interactive=False)
    with gr.Row():
        stt_button = gr.Button("π€ Transcribe & Translate (EN β KS)")
        tts_button = gr.Button("π Translate & Speak (KS β EN)")
    # Store audio when recorded
    audio_input.change(
        fn=store_audio,
        inputs=audio_input,
        outputs=stored_audio
    )
    # Events
    # Plain translation: generate_tts is fixed to False via an inline State;
    # afterwards the result is copied into the editable box.
    translate_button.click(
        fn=translate,
        inputs=[input_text, translation_direction, gr.State(False)],
        outputs=[output_text, audio_output]
    ).then(
        fn=lambda txt: txt,
        inputs=output_text,
        outputs=verified_text
    )
    # Same handler with generate_tts fixed to True; translate() only produces
    # audio when the direction is "ks_to_en".
    tts_button.click(
        fn=translate,
        inputs=[input_text, translation_direction, gr.State(True)],
        outputs=[output_text, audio_output]
    ).then(
        fn=lambda txt: txt,
        inputs=output_text,
        outputs=verified_text
    )
    # Save the current pair, then refresh the history box.
    save_button.click(
        fn=save_to_supabase,
        inputs=[input_text, output_text, translation_direction],
        outputs=save_status
    ).then(
        fn=get_translation_history,
        inputs=translation_direction,
        outputs=history
    )
    # Flip direction, swap the two text boxes; switch_direction returns None
    # for the audio_output slot.
    switch_button.click(
        fn=switch_direction,
        inputs=[translation_direction, input_text, output_text, stored_audio],
        outputs=[translation_direction, input_text, output_text, audio_input if False else audio_output]
    ) if False else switch_button.click(
        fn=switch_direction,
        inputs=[translation_direction, input_text, output_text, stored_audio],
        outputs=[translation_direction, input_text, output_text, audio_output]
    )
    # Transcribe the stored recording and translate it; the handler's status
    # message goes to save_status and the audio path is written back to
    # audio_input. The translation is then copied into the editable box.
    stt_button.click(
        fn=handle_audio_translation,
        inputs=[stored_audio, translation_direction],
        outputs=[save_status, input_text, output_text, audio_input]
    ).then(
        fn=lambda txt: txt,
        inputs=output_text,
        outputs=verified_text
    )
    # Save the (machine translation, edited translation) pair, then refresh
    # the history box.
    verify_button.click(
        fn=save_verified_translation,
        inputs=[output_text, verified_text],
        outputs=save_status
    ).then(
        fn=get_translation_history,
        inputs=translation_direction,
        outputs=history
    )
if __name__ == "__main__":
    # Script entry point: turn on the request queue, then start the server
    # with share=True.
    queued_app = interface.queue()
    queued_app.launch(share=True)