# geminilive / app.py
import asyncio
import pathlib
from typing import AsyncGenerator, Literal
from google import genai
from google.genai.types import (
Content,
LiveConnectConfig,
Part,
PrebuiltVoiceConfig,
SpeechConfig,
VoiceConfig,
)
import gradio as gr
from fastrtc import AsyncStreamHandler, WebRTC, async_aggregate_bytes_to_16bit
import numpy as np
current_dir = pathlib.Path(__file__).parent
class GeminiHandler(AsyncStreamHandler):
"""Handler for the Gemini API"""
def __init__(
self,
expected_layout: Literal["mono"] = "mono",
output_sample_rate: int = 24000,
output_frame_size: int = 480,
input_sample_rate: int = 16000,
) -> None:
super().__init__(
expected_layout,
output_sample_rate,
output_frame_size,
input_sample_rate=input_sample_rate,
)
self.input_queue: asyncio.Queue = asyncio.Queue()
self.output_queue: asyncio.Queue = asyncio.Queue()
self.quit: asyncio.Event = asyncio.Event()
def copy(self) -> "GeminiHandler":
"""Required implementation of the copy method for AsyncStreamHandler"""
return GeminiHandler(
expected_layout=self.expected_layout,
output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
            input_sample_rate=self.input_sample_rate,
        )
async def stream(self) -> AsyncGenerator[bytes, None]:
"""Helper method to stream input audio to the server. Used in start_stream."""
while not self.quit.is_set():
audio = await self.input_queue.get()
yield audio
return
async def connect(
self,
project_id: str,
location: str,
voice_name: str | None = None,
system_instruction: str | None = None,
) -> AsyncGenerator[bytes, None]:
"""Connect to the Gemini server and start the stream."""
client = genai.Client(vertexai=True, project=project_id, location=location)
config = LiveConnectConfig(
response_modalities=["AUDIO"],
speech_config=SpeechConfig(
voice_config=VoiceConfig(
prebuilt_voice_config=PrebuiltVoiceConfig(
voice_name=voice_name,
)
)
),
system_instruction=Content(parts=[Part.from_text(text=system_instruction)]),
)
async with client.aio.live.connect(
model="gemini-2.0-flash-live-preview-04-09", config=config
) as session:
async for audio in session.start_stream(
stream=self.stream(), mime_type="audio/pcm"
):
if audio.data:
yield audio.data
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
"""Receive audio from the user and put it in the input stream."""
_, array = frame
array = array.squeeze()
audio_message = array.tobytes()
self.input_queue.put_nowait(audio_message)
async def generator(self) -> None:
"""Helper method for placing audio from the server into the output queue."""
async for audio_response in async_aggregate_bytes_to_16bit(
self.connect(*self.latest_args[1:])
):
self.output_queue.put_nowait(audio_response)
async def emit(self) -> tuple[int, np.ndarray]:
"""Required implementation of the emit method for AsyncStreamHandler"""
if not self.args_set.is_set():
await self.wait_for_args()
asyncio.create_task(self.generator())
array = await self.output_queue.get()
return (self.output_sample_rate, array)
def shutdown(self) -> None:
"""Stop the stream method on shutdown"""
self.quit.set()
css = (current_dir / "style.css").read_text()
header = (current_dir / "header.html").read_text()
with gr.Blocks(css=css) as demo:
gr.HTML(header)
with gr.Group(visible=True, elem_id="api-form") as api_key_row:
with gr.Row():
_project_id = gr.Textbox(
label="Project ID",
placeholder="Enter your Google Cloud Project ID",
)
_location = gr.Dropdown(
label="Location",
choices=[
"us-central1",
],
value="us-central1",
info="You can find additional locations [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#united-states)",
)
_voice_name = gr.Dropdown(
label="Voice",
choices=[
"Puck",
"Charon",
"Kore",
"Fenrir",
"Aoede",
],
value="Puck",
)
_system_instruction = gr.Textbox(
label="System Instruction",
placeholder="Talk like a pirate.",
)
with gr.Row():
submit = gr.Button(value="Submit")
with gr.Row(visible=False) as row:
webrtc = WebRTC(
label="Conversation",
modality="audio",
mode="send-receive",
            # See https://fastrtc.org/deployment/ for the changes needed to
            # deploy behind a firewall.
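            # A minimal sketch of what a TURN-enabled configuration could look
            # like; the server URL and credentials are placeholders, not real
            # values:
            # rtc_configuration={
            #     "iceServers": [
            #         {
            #             "urls": "turn:turn.example.com:3478",
            #             "username": "<username>",
            #             "credential": "<credential>",
            #         }
            #     ]
            # },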
rtc_configuration=None,
)
webrtc.stream(
GeminiHandler(),
inputs=[webrtc, _project_id, _location, _voice_name, _system_instruction],
outputs=[webrtc],
time_limit=90,
concurrency_limit=2,
)
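        # The extra inputs above (project ID, location, voice, system
        # instruction) reach the handler as `latest_args`; the handler's
        # `generator` method unpacks them into `connect(...)`.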
submit.click(
lambda: (gr.update(visible=False), gr.update(visible=True)),
None,
[api_key_row, row],
)
if __name__ == "__main__":
demo.launch()