import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from ttsmms import download, TTS
from langdetect import detect
from gradio_client import Client

# =========================
# Load ASR Model
# =========================
asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"
# asr_model_name = "openai/whisper-large-v3-turbo"
processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
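
# Sketch (optional, not in the original): to run ASR on GPU, move the model
# and, inside transcribe(), the input tensors to the same device:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   asr_model.to(device)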

# =========================
# Load Text Generation Model via Gradio Client
# =========================
llm_client = Client("Futuresony/Mr.Events")
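
# NOTE: generate_text() below assumes this Space exposes a "/chat" endpoint
# that accepts a `query` string; run llm_client.view_api() to confirm the
# endpoint name and its parameters.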

# =========================
# Load TTS Models
# =========================
swahili_dir = download("swh", "./data/swahili")
english_dir = download("eng", "./data/english")
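
# "swh" and "eng" are the ISO 639-3 codes used by the MMS TTS models for
# Swahili and English; the model files are stored under the given directories.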

swahili_tts = TTS(swahili_dir)
english_tts = TTS(english_dir)

# =========================
# ASR Function
# =========================
def transcribe(audio_file):
    speech_array, sample_rate = torchaudio.load(audio_file)
    # Downmix multi-channel recordings to mono; the CTC model expects one channel.
    if speech_array.shape[0] > 1:
        speech_array = speech_array.mean(dim=0, keepdim=True)
    # Wav2Vec2 models expect 16 kHz input, so resample only when needed.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        speech_array = resampler(speech_array)
    speech_array = speech_array.squeeze().numpy()
    input_values = processor(speech_array, sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    # Greedy CTC decoding: pick the most likely token per frame, then let
    # batch_decode collapse repeats and blanks.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
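
# Usage sketch (file path is hypothetical):
#   text = transcribe("recording.wav")  # -> decoded transcription string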

# =========================
# Text Generation Function (Safe)
# =========================
def generate_text(prompt):
    print(f"[DEBUG] Generating text for prompt: {prompt} (type: {type(prompt)})")

    # Guard the remote call: the hosted Space may be down or time out.
    try:
        result = llm_client.predict(query=prompt, api_name="/chat")
    except Exception as e:
        print(f"[ERROR] /chat request failed: {e}")
        return "Error: Unable to generate text."
    print(f"[DEBUG] /chat returned: {result} (type: {type(result)})")

    # Ensure the result is always a string: some Gradio endpoints return a
    # list or tuple (e.g. chat history) instead of plain text.
    if not isinstance(result, str):
        try:
            result = " ".join(map(str, result)) if isinstance(result, (list, tuple)) else str(result)
        except Exception as e:
            print(f"[ERROR] Failed to convert result to string: {e}")
            result = "Error: Unable to generate text."

    return result.strip()
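
# Usage sketch:
#   reply = generate_text("Habari za leo?")  # always returns a plain string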

# =========================
# TTS Function
# =========================
def text_to_speech(text):
    print(f"[DEBUG] Converting text to speech: {text} (type: {type(text)})")
    # langdetect raises an exception on empty or very short input, so fall
    # back to English when detection fails.
    try:
        lang = detect(text)
    except Exception as e:
        print(f"[ERROR] Language detection failed: {e}")
        lang = "en"
    wav_path = "./output.wav"
    try:
        # Route Swahili text to the Swahili voice; everything else to English.
        if lang == "sw":
            swahili_tts.synthesis(text, wav_path=wav_path)
        else:
            english_tts.synthesis(text, wav_path=wav_path)
    except Exception as e:
        print(f"[ERROR] TTS synthesis failed: {e}")
        return None
    return wav_path
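
# NOTE: the fixed "./output.wav" path is overwritten on every request, which
# can race under concurrent users. A safer sketch would write each result to
# a unique temp file instead:
#   import tempfile
#   wav_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name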

# =========================
# Combined Processing Function
# =========================
def process_audio(audio):
    print(f"[DEBUG] Processing audio: {audio} (type: {type(audio)})")
    # Gradio passes None when no audio was recorded or uploaded.
    if audio is None:
        return "No audio provided.", "", None

    transcription = transcribe(audio)
    print(f"[DEBUG] Transcription: {transcription}")

    generated_text = generate_text(transcription)
    print(f"[DEBUG] Generated Text: {generated_text}")

    speech_path = text_to_speech(generated_text)
    print(f"[DEBUG] Speech Path: {speech_path}")

    return transcription, generated_text, speech_path

# =========================
# Gradio Interface
# =========================
with gr.Blocks() as demo:
    gr.Markdown("<p align='center' style='font-size: 20px;'>End-to-End ASR, Text Generation, and TTS</p>")
    gr.HTML("<center>Upload or record audio. The model will transcribe, generate a response, and read it out.</center>")

    audio_input = gr.Audio(label="Input Audio", type="filepath")
    text_output = gr.Textbox(label="Transcription")
    generated_text_output = gr.Textbox(label="Generated Text")
    audio_output = gr.Audio(label="Output Speech")
    submit_btn = gr.Button("Submit")

    submit_btn.click(
        fn=process_audio,
        inputs=audio_input,
        outputs=[text_output, generated_text_output, audio_output]
    )

if __name__ == "__main__":
    demo.launch()