# Project_1 / app.py
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from ttsmms import download, TTS
from langdetect import detect
from gradio_client import Client

# =========================
# Load ASR Model
# =========================
asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"
# asr_model_name = "openai/whisper-large-v3-turbo"
processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
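# A minimal addition: the model is only used for inference here, so switch it to
# eval mode to disable dropout during transcription.
asr_model.eval()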

# =========================
# Load Text Generation Model via Gradio Client
# =========================
llm_client = Client("Futuresony/Mr.Events")
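# Client() connects to the hosted Space and proxies its API endpoints over HTTP.
# generate_text() below calls predict(query=..., api_name="/chat"), which assumes
# the "Futuresony/Mr.Events" Space exposes a /chat endpoint taking a single
# `query` string; adjust the call if that Space's API differs.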

# =========================
# Load TTS Models
# =========================
swahili_dir = download("swh", "./data/swahili")
english_dir = download("eng", "./data/english")
swahili_tts = TTS(swahili_dir)
english_tts = TTS(english_dir)
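# download() fetches the MMS TTS checkpoint for the given ISO 639-3 code ("swh",
# "eng") into the target directory and returns its path, which TTS() then loads.
# The first run requires network access to fetch the model files.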

# =========================
# ASR Function
# =========================
def transcribe(audio_file):
    speech_array, sample_rate = torchaudio.load(audio_file)
    # Collapse multi-channel recordings to mono; .squeeze() alone would leave a 2-D array
    if speech_array.shape[0] > 1:
        speech_array = speech_array.mean(dim=0, keepdim=True)
    # Wav2Vec2 expects 16 kHz input; resample only when the source rate differs
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        speech_array = resampler(speech_array)
    speech_array = speech_array.squeeze().numpy()
    input_values = processor(speech_array, sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription

# =========================
# Text Generation Function (Safe)
# =========================
def generate_text(prompt):
    print(f"[DEBUG] Generating text for prompt: {prompt} (type: {type(prompt)})")
    # Guard the remote call itself so a failed request degrades gracefully
    try:
        result = llm_client.predict(query=prompt, api_name="/chat")
    except Exception as e:
        print(f"[ERROR] /chat call failed: {e}")
        return "Error: Unable to generate text."
    print(f"[DEBUG] /chat returned: {result} (type: {type(result)})")
    # Ensure result is always a string before calling .strip()
    if not isinstance(result, str):
        try:
            result = " ".join(map(str, result)) if isinstance(result, (list, tuple)) else str(result)
        except Exception as e:
            print(f"[ERROR] Failed to convert result to string: {e}")
            result = "Error: Unable to generate text."
    return result.strip()

# =========================
# TTS Function
# =========================
def text_to_speech(text):
    print(f"[DEBUG] Converting text to speech: {text} (type: {type(text)})")
    wav_path = "./output.wav"
    try:
        # langdetect can raise on empty or ambiguous input, so keep it inside the guard
        lang = detect(text)
        if lang == "sw":
            swahili_tts.synthesis(text, wav_path=wav_path)
        else:
            english_tts.synthesis(text, wav_path=wav_path)
    except Exception as e:
        print(f"[ERROR] TTS synthesis failed: {e}")
        return None
    return wav_path

# =========================
# Combined Processing Function
# =========================
def process_audio(audio):
    print(f"[DEBUG] Processing audio: {audio} (type: {type(audio)})")
    transcription = transcribe(audio)
    print(f"[DEBUG] Transcription: {transcription}")
    generated_text = generate_text(transcription)
    print(f"[DEBUG] Generated Text: {generated_text}")
    speech_path = text_to_speech(generated_text)
    print(f"[DEBUG] Speech Path: {speech_path}")
    return transcription, generated_text, speech_path

# =========================
# Gradio Interface
# =========================
with gr.Blocks() as demo:
    gr.Markdown("<p align='center' style='font-size: 20px;'>End-to-End ASR, Text Generation, and TTS</p>")
    gr.HTML("<center>Upload or record audio. The model will transcribe, generate a response, and read it out.</center>")
    audio_input = gr.Audio(label="Input Audio", type="filepath")
    text_output = gr.Textbox(label="Transcription")
    generated_text_output = gr.Textbox(label="Generated Text")
    audio_output = gr.Audio(label="Output Speech")
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=process_audio,
        inputs=audio_input,
        outputs=[text_output, generated_text_output, audio_output]
    )

if __name__ == "__main__":
    demo.launch()
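# Usage note (a sketch): running `python app.py` serves the app at Gradio's default
# local address, http://127.0.0.1:7860. Pass share=True to demo.launch() to get a
# temporary public link.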