import os, tempfile, time, traceback
from pathlib import Path
import gradio as gr
from groq import Groq
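
# Pipeline: spoken or typed question -> Whisper transcription on Groq ->
# streamed LLM answer -> PlayAI TTS rendering of the answer.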

# Read the API key from HF Spaces secrets; accept either "groq_api_key" or "GROQ_API_KEY".
def _load_key() -> str:
    key = os.environ.get("GROQ_API_KEY") or os.environ.get("groq_api_key")
    if not key:
        raise RuntimeError(
            "Groq API key not found. In your Space settings -> Secrets, add 'groq_api_key'."
        )
    os.environ["GROQ_API_KEY"] = key  # normalize to the canonical name for the SDK
    return key

client = Groq(api_key=_load_key())

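# Send the audio file to Groq's Whisper transcription endpoint and return the text.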
def transcribe_audio(audio_path: str, model: str = "whisper-large-v3") -> str:
    if not audio_path or not Path(audio_path).exists():
        raise FileNotFoundError("Audio file path is missing or not found.")
    with open(audio_path, "rb") as f:
        resp = client.audio.transcriptions.create(
            file=(Path(audio_path).name, f.read()),
            model=model,
            response_format="verbose_json",
        )
    return (getattr(resp, "text", "") or "").strip()

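# Stream a chat completion; yields the accumulated answer after every delta so
# the Gradio textbox can update live.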
def stream_answer(prompt_text: str,
                  model: str = "llama-3.1-8b-instant",
                  temperature: float = 0.3):
    if not prompt_text.strip():
        raise ValueError("Empty prompt for the LLM.")
    stream = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "Answer clearly and concisely."},
            {"role": "user", "content": prompt_text},
        ],
        temperature=temperature,
        max_completion_tokens=1024,
        top_p=1,
        stream=True,
    )
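    # Accumulate deltas and re-yield the full text so far on each chunk.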
    acc = []
    for chunk in stream:
        delta = chunk.choices[0].delta.content or ""
        if delta:
            acc.append(delta)
            yield "".join(acc)
    yield "".join(acc)

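# Synthesize the answer with Groq TTS, write it to a temp file, and return the path.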
def text_to_speech(text: str,
                   voice: str = "Calum-PlayAI",
                   model: str = "playai-tts",
                   fmt: str = "wav") -> str:
    if not text.strip():
        raise ValueError("Empty text for TTS.")
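    # Cap the input length; very long answers can exceed the TTS request limit
    # (1200 chars is an assumed-safe cutoff here, not an API constant).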
    tts_input = text[:1200]
    resp = client.audio.speech.create(
        model=model,
        voice=voice,
        response_format=fmt,
        input=tts_input,
    )
    out_path = os.path.join(tempfile.gettempdir(), f"answer_{int(time.time())}.{fmt}")
    # The Groq SDK returns a BinaryAPIResponse, which exposes write_to_file().
    resp.write_to_file(out_path)
    return out_path

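# Generator wired to the Run button. Yields (transcript, answer, audio_path, status)
# tuples so Gradio can stream intermediate state into the four output components.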
def run_pipeline(audio_file, typed_question, llm_model, voice_name):
    transcript = ""
    answer = ""
    try:
        if typed_question and typed_question.strip():
            transcript = typed_question.strip()
            status = "Using typed question."
        else:
            if not audio_file:
                raise RuntimeError("Provide a recording or type a question.")
            status = "Transcribing audio..."
            yield transcript, answer, None, status
            transcript = transcribe_audio(audio_file)
            if not transcript:
                raise RuntimeError("No text returned by transcription.")
            status = "Transcription done."

        yield transcript, answer, None, status

        status = "Generating answer..."
        partial = ""
        for partial in stream_answer(transcript, model=llm_model):
            answer = partial
            yield transcript, answer, None, status
        if not answer.strip():
            raise RuntimeError("No text returned by the LLM.")

        status = "Converting answer to speech..."
        yield transcript, answer, None, status
        audio_out = text_to_speech(answer, voice=voice_name)
        status = "Done."
        yield transcript, answer, audio_out, status

    except Exception as e:
        err = "Error: " + str(e)
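        # Keep only the tail of the traceback so the status box stays readable.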
        short_tb = "\n".join(traceback.format_exc().splitlines()[-6:])
        help_tip = (
            "\nTips:\n"
            "- Check Space secret 'groq_api_key'.\n"
            "- Try a shorter audio clip.\n"
            "- Verify model names.\n"
            "- Confirm requirements installed."
        )
        yield transcript, answer, None, err + "\n" + short_tb + help_tip

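# UI: audio or text in; transcription, streamed answer, and spoken audio out.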
with gr.Blocks(title="Audio Q&A with Groq") as demo:
    gr.Markdown("# Audio Q&A with Groq")
    gr.Markdown("One audio or typed question in, one answer out, plus speech.")

    with gr.Row():
        audio_in = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="Question audio"
        )
        typed_in = gr.Textbox(label="Or type your question", placeholder="Optional")

    with gr.Row():
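        # Model list is a snapshot; Groq rotates its catalog, so adjust as needed.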
        llm_model = gr.Dropdown(
            choices=[
                "llama-3.1-8b-instant",
                "llama-3.1-70b-versatile",
                "llama3-8b-8192",
            ],
            value="llama-3.1-8b-instant",
            label="LLM model"
        )
        voice_name = gr.Textbox(value="Calum-PlayAI", label="TTS voice")

    ask_btn = gr.Button("Run")
    clear_btn = gr.Button("Clear")

    transcript_box = gr.Textbox(label="Transcription", interactive=False, lines=4)
    answer_box = gr.Textbox(label="Answer", interactive=False, lines=10)
    answer_audio = gr.Audio(label="Answer speech", interactive=False)
    status_md = gr.Markdown("")

    ask_btn.click(
        fn=run_pipeline,
        inputs=[audio_in, typed_in, llm_model, voice_name],
        outputs=[transcript_box, answer_box, answer_audio, status_md]
    )

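    # Reset all four outputs.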
    def clear_all():
        return "", "", None, ""
    clear_btn.click(fn=clear_all, inputs=None, outputs=[transcript_box, answer_box, answer_audio, status_md])

if __name__ == "__main__":
    # On HF Spaces a plain demo.launch() is enough; queue() enables streaming
    # from generator functions without extra arguments in Gradio 4.
    demo.queue().launch()