Spaces:
Sleeping
Sleeping
# -- coding: utf-8 -- | |
import asyncio | |
import websockets | |
import json | |
import base64 | |
import os | |
import time | |
from typing import Optional, Callable, List, Dict, Any | |
from enum import Enum | |
class TurnDetectionMode(Enum):
    """Strategy used to decide when the user's turn has ended.

    Consumed by ``OmniRealtimeClient``, which picks the session's
    ``turn_detection`` settings according to the selected member.
    """

    # Let the server's voice-activity detection end the user's turn.
    SERVER_VAD = "server_vad"
    # No server-side turn detection; responses are requested explicitly.
    MANUAL = "manual"
class OmniRealtimeClient:
    """
    A demo client for interacting with the Omni Realtime API.

    This class provides methods to connect to the Realtime API over a
    WebSocket, send text and audio data, dispatch server events to
    callbacks, and close the connection.

    Attributes:
        base_url (str):
            The base URL for the Realtime API. ``?model=<model>`` is appended
            to it when connecting.
        api_key (str):
            The API key, sent as a Bearer token in the ``Authorization`` header.
        model (str):
            Omni model to use for chat.
        voice (str):
            The voice to use for audio output.
        turn_detection_mode (TurnDetectionMode):
            The mode for turn detection. ``MANUAL`` disables server-side turn
            detection (responses must be requested with ``create_response``);
            ``SERVER_VAD`` enables the server's voice-activity detection.
        on_text_delta (Callable[[str], None]):
            Callback for ``response.text.delta`` events. Called directly on the
            event loop, so it should return quickly.
        on_audio_delta (Callable[[bytes], None]):
            Callback for ``response.audio.delta`` events; receives the decoded
            audio bytes. Called directly on the event loop.
        on_input_transcript (Callable[[str], None]):
            Callback for completed input transcripts. Run via
            ``asyncio.to_thread``.
        on_interrupt (Callable[[], None]):
            Callback for user interrupt events (speech detected); should be
            used to stop audio playback.
        on_output_transcript (Callable[[str], None]):
            Callback for output transcript text. Run via ``asyncio.to_thread``.
            Text that arrives before the input transcript is buffered so the
            user's words are always reported first.
        extra_event_handlers (Dict[str, Callable[[Dict[str, Any]], None]]):
            Additional handlers keyed by event type. They are only consulted
            for event types this class does not already handle.
    """

    # Instructions sent with ``response.create`` unless the caller overrides
    # them. English: "You are Tom, an American shopping guide who sells
    # phones and TVs."
    DEFAULT_INSTRUCTIONS = "你是Tom,一个美国的导购,负责售卖手机、电视"

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "",
        voice: str = "Chelsie",
        turn_detection_mode: TurnDetectionMode = TurnDetectionMode.MANUAL,
        on_text_delta: Optional[Callable[[str], None]] = None,
        on_audio_delta: Optional[Callable[[bytes], None]] = None,
        on_interrupt: Optional[Callable[[], None]] = None,
        on_input_transcript: Optional[Callable[[str], None]] = None,
        on_output_transcript: Optional[Callable[[str], None]] = None,
        extra_event_handlers: Optional[Dict[str, Callable[[Dict[str, Any]], None]]] = None
    ):
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.voice = voice
        self.ws = None  # set by connect()
        self.on_text_delta = on_text_delta
        self.on_audio_delta = on_audio_delta
        self.on_interrupt = on_interrupt
        self.on_input_transcript = on_input_transcript
        self.on_output_transcript = on_output_transcript
        self.turn_detection_mode = turn_detection_mode
        self.extra_event_handlers = extra_event_handlers or {}

        # Track current response state
        self._current_response_id = None
        self._current_item_id = None
        self._is_responding = False

        # Ordering of transcripts: output text is held back until the input
        # transcript for the current turn has been emitted.
        self._print_input_transcript = False
        self._output_transcript_buffer = ""

        # Per-client counter that keeps event ids unique within a millisecond.
        self._event_seq = 0

    async def connect(self) -> None:
        """Establish the WebSocket connection and send the session configuration.

        Raises:
            ValueError: If ``turn_detection_mode`` is not a known
                ``TurnDetectionMode``. This is checked before connecting, so no
                socket is left open on failure.
        """
        session_config = self._build_session_config()

        url = f"{self.base_url}?model={self.model}"
        headers = {
            "Authorization": f"Bearer {self.api_key}"
        }
        # Never log the real credential.
        redacted_headers = {"Authorization": "Bearer ***"}
        print(f"url: {url}, headers: {redacted_headers}")

        self.ws = await websockets.connect(url, additional_headers=headers)
        await self.update_session(session_config)

    def _build_session_config(self) -> Dict[str, Any]:
        """Return the ``session.update`` payload for the configured turn-detection mode.

        Raises:
            ValueError: If ``turn_detection_mode`` is not recognized.
        """
        if self.turn_detection_mode == TurnDetectionMode.MANUAL:
            # No server-side turn detection; create_response() drives replies.
            turn_detection = None
        elif self.turn_detection_mode == TurnDetectionMode.SERVER_VAD:
            turn_detection = {
                "type": "server_vad",
                "threshold": 0.1,
                "prefix_padding_ms": 500,
                "silence_duration_ms": 900
            }
        else:
            raise ValueError(f"Invalid turn detection mode: {self.turn_detection_mode}")

        return {
            "modalities": ["text", "audio"],
            "voice": self.voice,
            "input_audio_format": "pcm16",
            "output_audio_format": "pcm16",
            "input_audio_transcription": {
                "model": "gummy-realtime-v1"
            },
            "turn_detection": turn_detection
        }

    async def send_event(self, event: Dict[str, Any]) -> None:
        """Stamp ``event`` with an ``event_id`` and send it as JSON.

        Note that ``event`` is mutated in place: its ``event_id`` key is set.

        Raises:
            RuntimeError: If called before ``connect()``.
        """
        if self.ws is None:
            raise RuntimeError("Not connected; call connect() first")
        # A bare millisecond timestamp collides when several events (such as
        # audio chunks) go out within the same millisecond, so a sequence
        # number is appended to keep ids unique.
        self._event_seq += 1
        event['event_id'] = f"event_{int(time.time() * 1000)}_{self._event_seq}"
        print(f" Send event: type={event['type']}, event_id={event['event_id']}")
        await self.ws.send(json.dumps(event))

    async def update_session(self, config: Dict[str, Any]) -> None:
        """Send a ``session.update`` event carrying ``config``."""
        event = {
            "type": "session.update",
            "session": config
        }
        print("update session: ", event)
        await self.send_event(event)

    async def stream_audio(self, audio_chunk: bytes) -> None:
        """Append raw audio to the server's input buffer.

        The chunk must be raw PCM matching the session's ``pcm16`` input
        format. The original author notes that only 16-bit, 16 kHz, mono PCM
        is supported.
        """
        audio_b64 = base64.b64encode(audio_chunk).decode()
        append_event = {
            "type": "input_audio_buffer.append",
            "audio": audio_b64
        }
        # Pass the dict itself: send_event adds an event_id and serializes it.
        # Pre-serializing here made send_event fail with a TypeError.
        await self.send_event(append_event)

    async def create_response(self, instructions: Optional[str] = None) -> None:
        """Request a response from the API. Needed when using manual mode.

        Args:
            instructions: Instructions text placed in the ``response.create``
                payload. Defaults to ``DEFAULT_INSTRUCTIONS``.
        """
        event = {
            "type": "response.create",
            "response": {
                "instructions": self.DEFAULT_INSTRUCTIONS if instructions is None else instructions,
                "modalities": ["text", "audio"]
            }
        }
        print("create response: ", event)
        await self.send_event(event)

    async def cancel_response(self) -> None:
        """Send a ``response.cancel`` event for the in-flight response."""
        event = {
            "type": "response.cancel"
        }
        await self.send_event(event)

    async def handle_interruption(self) -> None:
        """Cancel the in-flight response (if any) and reset response tracking.

        Does nothing when no response is in progress.
        """
        if not self._is_responding:
            return

        print(" Handling interruption")

        # Only ask the server to cancel if it has told us which response is active.
        if self._current_response_id:
            await self.cancel_response()

        self._is_responding = False
        self._current_response_id = None
        self._current_item_id = None

    async def handle_messages(self) -> None:
        """Consume server events until the connection closes.

        Connection closure ends the loop quietly. Any other exception,
        including one raised by a user callback, is printed and also ends
        the loop.
        """
        try:
            async for message in self.ws:
                await self._handle_event(json.loads(message))
        except websockets.exceptions.ConnectionClosed:
            print(" Connection closed")
        except Exception as e:
            print(" Error in message handling: ", str(e))

    async def _handle_event(self, event: Dict[str, Any]) -> None:
        """Update response state and invoke callbacks for one decoded server event."""
        event_type = event.get("type")

        # Audio deltas carry large base64 payloads, so only their type is logged.
        if event_type != "response.audio.delta":
            print(" event: ", event)
        else:
            print(" event_type: ", event_type)

        if event_type == "error":
            print(" Error: ", event.get("error"))
        elif event_type == "response.created":
            self._current_response_id = event.get("response", {}).get("id")
            self._is_responding = True
        elif event_type == "response.output_item.added":
            self._current_item_id = event.get("item", {}).get("id")
        elif event_type == "response.done":
            self._is_responding = False
            self._current_response_id = None
            self._current_item_id = None
        # Handle interruptions
        elif event_type == "input_audio_buffer.speech_started":
            print(" Speech detected")
            # handle_interruption() is a no-op unless a response is in progress.
            await self.handle_interruption()
            # Always notify: locally buffered audio may still be playing even
            # after the server has finished the response.
            if self.on_interrupt:
                print(" Handling on_interrupt, stop playback")
                self.on_interrupt()
        elif event_type == "input_audio_buffer.speech_stopped":
            print(" Speech ended")
        # Handle normal response events
        elif event_type == "response.text.delta":
            if self.on_text_delta:
                self.on_text_delta(event["delta"])
        elif event_type == "response.audio.delta":
            if self.on_audio_delta:
                self.on_audio_delta(base64.b64decode(event["delta"]))
        elif event_type == "conversation.item.input_audio_transcription.completed":
            transcript = event.get("transcript", "")
            if self.on_input_transcript:
                await asyncio.to_thread(self.on_input_transcript, transcript)
            # Mark the input transcript as emitted even without an input
            # callback; otherwise output text would be buffered forever.
            self._print_input_transcript = True
            # Release output text that arrived before the input transcript now,
            # instead of waiting for a later delta that may never come.
            await self._flush_output_transcript()
        elif event_type == "response.audio_transcript.delta":
            await self._handle_output_transcript_delta(event.get("delta", ""))
        elif event_type == "response.audio_transcript.done":
            self._print_input_transcript = False
        elif event_type in self.extra_event_handlers:
            # Only reached for event types not handled above.
            self.extra_event_handlers[event_type](event)

    async def _handle_output_transcript_delta(self, delta: str) -> None:
        """Forward output-transcript text, holding it until the input transcript is out."""
        if not self.on_output_transcript:
            return
        if not self._print_input_transcript:
            # The user's transcript has not been emitted for this turn yet;
            # buffer so the assistant's words never appear before it.
            self._output_transcript_buffer += delta
            return
        await self._flush_output_transcript()
        await asyncio.to_thread(self.on_output_transcript, delta)

    async def _flush_output_transcript(self) -> None:
        """Emit and clear any buffered output-transcript text."""
        if self._output_transcript_buffer and self.on_output_transcript:
            buffered, self._output_transcript_buffer = self._output_transcript_buffer, ""
            await asyncio.to_thread(self.on_output_transcript, buffered)

    async def close(self) -> None:
        """Close the WebSocket connection, if one was opened."""
        if self.ws:
            await self.ws.close()