import asyncio
import pathlib
from typing import AsyncGenerator, Literal

import gradio as gr
import numpy as np
from fastrtc import AsyncStreamHandler, WebRTC, async_aggregate_bytes_to_16bit
from google import genai
from google.genai.types import (
    Content,
    LiveConnectConfig,
    Part,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)

current_dir = pathlib.Path(__file__).parent


class GeminiHandler(AsyncStreamHandler):
    """FastRTC handler that proxies audio between a WebRTC peer and the
    Gemini Live API on Vertex AI.

    Audio received from the user is queued and streamed to Gemini; audio
    produced by Gemini is aggregated into 16-bit frames and emitted back
    to the peer.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
        input_sample_rate: int = 16000,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=input_sample_rate,
        )
        # Raw PCM byte chunks from the user; consumed by ``stream``.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # 16-bit audio arrays from Gemini; consumed by ``emit``.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by ``shutdown`` to stop the ``stream`` generator.
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Required by AsyncStreamHandler: build a fresh handler per connection.

        Fix: the original omitted ``input_sample_rate``, so a handler
        constructed with a non-default input rate was copied with the
        default one.
        """
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
            input_sample_rate=self.input_sample_rate,
        )

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield user audio chunks to the server until ``quit`` is set.

        Used as the input stream for ``session.start_stream``.
        """
        while not self.quit.is_set():
            audio = await self.input_queue.get()
            yield audio
        return

    async def connect(
        self,
        project_id: str,
        location: str,
        voice_name: str | None = None,
        system_instruction: str | None = None,
    ) -> AsyncGenerator[bytes, None]:
        """Connect to the Gemini Live server and yield raw audio bytes.

        Args:
            project_id: Google Cloud project to bill against.
            location: Vertex AI region (e.g. ``us-central1``).
            voice_name: Prebuilt voice to use for synthesized speech.
            system_instruction: Optional system prompt for the model.
        """
        client = genai.Client(vertexai=True, project=project_id, location=location)
        config = LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
            # Fix: only attach a system instruction when one was actually
            # provided — ``Part.from_text(text=None)`` is invalid input.
            system_instruction=(
                Content(parts=[Part.from_text(text=system_instruction)])
                if system_instruction
                else None
            ),
        )
        async with client.aio.live.connect(
            model="gemini-2.0-flash-live-preview-04-09", config=config
        ) as session:
            async for audio in session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            ):
                if audio.data:
                    yield audio.data

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Receive one (sample_rate, samples) frame from the user and queue it."""
        _, array = frame
        array = array.squeeze()
        audio_message = array.tobytes()
        self.input_queue.put_nowait(audio_message)

    async def generator(self) -> None:
        """Pump Gemini's audio into the output queue as 16-bit frames.

        ``latest_args[1:]`` carries (project_id, location, voice_name,
        system_instruction) collected from the Gradio inputs.
        """
        async for audio_response in async_aggregate_bytes_to_16bit(
            self.connect(*self.latest_args[1:])
        ):
            self.output_queue.put_nowait(audio_response)

    async def emit(self) -> tuple[int, np.ndarray]:
        """Required by AsyncStreamHandler: emit the next output audio frame.

        Lazily starts the Gemini pump task the first time the UI args arrive.
        """
        if not self.args_set.is_set():
            await self.wait_for_args()
            asyncio.create_task(self.generator())
        array = await self.output_queue.get()
        return (self.output_sample_rate, array)

    def shutdown(self) -> None:
        """Stop the ``stream`` generator on shutdown."""
        self.quit.set()


css = (current_dir / "style.css").read_text()
header = (current_dir / "header.html").read_text()

with gr.Blocks(css=css) as demo:
    gr.HTML(header)
    # Configuration form, shown first; hidden once the user submits.
    with gr.Group(visible=True, elem_id="api-form") as api_key_row:
        with gr.Row():
            _project_id = gr.Textbox(
                label="Project ID",
                placeholder="Enter your Google Cloud Project ID",
            )
            _location = gr.Dropdown(
                label="Location",
                choices=[
                    "us-central1",
                ],
                value="us-central1",
                info="You can find additional locations [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#united-states)",
            )
            _voice_name = gr.Dropdown(
                label="Voice",
                choices=[
                    "Puck",
                    "Charon",
                    "Kore",
                    "Fenrir",
                    "Aoede",
                ],
                value="Puck",
            )
            _system_instruction = gr.Textbox(
                label="System Instruction",
                placeholder="Talk like a pirate.",
            )
        with gr.Row():
            submit = gr.Button(value="Submit")
    # Conversation view, revealed after submit.
    with gr.Row(visible=False) as row:
        webrtc = WebRTC(
            label="Conversation",
            modality="audio",
            mode="send-receive",
            # See for changes needed to deploy behind a firewall
            # https://fastrtc.org/deployment/
            rtc_configuration=None,
        )
        webrtc.stream(
            GeminiHandler(),
            inputs=[webrtc, _project_id, _location, _voice_name, _system_instruction],
            outputs=[webrtc],
            time_limit=90,
            concurrency_limit=2,
        )
    submit.click(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        None,
        [api_key_row, row],
    )

if __name__ == "__main__":
    demo.launch()