#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Realtime voice chat demo against the Qwen-Omni realtime API.

Docs: https://help.aliyun.com/zh/model-studio/realtime#1234095db03g3

Third-party requirements: websockets, pyaudio (plus a local project_settings
module that exposes `environment`).
"""
import argparse
import asyncio
import base64
from enum import Enum
import json
import queue
import threading
import time
from typing import Optional, Callable, List, Dict, Any

import websockets
import pyaudio

from project_settings import environment


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--qwen_api_key",
        default=environment.get(key="QWEN_API_KEY"),
        type=str
    )
    parser.add_argument(
        "--model_name",
        default="qwen-omni-turbo-realtime-2025-05-08",
        type=str
    )
    args = parser.parse_args()
    return args
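

# Example invocation (the filename is illustrative; QWEN_API_KEY is assumed to
# be resolvable through project_settings.environment):
#   python3 omni_realtime_demo.py --model_name qwen-omni-turbo-realtime-2025-05-08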


class TurnDetectionMode(Enum):
    SERVER_VAD = "server_vad"
    MANUAL = "manual"
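

# How the two modes play out below (see connect() and handle_messages()):
# - MANUAL: the session is opened with "turn_detection": None; the client
#   decides when a user turn ends and must call create_response() itself.
# - SERVER_VAD: the server runs voice activity detection, emits
#   input_audio_buffer.speech_started / speech_stopped events, and starts
#   responses on its own.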


class OmniRealtimeClient:
    """
    A demo client for interacting with the Omni Realtime API.

    This class provides methods to connect to the Realtime API, send text and
    audio data, handle responses, and manage the WebSocket connection.

    Attributes:
        base_url (str):
            The base URL for the Realtime API.
        api_key (str):
            The API key for authentication.
        model (str):
            Omni model to use for chat.
        voice (str):
            The voice to use for audio output.
        turn_detection_mode (TurnDetectionMode):
            The mode for turn detection.
        on_text_delta (Callable[[str], None]):
            Callback for text delta events.
            Takes in a string and returns nothing.
        on_audio_delta (Callable[[bytes], None]):
            Callback for audio delta events.
            Takes in bytes and returns nothing.
        on_input_transcript (Callable[[str], None]):
            Callback for input transcript events.
            Takes in a string and returns nothing.
        on_interrupt (Callable[[], None]):
            Callback for user interrupt events; should be used to stop audio
            playback.
        on_output_transcript (Callable[[str], None]):
            Callback for output transcript events.
            Takes in a string and returns nothing.
        extra_event_handlers (Dict[str, Callable[[Dict[str, Any]], None]]):
            Additional event handlers: a mapping of event names to functions
            that process the event payload.
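
    Example (illustrative sketch; the API key is a placeholder):

        client = OmniRealtimeClient(
            base_url="wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
            api_key="sk-...",
            model="qwen-omni-turbo-realtime-2025-05-08",
            on_text_delta=lambda text: print(text, end="", flush=True),
        )
        await client.connect()
        asyncio.create_task(client.handle_messages())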
""" | |

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "",
        voice: str = "Chelsie",
        turn_detection_mode: TurnDetectionMode = TurnDetectionMode.MANUAL,
        on_text_delta: Optional[Callable[[str], None]] = None,
        on_audio_delta: Optional[Callable[[bytes], None]] = None,
        on_interrupt: Optional[Callable[[], None]] = None,
        on_input_transcript: Optional[Callable[[str], None]] = None,
        on_output_transcript: Optional[Callable[[str], None]] = None,
        extra_event_handlers: Optional[Dict[str, Callable[[Dict[str, Any]], None]]] = None
    ):
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.voice = voice
        self.ws = None
        self.on_text_delta = on_text_delta
        self.on_audio_delta = on_audio_delta
        self.on_interrupt = on_interrupt
        self.on_input_transcript = on_input_transcript
        self.on_output_transcript = on_output_transcript
        self.turn_detection_mode = turn_detection_mode
        self.extra_event_handlers = extra_event_handlers or {}

        # Track current response state.
        self._current_response_id = None
        self._current_item_id = None
        self._is_responding = False
        # Track printing state for input and output transcripts.
        self._print_input_transcript = False
        self._output_transcript_buffer = ""

    async def connect(self) -> None:
        """Establish a WebSocket connection with the Realtime API."""
        url = f"{self.base_url}?model={self.model}"
        headers = {
            "Authorization": f"Bearer {self.api_key}"
        }
        # NOTE: this logs the Authorization header, API key included; drop it
        # outside of local debugging.
        print(f"url: {url}, headers: {headers}")
        self.ws = await websockets.connect(url, additional_headers=headers)

        # Set up the default session configuration.
        if self.turn_detection_mode == TurnDetectionMode.MANUAL:
            await self.update_session({
                "modalities": ["text", "audio"],
                "voice": self.voice,
                "input_audio_format": "pcm16",
                "output_audio_format": "pcm16",
                "input_audio_transcription": {
                    "model": "gummy-realtime-v1"
                },
                "turn_detection": None
            })
        elif self.turn_detection_mode == TurnDetectionMode.SERVER_VAD:
            await self.update_session({
                "modalities": ["text", "audio"],
                "voice": self.voice,
                "input_audio_format": "pcm16",
                "output_audio_format": "pcm16",
                "input_audio_transcription": {
                    "model": "gummy-realtime-v1"
                },
                "turn_detection": {
                    "type": "server_vad",
                    "threshold": 0.1,
                    "prefix_padding_ms": 500,
                    "silence_duration_ms": 900
                }
            })
        else:
            raise ValueError(f"Invalid turn detection mode: {self.turn_detection_mode}")

    async def send_event(self, event: Dict[str, Any]) -> None:
        # Stamp every outgoing event with a millisecond-resolution id.
        event["event_id"] = "event_" + str(int(time.time() * 1000))
        print(f" Send event: type={event['type']}, event_id={event['event_id']}")
        await self.ws.send(json.dumps(event))

    async def update_session(self, config: Dict[str, Any]) -> None:
        """Update the session configuration."""
        event = {
            "type": "session.update",
            "session": config
        }
        print("update session: ", event)
        await self.send_event(event)

    async def stream_audio(self, audio_chunk: bytes) -> None:
        """Stream raw audio data to the API."""
        # Only 16-bit, 16 kHz, mono PCM is supported.
        audio_b64 = base64.b64encode(audio_chunk).decode()
        append_event = {
            "type": "input_audio_buffer.append",
            "audio": audio_b64
        }
        # Pass the dict itself: send_event() stamps an event_id and serializes it.
        await self.send_event(append_event)

    async def create_response(self) -> None:
        """Request a response from the API. Needed when using manual mode."""
        event = {
            "type": "response.create",
            "response": {
                # "You are Tom, an American shopping guide selling phones and TVs."
                "instructions": "你是Tom,一个美国的导购,负责售卖手机、电视",
                "modalities": ["text", "audio"]
            }
        }
        print("create response: ", event)
        await self.send_event(event)

    async def cancel_response(self) -> None:
        """Cancel the current response."""
        event = {
            "type": "response.cancel"
        }
        await self.send_event(event)

    async def handle_interruption(self):
        """Handle user interruption of the current response."""
        if not self._is_responding:
            return

        print(" Handling interruption")
        # Cancel the current response and reset the response-tracking state.
        if self._current_response_id:
            await self.cancel_response()
        self._is_responding = False
        self._current_response_id = None
        self._current_item_id = None

    async def handle_messages(self) -> None:
        try:
            async for message in self.ws:
                event = json.loads(message)
                event_type = event.get("type")
                if event_type != "response.audio.delta":
                    print(" event: ", event)
                else:
                    print(" event_type: ", event_type)

                if event_type == "error":
                    print(" Error: ", event["error"])
                    continue
                elif event_type == "response.created":
                    self._current_response_id = event.get("response", {}).get("id")
                    self._is_responding = True
                elif event_type == "response.output_item.added":
                    self._current_item_id = event.get("item", {}).get("id")
                elif event_type == "response.done":
                    self._is_responding = False
                    self._current_response_id = None
                    self._current_item_id = None
                # Handle interruptions.
                elif event_type == "input_audio_buffer.speech_started":
                    print(" Speech detected")
                    if self._is_responding:
                        print(" Handling interruption")
                        await self.handle_interruption()
                    if self.on_interrupt:
                        print(" Handling on_interrupt, stop playback")
                        self.on_interrupt()
                elif event_type == "input_audio_buffer.speech_stopped":
                    print(" Speech ended")
                # Handle normal response events.
                elif event_type == "response.text.delta":
                    if self.on_text_delta:
                        self.on_text_delta(event["delta"])
                elif event_type == "response.audio.delta":
                    if self.on_audio_delta:
                        audio_bytes = base64.b64decode(event["delta"])
                        self.on_audio_delta(audio_bytes)
                elif event_type == "conversation.item.input_audio_transcription.completed":
                    transcript = event.get("transcript", "")
                    if self.on_input_transcript:
                        await asyncio.to_thread(self.on_input_transcript, transcript)
                        self._print_input_transcript = True
                elif event_type == "response.audio_transcript.delta":
                    if self.on_output_transcript:
                        delta = event.get("delta", "")
                        if not self._print_input_transcript:
                            # Buffer the output transcript until the input transcript has printed.
                            self._output_transcript_buffer += delta
                        else:
                            if self._output_transcript_buffer:
                                await asyncio.to_thread(self.on_output_transcript, self._output_transcript_buffer)
                                self._output_transcript_buffer = ""
                            await asyncio.to_thread(self.on_output_transcript, delta)
                elif event_type == "response.audio_transcript.done":
                    self._print_input_transcript = False
                elif event_type in self.extra_event_handlers:
                    self.extra_event_handlers[event_type](event)
        except websockets.exceptions.ConnectionClosed:
            print(" Connection closed")
        except Exception as e:
            print(" Error in message handling: ", str(e))

    async def close(self) -> None:
        """Close the WebSocket connection."""
        if self.ws:
            await self.ws.close()
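

# A minimal manual-mode sketch (not exercised by this demo, which uses server
# VAD). Assumption: the `input_audio_buffer.commit` event follows the
# OpenAI-style realtime protocol this API mirrors; check it against the
# DashScope docs linked at the top before relying on it.
async def manual_turn_example(client: OmniRealtimeClient, pcm_chunks: List[bytes]) -> None:
    """Send one user turn and request a reply when turn detection is MANUAL."""
    for chunk in pcm_chunks:
        # 16-bit, 16 kHz, mono PCM chunks.
        await client.stream_audio(chunk)
    # Mark the end of the user turn (assumed event name, see note above).
    await client.send_event({"type": "input_audio_buffer.commit"})
    # In manual mode the client must explicitly ask for a response.
    await client.create_response()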


class OmniRealtime(object):
    def __init__(self,
                 qwen_api_key: str,
                 model_name: str,
                 sample_rate: int = 16000,
                 chunk_size: int = 3200,
                 channels: int = 1,
                 ):
        self.qwen_api_key = qwen_api_key
        self.model_name = model_name
        self.sample_rate = sample_rate
        self.chunk_size = chunk_size
        self.channels = channels

        # A shared audio queue and the playback thread that drains it.
        self.audio_queue = queue.Queue()
        self.audio_player = None

        self.realtime_client = OmniRealtimeClient(
            base_url="wss://dashscope.aliyuncs.com/api-ws/v1/realtime",
            api_key=self.qwen_api_key,
            model=self.model_name,
            voice="Chelsie",
            on_text_delta=lambda text: print(f"\nAssistant: {text}", end="", flush=True),
            on_audio_delta=self.handle_audio_data,
            # turn_detection_mode=TurnDetectionMode.MANUAL,
            turn_detection_mode=TurnDetectionMode.SERVER_VAD,
        )

    def start_audio_player(self):
        """Start the audio playback thread."""
        if self.audio_player is None or not self.audio_player.is_alive():
            # Keep the handle on self so run() can join it during shutdown.
            self.audio_player = threading.Thread(target=self.audio_player_thread, daemon=True)
            self.audio_player.start()

    def audio_player_thread(self):
        """Background thread that plays queued audio data."""
        p = pyaudio.PyAudio()
        # The output stream is opened at a fixed 24 kHz for the API's PCM
        # output, independent of the 16 kHz capture rate.
        stream = p.open(format=pyaudio.paInt16,
                        channels=self.channels,
                        rate=24000,
                        output=True,
                        frames_per_buffer=self.chunk_size)
        try:
            while True:
                try:
                    # Fetch audio data from the queue; None is the stop signal.
                    audio_data = self.audio_queue.get(block=True, timeout=0.5)
                    if audio_data is None:
                        break
                    # Play the audio data.
                    stream.write(audio_data)
                    self.audio_queue.task_done()
                except queue.Empty:
                    # Queue is empty; keep waiting.
                    continue
        finally:
            # Clean up.
            stream.stop_stream()
            stream.close()
            p.terminate()

    def handle_audio_data(self, audio_data: bytes):
        """Handle received audio data."""
        # Print the size of the received chunk (for debugging).
        print(f"\nReceived audio data: {len(audio_data)} bytes")
        # Hand the audio to the playback thread via the queue.
        self.audio_queue.put(audio_data)

    async def start_microphone_streaming(self):
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paInt16,
                        channels=self.channels,
                        rate=self.sample_rate,
                        input=True,
                        frames_per_buffer=self.chunk_size)
        try:
            print("Recording started, please speak...")
            while True:
                # 3200 frames at 16 kHz is 200 ms of audio per chunk. Read in
                # a worker thread so the blocking call does not stall the
                # event loop that is also running handle_messages().
                audio_data = await asyncio.to_thread(stream.read, self.chunk_size)
                encoded_data = base64.b64encode(audio_data).decode("utf-8")
                event = {
                    "type": "input_audio_buffer.append",
                    "audio": encoded_data
                }
                # send_event() stamps the event_id before sending.
                await self.realtime_client.send_event(event)
                # Keep a short pause to approximate real-time pacing.
                await asyncio.sleep(0.05)
        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

    async def run(self):
        self.start_audio_player()
        await self.realtime_client.connect()
        # await self.realtime_client.create_response()
        try:
            # Run message handling and microphone capture concurrently.
            message_handler = asyncio.create_task(self.realtime_client.handle_messages())
            streaming_task = asyncio.create_task(self.start_microphone_streaming())
            # Wait on both tasks so their exceptions propagate, instead of
            # blocking forever on a queue that nothing ever fills.
            await asyncio.gather(message_handler, streaming_task)
        except Exception as e:
            print(f"Error: {e}")
        finally:
            # Signal the audio playback thread to exit.
            self.audio_queue.put(None)
            if self.audio_player:
                self.audio_player.join(timeout=1)
            await self.realtime_client.close()
        return


async def main():
    args = get_args()
    omni_realtime = OmniRealtime(
        qwen_api_key=args.qwen_api_key,
        model_name=args.model_name,
    )
    await omni_realtime.run()
    return


if __name__ == "__main__":
    asyncio.run(main())