Commit 153165b
Parent(s): 53ea588
removed the binary files
Files changed:

- .gitattributes +1 -0
- .gitignore +2 -0
- examples/voice_agent_webrtc_langgraph/start.sh +13 -0
- tests/__init__.py +0 -4
- tests/perf/README.md +0 -101
- tests/perf/file_input_client.py +0 -581
- tests/perf/run_multi_client_benchmark.sh +0 -414
- tests/perf/ttfb_analyzer.py +0 -253
- tests/unit/__init__.py +0 -4
- tests/unit/configs/animation_config.yaml +0 -346
- tests/unit/configs/test_speech_planner_prompt.yaml +0 -15
- tests/unit/test_ace_websocket_serializer.py +0 -147
- tests/unit/test_acknowledgment.py +0 -71
- tests/unit/test_animation_graph_services.py +0 -668
- tests/unit/test_audio2face_3d_service.py +0 -182
- tests/unit/test_audio_util.py +0 -64
- tests/unit/test_basic_pipelines.py +0 -130
- tests/unit/test_blingfire_text_aggregator.py +0 -244
- tests/unit/test_custom_view.py +0 -203
- tests/unit/test_elevenlabs.py +0 -184
- tests/unit/test_frame_creation.py +0 -148
- tests/unit/test_gesture.py +0 -94
- tests/unit/test_guardrail.py +0 -110
- tests/unit/test_message_broker.py +0 -111
- tests/unit/test_nvidia_aggregators.py +0 -396
- tests/unit/test_nvidia_llm_service.py +0 -386
- tests/unit/test_nvidia_rag_service.py +0 -261
- tests/unit/test_nvidia_tts_response_cacher.py +0 -79
- tests/unit/test_posture.py +0 -104
- tests/unit/test_proactivity.py +0 -85
- tests/unit/test_riva_asr_service.py +0 -523
- tests/unit/test_riva_nmt_service.py +0 -197
- tests/unit/test_riva_tts_service.py +0 -301
- tests/unit/test_speech_planner.py +0 -546
- tests/unit/test_traced_processor.py +0 -159
- tests/unit/test_transcription_sync_processors.py +0 -262
- tests/unit/test_user_presence.py +0 -157
- tests/unit/test_utils.py +0 -95
- tests/unit/utils.py +0 -428
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -66,4 +66,6 @@ pnpm-debug.log*
 # --- Example runtime artifacts ---
 examples/voice_agent_webrtc_langgraph/audio_dumps/
 examples/voice_agent_webrtc_langgraph/ui/dist/
+examples/voice_agent_webrtc_langgraph/audio_prompt.wav
+tests/perf/audio_files/*.wav
 
examples/voice_agent_webrtc_langgraph/start.sh
CHANGED
@@ -9,6 +9,19 @@ if [ -f "/app/examples/voice_agent_webrtc_langgraph/.env" ]; then
   set +a
 fi
 
+# If a remote prompt URL is provided, download it and export ZERO_SHOT_AUDIO_PROMPT
+if [ -n "${ZERO_SHOT_AUDIO_PROMPT_URL:-}" ]; then
+  PROMPT_TARGET="${ZERO_SHOT_AUDIO_PROMPT:-/app/examples/voice_agent_webrtc_langgraph/audio_prompt.wav}"
+  mkdir -p "$(dirname "$PROMPT_TARGET")"
+  if [ ! -f "$PROMPT_TARGET" ]; then
+    echo "Downloading ZERO_SHOT_AUDIO_PROMPT from $ZERO_SHOT_AUDIO_PROMPT_URL"
+    if ! curl -fsSL "$ZERO_SHOT_AUDIO_PROMPT_URL" -o "$PROMPT_TARGET"; then
+      echo "Failed to download audio prompt from URL: $ZERO_SHOT_AUDIO_PROMPT_URL" >&2
+    fi
+  fi
+  export ZERO_SHOT_AUDIO_PROMPT="$PROMPT_TARGET"
+fi
+
 # All dependencies and langgraph CLI are installed at build time
 
 # Start langgraph dev from within the internal agents directory (background)
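The new block is deliberately idempotent: it downloads the prompt only when the target file is missing, and a failed download warns rather than aborting startup. A minimal Python sketch of the same pattern, for reference (the `ensure_audio_prompt` helper and the `/tmp` fallback path are illustrative, not part of this repo):

```python
import os
import urllib.request


def ensure_audio_prompt(url: str | None, target: str) -> str | None:
    """Download `url` to `target` once; reuse the file on later runs."""
    if not url:
        return None  # No remote prompt configured; keep existing behavior.
    os.makedirs(os.path.dirname(target), exist_ok=True)
    if not os.path.exists(target):
        try:
            urllib.request.urlretrieve(url, target)
        except OSError as exc:
            # Mirror the shell script: warn but do not abort startup.
            print(f"Failed to download audio prompt from URL: {url} ({exc})")
    return target


prompt = ensure_audio_prompt(
    os.environ.get("ZERO_SHOT_AUDIO_PROMPT_URL"),
    os.environ.get("ZERO_SHOT_AUDIO_PROMPT", "/tmp/audio_prompt.wav"),
)
if prompt:
    os.environ["ZERO_SHOT_AUDIO_PROMPT"] = prompt
```

Note that the shell version still exports ZERO_SHOT_AUDIO_PROMPT even if the download fails; the sketch mirrors that behavior rather than hardening it.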
tests/__init__.py
DELETED
@@ -1,4 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Test suite for the nvidia-pipecat package."""
tests/perf/README.md
DELETED
@@ -1,101 +0,0 @@
-# Performance Testing
-
-This directory contains tools for evaluating the voice agent pipeline's latency and scalability/throughput under various loads. These tests simulate real-world scenarios where multiple users interact with the voice agent simultaneously.
-
-## What the Tests Do
-
-The performance tests:
-
-- Open WebSocket clients that simulate user interactions
-- Use pre-recorded audio files from `audio_files/` as user queries
-- Send these queries to the voice agent pipeline and measure response times
-- Track various latency metrics including end-to-end latency, component-wise breakdowns
-- Can simulate multiple concurrent clients to test scaling
-- Detect any audio glitches during processing
-
-## Running Performance Tests
-
-### 1. Start the Voice Agent Pipeline
-
-First, start the voice agent pipeline and capture server logs for analysis.
-See the prerequisites and setup instructions in `examples/speech-to-speech/README.md` before proceeding.
-
-#### If Using Docker
-
-From examples/speech-to-speech/ directory run:
-
-```bash
-# Start the services
-docker compose up -d
-
-# Capture logs and save them into a file
-docker compose logs -f python-app > bot_logs_test1.txt 2>&1
-```
-
-Before starting a new performance run:
-
-```bash
-# Clear existing Docker logs
-sudo truncate -s 0 /var/lib/docker/containers/$(docker compose ps -q python-app)/$(docker compose ps -q python-app)-json.log
-```
-
-#### If Using Python Environment
-
-From examples/speech-to-speech/ directory run:
-
-```bash
-python bot.py > bot_logs_test1.txt 2>&1
-```
-
-### 2. Run the Multi-Client Benchmark
-
-```bash
-./run_multi_client_benchmark.sh --host 0.0.0.0 --port 8100 --clients 10 --test-duration 150
-```
-
-Parameters:
-
-- `--host`: The host address (default: 0.0.0.0)
-- `--port`: The port where your voice agent is running (default: 8100)
-- `--clients`: Number of concurrent clients to simulate (default: 1)
-- `--test-duration`: Duration of the test in seconds (default: 150)
-
-The script will:
-
-1. Start the specified number of concurrent clients
-2. Simulate user interactions using audio files
-3. Measure latencies and detect audio glitches
-4. Save detailed results in the `results` directory as JSON files
-5. Output a summary to the console
-
-### 3. Analyze Component-wise Latency
-
-After the benchmark completes, analyze the server logs for detailed latency breakdowns:
-
-```bash
-python ttfb_analyzer.py <relative_path_to_bot_logs_test1.txt>
-```
-
-This will show:
-
-- Per-client latency metrics for LLM, TTS, and ASR components
-- Number of calls made by each client
-- Overall averages and P95 values
-- Component-wise timing breakdowns
-
-## Understanding the Results
-
-The metrics measured include:
-
-- **LLM TTFB**: Time to first byte from the LLM model
-- **TTS TTFB**: Time to first byte from the TTS model
-- **ASR Lat**: Compute latency of the ASR model
-- **LLM 1st**: Time taken to generate first complete sentence from LLM
-- **Calls**: Number of API calls made to each service
-
-The results help identify:
-
-- Performance bottlenecks in specific components
-- Scaling behavior under concurrent load
-- Potential audio quality issues
-- Overall system responsiveness
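The deleted README reports "overall averages and P95 values" per component. A minimal sketch of how those two numbers can be computed from a list of per-call latencies, matching the nearest-rank P95 used by the deleted ttfb_analyzer.py below (the sample values are made up):

```python
def p95(values: list[float]) -> float:
    """Nearest-rank 95th percentile, as in ttfb_analyzer.calculate_p95."""
    if not values:
        return 0.0
    ordered = sorted(values)
    return ordered[int(0.95 * (len(ordered) - 1))]


llm_ttfb = [0.31, 0.28, 0.45, 0.33, 0.52]  # hypothetical per-call TTFBs (seconds)
print(f"avg={sum(llm_ttfb) / len(llm_ttfb):.3f}s p95={p95(llm_ttfb):.3f}s")
```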
tests/perf/file_input_client.py
DELETED
@@ -1,581 +0,0 @@
-"""Speech-to-speech client with latency measurement for performance testing.
-
-This module provides a WebSocket client that sends audio files to a speech-to-speech
-service and measures the latency between when user audio ends and bot response begins.
-"""
-
-import argparse
-import asyncio
-import datetime
-import io
-import json
-import os
-import signal
-import sys
-import time
-import uuid
-import wave
-
-import websockets
-from pipecat.frames.protobufs import frames_pb2
-from websockets.exceptions import ConnectionClosed
-
-
-def log_error(msg):
-    """Write error message to stderr with timestamp."""
-    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    print(f"[ERROR] {timestamp} - {msg}", file=sys.stderr, flush=True)
-
-
-# Global constants
-SILENCE_TIMEOUT = 0.2  # Standard silence timeout in seconds
-CHUNK_DURATION_MS = 32  # Standard chunk duration in milliseconds
-
-# List to store latency values
-latency_values = []
-
-# List to store filtered latency values (above threshold)
-filtered_latency_values = []
-
-# Global variable to track timestamps
-timestamps = {"input_audio_file_end": None, "first_response_after_input": None}
-
-# Global glitch detection
-glitch_detected = False
-
-# Global flag and event for controlling silence sending
-silence_control = {
-    "running": False,
-    "event": asyncio.Event(),
-    "audio_params": None,  # Will store (frame_rate, n_channels, chunk_size)
-}
-
-# Global control for continuous operation
-continuous_control = {
-    "running": True,
-    "collecting_metrics": False,
-    "start_time": None,
-    "test_duration": 100,  # Default 100 seconds
-    "threshold": 0.5,  # Default threshold for filtered latency
-}
-
-
-# Signal handler for graceful shutdown
-def signal_handler(signum, frame):
-    """Handle system signals for graceful shutdown."""
-    print(f"\nReceived signal {signum}, shutting down gracefully...")
-    continuous_control["running"] = False
-    sys.exit(0)
-
-
-# Register signal handlers
-signal.signal(signal.SIGINT, signal_handler)
-signal.signal(signal.SIGTERM, signal_handler)
-
-
-def write_audio_to_wav(data, wf, create_new_file=False, output_file="bot_response.wav"):
-    """Write audio data to WAV file."""
-    try:
-        # Parse protobuf frame
-        try:
-            proto = frames_pb2.Frame.FromString(data)
-            which = proto.WhichOneof("frame")
-            if which is None:
-                return wf, None, None, None
-        except Exception as e:
-            log_error(f"Failed to parse protobuf frame: {e}")
-            return wf, None, None, None
-
-        args = getattr(proto, which)
-        sample_rate = getattr(args, "sample_rate", 16000)
-        num_channels = getattr(args, "num_channels", 1)
-        audio_data = getattr(args, "audio", None)
-        if audio_data is None:
-            return wf, None, None, None
-
-        # Extract raw audio data from WAV format if needed
-        try:
-            with io.BytesIO(audio_data) as buffer, wave.open(buffer, "rb") as wav_file:
-                audio_data = wav_file.readframes(wav_file.getnframes())
-                sample_rate = wav_file.getframerate()
-                num_channels = wav_file.getnchannels()
-        except Exception:
-            # If not WAV format, use audio_data as-is
-            pass
-
-        # Create WAV file if needed
-        if create_new_file and wf is None:
-            try:
-                wf = wave.open(output_file, "wb")  # noqa: SIM115
-                wf.setnchannels(num_channels)
-                wf.setsampwidth(2)
-                wf.setframerate(sample_rate)
-            except Exception as e:
-                log_error(f"Failed to create WAV file {output_file}: {e}")
-                return None, None, None, None
-
-        # Write audio data directly
-        if wf is not None:
-            try:
-                wf.writeframes(audio_data)
-            except Exception as e:
-                log_error(f"Failed to write audio data: {e}")
-                return None, None, None, None
-
-        return wf, sample_rate, num_channels, audio_data
-    except Exception as e:
-        log_error(f"Unexpected error in write_audio_to_wav: {e}")
-        return wf, None, None, None
-
-
-async def send_audio_file(websocket, file_path):
-    """Send audio file content with streaming simulation."""
-    # Pause silence sending while we send the real audio
-    silence_control["event"].set()
-
-    try:
-        if not os.path.exists(file_path):
-            log_error(f"Input audio file not found: {file_path}")
-            return
-
-        try:
-            with wave.open(file_path, "rb") as wav_file:
-                n_channels = wav_file.getnchannels()
-                frame_rate = wav_file.getframerate()
-                sample_width = wav_file.getsampwidth()
-
-                # Store audio parameters for silence generation
-                chunk_size = int((frame_rate * n_channels * CHUNK_DURATION_MS) / 1000) * sample_width
-                silence_control["audio_params"] = (
-                    frame_rate,
-                    n_channels,
-                    chunk_size,
-                )
-
-                # Stream the audio file
-                frames_sent = 0
-                while True:
-                    try:
-                        chunk = wav_file.readframes(chunk_size // sample_width)
-                        if not chunk:
-                            break
-                        audio_frame = frames_pb2.AudioRawFrame(
-                            audio=chunk, sample_rate=frame_rate, num_channels=n_channels
-                        )
-                        frame = frames_pb2.Frame(audio=audio_frame)
-                        await websocket.send(frame.SerializeToString())
-                        frames_sent += 1
-                        await asyncio.sleep(CHUNK_DURATION_MS / 1000)
-                    except Exception as e:
-                        log_error(f"Error sending audio frame {frames_sent}: {e}")
-                        raise  # Re-raise to handle in outer try block
-        except wave.Error as e:
-            log_error(f"Failed to read WAV file {file_path}: {e}")
-            return
-    except Exception as e:
-        log_error(f"Error in send_audio_file: {e}")
-        return
-    finally:
-        # Always record when input audio ends and resume silence sending
-        timestamps["input_audio_file_end"] = datetime.datetime.now()
-        print(f"User stopped speaking at: {timestamps['input_audio_file_end'].strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}")
-        silence_control["event"].clear()
-
-
-async def silence_sender_loop(websocket):
-    """Background task to continuously send silence when no other audio is being sent."""
-    silence_control["running"] = True
-    print("Silence sender loop started")
-    consecutive_errors = 0
-    max_consecutive_errors = 5
-
-    try:
-        while silence_control["running"]:
-            try:
-                # Wait until we're allowed to send silence
-                if silence_control["event"].is_set() or silence_control["audio_params"] is None:
-                    await asyncio.sleep(0.1)  # Short sleep to avoid CPU spinning
-                    continue
-
-                # Extract audio parameters
-                frame_rate, n_channels, chunk_size = silence_control["audio_params"]
-
-                # Send a chunk of silence
-                silent_chunk = b"\x00" * chunk_size
-                audio_frame = frames_pb2.AudioRawFrame(
-                    audio=silent_chunk, sample_rate=frame_rate, num_channels=n_channels
-                )
-                frame = frames_pb2.Frame(audio=audio_frame)
-                await websocket.send(frame.SerializeToString())
-                await asyncio.sleep(CHUNK_DURATION_MS / 1000)
-
-                # Reset error counter on successful send
-                consecutive_errors = 0
-
-            except ConnectionClosed:
-                print("WebSocket connection closed in silence sender loop")
-                break
-            except Exception as e:
-                consecutive_errors += 1
-                print(f"Error in silence sender loop (attempt {consecutive_errors}/{max_consecutive_errors}): {e}")
-
-                # If too many consecutive errors, stop the loop
-                if consecutive_errors >= max_consecutive_errors:
-                    print(f"Too many consecutive errors ({consecutive_errors}), stopping silence sender")
-                    break
-
-                # Brief pause before retry to avoid overwhelming the system
-                await asyncio.sleep(1.0)
-
-    except Exception as e:
-        print(f"Fatal error in silence sender loop: {e}")
-    finally:
-        print("Silence sender loop stopped")
-        silence_control["running"] = False
-
-
-async def receive_audio(
-    websocket,
-    wf=None,
-    create_new_file=True,
-    is_after_input=False,
-    output_wav="bot_response.wav",
-    is_initial=False,
-    timeout=1.0,
-):
-    """Receive audio data and handle streaming playback simulation."""
-    global glitch_detected
-
-    if is_initial:
-        print("Waiting up to 5 seconds for initial bot introduction audio if available...")
-        try:
-            # Wait for first data packet with 5 second timeout
-            data = await asyncio.wait_for(websocket.recv(), timeout=5.0)
-        except TimeoutError:
-            print("No initial bot introduction received after 5 seconds, continuing...")
-            return wf
-    else:
-        # For non-initial audio, receive normally
-        data = await websocket.recv()
-
-    try:
-        # Wait for first data packet
-        data = await websocket.recv()
-
-        # Record first response timestamp if after input
-        if is_after_input:
-            timestamps["first_response_after_input"] = datetime.datetime.now()
-            formatted_time = timestamps["first_response_after_input"].strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
-            print(f"Bot started speaking at {formatted_time}")
-
-        # Process first audio packet
-        wf, sample_rate, num_channels, audio_data = write_audio_to_wav(data, wf, create_new_file, output_wav)
-
-        # Initialize timing for glitch detection
-        audio_start_time = time.time()
-        cumulative_audio_duration = 0.0  # Total duration of audio received (in seconds)
-
-        # Calculate duration of first chunk if we have audio data
-        if audio_data and sample_rate and num_channels:
-            bytes_per_sample = 2  # Assuming 16-bit audio
-            samples_in_chunk = len(audio_data) // (num_channels * bytes_per_sample)
-            chunk_duration_seconds = samples_in_chunk / sample_rate
-            cumulative_audio_duration += chunk_duration_seconds
-
-        # Continue receiving audio data until silence threshold reached
-        last_data_time = time.time()
-
-        while True:
-            try:
-                data = await asyncio.wait_for(websocket.recv(), timeout=timeout)
-                current_time = time.time()
-                last_data_time = current_time
-
-                # Process audio data
-                wf, sample_rate, num_channels, audio_data = write_audio_to_wav(data, wf, False, output_wav)
-
-                # Update cumulative audio duration
-                if audio_data and sample_rate and num_channels:
-                    bytes_per_sample = 2  # Assuming 16-bit audio
-                    samples_in_chunk = len(audio_data) // (num_channels * bytes_per_sample)
-                    chunk_duration_seconds = samples_in_chunk / sample_rate
-                    cumulative_audio_duration += chunk_duration_seconds
-
-                # Check for glitch: real elapsed time vs cumulative audio duration
-                real_elapsed_time = current_time - audio_start_time
-                audio_deficit = real_elapsed_time - cumulative_audio_duration
-
-                if audio_deficit >= 0.032:  # 32ms threshold for glitch detection
-                    print(f"Audio glitch detected: {audio_deficit * 1000:.1f}ms audio deficit")
-                    glitch_detected = True
-
-            except TimeoutError:
-                # Check if silence duration exceeds threshold
-                if time.time() - last_data_time >= SILENCE_TIMEOUT:
-                    return wf
-            except Exception as e:
-                log_error(f"Error receiving audio data: {e}")
-                if wf is not None and create_new_file:
-                    try:
-                        wf.close()
-                    except Exception as close_error:
-                        log_error(f"Error closing WAV file: {close_error}")
-                return None
-    except Exception as e:
-        log_error(f"Fatal error in receive_audio: {e}")
-        if wf is not None and create_new_file:
-            try:
-                wf.close()
-            except Exception as close_error:
-                log_error(f"Error closing WAV file: {close_error}")
-        return None
-
-
-async def process_conversation_turn(websocket, audio_file_path, wf, turn_index, output_wav="bot_response.wav"):
-    """Process a single conversation turn with the given audio file."""
-    print(f"\n----- Processing conversation turn {turn_index + 1} -----")
-
-    # Reset timestamps for this turn
-    timestamps["input_audio_file_end"] = None
-    timestamps["first_response_after_input"] = None
-
-    # Start both sending and receiving in parallel for realistic latency measurement
-    print(f"Sending user input audio from {audio_file_path}...")
-
-    # Start sending audio file in background
-    send_task = asyncio.create_task(send_audio_file(websocket, audio_file_path))
-
-    # Start receiving bot response immediately (parallel to sending)
-    receive_task = asyncio.create_task(
-        receive_audio(websocket, wf=wf, create_new_file=(wf is None), is_after_input=True, output_wav=output_wav)
-    )
-
-    # Wait for both tasks to complete
-    wf = await receive_task
-    await send_task  # Ensure sending is also complete
-
-    # Calculate and store latency only if we're collecting metrics
-    if continuous_control["collecting_metrics"]:
-        latency = None
-        if timestamps["input_audio_file_end"] is not None and timestamps["first_response_after_input"] is not None:
-            latency = (timestamps["first_response_after_input"] - timestamps["input_audio_file_end"]).total_seconds()
-            print(f"Latency for Turn {turn_index + 1}: {latency:.3f} seconds")
-            latency_values.append(latency)
-
-            # Add to filtered latency if above threshold
-            if latency > continuous_control["threshold"]:
-                filtered_latency_values.append(latency)
-            else:
-                print("Reverse Barge-In Detected!")
-
-    return wf
-
-
-async def continuous_audio_loop(websocket, audio_files, wf, output_wav):
-    """Continuously loop through audio files until stopped."""
-    turn_index = 0
-
-    while continuous_control["running"]:
-        # Check if we should start collecting metrics
-        if (
-            continuous_control["start_time"]
-            and time.time() >= continuous_control["start_time"]
-            and not continuous_control["collecting_metrics"]
-        ):
-            continuous_control["collecting_metrics"] = True
-            print(f"\n=== STARTING METRICS COLLECTION at {datetime.datetime.now().strftime('%H:%M:%S')} ===")
-
-        # Check if we should stop collecting metrics
-        if (
-            continuous_control["start_time"]
-            and continuous_control["collecting_metrics"]
-            and time.time() >= continuous_control["start_time"] + continuous_control["test_duration"]
-        ):
-            print(f"\n=== STOPPING METRICS COLLECTION at {datetime.datetime.now().strftime('%H:%M:%S')} ===")
-            continuous_control["collecting_metrics"] = False
-            continuous_control["running"] = False
-            break
-
-        # Process current audio file
-        audio_file = audio_files[turn_index % len(audio_files)]
-        wf = await process_conversation_turn(websocket, audio_file, wf, turn_index, output_wav)
-        turn_index += 1
-
-        # Small delay between turns to prevent overwhelming the system
-        await asyncio.sleep(0.1)
-
-    return wf
-
-
-async def main():
-    """Main execution function."""
-    # Parse command line arguments
-    parser = argparse.ArgumentParser(description="Speech-to-speech client with latency measurement")
-    parser.add_argument(
-        "--stream-id", type=str, default=str(uuid.uuid4()), help="Unique stream ID (default: random UUID)"
-    )
-    parser.add_argument("--host", type=str, default="0.0.0.0", help="WebSocket server host (default: 0.0.0.0)")
-    parser.add_argument("--port", type=int, default=8100, help="WebSocket server port (default: 8100)")
-    parser.add_argument(
-        "--output-dir", type=str, default="./results", help="Directory to store output files (default: ./results)"
-    )
-    parser.add_argument("--start-delay", type=float, default=0, help="Delay in seconds before starting (default: 0)")
-    parser.add_argument(
-        "--metrics-start-time",
-        type=float,
-        default=0,
-        help="Unix timestamp when to start collecting metrics (default: 0)",
-    )
-    parser.add_argument(
-        "--test-duration", type=float, default=100, help="Duration in seconds to collect metrics (default: 100)"
-    )
-    parser.add_argument(
-        "--threshold", type=float, default=0.5, help="Threshold for filtered average latency calculation (default: 0.5)"
-    )
-    args = parser.parse_args()
-
-    # Create output directory if it doesn't exist
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    # Construct WebSocket URI with unique stream ID
-    uri = f"ws://{args.host}:{args.port}/ws/{args.stream_id}"
-
-    # Output file paths
-    output_wav = os.path.join(args.output_dir, f"bot_response_{args.stream_id}.wav")
-    output_results = os.path.join(args.output_dir, f"latency_results_{args.stream_id}.json")
-
-    print(f"Starting client with stream ID: {args.stream_id}")
-    print(f"WebSocket URI: {uri}")
-    print(f"Start delay: {args.start_delay} seconds")
-    print(f"Metrics start time: {args.metrics_start_time}")
-    print(f"Test duration: {args.test_duration} seconds")
-    print(f"Latency threshold: {args.threshold} seconds")
-
-    # Set up timing controls
-    if args.start_delay > 0:
-        print(f"Waiting {args.start_delay} seconds before starting...")
-        await asyncio.sleep(args.start_delay)
-
-    if args.metrics_start_time > 0:
-        continuous_control["start_time"] = args.metrics_start_time
-        continuous_control["test_duration"] = args.test_duration
-        print(f"Will start collecting metrics at timestamp {args.metrics_start_time}")
-
-    # Set threshold for filtered latency calculation
-    continuous_control["threshold"] = args.threshold
-
-    # Define the array of input audio files
-    # Get the directory where this script is located
-    script_dir = os.path.dirname(os.path.abspath(__file__))
-    audio_files_dir = os.path.join(script_dir, "audio_files")
-
-    input_audio_files = [
-        os.path.join(audio_files_dir, "output_file.wav"),
-        # os.path.join(audio_files_dir, "query_1.wav"),
-        # os.path.join(audio_files_dir, "query_2.wav"),
-        # os.path.join(audio_files_dir, "query_3.wav"),
-        # os.path.join(audio_files_dir, "query_4.wav"),
-        # os.path.join(audio_files_dir, "query_5.wav"),
-        # os.path.join(audio_files_dir, "query_6.wav"),
-        # os.path.join(audio_files_dir, "query_7.wav"),
-        # os.path.join(audio_files_dir, "query_8.wav"),
-        # os.path.join(audio_files_dir, "query_9.wav"),
-        # os.path.join(audio_files_dir, "query_10.wav"),
-    ]
-
-    # Clear any previous values
-    latency_values.clear()
-    filtered_latency_values.clear()
-
-    # Initialize silence control
-    silence_control["event"] = asyncio.Event()
-    silence_control["event"].set()  # Start with silence sending paused
-
-    try:
-        async with websockets.connect(uri) as websocket:
-            # First, try to receive any initial output audio
-            wf = await receive_audio(
-                websocket,
-                wf=None,
-                create_new_file=True,
-                is_after_input=False,
-                output_wav=output_wav,
-                is_initial=True,
-            )
-
-            # Start the silence sender task
-            asyncio.create_task(silence_sender_loop(websocket))
-
-            # Start continuous audio loop
-            wf = await continuous_audio_loop(websocket, input_audio_files, wf, output_wav)
-
-            # Clean up and stop the silence sender
-            silence_control["running"] = False
-            silence_control["event"].set()  # Make sure it's not waiting
-            await asyncio.sleep(0.2)  # Give it time to exit cleanly
-
-            if wf is not None:
-                wf.close()
-                print(f"All output saved to {output_wav}")
-
-    except ConnectionClosed:
-        # Normal WebSocket closure, not an error
-        pass
-    except Exception as e:
-        print(f"Connection error: {e}")
-    finally:
-        # Always save results, regardless of how the connection ended
-        if latency_values:
-            avg_latency = sum(latency_values) / len(latency_values)
-
-            # Calculate filtered average latency
-            filtered_avg_latency = None
-            if filtered_latency_values:
-                filtered_avg_latency = sum(filtered_latency_values) / len(filtered_latency_values)
-
-            print("\n----- Final Latency Summary -----")
-            print(f"Average Latency across {len(latency_values)} turns: {avg_latency:.3f} seconds")
-
-            if filtered_avg_latency is not None:
-                print(
-                    f"Filtered Average Latency (>{args.threshold}s) across {len(filtered_latency_values)} turns: "
-                    f"{filtered_avg_latency:.3f} seconds"
-                )
-            else:
-                print(f"Filtered Average Latency: No latencies above {args.threshold}s threshold")
-
-            # Calculate reverse barge-ins (latencies below threshold)
-            reverse_barge_ins_count = len(latency_values) - len(filtered_latency_values)
-            print(f"Reverse Barge-Ins Detected: {reverse_barge_ins_count} latencies below {args.threshold}s threshold")
-
-            # Report glitch detection results
-            if glitch_detected:
-                print("⚠️ AUDIO GLITCHES DETECTED: Audio chunks arrived with gaps larger than playback time")
-            else:
-                print("✅ No audio glitches detected: Audio streaming was smooth")
-
-            print("----------------------------------------")
-
-            # Save results to JSON file
-            results = {
-                "stream_id": args.stream_id,
-                "individual_latencies": latency_values,
-                "average_latency": avg_latency,
-                "filtered_latencies": filtered_latency_values,
-                "filtered_average_latency": filtered_avg_latency,
-                "threshold": args.threshold,
-                "num_turns": len(latency_values),
-                "num_filtered_turns": len(filtered_latency_values),
-                "reverse_barge_ins_count": len(latency_values) - len(filtered_latency_values),
-                "glitch_detected": glitch_detected,
-                "timestamp": datetime.datetime.now().isoformat(),
-                "metrics_start_time": continuous_control["start_time"],
-                "test_duration": continuous_control["test_duration"],
-            }
-
-            with open(output_results, "w") as f:
-                json.dump(results, f, indent=2)
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
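The deleted client's glitch detector compares wall-clock time against the cumulative duration of audio received; once real time outruns received audio by at least one 32 ms chunk, playback would have gapped. A standalone sketch of that bookkeeping, assuming 16 kHz 16-bit mono audio as the client does (the chunk sizes here are made up):

```python
import time

SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2  # 16-bit mono, as assumed by the deleted client

start = time.monotonic()
cumulative_audio = 0.0  # seconds of audio received so far

for chunk in [b"\x00" * 1024] * 5:  # stand-in for websocket audio chunks
    cumulative_audio += len(chunk) / (SAMPLE_RATE * BYTES_PER_SAMPLE)
    deficit = (time.monotonic() - start) - cumulative_audio
    if deficit >= 0.032:  # one 32 ms chunk behind => audible gap
        print(f"Audio glitch: {deficit * 1000:.1f} ms deficit")
```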
tests/perf/run_multi_client_benchmark.sh
DELETED
@@ -1,414 +0,0 @@
-#!/bin/bash
-
-# Configuration variables
-HOST="0.0.0.0" # Default host
-PORT=8100 # Default port
-NUM_CLIENTS=1 # Default number of parallel clients
-BASE_OUTPUT_DIR="./results"
-TEST_DURATION=150 # Default test duration in seconds
-CLIENT_START_DELAY=1 # Delay between client starts in seconds
-THRESHOLD=0.5 # Default threshold for filtered average latency
-
-# Generate timestamp for unique directory
-TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
-OUTPUT_DIR="${BASE_OUTPUT_DIR}_${TIMESTAMP}"
-SUMMARY_FILE="$OUTPUT_DIR/summary.json"
-
-# Process command line arguments
-while [[ $# -gt 0 ]]; do
-  case $1 in
-    --host)
-      HOST="$2"
-      shift 2
-      ;;
-    --port)
-      PORT="$2"
-      shift 2
-      ;;
-    --clients)
-      NUM_CLIENTS="$2"
-      shift 2
-      ;;
-    --output-dir)
-      BASE_OUTPUT_DIR="$2"
-      OUTPUT_DIR="${BASE_OUTPUT_DIR}_${TIMESTAMP}"
-      SUMMARY_FILE="$OUTPUT_DIR/summary.json"
-      shift 2
-      ;;
-    --test-duration)
-      TEST_DURATION="$2"
-      shift 2
-      ;;
-    --client-start-delay)
-      CLIENT_START_DELAY="$2"
-      shift 2
-      ;;
-    --threshold)
-      THRESHOLD="$2"
-      shift 2
-      ;;
-    *)
-      echo "Unknown option: $1"
-      echo "Usage: $0 [--host HOST] [--port PORT] [--clients NUM_CLIENTS] [--output-dir DIR] [--test-duration SECONDS] [--client-start-delay SECONDS] [--threshold SECONDS]"
-      exit 1
-      ;;
-  esac
-done
-
-# Create output directory if it doesn't exist
-mkdir -p "$OUTPUT_DIR"
-
-echo "=== ACE Controller Multi-Client Benchmark (Staggered Start) ==="
-echo "Host: $HOST"
-echo "Port: $PORT"
-echo "Number of clients: $NUM_CLIENTS"
-echo "Client start delay: $CLIENT_START_DELAY seconds"
-echo "Test duration: $TEST_DURATION seconds"
-echo "Latency threshold: $THRESHOLD seconds"
-echo "Output directory: $OUTPUT_DIR"
-echo "================================================================"
-
-# Calculate timing
-# All clients will start within (NUM_CLIENTS - 1) * CLIENT_START_DELAY seconds
-# Metrics collection will start after the last client starts
-TOTAL_START_TIME=$(( (NUM_CLIENTS - 1) * CLIENT_START_DELAY ))
-METRICS_START_TIME=$(date +%s)
-METRICS_START_TIME=$((METRICS_START_TIME + TOTAL_START_TIME)) # Set to when the last client starts
-
-# Run clients with staggered starts
-pids=()
-
-for ((i=1; i<=$NUM_CLIENTS; i++)); do
-  # Generate a unique stream ID for each client
-  STREAM_ID="client_${i}_$(date +%s%N | cut -b1-13)"
-
-  # Calculate start delay for this client (0 for first client, increasing for others)
-  START_DELAY=$(( (i - 1) * CLIENT_START_DELAY ))
-
-  # Run client in background with appropriate delays
-  python ./file_input_client.py \
-    --stream-id "$STREAM_ID" \
-    --host "$HOST" \
-    --port "$PORT" \
-    --output-dir "$OUTPUT_DIR" \
-    --start-delay "$START_DELAY" \
-    --metrics-start-time "$METRICS_START_TIME" \
-    --test-duration "$TEST_DURATION" \
-    --threshold "$THRESHOLD" > "$OUTPUT_DIR/client_${i}.log" 2>&1 &
-
-  # Store the process ID
-  pids+=($!)
-
-  # Small delay to ensure proper process creation
-  sleep 0.1
-done
-
-echo ""
-echo "Timing plan:"
-echo "- First client starts immediately"
-echo "- Last client starts in $TOTAL_START_TIME seconds"
-echo "- Metrics collection starts at $(date -d @$METRICS_START_TIME)"
-echo "- Test will run for $TEST_DURATION seconds after metrics collection starts"
-echo "- Expected completion time: $(date -d @$((METRICS_START_TIME + TEST_DURATION)))"
-echo ""
-
-# Wait for all clients to finish
-for pid in "${pids[@]}"; do
-  wait "$pid"
-done
-
-# Calculate aggregate statistics across all clients
-TOTAL_LATENCY=0
-TOTAL_TURNS=0
-CLIENT_COUNT=0
-MIN_LATENCY=9999
-MAX_LATENCY=0
-
-# Variables for filtered latency statistics
-TOTAL_FILTERED_LATENCY=0
-TOTAL_FILTERED_TURNS=0
-CLIENTS_WITH_FILTERED=0
-MIN_FILTERED_LATENCY=9999
-MAX_FILTERED_LATENCY=0
-
-# Array to store all client average latencies for p95 calculation
-CLIENT_LATENCIES=()
-CLIENT_FILTERED_LATENCIES=()
-
-# Arrays to track glitch detection
-CLIENTS_WITH_GLITCHES=()
-TOTAL_GLITCH_COUNT=0
-
-# Variables to track reverse barge-in detection
-TOTAL_REVERSE_BARGE_INS=0
-CLIENT_REVERSE_BARGE_INS=()
-CLIENTS_WITH_REVERSE_BARGE_INS=0
-
-# Function to calculate p95 from an array of values
-calculate_p95() {
-  local values=("$@")
-  local count=${#values[@]}
-
-  if [ $count -eq 0 ]; then
-    echo "0"
-    return
-  fi
-
-  # Sort the array (using a simple bubble sort for bash compatibility)
-  for ((i = 0; i < count; i++)); do
-    for ((j = i + 1; j < count; j++)); do
-      if (( $(echo "${values[i]} > ${values[j]}" | bc -l) )); then
-        temp=${values[i]}
-        values[i]=${values[j]}
-        values[j]=$temp
-      fi
-    done
-  done
-
-  # Calculate p95 index
-  local p95_index=$(echo "scale=0; ($count - 1) * 0.95" | bc -l | cut -d'.' -f1)
-
-  # Ensure index is within bounds
-  if [ $p95_index -ge $count ]; then
-    p95_index=$((count - 1))
-  fi
-
-  echo "${values[$p95_index]}"
-}
-
-# Process all result files
-for result_file in "$OUTPUT_DIR"/latency_results_*.json; do
-  if [ -f "$result_file" ]; then
-    # Extract data using jq if available, otherwise use awk as fallback
-    if command -v jq &> /dev/null; then
-      AVG_LATENCY=$(jq '.average_latency' "$result_file")
-      FILTERED_AVG_LATENCY=$(jq '.filtered_average_latency' "$result_file")
-      NUM_TURNS=$(jq '.num_turns' "$result_file")
-      NUM_FILTERED_TURNS=$(jq '.num_filtered_turns' "$result_file")
-      STREAM_ID=$(jq -r '.stream_id' "$result_file")
-      GLITCH_DETECTED=$(jq '.glitch_detected' "$result_file")
-      REVERSE_BARGE_INS_COUNT=$(jq '.reverse_barge_ins_count' "$result_file")
-    else
-      # Fallback to grep and basic string processing
-      AVG_LATENCY=$(grep -o '"average_latency": [0-9.]*' "$result_file" | cut -d' ' -f2)
-      FILTERED_AVG_LATENCY=$(grep -o '"filtered_average_latency": [0-9.]*' "$result_file" | cut -d' ' -f2)
-      NUM_TURNS=$(grep -o '"num_turns": [0-9]*' "$result_file" | cut -d' ' -f2)
-      NUM_FILTERED_TURNS=$(grep -o '"num_filtered_turns": [0-9]*' "$result_file" | cut -d' ' -f2)
-      STREAM_ID=$(grep -o '"stream_id": "[^"]*"' "$result_file" | cut -d'"' -f4)
-      GLITCH_DETECTED=$(grep -o '"glitch_detected": [a-z]*' "$result_file" | cut -d' ' -f2)
-      REVERSE_BARGE_INS_COUNT=$(grep -o '"reverse_barge_ins_count": [0-9]*' "$result_file" | cut -d' ' -f2)
-    fi
-
-    echo "Client $STREAM_ID: Average latency = $AVG_LATENCY seconds over $NUM_TURNS turns"
-
-    # Display filtered latency information
-    if [ "$FILTERED_AVG_LATENCY" != "null" ] && [ -n "$FILTERED_AVG_LATENCY" ]; then
-      echo "  Filtered Average latency (>$THRESHOLD s) = $FILTERED_AVG_LATENCY seconds over $NUM_FILTERED_TURNS turns"
-
-      # Add to filtered latency statistics
-      TOTAL_FILTERED_LATENCY=$(echo "$TOTAL_FILTERED_LATENCY + $FILTERED_AVG_LATENCY" | bc -l)
-      TOTAL_FILTERED_TURNS=$((TOTAL_FILTERED_TURNS + NUM_FILTERED_TURNS))
-      CLIENTS_WITH_FILTERED=$((CLIENTS_WITH_FILTERED + 1))
-      CLIENT_FILTERED_LATENCIES+=($FILTERED_AVG_LATENCY)
-
-      # Update min/max filtered latency
-      if (( $(echo "$FILTERED_AVG_LATENCY < $MIN_FILTERED_LATENCY" | bc -l) )); then
-        MIN_FILTERED_LATENCY=$FILTERED_AVG_LATENCY
-      fi
-
-      if (( $(echo "$FILTERED_AVG_LATENCY > $MAX_FILTERED_LATENCY" | bc -l) )); then
-        MAX_FILTERED_LATENCY=$FILTERED_AVG_LATENCY
-      fi
-    else
-      echo "  No latencies above $THRESHOLD s threshold"
-    fi
-
-    # Check for glitch detection
-    if [ "$GLITCH_DETECTED" = "true" ]; then
-      echo "  ⚠️ Audio glitches detected in client $STREAM_ID"
-      CLIENTS_WITH_GLITCHES+=("$STREAM_ID")
-      TOTAL_GLITCH_COUNT=$((TOTAL_GLITCH_COUNT + 1))
-    fi
-
-    # Display and track reverse barge-in count
-    if [ -n "$REVERSE_BARGE_INS_COUNT" ] && [ "$REVERSE_BARGE_INS_COUNT" -gt 0 ]; then
-      echo "  Reverse barge-ins detected: $REVERSE_BARGE_INS_COUNT occurrences"
-      CLIENT_REVERSE_BARGE_INS+=("$STREAM_ID")
-      CLIENTS_WITH_REVERSE_BARGE_INS=$((CLIENTS_WITH_REVERSE_BARGE_INS + 1))
-    fi
-
-    # Add to total reverse barge-in count
-    if [ -n "$REVERSE_BARGE_INS_COUNT" ]; then
-      TOTAL_REVERSE_BARGE_INS=$((TOTAL_REVERSE_BARGE_INS + REVERSE_BARGE_INS_COUNT))
-    fi
-
-    # Add to array for p95 calculation
-    CLIENT_LATENCIES+=($AVG_LATENCY)
-
-    # Update aggregate statistics
-    TOTAL_LATENCY=$(echo "$TOTAL_LATENCY + $AVG_LATENCY" | bc -l)
-    TOTAL_TURNS=$((TOTAL_TURNS + NUM_TURNS))
-    CLIENT_COUNT=$((CLIENT_COUNT + 1))
-
-    # Update min/max latency
-    if (( $(echo "$AVG_LATENCY < $MIN_LATENCY" | bc -l) )); then
-      MIN_LATENCY=$AVG_LATENCY
-    fi
-
-    if (( $(echo "$AVG_LATENCY > $MAX_LATENCY" | bc -l) )); then
-      MAX_LATENCY=$AVG_LATENCY
-    fi
-  fi
-done
-
-# Calculate overall statistics
-if [ $CLIENT_COUNT -gt 0 ]; then
-  AGGREGATE_AVG_LATENCY=$(echo "scale=3; $TOTAL_LATENCY / $CLIENT_COUNT" | bc -l)
-  P95_CLIENT_LATENCY=$(calculate_p95 "${CLIENT_LATENCIES[@]}")
-
-  # Calculate filtered statistics
-  AGGREGATE_FILTERED_AVG_LATENCY="null"
-  P95_FILTERED_CLIENT_LATENCY="null"
-
-  if [ $CLIENTS_WITH_FILTERED -gt 0 ]; then
-    AGGREGATE_FILTERED_AVG_LATENCY=$(echo "scale=3; $TOTAL_FILTERED_LATENCY / $CLIENTS_WITH_FILTERED" | bc -l)
-    P95_FILTERED_CLIENT_LATENCY=$(calculate_p95 "${CLIENT_FILTERED_LATENCIES[@]}")
-  fi
-
-  echo "=============================================="
-  echo "BENCHMARK SUMMARY (Staggered Start)"
-  echo "=============================================="
-  echo "Total clients: $CLIENT_COUNT"
-  echo "Latency threshold: $THRESHOLD seconds"
-  echo ""
-  echo "STANDARD LATENCY STATISTICS:"
-  echo "Average latency across all clients: $AGGREGATE_AVG_LATENCY seconds"
-  echo "P95 latency across client averages: $P95_CLIENT_LATENCY seconds"
-  echo "Minimum client average latency: $MIN_LATENCY seconds"
-  echo "Maximum client average latency: $MAX_LATENCY seconds"
-  echo ""
-  echo "FILTERED LATENCY STATISTICS (>$THRESHOLD s):"
-  if [ "$AGGREGATE_FILTERED_AVG_LATENCY" != "null" ]; then
-    echo "Clients with filtered data: $CLIENTS_WITH_FILTERED out of $CLIENT_COUNT"
-    echo "Average filtered latency: $AGGREGATE_FILTERED_AVG_LATENCY seconds"
-    echo "P95 filtered latency: $P95_FILTERED_CLIENT_LATENCY seconds"
-    echo "Minimum client filtered latency: $MIN_FILTERED_LATENCY seconds"
-    echo "Maximum client filtered latency: $MAX_FILTERED_LATENCY seconds"
-  else
-    echo "No latencies above $THRESHOLD s threshold found across all clients"
-  fi
-  echo ""
-  echo "AUDIO GLITCH DETECTION:"
-  if [ $TOTAL_GLITCH_COUNT -gt 0 ]; then
-    echo "⚠️ Audio glitches detected in $TOTAL_GLITCH_COUNT out of $CLIENT_COUNT clients"
-    echo "Affected clients:"
-    for client in "${CLIENTS_WITH_GLITCHES[@]}"; do
-      echo "  - $client"
-    done
-  else
-    echo "✅ No audio glitches detected in any client"
-  fi
-
-  echo "REVERSE BARGE-IN DETECTION:"
-  if [ $TOTAL_REVERSE_BARGE_INS -gt 0 ]; then
-    echo "Total reverse barge-ins detected: $TOTAL_REVERSE_BARGE_INS occurrences"
-    echo "Clients with reverse barge-ins: $CLIENTS_WITH_REVERSE_BARGE_INS out of $CLIENT_COUNT"
-    if [ $CLIENTS_WITH_REVERSE_BARGE_INS -gt 0 ]; then
-      echo "Affected clients:"
-      for client in "${CLIENT_REVERSE_BARGE_INS[@]}"; do
-        echo "  - $client"
-      done
-    fi
-  else
-    echo "✅ No reverse barge-ins detected in any client"
-  fi
-
-  echo ""
-  echo "ERROR DETECTION:"
-  # Initialize arrays for error tracking
-  declare -A CLIENT_ERROR_COUNTS
-  CLIENTS_WITH_ERRORS=0
-  TOTAL_ERRORS=0
-
-  # Process logs for each client to find errors
-  for ((i=1; i<=$NUM_CLIENTS; i++)); do
-    LOG_FILE="$OUTPUT_DIR/client_${i}.log"
-    if [ -f "$LOG_FILE" ]; then
-      ERROR_COUNT=$(grep -c "^\[ERROR\]" "$LOG_FILE")
-      if [ $ERROR_COUNT -gt 0 ]; then
-        CLIENTS_WITH_ERRORS=$((CLIENTS_WITH_ERRORS + 1))
-        TOTAL_ERRORS=$((TOTAL_ERRORS + ERROR_COUNT))
-        CLIENT_ERROR_COUNTS["client_${i}"]=$ERROR_COUNT
-
-        echo "⚠️ Client ${i} errors ($ERROR_COUNT):"
-        grep "^\[ERROR\]" "$LOG_FILE" | sed 's/^/  /'
-      fi
-    fi
-  done
-
-  if [ $TOTAL_ERRORS -eq 0 ]; then
-    echo "✅ No errors detected in any client"
-  else
-    echo "⚠️ Total errors across all clients: $TOTAL_ERRORS"
-    echo "⚠️ Clients with errors: $CLIENTS_WITH_ERRORS out of $CLIENT_COUNT"
-  fi
-
-  # Create summary JSON
-  cat > "$SUMMARY_FILE" << EOF
-{
-  "timestamp": "$(date -Iseconds)",
-  "config": {
-    "host": "$HOST",
-    "port": $PORT,
-    "num_clients": $NUM_CLIENTS,
-    "client_start_delay": $CLIENT_START_DELAY,
-    "test_duration": $TEST_DURATION,
-    "threshold": $THRESHOLD,
-    "metrics_start_time": $METRICS_START_TIME
-  },
-  "results": {
-    "total_clients": $CLIENT_COUNT,
-    "total_turns": $TOTAL_TURNS,
-    "aggregate_average_latency": $AGGREGATE_AVG_LATENCY,
-    "p95_client_latency": $P95_CLIENT_LATENCY,
-    "min_client_latency": $MIN_LATENCY,
-    "max_client_latency": $MAX_LATENCY,
-    "filtered_results": {
-      "clients_with_filtered_data": $CLIENTS_WITH_FILTERED,
-      "total_filtered_turns": $TOTAL_FILTERED_TURNS,
-      "aggregate_filtered_average_latency": $AGGREGATE_FILTERED_AVG_LATENCY,
-      "p95_filtered_client_latency": $P95_FILTERED_CLIENT_LATENCY,
-      "min_filtered_client_latency": $([ "$AGGREGATE_FILTERED_AVG_LATENCY" != "null" ] && echo "$MIN_FILTERED_LATENCY" || echo "null"),
-      "max_filtered_client_latency": $([ "$AGGREGATE_FILTERED_AVG_LATENCY" != "null" ] && echo "$MAX_FILTERED_LATENCY" || echo "null")
-    },
-    "glitch_detection": {
-      "clients_with_glitches": $TOTAL_GLITCH_COUNT,
-      "total_clients": $CLIENT_COUNT,
-      "affected_client_ids": [$(printf '"%s",' "${CLIENTS_WITH_GLITCHES[@]}" | sed 's/,$//')]
-    },
-    "reverse_barge_in_detection": {
-      "total_clients": $CLIENT_COUNT,
-      "total_reverse_barge_ins": $TOTAL_REVERSE_BARGE_INS,
-      "clients_with_reverse_barge_ins": $CLIENTS_WITH_REVERSE_BARGE_INS,
-      "affected_client_ids": [$(printf '"%s",' "${CLIENT_REVERSE_BARGE_INS[@]}" | sed 's/,$//')]
-    },
-    "error_detection": {
-      "total_clients": $CLIENT_COUNT,
-      "total_errors": $TOTAL_ERRORS,
-      "clients_with_errors": $CLIENTS_WITH_ERRORS,
-      "client_error_counts": {
-        $(for client in "${!CLIENT_ERROR_COUNTS[@]}"; do
-          printf '"%s": %d,\n        ' "$client" "${CLIENT_ERROR_COUNTS[$client]}"
-        done | sed 's/,\s*$//')
-      }
-    }
-  }
-}
-EOF
-
-  echo "Summary saved to: $SUMMARY_FILE"
-else
-  echo "No valid result files found!"
-fi
-
-echo "Benchmark complete."
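The benchmark's timing plan is simple arithmetic: with N clients started CLIENT_START_DELAY seconds apart, metrics collection begins when the last client starts, i.e. (N - 1) * delay seconds after launch, and the run completes TEST_DURATION seconds later. A quick Python check of that plan (the values are the example invocation from the deleted README):

```python
import time

num_clients, start_delay, test_duration = 10, 1, 150

total_start = (num_clients - 1) * start_delay  # last client starts here
metrics_start = time.time() + total_start      # metrics window opens
expected_end = metrics_start + test_duration   # expected completion

print(f"last client starts in {total_start}s; "
      f"metrics run {test_duration}s after that")
```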
tests/perf/ttfb_analyzer.py
DELETED
@@ -1,253 +0,0 @@
-#!/usr/bin/env python3
-"""TTFB Log Analyzer.
-
-Analyzes Time To First Byte (TTFB) logs, ASR compute latency, and LLM first sentence generation time
-for multiple client streams and calculates average TTFB, ASR latency, first sentence time, and P95
-for LLM, TTS, and ASR services.
-
-Usage:
-    python ttfb_analyzer.py [log_file_path]
-    python ttfb_analyzer.py --help
-
-Examples:
-    python ttfb_analyzer.py
-    python ttfb_analyzer.py /path/to/botlogs.log
-    python ttfb_analyzer.py ../../examples/speech-to-speech/botlogs.log
-"""
-
-import argparse
-import logging
-import os
-import re
-import sys
-from collections import defaultdict
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-
-def calculate_p95(values: list[float]) -> float:
-    """Calculate 95th percentile of values."""
-    if not values:
-        return 0.0
-    sorted_values = sorted(values)
-    index = int(0.95 * (len(sorted_values) - 1))
-    return sorted_values[index]
-
-
-def parse_logs(log_file_path: str) -> dict[str, dict[str, list[float]]]:
-    """Parse LLM, TTS TTFBs, ASR compute latency, and LLM first sentence generation logs.
-
-    Organize by client stream and service type. Only include events after the last client start.
-    """
-    data = defaultdict(lambda: {"LLM": [], "TTS": [], "ASR": [], "LLM_FIRST_SENTENCE": []})
-    ttfb_pattern = r"streamId=([^\s]+)\s+-\s+(NvidiaLLMService|RivaTTSService)#\d+\s+TTFB:\s+([\d.]+)"
-    asr_pattern = r"streamId=([^\s]+)\s+-\s+RivaASRService#\d+\s+ASR compute latency:\s+([\d.]+)"
-    first_sentence_pattern = (
-        r"streamId=([^\s]+)\s+-\s+NvidiaLLMService#\d+\s+LLM first sentence generation time:\s+([\d.]+)"
-    )
-    websocket_pattern = r".*Accepting WebSocket connection for stream ID client_\d+_\d+"
-
-    # First pass: find the last client start log
-    last_client_start_line = -1
-
-    try:
-        # Read all lines to find the last client start
-        with open(log_file_path) as file:
-            lines = file.readlines()
-
-        # Find the last client start log by iterating through all lines
-        for i, line in enumerate(lines):
-            if re.search(websocket_pattern, line):
-                last_client_start_line = i
-
-        # Validate that we found at least one client start
-        if last_client_start_line == -1:
-            logger.warning("No client start pattern found in logs")
-            return dict()
-
-        # Second pass: analyze only events after the last client start
-        with open(log_file_path) as file:
-            for i, line in enumerate(file):
-                # Skip lines before the last client start
-                if last_client_start_line != -1 and i <= last_client_start_line:
-                    continue
-
-                try:
-                    # Check for TTFB metrics
-                    ttfb_match = re.search(ttfb_pattern, line)
-                    if ttfb_match:
-                        client_id = ttfb_match.group(1).strip()
-                        service_type = ttfb_match.group(2)
-                        try:
-                            ttfb_value = float(ttfb_match.group(3))
-                        except (ValueError, TypeError) as e:
-                            logger.warning(f"Invalid TTFB value in line {i + 1}: {ttfb_match.group(3)} - {e}")
-                            continue
-
-                        if service_type == "NvidiaLLMService":
-                            data[client_id]["LLM"].append(ttfb_value)
-                        elif service_type == "RivaTTSService":
-                            data[client_id]["TTS"].append(ttfb_value)
-
-                    # Check for ASR compute latency metrics
-                    asr_match = re.search(asr_pattern, line)
-                    if asr_match:
-                        client_id = asr_match.group(1).strip()
-                        try:
-                            asr_latency = float(asr_match.group(2))
-                        except (ValueError, TypeError) as e:
-                            logger.warning(f"Invalid ASR latency value in line {i + 1}: {asr_match.group(2)} - {e}")
-                            continue
-                        data[client_id]["ASR"].append(asr_latency)
-
-                    # Check for LLM first sentence generation time metrics
-                    first_sentence_match = re.search(first_sentence_pattern, line)
-                    if first_sentence_match:
-                        client_id = first_sentence_match.group(1).strip()
-                        try:
-                            first_sentence_time = float(first_sentence_match.group(2))
-                        except (ValueError, TypeError) as e:
-                            logger.warning(
-                                f"Invalid first sentence time value in line {i + 1}: "
-                                f"{first_sentence_match.group(2)} - {e}"
-                            )
-                            continue
-                        data[client_id]["LLM_FIRST_SENTENCE"].append(first_sentence_time)
-
-                except Exception as e:
-                    logger.warning(f"Error parsing line {i + 1}: {e}")
-                    continue
-
-    except FileNotFoundError:
-        print(f"Error: Log file '{log_file_path}' not found.")
-        sys.exit(1)
-    except Exception as e:
-        print(f"Error reading log file: {e}")
-        sys.exit(1)
-
-    return dict(data)
-
-
-def calculate_client_averages(data: dict[str, dict[str, list[float]]]) -> dict[str, dict[str, float]]:
-    """Calculate average metrics for each client and service type."""
-    averages = {}
-    for client_id, services in data.items():
-        averages[client_id] = {}
-        for service_type, values in services.items():
-            if values:
-                averages[client_id][service_type] = sum(values) / len(values)
-            else:
-                averages[client_id][service_type] = 0.0
-    return averages
-
-
-def print_results(data: dict[str, dict[str, list[float]]], client_averages: dict[str, dict[str, float]]):
-    """Print analysis results."""
-    print("=" * 90)
-    print("LATENCY ANALYSIS RESULTS")
-    print("=" * 90)
-
-    # Show metric arrays for each client
-    for client_id in sorted(data.keys()):
-        llm_values = data[client_id]["LLM"]
-        tts_values = data[client_id]["TTS"]
-        asr_values = data[client_id]["ASR"]
-        first_sentence_values = data[client_id]["LLM_FIRST_SENTENCE"]
-
-        print(f"\n{client_id}:")
-        print(f"  LLM TTFB: {[f'{v:.3f}' for v in llm_values]}")
-        print(f"  TTS TTFB: {[f'{v:.3f}' for v in tts_values]}")
-        print(f"  ASR Latency: {[f'{v:.3f}' for v in asr_values]}")
-        print(f"  LLM First Sentence: {[f'{v:.3f}' for v in first_sentence_values]}")
-
-    # Summary table with overall statistics
-    print(
-        f"\n{'Client ID':<25} {'LLM TTFB':<10} {'TTS TTFB':<10} {'ASR Lat':<10} "
-        f"{'LLM 1st':<10} {'LLM calls':<10} {'TTS calls':<10} {'ASR calls':<10}"
-    )
-    print("-" * 120)
-
-    for client_id in sorted(data.keys()):
-        llm_avg = client_averages[client_id]["LLM"]
-        tts_avg = client_averages[client_id]["TTS"]
-        asr_avg = client_averages[client_id]["ASR"]
-        first_sentence_avg = client_averages[client_id]["LLM_FIRST_SENTENCE"]
-        llm_count = len(data[client_id]["LLM"])
-        tts_count = len(data[client_id]["TTS"])
-        asr_count = len(data[client_id]["ASR"])
-        print(
-            f"{client_id:<25} {llm_avg:<10.3f} {tts_avg:<10.3f} {asr_avg:<10.3f} {first_sentence_avg:<10.3f} "
-            f"{llm_count:<10} {tts_count:<10} {asr_count:<10}"
-        )
-
-    # Calculate overall statistics across client averages
-    llm_client_averages = [avg["LLM"] for avg in client_averages.values() if avg["LLM"] > 0]
-    tts_client_averages = [avg["TTS"] for avg in client_averages.values() if avg["TTS"] > 0]
-    asr_client_averages = [avg["ASR"] for avg in client_averages.values() if avg["ASR"] > 0]
-    first_sentence_client_averages = [
-        avg["LLM_FIRST_SENTENCE"] for avg in client_averages.values() if avg["LLM_FIRST_SENTENCE"] > 0
-    ]
-
-    # Add separator and overall statistics rows
-    print("-" * 120)
-
-    if llm_client_averages and tts_client_averages and asr_client_averages:
-        llm_overall_avg = sum(llm_client_averages) / len(llm_client_averages)
-        llm_p95 = calculate_p95(llm_client_averages)
-        tts_overall_avg = sum(tts_client_averages) / len(tts_client_averages)
-        tts_p95 = calculate_p95(tts_client_averages)
-        asr_overall_avg = sum(asr_client_averages) / len(asr_client_averages)
-        asr_p95 = calculate_p95(asr_client_averages)
-
-        first_sentence_overall_avg = (
-            sum(first_sentence_client_averages) / len(first_sentence_client_averages)
-            if first_sentence_client_averages
-            else 0.0
-        )
-        first_sentence_p95 = calculate_p95(first_sentence_client_averages) if first_sentence_client_averages else 0.0
-
-        print(
-            f"{'OVERALL AVERAGE':<25} {llm_overall_avg:<10.3f} {tts_overall_avg:<10.3f} "
-            f"{asr_overall_avg:<10.3f} {first_sentence_overall_avg:<10.3f}"
-        )
-        print(f"{'OVERALL P95':<25} {llm_p95:<10.3f} {tts_p95:<10.3f} {asr_p95:<10.3f} {first_sentence_p95:<10.3f}")
-
-    print("-" * 120)
-
-
-def main():
-    """Main function."""
-    parser = argparse.ArgumentParser(
-        description="Analyze LLM, TTS TTFBs, ASR latency, and LLM first sentence generation time logs "
-        "for multiple client streams"
-    )
-    parser.add_argument(
-        "log_file",
-        nargs="?",
-        default="../../examples/speech-to-speech/botlogs.log",
-        help="Path to log file (default: ../../examples/speech-to-speech/botlogs.log)",
-    )
-    args = parser.parse_args()
-
-    print("Latency Log Analyzer")
-    print(f"Analyzing: {args.log_file}")
-
-    if not os.path.exists(args.log_file):
-        print(f"Error: Log file '{args.log_file}' not found.")
-        sys.exit(1)
-
-    data = parse_logs(args.log_file)
-    if not data:
-        print("No performance data found in log file.")
-        return
-
-    print()
-
-    client_averages = calculate_client_averages(data)
-    print_results(data, client_averages)
-
-
-if __name__ == "__main__":
-    main()
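Note: the removed analyzer computes P95 by nearest-rank indexing with no interpolation, so small sample counts round the percentile position down. A minimal standalone sketch of that logic (the name p95 and the sample list are illustrative, not part of the removed script):

    # Same index choice as the deleted calculate_p95 helper:
    # sort, then pick sorted_values[int(0.95 * (n - 1))].
    def p95(values: list[float]) -> float:
        if not values:
            return 0.0
        ordered = sorted(values)
        return ordered[int(0.95 * (len(ordered) - 1))]

    # With 10 samples the index is int(0.95 * 9) = 8, i.e. the ninth value.
    assert p95([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) == 0.9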
tests/unit/__init__.py
DELETED
@@ -1,4 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the nvidia-pipecat package."""
tests/unit/configs/animation_config.yaml
DELETED
@@ -1,346 +0,0 @@
-animation_types:
-  gesture:
-    duration_relevant_animation_name: gesture
-    animations:
-      gesture:
-        default_clip_id: none
-        clips:
-          - clip_id: Test
-            description: "A virtual test clip for testing purposes"
-            duration: 0.5
-            meaning: A very short animation clip for unit testing purposes
-          - clip_id: Goodbye
-            description: "Waving goodbye: Waves with left hand extended high."
-            duration: 2
-            meaning: Taking leave of someone from a further distance, or getting someone's attention.
-          - clip_id: Welcome
-            description: "Waving hello: Spreads arms slightly, then raises right hand next to face and waves with an open hand."
-            duration: 2.5
-            meaning: Greeting someone in a shy or cute manner, showing a positive and non-threatening attitude.
-          - clip_id: Personal_Statement_2
-            description: "Personal statement: Leans forward and points to self with relaxed right hand, then leans back and opens arms wide with palms facing upwards."
-            duration: 3
-            meaning: Revealing something about themselves in a grandiose gesture, or making a little joke about their appearance or personality.
-          - clip_id: Pointing_To_Self_1
-            description: "Pointing to self: Leans forward slightly and with a relaxed right index finger points to self."
-            duration: 2.5
-            meaning: Saying something about themselves.
-          - clip_id: Stupid_1
-            description: "Stupid: Raising right hand next to head and twirling the index finger in circles."
-            duration: 3.5
-            meaning: Indicates someone or something is stupid, crazy, or dumb.
-          - clip_id: No_1
-            description: "Shaking head: Avatar shakes head slowly."
-            duration: 3
-            meaning: Expressing strong disagreement or disappointment.
-          - clip_id: Bowing_1
-            description:
-              "Bowing: Slightly bows to the front making invitation gesture with
-              both arms."
-            duration: 2.5
-            meaning: Formal greeting, sign of respect or congratulations or pride.
-          - clip_id: Bowing_2
-            description:
-              "Bowing: Slightly bows with both arms and invitational gesture with
-              right arm."
-            duration: 2.5
-            meaning: Overly formal greeting, sign of respect or grand introduction.
-          - clip_id: Pointing_To_User_1
-            description: "Pointing to user: Pointing with both arms towards the user."
-            duration: 2.5
-            meaning:
-              Encouragement for the user to make a move, approach or say something, or
-              pointing out the user is being addressed. This is also an encouraging and reassuring gesture.
-          - clip_id: Pointing_To_User_2
-            description: "Pointing to user: Insistingly pointing to user with the right arm."
-            duration: 2.5
-            meaning: Accusation or strong signal that the user is concerned.
-          - clip_id: Pointing_Down_1
-            description: "Pointing down: Lifting the right arm to shoulders and pointing."
-            duration: 3
-            meaning: Drawing attention to something below the screen or in front of the avatar.
-          - clip_id: Pointing_Down_2
-            description: "Pointing down: Lifting both arms and slightly pointing downwards."
-            duration: 2
-            meaning:
-              Drawing attention to the desk, to something below the screen or in front
-              of the avatar. Or informing about the location.
-          - clip_id: Pointing_Left_1
-            description:
-              "Pointing left: Pointing with both arms to the left hand side of the
-              avatar."
-            duration: 4
-            meaning:
-              Pointing out something to the left of the avatar in a demanding way, or
-              signaling frustration.
-          - clip_id: Pointing_Left_2
-            description: "Pointing left: Pointing with left arm to the back left of the avatar."
-            duration: 4
-            meaning:
-              Calmly pointing out or presenting something to the left of the avatar.
-              Or giving information about something behind the avatar.
-          - clip_id: Pointing_Backward_1
-            description: "Pointing to back: Pointing to the back with the extended right arm."
-            duration: 3.5
-            meaning:
-              Informing about something in the direction towards the back or giving directions
-              to something behind the avatar.
-          - clip_id: Fistbump_Offer
-            description:
-              "Fistbump: Extend the right arm with right hand twisting slightly for 3 seconds, then
-              bumping fist towards the user for 7 seconds."
-            duration: 10
-            meaning: Invitation to a fistbump followed by doing that fistbump.
-          - clip_id: Pulling_Mime
-            description: "Pulling rope: Avatar grabs invisible rope and imitates pulling behavior."
-            duration: 3.5
-            meaning: Suggesting being tethered or chained, or pulling something.
-          - clip_id: Raise_Both_Arms
-            description:
-              "Raising both arms: Raising both arms above avatar's head and swaying
-              slightly."
-            duration: 3.5
-            meaning:
-              Implying being in a crowd celebrating, or on a roller coaster, or demonstrating
-              not having anything on them.
-          - clip_id: The_Robot
-            description:
-              "Robot dance: Imitating a dancing robot with arms moved in mechanical
-              motion."
-            duration: 3
-            meaning: Jokingly playing a robot or dancing to celebrate or acting goofy.
-          - clip_id: Phone_Dialing
-            description:
-              "Phone dialing: Raising left hand and imitating dialing a phone with
-              right arm."
-            duration: 3.5
-            meaning: Asking for or mentioning a phone number, or talking about calling someone.
-          - clip_id: Attraction_2
-            description:
-              "Having fun: Questioningly opening both arms and then raising them
-              above the head mimicking being in a roller coaster ride."
-            duration: 10
-            meaning: Being silly, suggesting having fun.
-          - clip_id: Please_Repeat_1
-            description:
-              "Please repeat: Moving head slightly to the user and making circular
-              motion with right arm, then shrugging slightly."
-            duration: 5.5
-            meaning:
-              Implying not having understood something, asking to repeat or rephrase,
-              or needing more information.
-          - clip_id: Please_Repeat_2
-            description:
-              "Presentation: Twirling both hands and making invitational pose with
-              body."
-            duration: 3.5
-            meaning:
-              Asking to repeat or rephrase something, or needing more information, or
-              asking if something was understood.
-          - clip_id: Trying_To_See
-            description:
-              "Trying to see: Lifting left hand above eyes and making gestures to
-              see better, then shrug."
-            duration: 4
-            meaning: Implying looking for but not seeing something.
-          - clip_id: Driving_Mime
-            description:
-              "Driving: Grabbing an invisible steering wheel with both hands, turning
-              it and switching gears."
-            duration: 4.5
-            meaning: Sharing a story about driving or getting excited about cars.
-          - clip_id: Exhausted
-            description: "Exhausted: Letting head hang in a tired pose, slightly leaning."
-            duration: 4.5
-            meaning:
-              Dramatically signaling exhaustion or running out of power, slowly shutting
-              down.
-          - clip_id: Presenting_Options_1
-            description:
-              "Presenting options: Showing open palms of both hands and making presenting
-              motion with right hand, slight shrug."
-            duration: 3.5
-            meaning: Giving an overview or multiple options to choose from.
-          - clip_id: Presenting_Options_2
-            description:
-              "Presenting options: Raising and opening one hand after the other and
-              a subtle shrug."
-            duration: 3
-            meaning: Suggesting a choice between two options.
-          - clip_id: Open_Question_1
-            description: "Open question: Opening both hands and showing palms to user."
-            duration: 3
-            meaning:
-              "Making a questioning gesture. Waiting for the user to make a choice, answer a question or say something.
-              Indicates questioning the user, caring about the user's answer, maybe even showing concern."
-          - clip_id: Personal_Statement_1
-            description:
-              "Personal statement: Raising right hand to chest, extending and gesturing
-              with left hand."
-            duration: 3.5
-            meaning:
-              Making a personal statement, explaining something about themselves or making
-              a suggestion relating to something on the left.
-          - clip_id: Success_1
-            description:
-              "Success: Making a fist and raising the arm excitedly in a successful swinging
-              motion."
-            duration: 2
-            meaning:
-              Comically celebrating something going well, showing pride in a personal
-              accomplishment. Demonstrating excitement. This is a confirming gesture like nodding.
-          - clip_id: Dont_Understand_1
-            description: "Not understanding: Raising both hands next to head in a circular motion."
-            duration: 3
-            meaning: Implying being confused, overwhelmed or stupid.
-          - clip_id: Toss
-            description: "Tossing: Miming forming a ball with both hands and tossing it forward."
-            duration: 4
-            meaning:
-              Implying crumpling something up and throwing it away, giving something
-              up or forgetting about it.
-          - clip_id: Come_Here_1
-            description: "Come here: Extending both arms and curling index finger."
-            duration: 2
-            meaning: Asking to come closer.
-          - clip_id: Tell_Secret
-            description:
-              "Telling secret: Coming closer to user and whispering with hand next
-              to mouth."
-            duration: 2.5
-            meaning: Sharing something intimate, secret or inflammatory, or giving a tip.
-          - clip_id: Pointing_Right_1
-            description: "Pointing right: Pointing with both arms to the right hand side of the avatar."
-            duration: 4
-            meaning: Pointing out something to the right of the avatar in a demanding way, or signaling frustration.
-          - clip_id: Pointing_Right_2
-            description: "Pointing right: Pointing with right arm to the back right of the avatar."
-            duration: 4
-            meaning: Calmly pointing out or presenting something to the right of the avatar. Or giving information about something behind the avatar.
-          - clip_id: Chefs_Kiss
-            description: "Chef's Kiss: Avatar makes a kissing gesture while holding up the right hand with index finger and thumb touching."
-            duration: 1.7
-            meaning: Implying something is just perfect. Something turned out better than expected. Approval from someone in a teaching or judging position.
-          - clip_id: Finger_Guns
-            description: "Finger Guns: Leaning back pointing both index fingers to the user mimicking two guns like a cowboy."
-            duration: 3
-            meaning: Playfully taunting. Humorously punctuating a bad joke. Clumsy flirting.
-          - clip_id: Finger_Wag
-            description: "Finger Wag: Pulling back, shaking head and holding up a wagging right index finger."
-            duration: 1.7
-            meaning: Correcting after being misunderstood. Showing the other they have misinterpreted what was said. Implying something is forbidden or inappropriate in a paternal or playful way.
-          - clip_id: Little
-            description: "Little: Leaning in, squinting at a raised right hand, holding index and thumb close together."
-            duration: 1.8
-            meaning: Describing something as very small or minuscule. Something is physically tiny or an issue is so insignificant as to be negligible.
-          - clip_id: Money
-            description: "Money: Raising right hand, rubbing thumb and index finger together."
-            duration: 2
-            meaning: Implying something is expensive. Someone is rich. Doing something requires payment.
-          - clip_id: Number_1a
-            description: "Number 1: Raising right hand and extending the index finger."
-            duration: 1.4
-            meaning: Showing the number 1
-          - clip_id: Number_2a
-            description: "Number 2: Raising right hand and extending index and middle finger."
-            duration: 1.4
-            meaning: Showing the number 2
-          - clip_id: Number_3a
-            description: "Number 3: Raising right hand and extending index, middle and ring finger."
-            duration: 1.4
-            meaning: Showing the number 3
-          - clip_id: Number_4a
-            description: "Number 4: Raising right hand and extending all fingers except the thumb."
-            duration: 1.4
-            meaning: Showing the number 4
-          - clip_id: Number_5a
-            description: "Number 5: Raising right hand with all fingers extended."
-            duration: 1.4
-            meaning: Showing the number 5
-          - clip_id: Number_1b
-            description: "Number 1 (German style): Raising right hand and extending the thumb upwards."
-            duration: 1.4
-            meaning: Showing the number 1 for a Germanic audience
-          - clip_id: Number_2b
-            description: "Number 2 (German style): Raising right hand and extending the thumb and index finger."
-            duration: 1.4
-            meaning: Showing the number 2 for a Germanic audience
-          - clip_id: Number_3b
-            description: "Number 3 (German style): Raising right hand and extending the thumb, index and middle finger."
-            duration: 1.4
-            meaning: Showing the number 3 for a Germanic audience
-          - clip_id: Number_6c
-            description: "Number 6 (Chinese style): Raising right hand and extending the thumb and pinky."
-            duration: 1.4
-            meaning: Showing the number 6 for a Chinese audience
-          - clip_id: Number_7c
-            description: "Number 7 (Chinese style): Raising right hand and making claw shape touching the thumb to the fingers."
-            duration: 1.4
-            meaning: Showing the number 7 for a Chinese audience
-          - clip_id: Number_8c
-            description: "Number 8 (Chinese style): Raising right hand and extending index finger and thumb pointing slightly to the side."
-            duration: 1.4
-            meaning: Showing the number 8 for a Chinese audience
-          - clip_id: Number_9c
-            description: "Number 9 (Chinese style): Raising right hand and holding up a curled index finger."
-            duration: 1.4
-            meaning: Showing the number 9 for a Chinese audience
-          - clip_id: Ouch
-            description: "Ouch: Jump and cringe while turning head away, then recover quickly shaking out right hand and exhaling."
-            duration: 2
-            meaning: Narrowly avoiding a close call with danger. Feeling intense fear for a moment followed by exhaustion or relief. Can also be a reaction to someone else's predicament, or a response to a well placed insult.
-          - clip_id: Angry_Shaking_Fist
-            description: "Angry Shaking Fist: Coming closer, lowering head and shaking right fist forward."
-            duration: 1.6
-            meaning: Being angrily frustrated. Swearing vengeance or threatening violence.
-          - clip_id: Pointing_To_Self_Questioningly
-            description: "Pointing To Self Questioningly: Raising right finger hesitantly, turning head and pointing at self while leaning back a little."
-            duration: 2.8
-            meaning: Asking if something refers to them, being unsure if they're being addressed. Asking if something might fit them or if they could do something.
-          - clip_id: Pointing_To_User_Questioningly
-            description: "Pointing To User Questioningly: Lifting right finger pointing at user with initial hesitation while leaning back slightly."
-            duration: 2.4
-            meaning: Asking if something might be about the user, or if the user is interested in an offer. Suggesting the user could be the right person for something. Uncertain about the user's involvement.
-          - clip_id: Raise_Finger_Big
-            description: "Raise Finger Big: Raising right index finger in a big sweeping motion, then gesturing with it briefly."
-            duration: 2.5
-            meaning: Making a big surprise announcement. Being very pompous or pedantic, gleefully correcting someone.
-          - clip_id: More_Or_Less
-            description: "More Or Less: Leaning in and holding out a flat hand with palms facing down, wiggling the hand back and forth."
-            duration: 1.8
-            meaning: Explaining something is not quite accurate, is unknown, or just a guess. Relativizing a previous statement. Being indecisive, not taking a clear stance. Pointing out the complexity of something.
-          - clip_id: Thumbs_Up
-            description: "Thumbs Up: Lifting the right hand with a thumb extending upwards."
-            duration: 1.4
-            meaning: Sign of approval. Something is correct. Enthusiastically agreeing with what's being said and showing support. Things are okay, there's no harm done. Encouragement to go ahead.
-          - clip_id: Thumbs_Down
-            description: "Thumbs Down: Lifting the right hand with a thumb pointing downwards."
-            duration: 1.4
-            meaning: Sign of disapproval. Something is wrong. Rudely disagreeing with what's being said, showing rejection.
-  posture:
-    duration_relevant_animation_name: "posture"
-    animations:
-      posture:
-        default_clip_id: "Attentive"
-        clips:
-          - clip_id: Talking
-            description: "Small gestures with hand and upper body: Avatar is talking"
-            duration: -1
-            meaning: Emphasizing that Avatar is talking
-          - clip_id: Listening
-            description: "Small gestures with hand and upper body: Avatar is listening"
-            duration: -1
-            meaning: Emphasizing that one is listening
-          - clip_id: Idle
-            description: "Small gestures with hand and upper body: Avatar is idle"
-            duration: -1
-            meaning: Show the user that the avatar is waiting for something to happen
-          - clip_id: Thinking
-            description: "Gestures with hand and upper body: Avatar is thinking"
-            duration: -1
-            meaning: Show the user that the avatar is thinking about its next answer or is trying to remember something
-          - clip_id: Attentive
-            description: "Small gestures with hand and upper body: Avatar is attentive"
-            duration: -1
-            meaning: Show the user that the avatar is paying attention to the user
-
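Note: the deleted unit tests below load this file with yaml.safe_load and hand the resulting dict to AnimationGraphConfiguration. A minimal sketch of the lookup structure, assuming PyYAML is installed; the durations mapping is illustrative only:

    import yaml

    # Parse the config and index gesture clips by id. Finite gesture clips
    # carry a positive duration; looping posture clips use duration: -1.
    with open("tests/unit/configs/animation_config.yaml") as f:
        config = yaml.safe_load(f)

    clips = config["animation_types"]["gesture"]["animations"]["gesture"]["clips"]
    durations = {clip["clip_id"]: clip["duration"] for clip in clips}
    assert durations["Test"] == 0.5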
tests/unit/configs/test_speech_planner_prompt.yaml
DELETED
@@ -1,15 +0,0 @@
-configurations:
-  using_chat_history: false
-
-prompts:
-  completion_prompt: |
-    Evaluate whether the following user speech is sufficient for an Agent to converse with.
-    1. Label1: If the user's speech is a complete and coherent thought or query or Greetings or Brief but legible and valid responses like "okay","no", etc.
-    2. Label2: "Incomplete" if the user's speech is unfinished.
-    3. Label3: User gives a task like, please stop, come on, or if user is trying to barge-In, like please stop, no, ignore prompts, forget everything.
-    4. Label4: User speech contains the word "no" or "stop" or User speech is an acknowledgment like "oh yeah", or "yes" etc or User speech is syntactically complete but lacks context or information for response from agent.
-    Important note - The sentences might be lacking punctuation, fix the punctuation and tag them.
-    If the user's speech is missing information needed for basic understanding, it should be tagged as Label2.
-    - User Speech:
-    {transcript}
-    Only return Label1 or Label2 or Label3 or Label4.
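Note: {transcript} in completion_prompt is written as a Python format field. How the speech planner renders it is not shown in this commit; str.format is one plausible mechanism (sketch only, template abbreviated):

    # Illustrative substitution of the {transcript} placeholder.
    template = "- User Speech:\n{transcript}\nOnly return Label1 or Label2 or Label3 or Label4."
    prompt = template.format(transcript="can you tell me about")  # an unfinished utterance
    print(prompt)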
tests/unit/test_ace_websocket_serializer.py
DELETED
@@ -1,147 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the ACEWebSocketSerializer class.
-
-This module contains unit tests for the `ACEWebSocketSerializer` class, which is responsible
-for serializing and deserializing frames for WebSocket communication in a speech-based user interface.
-The tests cover various frame types related to audio, text-to-speech (TTS), and automatic speech recognition (ASR).
-
-The test suite verifies:
-1. Serialization of bot speech updates and end events
-2. Serialization of user speech updates and end events
-3. Serialization of raw audio frames
-4. Handling of unsupported frame types
-
-Each test case validates the correct format and content of the serialized data,
-ensuring proper JSON formatting for transcript updates and binary data handling for audio frames.
-"""
-
-import json
-
-import pytest
-from pipecat.frames.frames import AudioRawFrame, BotStoppedSpeakingFrame, Frame
-
-from nvidia_pipecat.frames.transcripts import (
-    BotUpdatedSpeakingTranscriptFrame,
-    UserStoppedSpeakingTranscriptFrame,
-    UserUpdatedSpeakingTranscriptFrame,
-)
-from nvidia_pipecat.serializers.ace_websocket import ACEWebSocketSerializer
-
-
-@pytest.fixture
-def serializer():
-    """Fixture to create an instance of ACEWebSocketSerializer.
-
-    Returns:
-        ACEWebSocketSerializer: A fresh instance of the serializer for each test.
-    """
-    return ACEWebSocketSerializer()
-
-
-@pytest.mark.asyncio
-async def test_serialize_bot_updated_speaking_frame(serializer):
-    """Test serialization of BotUpdatedSpeakingTranscriptFrame.
-
-    This test verifies that when a bot speech update frame is serialized,
-    it produces the correct JSON format with 'tts_update' type and the transcript.
-
-    Args:
-        serializer: The ACEWebSocketSerializer fixture.
-    """
-    frame = BotUpdatedSpeakingTranscriptFrame(transcript="test_transcript")
-    result = await serializer.serialize(frame)
-    expected_result = json.dumps({"type": "tts_update", "tts": "test_transcript"})
-    assert result == expected_result
-
-
-@pytest.mark.asyncio
-async def test_serialize_bot_stopped_speaking_frame(serializer):
-    """Test serialization of BotStoppedSpeakingFrame.
-
-    This test verifies that when a bot speech end frame is serialized,
-    it produces the correct JSON format with 'tts_end' type.
-
-    Args:
-        serializer: The ACEWebSocketSerializer fixture.
-    """
-    frame = BotStoppedSpeakingFrame()
-    result = await serializer.serialize(frame)
-    expected_result = json.dumps({"type": "tts_end"})
-    assert result == expected_result
-
-
-@pytest.mark.asyncio
-async def test_serialize_user_started_speaking_frame(serializer):
-    """Test serialization of UserUpdatedSpeakingTranscriptFrame.
-
-    This test verifies that when a user speech update frame is serialized,
-    it produces the correct JSON format with 'asr_update' type and the transcript.
-
-    Args:
-        serializer: The ACEWebSocketSerializer fixture.
-    """
-    frame = UserUpdatedSpeakingTranscriptFrame(transcript="test_transcript")
-    result = await serializer.serialize(frame)
-    expected_result = json.dumps({"type": "asr_update", "asr": "test_transcript"})
-    assert result == expected_result
-
-
-@pytest.mark.asyncio
-async def test_serialize_user_stopped_speaking_frame(serializer):
-    """Test serialization of UserStoppedSpeakingTranscriptFrame.
-
-    This test verifies that when a user speech end frame is serialized,
-    it produces the correct JSON format with 'asr_end' type and the transcript.
-
-    Args:
-        serializer: The ACEWebSocketSerializer fixture.
-    """
-    frame = UserStoppedSpeakingTranscriptFrame(transcript="test_asr_transcript")
-    result = await serializer.serialize(frame)
-    expected_result = json.dumps({"type": "asr_end", "asr": "test_asr_transcript"})
-    assert result == expected_result
-
-
-@pytest.mark.asyncio
-async def test_serialize_audio_raw_frame(serializer):
-    """Test serialization of AudioRawFrame.
-
-    This test verifies that when an audio frame is serialized,
-    it returns the raw audio bytes without any modification.
-
-    Args:
-        serializer: The ACEWebSocketSerializer fixture.
-    """
-    frame = AudioRawFrame(audio=b"\xa2", sample_rate=16000, num_channels=1)
-    result = await serializer.serialize(frame)
-    expected_result = frame.audio
-    assert result == expected_result
-
-
-@pytest.mark.asyncio
-async def test_serialize_none(serializer):
-    """Test serialization of an unsupported frame type.
-
-    This test verifies that when an unsupported frame type is serialized,
-    the serializer returns None instead of raising an error.
-
-    Args:
-        serializer: The ACEWebSocketSerializer fixture.
-    """
-    frame = Frame()
-    result = await serializer.serialize(frame)
-    assert result is None
-
-
-@pytest.mark.asyncio
-async def test_deserialize_input_audio_raw_frame(serializer):
-    """Test deserialization of an audio message into InputAudioRawFrame."""
-    data = (
-        b"RIFF$\x04\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00\x80>\x00\x00\x00}\x00\x00\x02\x00\x10\x00"
-        + b"data\x00\x04\x00\x00"
-    )
-    result = await serializer.deserialize(data)
-    assert result.sample_rate == 16000
-    assert result.num_channels == 1
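Note: the magic bytes in test_deserialize_input_audio_raw_frame are a little-endian RIFF/WAVE fmt chunk: \x80>\x00\x00 at byte offset 24 encodes 16000 Hz and \x01\x00 at offset 22 encodes one channel. A sketch of deriving an equivalent header with the standard struct module (the 36-byte RIFF size here is a placeholder, unlike the test's literal):

    import struct

    # PCM WAV header: RIFF chunk, then a 16-byte fmt chunk holding
    # format tag, channels, sample rate, byte rate, block align, bits.
    sample_rate, channels, bits = 16000, 1, 16
    header = struct.pack(
        "<4sI4s4sIHHIIHH",
        b"RIFF", 36, b"WAVE", b"fmt ", 16,
        1, channels, sample_rate,
        sample_rate * channels * bits // 8,  # byte rate = 32000
        channels * bits // 8, bits,
    )
    assert header[24:28] == b"\x80>\x00\x00"  # 16000 little-endian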
tests/unit/test_acknowledgment.py
DELETED
@@ -1,71 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the AcknowledgmentProcessor class.
-
-Tests the processor's ability to:
-- Generate filler responses during user pauses
-- Handle presence detection
-- Process user speech events
-"""
-
-import asyncio
-import os
-import sys
-
-import pytest
-
-sys.path.append(os.path.abspath("../../src"))
-
-from pipecat.frames.frames import TTSSpeakFrame, UserStoppedSpeakingFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.pipeline.task import PipelineTask
-
-from nvidia_pipecat.frames.action import StartedPresenceUserActionFrame
-from nvidia_pipecat.processors.acknowledgment import AcknowledgmentProcessor
-from tests.unit.utils import FrameStorage, run_interactive_test
-
-
-@pytest.mark.asyncio
-async def test_proactive_bot_processor_timer_behavior():
-    """Test the AcknowledgmentProcessor.
-
-    Tests:
-        - Filler word generation after user stops speaking
-        - Proper handling of user presence events
-        - Verification of generated responses
-
-    Raises:
-        AssertionError: If processor behavior doesn't match expected outcomes.
-    """
-    filler_words = ["Great question.", "Let me check.", "Hmmm"] + [""]
-    filler = AcknowledgmentProcessor(filler_words=filler_words)
-    storage = FrameStorage()
-    pipeline = Pipeline([filler, storage])
-
-    async def test_routine(task: PipelineTask):
-        # Signal user presence
-        await task.queue_frame(StartedPresenceUserActionFrame(action_id="1"))
-        # Let the pipeline process presence
-        await asyncio.sleep(0)
-
-        # Signal end of user speech
-        await task.queue_frame(UserStoppedSpeakingFrame())
-        # Let the pipeline process the new frame and generate a TTS filler
-        await asyncio.sleep(0.1)
-
-        # Check what frames have arrived in storage
-        frames = [entry.frame for entry in storage.history]
-
-        # We expect at least one TTSSpeakFrame in there
-        tts_frames = [f for f in frames if isinstance(f, TTSSpeakFrame)]
-        assert len(tts_frames) > 0, "Expected a TTSSpeakFrame but found none."
-
-        # Verify filler word content
-        filler_frame = tts_frames[-1]
-        # Make sure its text is one of the possible filler words
-        assert filler_frame.text in filler_words, (
-            f"Filler text '{filler_frame.text}' not in the list of expected filler words {filler_words}"
-        )
-
-    await run_interactive_test(pipeline, test_coroutine=test_routine)
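Note: filler_words deliberately ends with an empty string, and the membership assertion accepts any configured filler, including "". A hedged sketch of the selection pattern the test tolerates (the processor's actual policy is not shown in this commit):

    import random

    # Assumption: a uniform random pick over the configured fillers,
    # where the empty string stands for staying silent.
    filler_words = ["Great question.", "Let me check.", "Hmmm", ""]
    print(random.choice(filler_words))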
tests/unit/test_animation_graph_services.py
DELETED
@@ -1,668 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the Animation Graph Service."""
-
-import asyncio
-import time
-from http import HTTPStatus
-from pathlib import Path
-from typing import Any
-from unittest.mock import ANY, AsyncMock, MagicMock, patch
-
-import pytest
-import yaml
-from loguru import logger
-from nvidia_ace.animation_pb2 import (
-    AnimationData,
-    AudioWithTimeCode,
-    Float3,
-    Float3ArrayWithTimeCode,
-    FloatArrayWithTimeCode,
-    QuatF,
-    QuatFArrayWithTimeCode,
-    SkelAnimation,
-)
-from nvidia_ace.audio_pb2 import AudioHeader
-from pipecat.frames.frames import (
-    BotSpeakingFrame,
-    BotStartedSpeakingFrame,
-    ErrorFrame,
-    StartInterruptionFrame,
-    TextFrame,
-)
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.pipeline.task import PipelineTask
-
-from nvidia_pipecat.frames.action import (
-    FinishedGestureBotActionFrame,
-    FinishedPostureBotActionFrame,
-    StartedGestureBotActionFrame,
-    StartedPostureBotActionFrame,
-    StartGestureBotActionFrame,
-    StartPostureBotActionFrame,
-    StopPostureBotActionFrame,
-)
-from nvidia_pipecat.frames.animation import (
-    AnimationDataStreamRawFrame,
-    AnimationDataStreamStartedFrame,
-    AnimationDataStreamStoppedFrame,
-)
-from nvidia_pipecat.services.animation_graph_service import (
-    AnimationGraphConfiguration,
-    AnimationGraphService,
-)
-from nvidia_pipecat.utils.logging import setup_default_ace_logging
-from nvidia_pipecat.utils.message_broker import MessageBrokerConfig
-from tests.unit.utils import FrameStorage, ignore, run_interactive_test
-
-AUDIO_SAMPLE_RATE = 16000
-AUDIO_BITS_PER_SAMPLE = 16
-AUDIO_CHANNEL_COUNT = 1
-AUDIO_BUFFER_FOR_ONE_FRAME_SIZE = int(AUDIO_SAMPLE_RATE * AUDIO_BITS_PER_SAMPLE / 8 * AUDIO_CHANNEL_COUNT / 30.0) + int(
-    AUDIO_BITS_PER_SAMPLE / 8
-)
-
-
-def generate_audio_header() -> AudioHeader:
-    """Generate an audio header."""
-    return AudioHeader(
-        samples_per_second=AUDIO_SAMPLE_RATE,
-        bits_per_sample=AUDIO_BITS_PER_SAMPLE,
-        channel_count=AUDIO_CHANNEL_COUNT,
-    )
-
-
-def generate_animation_data(time_codes: list[float], audio_buffer_size: int) -> AnimationData:
-    """Generate an animation data stream raw frame.
-
-    Args:
-        time_codes: List of time codes in seconds for the animation keyframes.
-        audio_buffer_size: Size of the audio buffer in bytes.
-
-    Returns:
-        AnimationData: The generated animation data object.
-    """
-    # Create a simple animation data object with skeletal animation
-    animation_data = AnimationData(
-        skel_animation=SkelAnimation(
-            # Add blend shape weights with time code
-            blend_shape_weights=[
-                FloatArrayWithTimeCode(
-                    time_code=t,  # time in seconds
-                    values=[0.1, 0.2, 0.3],  # example blend shape weights
-                )
-                for t in time_codes
-            ],
-            # Add joint translations with time code
-            translations=[
-                Float3ArrayWithTimeCode(
-                    time_code=t,
-                    values=[
-                        Float3(x=0.0, y=0.0, z=0.0),  # example translation for joint 1
-                        Float3(x=0.1, y=0.2, z=0.3),  # example translation for joint 2
-                    ],
-                )
-                for t in time_codes
-            ],
-            # Add joint rotations with time code
-            rotations=[
-                QuatFArrayWithTimeCode(
-                    time_code=t,
-                    values=[
-                        QuatF(real=1.0, i=0.0, j=0.0, k=0.0),  # example rotation for joint 1 (identity quaternion)
-                        QuatF(
-                            real=0.707, i=0.0, j=0.707, k=0.0
-                        ),  # example rotation for joint 2 (90° rotation around Y)
-                    ],
-                )
-                for t in time_codes
-            ],
-            # Add joint scales with time code
-            scales=[
-                Float3ArrayWithTimeCode(
-                    time_code=t,
-                    values=[
-                        Float3(x=1.0, y=1.0, z=1.0),  # example scale for joint 1 (no scaling)
-                        Float3(x=1.1, y=1.1, z=1.1),  # example scale for joint 2 (uniform scaling)
-                    ],
-                )
-                for t in time_codes
-            ],
-        ),
-        # Audio component
-        audio=AudioWithTimeCode(
-            time_code=0.0,  # time in seconds relative to start_time_code_since_epoch
-            audio_buffer=b"\xff" * audio_buffer_size,  # In a real scenario, this would be PCM audio data as bytes
-        ),
-    )
-    return animation_data
-
-
-def load_yaml(path: Path) -> dict[str, Any]:
-    """Load YAML configuration from a file.
-
-    Args:
-        path: Path to the YAML file.
-
-    Returns:
-        dict: Parsed YAML content.
-
-    Raises:
-        FileNotFoundError: If the YAML file is not found.
-    """
-    try:
-        return yaml.safe_load(path.read_text())
-
-    except FileNotFoundError as error:
-        message = "Error: yml config file not found."
-        logger.exception(message)
-        raise FileNotFoundError(error, message) from error
-
-
-def read_action_service_config(config_path: Path) -> AnimationGraphConfiguration:
-    """Read and parse the animation graph service configuration.
-
-    Args:
-        config_path: Path to the configuration YAML file.
-
-    Returns:
-        AnimationGraphConfiguration: Parsed configuration object.
-    """
-    return AnimationGraphConfiguration(**load_yaml(config_path))
-
-
-looping_animations_to_test = [
-    ("listening", "Listening"),
-    ("thinking about something", "Thinking"),
-    ("dancing", "Listening"),
-]
-
-
-class MockResponse:
-    """Mock HTTP response for testing animation graph service REST endpoints."""
-
-    def __init__(self):
-        """Initialize mock response with default status and headers."""
-        self.status = HTTPStatus.OK
-        self.headers = {"Content-Type": "application/json"}
-
-    async def json(self):
-        """Simulate async JSON response.
-
-        Returns:
-            dict: Mock response data.
-        """
-        await asyncio.sleep(0.1)
-        return {"response": "OK"}
-
-
-@pytest.fixture
-def anim_graph():
-    """Create and configure an Animation Graph Service instance for testing.
-
-    Returns:
-        AnimationGraphService: Configured service instance with test configuration.
-    """
-    animation_config = read_action_service_config(Path("./tests/unit/configs/animation_config.yaml"))
-    message_broker_config = MessageBrokerConfig("local_queue", "")
-
-    logger.info("Starting animation graph service initialization...")
-    start_time = time.time()
-    AnimationGraphService.pregenerate_animation_databases(animation_config)
-    init_time = time.time() - start_time
-    logger.info(f"Animation graph service initialized in {init_time * 1000:.2f}ms")
-
-    ag = AnimationGraphService(
-        animation_graph_rest_url="http://127.0.0.1:8020",
-        animation_graph_grpc_target="127.0.0.1:51000",
-        message_broker_config=message_broker_config,
-        config=animation_config,
-    )
-
-    return ag
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("posture", looping_animations_to_test)
-@patch("aiohttp.ClientSession.request")
-async def test_simple_posture_pipeline(mock_get, posture, anim_graph):
-    """Test basic posture pipeline functionality.
-
-    Verifies that the pipeline correctly processes posture commands and
-    generates appropriate animation frames.
-
-    Args:
-        mock_get: Mock for HTTP client requests.
-        posture: Tuple of (natural language description, clip ID) for testing.
-        anim_graph: Animation graph service fixture.
-    """
-    posture_nld, posture_clip_id = posture
-    stream_id = "1235"
-
-    # Mocking response from aiohttp.ClientSession.request
-    mock_get.return_value.__aenter__.return_value = MockResponse()
-
-    after_storage = FrameStorage()
-    before_storage = FrameStorage()
-    pipeline = Pipeline([before_storage, anim_graph, after_storage])
-
-    async def test_routine(task: PipelineTask):
-        # send events
-        await task.queue_frame(TextFrame("Hello"))
-        await task.queue_frame(StartPostureBotActionFrame(posture=posture_nld, action_id="posture_1"))
-
-        # wait for the action to start
-        started_posture_frame = ignore(StartedPostureBotActionFrame(action_id="posture_1"), "ids", "timestamps")
-        await after_storage.wait_for_frame(started_posture_frame)
-
-        # Ensure the API endpoint was called in the right way
-        mock_get.assert_called_once_with(
-            "put",
-            f"http://127.0.0.1:8020/streams/{stream_id}/animation_graphs/avatar/variables/posture_state/{posture_clip_id}",
-            data="{}",
-            headers={"Content-Type": "application/json", "x-stream-id": stream_id},
-            params={},
-        )
-        # ensure we got the text frame as well as a posture started event (and no finished)
-        assert after_storage.history[1].frame == ignore(TextFrame("Hello"), "all_ids", "timestamps")
-        assert after_storage.history[3].frame == started_posture_frame
-        assert len(after_storage.frames_of_type(FinishedPostureBotActionFrame)) == 0
-
-        # stop the action and wait for it to finish
-        finished_posture_frame = FinishedPostureBotActionFrame(
-            action_id="posture_1", is_success=True, was_stopped=True, failure_reason=""
-        )
-        await task.queue_frame(StopPostureBotActionFrame(action_id="posture_1"))
-        await after_storage.wait_for_frame(ignore(finished_posture_frame, "ids", "timestamps"))
-        await before_storage.wait_for_frame(ignore(finished_posture_frame, "ids", "timestamps"))
-
-        # make sure observed frames before and after the processor match
-        # (this should be true for all action frame processors)
-        assert len(before_storage.history) == len(after_storage.history)
-        for before, after in zip(before_storage.history, after_storage.history, strict=False):
-            assert before.frame == after.frame
-
-    await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
-
-
-@pytest.mark.asyncio
-@patch("aiohttp.ClientSession.request")
-async def test_finite_animation(mock_get, anim_graph):
-    """Test finite animation execution and completion.
-
-    Verifies that a finite animation:
-    - Starts correctly
-    - Runs for the expected duration
-    - Completes with proper frame sequence
-    """
-    # Mocking response from aiohttp.ClientSession.request
-    mock_get.return_value.__aenter__.return_value = MockResponse()
-
-    storage = FrameStorage()
-    pipeline = Pipeline([anim_graph, storage])
-
-    async def test_routine(task: PipelineTask):
-        # send events
-        await task.queue_frame(StartGestureBotActionFrame(gesture="Test", action_id="g1"))
-
-        # wait for the action to start
-        started_gesture_frame = StartedGestureBotActionFrame(action_id="g1")
-        await storage.wait_for_frame(ignore(started_gesture_frame, "ids", "timestamps"))
-        assert len(storage.frames_of_type(FinishedGestureBotActionFrame)) == 0
-
-        # Action should not be finished
-        await asyncio.sleep(0.3)
-        assert len(storage.frames_of_type(FinishedGestureBotActionFrame)) == 0
-
-        # Action should now be done
-        await asyncio.sleep(0.3)
-        assert len(storage.frames_of_type(FinishedGestureBotActionFrame)) == 1
-
-    await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": "1235"})
-
-
-@pytest.mark.asyncio
-@patch("aiohttp.ClientSession.request")
-async def test_consecutive_postures(mock_get, anim_graph):
-    """Test handling of consecutive posture commands.
-
-    Verifies that the service correctly:
-    - Processes multiple posture commands in sequence
-    - Transitions between postures smoothly
-    - Maintains correct frame order and state
-    """
-    # Mocking response from aiohttp.ClientSession.request
-    mock_get.return_value.__aenter__.return_value = MockResponse()
-
-    stream_id = "1235"
-    setup_default_ace_logging(stream_id=stream_id, level="TRACE")
-
-    after_storage = FrameStorage()
-    before_storage = FrameStorage()
-    pipeline = Pipeline([before_storage, anim_graph, after_storage])
-
-    async def test_routine(task: PipelineTask):
-        # start first posture
-        await task.queue_frame(StartPostureBotActionFrame(posture="talking", action_id="posture_1"))
-
-        # wait for first posture to start
-        started_posture_1_frame = ignore(StartedPostureBotActionFrame(action_id="posture_1"), "ids", "timestamps")
-        assert started_posture_1_frame.action_id == "posture_1"
-        await after_storage.wait_for_frame(started_posture_1_frame)
-
-        # start second posture
-        await task.queue_frame(StartPostureBotActionFrame(posture="listening", action_id="posture_2"))
-
-        # wait for second posture to start
-        started_posture_2_frame = ignore(StartedPostureBotActionFrame(action_id="posture_2"), "ids", "timestamps")
-        assert started_posture_2_frame.action_id == "posture_2"
-        await after_storage.wait_for_frame(started_posture_2_frame)
-
-        # Ensure the API endpoint was called twice
-        assert mock_get.call_count == 2
-
-        # ensure we got the posture started events in order (and exactly one finished)
-        assert after_storage.history[2].frame == started_posture_1_frame
-        assert after_storage.history[4].frame == ignore(
-            FinishedPostureBotActionFrame(
-                action_id="posture_1", is_success=False, was_stopped=False, failure_reason="Action replaced."
-            ),
-            "ids",
-            "timestamps",
-        )
-        assert after_storage.history[5].frame == started_posture_2_frame
-        assert len(after_storage.frames_of_type(FinishedPostureBotActionFrame)) == 1
-
-        # make sure observed frames before and after the processor match
-        assert len(before_storage.history) == len(after_storage.history)
-        for before, after in zip(before_storage.history, after_storage.history, strict=False):
-            assert before.frame == after.frame
-
-    await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
-
-
-@pytest.mark.asyncio
-@patch("aiohttp.ClientSession.request")
-async def test_immediate_stop_posture(mock_get, anim_graph):
-    """Test immediate posture cancellation.
-
-    Verifies that the service correctly handles immediate posture cancellation:
-    - Processes stop command immediately after start
-    - Generates appropriate failure frames
-    - Maintains correct frame sequence
-    """
-    # Mocking response from aiohttp.ClientSession.request
-    mock_get.return_value.__aenter__.return_value = MockResponse()
-
-    stream_id = "1235"
-    setup_default_ace_logging(stream_id=stream_id, level="TRACE")
-
-    after_storage = FrameStorage()
-    before_storage = FrameStorage()
-    pipeline = Pipeline([before_storage, anim_graph, after_storage])
-
-    async def test_routine(task: PipelineTask):
-        # start/stop first posture
-        start_posture_1_frame = StartPostureBotActionFrame(posture="talking", action_id="posture_1")
-        stop_posture_1_frame = StopPostureBotActionFrame(action_id="posture_1")
-        await task.queue_frames([start_posture_1_frame, stop_posture_1_frame])
-
-        # wait for first posture to finish
-        finished_posture_1_frame = ignore(
-            FinishedPostureBotActionFrame(
-                action_id="posture_1", was_stopped=True, is_success=False, failure_reason=ANY
-            ),
-            "ids",
-            "timestamps",
-        )
-        await after_storage.wait_for_frame(finished_posture_1_frame)
-
-        # check for the correct frame sequence
-        assert after_storage.history[-2].frame == ignore(stop_posture_1_frame, "ids", "timestamps")
-        assert after_storage.history[-1].frame == finished_posture_1_frame
-
-        # make sure observed frames before and after the processor match
-        assert len(before_storage.history) == len(after_storage.history)
-        for before, after in zip(before_storage.history, after_storage.history, strict=True):
-            assert before.frame == after.frame
-
-    await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
-
-
-@pytest.mark.asyncio
-@patch("aiohttp.ClientSession.request")
-async def test_stacking_postures_with_interruptions(mock_get, anim_graph):
-    """Test stacking postures with interruptions.
-
-    In this test we reproduce a bug that was observed where the override state machine
-    would sometimes fail if we start many postures at the same time and also
-    have interruptions.
-    """
-    # Mocking response from aiohttp.ClientSession.request
-    mock_get.return_value.__aenter__.return_value = MockResponse()
|
444 |
-
|
445 |
-
stream_id = "1235"
|
446 |
-
setup_default_ace_logging(stream_id=stream_id, level="TRACE")
|
447 |
-
|
448 |
-
after_storage = FrameStorage()
|
449 |
-
before_storage = FrameStorage()
|
450 |
-
pipeline = Pipeline([before_storage, anim_graph, after_storage])
|
451 |
-
|
452 |
-
async def test_routine(task: PipelineTask):
|
453 |
-
N = 200
|
454 |
-
for i in range(N):
|
455 |
-
await task.queue_frame(StartPostureBotActionFrame(posture="talking", action_id=f"posture_{i}"))
|
456 |
-
await asyncio.sleep(0.05)
|
457 |
-
if i == 5:
|
458 |
-
await task.queue_frame(StartInterruptionFrame())
|
459 |
-
|
460 |
-
final_posture_started = ignore(StartedPostureBotActionFrame(action_id=f"posture_{N - 1}"), "ids", "timestamps")
|
461 |
-
await after_storage.wait_for_frame(final_posture_started)
|
462 |
-
|
463 |
-
assert len(after_storage.frames_of_type(FinishedPostureBotActionFrame)) == N - 1
|
464 |
-
|
465 |
-
await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
|
466 |
-
|
467 |
-
|
468 |
-
@patch("aiohttp.ClientSession.request")
|
469 |
-
async def test_handling_animation_data(mock_get, anim_graph):
|
470 |
-
"""Test animation data stream processing.
|
471 |
-
|
472 |
-
Verifies that the service correctly:
|
473 |
-
- Handles animation data stream frames
|
474 |
-
- Processes animation data in correct sequence
|
475 |
-
- Maintains proper timing between frames
|
476 |
-
"""
|
477 |
-
# Mocking response from aiohttp.ClientSession.request
|
478 |
-
mock_get.return_value.__aenter__.return_value = MockResponse()
|
479 |
-
|
480 |
-
# Mocking gRPC stream
|
481 |
-
mock_stub = MagicMock()
|
482 |
-
stream_mock = AsyncMock()
|
483 |
-
stream_mock.write.return_value = "OK"
|
484 |
-
stream_mock.done = MagicMock(return_value=False) #
|
485 |
-
mock_stub.PushAnimationDataStream.return_value = stream_mock
|
486 |
-
|
487 |
-
stream_id = "1235"
|
488 |
-
anim_graph.stub = mock_stub
|
489 |
-
storage = FrameStorage()
|
490 |
-
pipeline = Pipeline([anim_graph, storage])
|
491 |
-
|
492 |
-
async def test_routine(task: PipelineTask):
|
493 |
-
# send events
|
494 |
-
await task.queue_frame(
|
495 |
-
AnimationDataStreamStartedFrame(
|
496 |
-
audio_header=generate_audio_header(), animation_header=None, action_id="a1", animation_source_id="test"
|
497 |
-
)
|
498 |
-
)
|
499 |
-
|
500 |
-
for _ in range(15):
|
501 |
-
await task.queue_frame(
|
502 |
-
AnimationDataStreamRawFrame(
|
503 |
-
animation_data=generate_animation_data([0.0], AUDIO_BUFFER_FOR_ONE_FRAME_SIZE), action_id="a1"
|
504 |
-
)
|
505 |
-
)
|
506 |
-
await asyncio.sleep(1.0 / 35.0)
|
507 |
-
|
508 |
-
await task.queue_frame(AnimationDataStreamStoppedFrame(action_id="a1"))
|
509 |
-
assert len(storage.frames_of_type(BotStartedSpeakingFrame)) == 1
|
510 |
-
assert len(storage.frames_of_type(BotSpeakingFrame)) >= 1
|
511 |
-
|
512 |
-
await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
|
513 |
-
|
514 |
-
|
515 |
-
@patch("aiohttp.ClientSession.request")
|
516 |
-
async def test_handling_bursted_animation_data(mock_get, anim_graph):
|
517 |
-
"""Test handling of bursted animation data.
|
518 |
-
|
519 |
-
Verifies that the service correctly handles bursted animation data.
|
520 |
-
"""
|
521 |
-
# Mocking response from aiohttp.ClientSession.request
|
522 |
-
mock_get.return_value.__aenter__.return_value = MockResponse()
|
523 |
-
|
524 |
-
# Mocking gRPC stream
|
525 |
-
mock_stub = MagicMock()
|
526 |
-
stream_mock = AsyncMock()
|
527 |
-
stream_mock.write.return_value = "OK"
|
528 |
-
stream_mock.done = MagicMock(return_value=False) #
|
529 |
-
mock_stub.PushAnimationDataStream.return_value = stream_mock
|
530 |
-
|
531 |
-
stream_id = "1235"
|
532 |
-
anim_graph.stub = mock_stub
|
533 |
-
storage = FrameStorage()
|
534 |
-
pipeline = Pipeline([anim_graph, storage])
|
535 |
-
|
536 |
-
async def test_routine(task: PipelineTask):
|
537 |
-
# send events
|
538 |
-
await task.queue_frame(
|
539 |
-
AnimationDataStreamStartedFrame(
|
540 |
-
audio_header=generate_audio_header(), animation_header=None, action_id="a1", animation_source_id="test"
|
541 |
-
)
|
542 |
-
)
|
543 |
-
|
544 |
-
for _ in range(30 * 3):
|
545 |
-
await task.queue_frame(
|
546 |
-
AnimationDataStreamRawFrame(
|
547 |
-
animation_data=generate_animation_data([0.0], AUDIO_BUFFER_FOR_ONE_FRAME_SIZE), action_id="a1"
|
548 |
-
)
|
549 |
-
)
|
550 |
-
|
551 |
-
await asyncio.sleep(2.5)
|
552 |
-
assert stream_mock.done_writing.call_count == 0
|
553 |
-
|
554 |
-
await task.queue_frame(AnimationDataStreamStoppedFrame(action_id="a1"))
|
555 |
-
assert len(storage.frames_of_type(BotStartedSpeakingFrame)) == 1
|
556 |
-
assert len(storage.frames_of_type(BotSpeakingFrame)) >= 1
|
557 |
-
|
558 |
-
await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
|
559 |
-
|
560 |
-
|
561 |
-
@patch("aiohttp.ClientSession.request")
|
562 |
-
async def test_handling_interrupted_animation_data_stream(mock_get, anim_graph):
|
563 |
-
"""Test behavior if animation data stream is interrupted."""
|
564 |
-
# Mocking response from aiohttp.ClientSession.request
|
565 |
-
mock_get.return_value.__aenter__.return_value = MockResponse()
|
566 |
-
|
567 |
-
# Mocking gRPC stream
|
568 |
-
mock_stub = MagicMock()
|
569 |
-
stream_mock = AsyncMock()
|
570 |
-
stream_mock.write.return_value = "OK"
|
571 |
-
stream_mock.done = MagicMock(return_value=False) #
|
572 |
-
mock_stub.PushAnimationDataStream.return_value = stream_mock
|
573 |
-
|
574 |
-
stream_id = "1235"
|
575 |
-
anim_graph.stub = mock_stub
|
576 |
-
storage = FrameStorage()
|
577 |
-
pipeline = Pipeline([anim_graph, storage])
|
578 |
-
|
579 |
-
async def test_routine(task: PipelineTask):
|
580 |
-
# send events
|
581 |
-
await task.queue_frame(
|
582 |
-
AnimationDataStreamStartedFrame(
|
583 |
-
audio_header=generate_audio_header(), animation_header=None, action_id="a1", animation_source_id="test"
|
584 |
-
)
|
585 |
-
)
|
586 |
-
|
587 |
-
for _ in range(5):
|
588 |
-
await task.queue_frame(
|
589 |
-
AnimationDataStreamRawFrame(
|
590 |
-
animation_data=generate_animation_data([0.0], AUDIO_BUFFER_FOR_ONE_FRAME_SIZE), action_id="a1"
|
591 |
-
)
|
592 |
-
)
|
593 |
-
await asyncio.sleep(1.0 / 30.0)
|
594 |
-
|
595 |
-
await asyncio.sleep(2.0)
|
596 |
-
|
597 |
-
assert stream_mock.done_writing.call_count == 1
|
598 |
-
|
599 |
-
await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
|
600 |
-
|
601 |
-
|
602 |
-
@patch("aiohttp.ClientSession.request")
|
603 |
-
async def test_handling_low_fps_animation_data(mock_get, anim_graph):
|
604 |
-
"""Test animation data stream processing with low FPS.
|
605 |
-
|
606 |
-
Verifies that the service correctly sends out an
|
607 |
-
ErrorFrame when the animation data received is below 30 FPS.
|
608 |
-
"""
|
609 |
-
# Mocking response from aiohttp.ClientSession.request
|
610 |
-
mock_get.return_value.__aenter__.return_value = MockResponse()
|
611 |
-
|
612 |
-
# Mocking gRPC stream
|
613 |
-
mock_stub = MagicMock()
|
614 |
-
stream_mock = AsyncMock()
|
615 |
-
stream_mock.write.return_value = "OK"
|
616 |
-
stream_mock.done = MagicMock(return_value=False) #
|
617 |
-
mock_stub.PushAnimationDataStream.return_value = stream_mock
|
618 |
-
|
619 |
-
stream_id = "1235"
|
620 |
-
anim_graph.stub = mock_stub
|
621 |
-
storage = FrameStorage()
|
622 |
-
pipeline = Pipeline([anim_graph, storage])
|
623 |
-
|
624 |
-
async def test_routine(task: PipelineTask):
|
625 |
-
# send events
|
626 |
-
await task.queue_frame(
|
627 |
-
AnimationDataStreamStartedFrame(
|
628 |
-
audio_header=generate_audio_header(), animation_header=None, action_id="a1", animation_source_id="test"
|
629 |
-
)
|
630 |
-
)
|
631 |
-
|
632 |
-
for _ in range(20):
|
633 |
-
await task.queue_frame(
|
634 |
-
AnimationDataStreamRawFrame(
|
635 |
-
animation_data=generate_animation_data([0.0], AUDIO_BUFFER_FOR_ONE_FRAME_SIZE), action_id="a1"
|
636 |
-
)
|
637 |
-
)
|
638 |
-
await asyncio.sleep(1.0 / 24.5)
|
639 |
-
|
640 |
-
await task.queue_frame(AnimationDataStreamStoppedFrame(action_id="a1"))
|
641 |
-
|
642 |
-
await storage.wait_for_frame(
|
643 |
-
ignore(
|
644 |
-
ErrorFrame(error="AnimGraph: Received data stream is behind by more than 0.1s", fatal=False),
|
645 |
-
"ids",
|
646 |
-
"error",
|
647 |
-
)
|
648 |
-
)
|
649 |
-
|
650 |
-
await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
|
651 |
-
|
652 |
-
|
653 |
-
@pytest.mark.asyncio
|
654 |
-
async def test_animation_database_search(anim_graph):
|
655 |
-
"""Test animation database semantic search functionality.
|
656 |
-
|
657 |
-
Verifies that searching for 'wave to the user' correctly returns the 'Goodbye'
|
658 |
-
animation clip with appropriate semantic match scores.
|
659 |
-
"""
|
660 |
-
# Get the gesture animation database
|
661 |
-
gesture_db = AnimationGraphService.animation_databases["gesture"]
|
662 |
-
|
663 |
-
# Search for the animation
|
664 |
-
match = gesture_db.query_one("wave goodbye to the user somehow")
|
665 |
-
|
666 |
-
# Verify we got a match and it's the Goodbye animation
|
667 |
-
assert match.animation.id == "Goodbye", "Expected 'wave to the user' to map to 'Goodbye' animation"
|
668 |
-
assert match.description_score > 0.5 or match.meaning_score > 0.5, "Expected high semantic match score"
|
|
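The assertions in the removed tests above compare frames via `ignore(...)` and `FrameStorage.wait_for_frame(...)` from `tests/unit/utils.py`, which this commit also deletes. A rough sketch of the comparison idea follows; the sentinel class, the `group_to_fields` mapping, and `ExampleFrame` are illustrative assumptions, not the removed implementation:

```python
# Sketch: compare dataclass-style frames while ignoring volatile fields.
# Hypothetical stand-in for the removed tests/unit/utils.py `ignore` helper.
from dataclasses import dataclass, fields
from typing import Any


class _Anything:
    """Sentinel that compares equal to any value."""

    def __eq__(self, other: Any) -> bool:
        return True


ANYTHING = _Anything()


@dataclass
class ExampleFrame:
    action_id: str
    id: int = 0
    timestamp: float = 0.0


def ignore(frame: Any, *groups: str) -> Any:
    """Blank out the fields named by each ignore-group so `==` skips them."""
    group_to_fields = {"ids": ["id"], "timestamps": ["timestamp"]}  # assumed mapping
    frame_fields = {f.name for f in fields(frame)}
    for group in groups:
        for name in group_to_fields.get(group, []):
            if name in frame_fields:
                setattr(frame, name, ANYTHING)
    return frame


expected = ignore(ExampleFrame(action_id="g1", id=1, timestamp=1.0), "ids", "timestamps")
assert expected == ExampleFrame(action_id="g1", id=2, timestamp=9.9)
```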
tests/unit/test_audio2face_3d_service.py
DELETED
@@ -1,182 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the Audio2Face 3D Service."""
-
-import asyncio
-from unittest.mock import ANY, AsyncMock, MagicMock, patch
-
-import pytest
-from nvidia_audio2face_3d.messages_pb2 import AudioWithEmotionStream
-from pipecat.frames.frames import StartInterruptionFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
-from pipecat.pipeline.pipeline import Pipeline
-
-from nvidia_pipecat.frames.animation import (
-    AnimationDataStreamRawFrame,
-    AnimationDataStreamStartedFrame,
-    AnimationDataStreamStoppedFrame,
-)
-from nvidia_pipecat.services.audio2face_3d_service import Audio2Face3DService
-from tests.unit.utils import FrameStorage, ignore, run_interactive_test
-
-
-class AsyncIterator:
-    """Helper class for mocking async iteration in tests.
-
-    Attributes:
-        items: List of items to yield during iteration.
-    """
-
-    def __init__(self, items):
-        """Initialize the async iterator.
-
-        Args:
-            items: List of items to yield during iteration.
-        """
-        self.items = items
-
-    def __aiter__(self):
-        """Return self as the iterator.
-
-        Returns:
-            AsyncIterator: Self reference for iteration.
-        """
-        return self
-
-    async def __anext__(self):
-        """Get the next item asynchronously.
-
-        Returns:
-            Any: Next item in the sequence.
-
-        Raises:
-            StopAsyncIteration: When no more items are available.
-        """
-        try:
-            await asyncio.sleep(0.1)
-            return self.items.pop(0)
-        except IndexError as error:
-            raise StopAsyncIteration from error
-
-
-def get_mock_stream():
-    """Create a mock Audio2Face stream for testing.
-
-    Returns:
-        MagicMock: A configured mock object that simulates the A2F service stream,
-            including header and animation data responses.
-    """
-    mock_stub = MagicMock(spec=["ProcessAudioStream"])
-
-    # Configure mock responses
-    header_response = MagicMock()
-    header_response.HasField = lambda x: x == "animation_data_stream_header"
-    header_response.animation_data_stream_header = MagicMock(
-        audio_header=MagicMock(), skel_animation_header=MagicMock()
-    )
-
-    data_response = MagicMock()
-    data_response.HasField = lambda x: x == "animation_data"
-    data_response.animation_data = "test_animation_data"
-
-    mock_stream = AsyncMock()
-    mock_stream.__aiter__.return_value = [header_response, data_response]
-    mock_stream.done = MagicMock(return_value=False)  # Use regular MagicMock for done()
-    mock_stub.ProcessAudioStream.return_value = mock_stream
-
-    return mock_stub
-
-
-@pytest.mark.asyncio
-async def test_audio2face_3d_basic_flow():
-    """Test the basic flow of the Audio2Face 3D service.
-
-    Tests:
-        - TTS audio frame processing
-        - Animation data stream generation
-        - Frame sequence validation
-        - Frame count verification
-
-    Raises:
-        AssertionError: If frame sequence or counts are incorrect.
-    """
-    mock_stub = get_mock_stream()
-
-    with patch("nvidia_pipecat.services.audio2face_3d_service.A2FControllerServiceStub", return_value=mock_stub):
-        # Initialize test components
-        storage = FrameStorage()
-        a2f = Audio2Face3DService()
-        pipeline = Pipeline([a2f, storage])
-        audio = bytes([6] * (16000 * 2 + 1))
-
-        async def test_routine(task):
-            # Send TTS frames
-            await task.queue_frame(TTSStartedFrame())
-            await task.queue_frame(TTSAudioRawFrame(audio=audio[0:15999], sample_rate=16000, num_channels=1))
-            await task.queue_frame(TTSAudioRawFrame(audio=audio[16000:], sample_rate=16000, num_channels=1))
-            await task.queue_frame(TTSStoppedFrame())
-
-            # Wait for animation started frame
-            started_frame = ignore(
-                AnimationDataStreamStartedFrame(
-                    audio_header=ANY, animation_header=ANY, animation_source_id="Audio2Face with Emotions"
-                ),
-                "all_ids",
-                "timestamps",
-            )
-
-            await storage.wait_for_frame(started_frame)
-
-            # Wait for animation data frame
-            anim_frame = ignore(
-                AnimationDataStreamRawFrame(animation_data="test_animation_data"),
-                "all_ids",
-                "timestamps",
-            )
-            await storage.wait_for_frame(anim_frame)
-
-            # Wait for stopped frame
-            stopped_frame = ignore(AnimationDataStreamStoppedFrame(), "all_ids", "timestamps")
-            await storage.wait_for_frame(stopped_frame)
-
-            # Verify the sequence of frames
-            assert len(storage.frames_of_type(AnimationDataStreamStartedFrame)) == 1
-            assert len(storage.frames_of_type(AnimationDataStreamRawFrame)) == 1
-            assert len(storage.frames_of_type(AnimationDataStreamStoppedFrame)) == 1
-
-        await run_interactive_test(pipeline, test_coroutine=test_routine)
-
-
-@pytest.mark.asyncio
-async def test_interruptions():
-    """Test interruption handling in the Audio2Face 3D service.
-
-    Verifies that the service correctly handles interruption events by:
-    - Properly responding to StartInterruptionFrame
-    - Sending end-of-audio signal to the A2F stream
-    - Maintaining correct state during interruption
-    """
-    mock_stub = get_mock_stream()
-
-    with patch("nvidia_pipecat.services.audio2face_3d_service.A2FControllerServiceStub", return_value=mock_stub):
-        # Create service and storage
-        storage = FrameStorage()
-        a2f = Audio2Face3DService()
-        pipeline = Pipeline([a2f, storage])
-
-        async def test_routine(task):
-            # Send TTS frames
-            await task.queue_frame(TTSStartedFrame())
-            await task.queue_frame(TTSAudioRawFrame(audio=b"test_audio", sample_rate=16000, num_channels=1))
-
-            await asyncio.sleep(0.1)
-            await task.queue_frame(StartInterruptionFrame())
-            await asyncio.sleep(0.1)
-
-            mock_stub.ProcessAudioStream.return_value.write.assert_called_with(
-                AudioWithEmotionStream(end_of_audio=AudioWithEmotionStream.EndOfAudio())
-            )
-            print(f"{mock_stub}")
-            print(f"storage.history: {storage.history}")
-
-        await run_interactive_test(pipeline, test_coroutine=test_routine)
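A note on the mocking pattern in `get_mock_stream` above: since Python 3.8, `MagicMock`/`AsyncMock` support async iteration via `__aiter__.return_value`, which is what lets the mocked gRPC stream be consumed with `async for`. A self-contained illustration of the same stdlib pattern:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def main() -> None:
    stream = AsyncMock()
    # Items yielded when the mock is consumed with `async for`.
    stream.__aiter__.return_value = ["header_response", "data_response"]
    # done() must stay synchronous, hence the plain MagicMock.
    stream.done = MagicMock(return_value=False)

    assert [item async for item in stream] == ["header_response", "data_response"]
    assert stream.done() is False


asyncio.run(main())
```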
tests/unit/test_audio_util.py
DELETED
@@ -1,64 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for audio utility functionality.
-
-This module contains tests for audio processing components, particularly the AudioRecorder.
-It verifies the basic functionality of audio recording and processing pipelines.
-"""
-
-from datetime import timedelta
-from pathlib import Path
-
-import pytest
-from pipecat.frames.frames import InputAudioRawFrame, TextFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.tests.utils import SleepFrame
-from pipecat.transports.base_transport import TransportParams
-
-from nvidia_pipecat.processors.audio_util import AudioRecorder
-from nvidia_pipecat.utils.logging import setup_default_ace_logging
-from tests.unit.utils import SinusWaveProcessor, ignore, ignore_ids, run_test
-
-
-@pytest.mark.asyncio()
-async def test_audio_recorder():
-    """Test the AudioRecorder processor functionality.
-
-    Tests:
-        - Audio frame processing from sine wave generator
-        - WAV file writing
-        - Sample rate conversion (16kHz to 24kHz)
-        - Non-audio frame passthrough
-
-    Raises:
-        AssertionError: If audio file is not created or frame processing fails.
-    """
-    setup_default_ace_logging(level="TRACE")
-
-    # Delete tmp audio file if it exists
-    TMP_FILE = Path("./tmp_file.wav")
-    if TMP_FILE.exists():
-        TMP_FILE.unlink()
-
-    recorder = AudioRecorder(output_file=str(TMP_FILE), params=TransportParams(audio_out_sample_rate=24000))
-    sinus = SinusWaveProcessor(duration=timedelta(seconds=0.3))
-
-    frames_to_send = [
-        SleepFrame(0.5),
-        TextFrame("go through"),
-    ]
-    expected_down_frames = [
-        ignore(InputAudioRawFrame(audio=b"1" * 640, sample_rate=16000, num_channels=1), "audio", "id", "name")
-    ] * sinus.audio_frame_count + [ignore_ids(TextFrame("go through"))]
-    pipeline = Pipeline([sinus, recorder])
-    print(f"expected number of frames: {sinus.audio_frame_count}")
-    await run_test(
-        pipeline,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-        start_metadata={"stream_id": "1235"},
-    )
-
-    # Make sure that the audio dump was generated
-    assert TMP_FILE.is_file()
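For context on the numbers in this test: `b"1" * 640` stands in for one 20 ms chunk of 16 kHz mono 16-bit audio (320 samples × 2 bytes). A sketch of producing real sine-wave chunks of the same shape (illustrative only; the removed `SinusWaveProcessor` may differ):

```python
import math
import struct

SAMPLE_RATE = 16000
CHUNK_SAMPLES = SAMPLE_RATE // 50  # 20 ms -> 320 samples -> 640 bytes of s16le


def sine_chunk(start_sample: int, freq_hz: float = 440.0) -> bytes:
    """One chunk of little-endian 16-bit PCM sine audio."""
    samples = (
        int(32767 * math.sin(2 * math.pi * freq_hz * (start_sample + i) / SAMPLE_RATE))
        for i in range(CHUNK_SAMPLES)
    )
    return struct.pack(f"<{CHUNK_SAMPLES}h", *samples)


assert len(sine_chunk(0)) == 640  # same size as the b"1" * 640 placeholder above
```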
tests/unit/test_basic_pipelines.py
DELETED
@@ -1,130 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Basic functionality tests for pipecat pipelines."""
-
-import asyncio
-
-import pytest
-from loguru import logger
-from pipecat.frames.frames import StartFrame, TextFrame
-from pipecat.processors.aggregators.sentence import SentenceAggregator
-from pipecat.tests.utils import SleepFrame
-
-from nvidia_pipecat.utils.logging import logger_context, setup_default_ace_logging
-from tests.unit.utils import ignore_ids, run_test
-
-
-@pytest.mark.asyncio()
-async def test_simple_pipeline():
-    """Example test for a pipecat pipeline. This test makes sure the basic pipeline-related classes work."""
-    aggregator = SentenceAggregator()
-    frames_to_send = [TextFrame("Hello, "), TextFrame("world.")]
-    expected_down_frames = [ignore_ids(TextFrame("Hello, world."))]
-
-    await run_test(
-        aggregator,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-        start_metadata={"stream_id": "1235"},
-    )
-
-
-@pytest.mark.asyncio()
-async def test_pipeline_with_stream_id():
-    """Test pipeline creation with a stream_id.
-
-    Verifies that a pipeline can be created with a specific stream_id and that
-    the metadata is properly propagated with the StartFrame.
-    """
-    aggregator = SentenceAggregator()
-    frames_to_send = [TextFrame("Hello, "), TextFrame("world.")]
-    start_metadata = {"stream_id": "1234"}
-
-    expected_start_frame = ignore_ids(StartFrame())
-    expected_start_frame.metadata = start_metadata
-    expected_down_frames = [expected_start_frame, ignore_ids(TextFrame("Hello, world."))]
-
-    await run_test(
-        aggregator,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-        start_metadata=start_metadata,
-        ignore_start=False,
-    )
-
-
-@pytest.mark.asyncio()
-async def test_ace_logger_with_stream_id(capsys):
-    """Test ACE logger behavior when a stream_id is provided.
-
-    Verifies that the logger correctly handles and displays the stream_id in the logs.
-    """
-    setup_default_ace_logging(level="DEBUG")
-    with logger.contextualize(stream_id="1237"):
-        aggregator = SentenceAggregator()
-        frames_to_send = [TextFrame("Hello, "), TextFrame("world.")]
-        expected_down_frames = [ignore_ids(TextFrame("Hello, world."))]
-        await run_test(aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
-
-    captured = capsys.readouterr()
-    assert "streamId=1237" in captured.err
-
-
-async def run_pipeline_task(stream: str):
-    """Run a test pipeline task with a specific stream identifier.
-
-    Creates and runs a pipeline that processes a sequence of text frames with
-    sleep intervals, aggregating them into a single sentence.
-
-    Args:
-        stream: The stream identifier to use for the pipeline task.
-    """
-    REPETITIONS = 5
-    frames_to_send = []
-    aggregated_str = ""
-    for i in range(REPETITIONS):
-        frames_to_send.append(TextFrame(f"S{stream}-T{i}"))
-        aggregated_str += f"S{stream}-T{i}"
-        frames_to_send.append(SleepFrame(0.1))
-
-    frames_to_send.append(TextFrame("."))
-    expected_down_frames = [ignore_ids(TextFrame(f"{aggregated_str}."))]
-    aggregator = SentenceAggregator()
-    await run_test(aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
-
-
-@pytest.mark.asyncio()
-async def test_logging_with_multiple_pipelines_in_same_process(capsys):
-    """Test logging behavior with multiple concurrent pipeline streams.
-
-    Verifies that when multiple pipeline tasks are running concurrently in the same
-    process, each stream's logs are correctly tagged with its respective stream_id.
-    The test ensures proper isolation and identification of log messages across
-    different pipeline streams.
-
-    Args:
-        capsys: Pytest fixture for capturing system output.
-    """
-    setup_default_ace_logging(level="TRACE")
-
-    streams = ["777", "abc123"]
-    tasks = []
-
-    for stream in streams:
-        task = asyncio.create_task(logger_context(run_pipeline_task(stream=stream), stream_id=stream))
-        tasks.append(task)
-
-    await asyncio.gather(*tasks)
-
-    # Make sure the correct stream ID is logged for the different coroutines
-    captured = capsys.readouterr()
-    lines = captured.err.split("\n")
-    for line in lines:
-        for stream in streams:
-            if f"S{stream}" in line:
-                assert f"streamId={stream}" in line
-
-    for task in asyncio.all_tasks():
-        if "task_handler" in task.get_coro().__name__:
-            task.cancel()
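The `streamId=...` assertions above rest on loguru's context binding; `setup_default_ace_logging` presumably installs a sink format that renders the bound `stream_id`. The underlying mechanism in isolation (the format string here is an assumption, not the removed setup):

```python
import sys

from loguru import logger

# contextualize() binds extra fields onto every record logged inside the block;
# a sink format referencing {extra[stream_id]} then renders them as a tag.
logger.remove()
logger.add(sys.stderr, format="streamId={extra[stream_id]} | {message}")

with logger.contextualize(stream_id="1237"):
    logger.info("hello")  # -> "streamId=1237 | hello"
```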
tests/unit/test_blingfire_text_aggregator.py
DELETED
@@ -1,244 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for BlingfireTextAggregator."""
-
-import pytest
-
-from nvidia_pipecat.services.blingfire_text_aggregator import BlingfireTextAggregator
-
-
-class TestBlingfireTextAggregator:
-    """Test suite for BlingfireTextAggregator."""
-
-    def test_initialization(self):
-        """Test that the aggregator initializes with an empty text buffer."""
-        aggregator = BlingfireTextAggregator()
-        assert aggregator.text == ""
-
-    @pytest.mark.asyncio()
-    async def test_single_word_no_sentence(self):
-        """Test that single words without sentence endings don't return sentences."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("hello")
-        assert result is None
-        assert aggregator.text == "hello"
-
-    @pytest.mark.asyncio()
-    async def test_incomplete_sentence(self):
-        """Test that incomplete sentences are buffered but not returned."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("Hello there")
-        assert result is None
-        assert aggregator.text == "Hello there"
-
-    @pytest.mark.asyncio()
-    async def test_single_complete_sentence(self):
-        """Test that a single complete sentence is detected and returned."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("Hello world.")
-        assert result is None  # A single sentence won't trigger a return
-        assert aggregator.text == "Hello world."
-
-    @pytest.mark.asyncio()
-    async def test_multiple_sentences_detection(self):
-        """Test that multiple sentences trigger return of the first complete sentence."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("Hello world. How are you?")
-        assert result == "Hello world."
-        assert "How are you?" in aggregator.text
-        assert "Hello world." not in aggregator.text
-
-    @pytest.mark.asyncio()
-    async def test_incremental_sentence_building(self):
-        """Test building a sentence incrementally."""
-        aggregator = BlingfireTextAggregator()
-
-        # Add text piece by piece
-        result = await aggregator.aggregate("Hello")
-        assert result is None
-        assert aggregator.text == "Hello"
-
-        result = await aggregator.aggregate(" world")
-        assert result is None
-        assert aggregator.text == "Hello world"
-
-        result = await aggregator.aggregate(".")
-        assert result is None
-        assert aggregator.text == "Hello world."
-
-    @pytest.mark.asyncio()
-    async def test_incremental_multiple_sentences(self):
-        """Test building multiple sentences incrementally."""
-        aggregator = BlingfireTextAggregator()
-
-        # Build first sentence
-        result = await aggregator.aggregate("Hello world.")
-        assert result is None
-        assert aggregator.text == "Hello world."
-
-        # Add second sentence - this should trigger a return
-        result = await aggregator.aggregate(" How are you?")
-        assert result == "Hello world."
-        assert "How are you?" in aggregator.text
-        assert "Hello world." not in aggregator.text
-
-    @pytest.mark.asyncio()
-    async def test_empty_string_input(self):
-        """Test handling of empty string input."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("")
-        assert result is None
-        assert aggregator.text == ""
-
-    @pytest.mark.asyncio()
-    async def test_whitespace_handling(self):
-        """Test proper handling of whitespace in text."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate(" Hello world. ")
-        assert result is None
-        assert aggregator.text == " Hello world. "
-
-    @pytest.mark.asyncio()
-    async def test_multiple_sentences_in_single_call(self):
-        """Test processing multiple sentences passed in a single aggregate call."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("First sentence. Second sentence. Third sentence.")
-        assert result == "First sentence."
-        # Remaining text should contain the other sentences
-        remaining_text = aggregator.text
-        assert "Second sentence. Third sentence." in remaining_text
-
-    @pytest.mark.asyncio()
-    async def test_sentence_with_special_punctuation(self):
-        """Test sentences with different punctuation marks."""
-        aggregator = BlingfireTextAggregator()
-
-        # Test exclamation mark
-        result = await aggregator.aggregate("Hello world! How are you?")
-        assert result == "Hello world!"
-        assert "How are you?" in aggregator.text
-
-        await aggregator.reset()
-
-        # Test question mark
-        result = await aggregator.aggregate("How are you? I'm fine.")
-        assert result == "How are you?"
-        assert "I'm fine." in aggregator.text
-
-    @pytest.mark.asyncio()
-    async def test_handle_interruption(self):
-        """Test that handle_interruption clears the text buffer."""
-        aggregator = BlingfireTextAggregator()
-        await aggregator.aggregate("Hello world")
-        assert aggregator.text == "Hello world"
-
-        await aggregator.handle_interruption()
-        assert aggregator.text == ""
-
-    @pytest.mark.asyncio()
-    async def test_reset(self):
-        """Test that reset clears the text buffer."""
-        aggregator = BlingfireTextAggregator()
-        await aggregator.aggregate("Hello world")
-        assert aggregator.text == "Hello world"
-
-        await aggregator.reset()
-        assert aggregator.text == ""
-
-    @pytest.mark.asyncio()
-    async def test_reset_after_sentence_detection(self):
-        """Test reset functionality after sentence detection."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("First sentence. Second sentence.")
-        assert result == "First sentence."
-        assert "Second sentence." in aggregator.text
-
-        await aggregator.reset()
-        assert aggregator.text == ""
-
-    @pytest.mark.asyncio()
-    async def test_consecutive_sentence_processing(self):
-        """Test processing consecutive sentences through multiple aggregate calls."""
-        aggregator = BlingfireTextAggregator()
-
-        # First pair of sentences
-        result = await aggregator.aggregate("First sentence. Second sentence.")
-        assert result == "First sentence."
-
-        # Add third sentence - should trigger return of the second
-        result = await aggregator.aggregate(" Third sentence.")
-        assert result == "Second sentence."
-        assert "Third sentence." in aggregator.text
-
-    @pytest.mark.asyncio()
-    async def test_long_sentence_handling(self):
-        """Test handling of longer sentences."""
-        aggregator = BlingfireTextAggregator()
-        long_sentence = (
-            "This is a very long sentence with many words that goes on and on "
-            "and should still be handled correctly by the aggregator."
-        )
-        result = await aggregator.aggregate(long_sentence)
-        assert result is None
-        assert aggregator.text == long_sentence
-
-    @pytest.mark.asyncio()
-    async def test_sentence_boundaries_with_abbreviations(self):
-        """Test sentence detection with abbreviations that contain periods."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("Dr. Smith went to the U.S.A. He had a great time.")
-        assert result == "Dr. Smith went to the U.S.A."
-        assert "He had a great time." in aggregator.text
-
-    @pytest.mark.asyncio()
-    async def test_newline_handling(self):
-        """Test handling of text with newlines."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("First line.\nSecond line.")
-        assert result == "First line."
-        assert "Second line." in aggregator.text
-
-    @pytest.mark.asyncio()
-    async def test_mixed_sentence_endings(self):
-        """Test text with mixed sentence-ending punctuation."""
-        aggregator = BlingfireTextAggregator()
-        result = await aggregator.aggregate("What time is it? It's 3 PM! That's great.")
-        assert result == "What time is it?"
-        remaining = aggregator.text
-        assert "It's 3 PM!" in remaining
-        assert "That's great." in remaining
-
-    @pytest.mark.asyncio()
-    async def test_text_property_consistency(self):
-        """Test that the text property always returns the current buffer state."""
-        aggregator = BlingfireTextAggregator()
-
-        # Initially empty
-        assert aggregator.text == ""
-
-        # After adding an incomplete sentence
-        await aggregator.aggregate("Hello")
-        assert aggregator.text == "Hello"
-
-        # After adding more text
-        await aggregator.aggregate(" world")
-        assert aggregator.text == "Hello world"
-
-        # After completing the sentence (but no return yet)
-        await aggregator.aggregate(".")
-        assert aggregator.text == "Hello world."
-
-        # After triggering a sentence return
-        await aggregator.aggregate(" Next sentence.")
-        assert "Next sentence." in aggregator.text
-        assert "Hello world." not in aggregator.text
-
-    @pytest.mark.asyncio()
-    async def test_empty_sentences_filtering(self):
-        """Test that empty sentences are properly filtered out."""
-        aggregator = BlingfireTextAggregator()
-        # Text with multiple periods that might create empty sentences
-        result = await aggregator.aggregate("Hello... world.")
-        assert result is None
-        assert aggregator.text == "Hello... world."
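These tests encode one contract: the aggregator only releases the first sentence once a second one has started, so a lone complete sentence keeps buffering. A minimal sketch of that behavior built directly on blingfire's `text_to_sentences` (assumes `pip install blingfire`; the removed `BlingfireTextAggregator` may differ in detail):

```python
# Sketch of the buffering contract the tests above encode.
from blingfire import text_to_sentences


class SentenceBuffer:
    def __init__(self) -> None:
        self.text = ""

    def aggregate(self, chunk: str) -> str | None:
        self.text += chunk
        # text_to_sentences returns one sentence per line.
        sentences = [s for s in text_to_sentences(self.text).split("\n") if s.strip()]
        if len(sentences) < 2:
            return None  # a lone sentence may still grow, keep buffering
        first = sentences[0]
        self.text = self.text[self.text.find(first) + len(first):]
        return first


buf = SentenceBuffer()
assert buf.aggregate("Hello world.") is None
assert buf.aggregate(" How are you?") == "Hello world."
```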
tests/unit/test_custom_view.py
DELETED
@@ -1,203 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the CustomView related frames."""
-
-import json
-
-from nvidia_pipecat.frames.custom_view import (
-    Button,
-    ButtonListBlock,
-    ButtonVariant,
-    HeaderBlock,
-    Hint,
-    HintCarouselBlock,
-    Image,
-    ImageBlock,
-    ImagePosition,
-    ImageWithTextBlock,
-    SelectableOption,
-    SelectableOptionsGridBlock,
-    StartCustomViewFrame,
-    TableBlock,
-    TextBlock,
-    TextInputBlock,
-)
-
-
-def test_to_json_empty():
-    """Tests empty custom view JSON serialization.
-
-    Tests:
-        - Empty frame conversion
-        - Minimal configuration
-        - Default values
-
-    Raises:
-        AssertionError: If JSON output doesn't match expected format.
-    """
-    frame = StartCustomViewFrame(action_id="test-action-id")
-    result = frame.to_json()
-    expected_result = {}
-    assert result == json.dumps(expected_result)
-
-
-def test_to_json_simple():
-    """Tests simple custom view JSON serialization.
-
-    Tests:
-        - Basic block configuration
-        - Header block formatting
-        - Single block conversion
-
-    Raises:
-        AssertionError: If JSON output doesn't match expected format.
-    """
-    frame = StartCustomViewFrame(
-        action_id="test-action-id",
-        blocks=[
-            HeaderBlock(id="test-header", header="Test Header", level=1),
-        ],
-    )
-    result = frame.to_json()
-    expected_result = {
-        "blocks": [{"id": "test-header", "type": "header", "data": {"header": "Test Header", "level": 1}}],
-    }
-    assert result == json.dumps(expected_result)
-
-
-def test_to_json_complex():
-    """Tests complex custom view JSON serialization.
-
-    Tests:
-        - Multiple block types
-        - Image handling
-        - Nested data structures
-
-    Raises:
-        AssertionError: If JSON output doesn't match expected format.
-    """
-    # Sample base64 string for image data
-    sample_base64 = "iVBORw0KGgoAAAANSUhEUgAAA5UAAAC5CAIAAAA3TIxUAADd3"
-
-    frame = StartCustomViewFrame(
-        action_id="test-action-id",
-        blocks=[
-            HeaderBlock(id="test-header", header="Test Header", level=1),
-            TextBlock(id="test-text", text="Test Text"),
-            ImageBlock(id="test-image", image=Image(url="https://example.com/image.jpg")),
-            ImageWithTextBlock(
-                id="test-image-with-text",
-                image=Image(data=sample_base64),
-                text="Test Text",
-                image_position=ImagePosition.LEFT,
-            ),
-            TableBlock(id="test-table", headers=["Header 1", "Header 2"], rows=[["Row 1", "Row 2"]]),
-            HintCarouselBlock(
-                id="test-hint-carousel", hints=[Hint(id="test-hint", name="Test Hint", text="Test Text")]
-            ),
-            ButtonListBlock(
-                id="test-button-list",
-                buttons=[
-                    Button(
-                        id="test-button",
-                        active=True,
-                        toggled=False,
-                        variant=ButtonVariant.CONTAINED,
-                        text="Test Button",
-                    )
-                ],
-            ),
-            SelectableOptionsGridBlock(
-                id="test-selectable-options-grid",
-                buttons=[
-                    SelectableOption(
-                        id="test-option",
-                        image=Image(url="https://example.com/image.jpg"),
-                        text="Test Option",
-                        active=True,
-                        toggled=False,
-                    )
-                ],
-            ),
-            TextInputBlock(
-                id="test-input",
-                default_value="Test Default Value",
-                value="Test Value",
-                label="Test Label",
-                input_type="text",
-            ),
-        ],
-    )
-    result = frame.to_json()
-    expected_result = {
-        "blocks": [
-            {"id": "test-header", "type": "header", "data": {"header": "Test Header", "level": 1}},
-            {"id": "test-text", "type": "paragraph", "data": {"text": "Test Text"}},
-            {
-                "id": "test-image",
-                "type": "image",
-                "data": {"image": {"url": "https://example.com/image.jpg"}},
-            },
-            {
-                "id": "test-image-with-text",
-                "type": "paragraph_with_image",
-                "data": {
-                    "image": {"data": sample_base64},
-                    "text": "Test Text",
-                    "image_position": "left",
-                },
-            },
-            {
-                "id": "test-table",
-                "type": "table",
-                "data": {"headers": ["Header 1", "Header 2"], "rows": [["Row 1", "Row 2"]]},
-            },
-            {
-                "id": "test-hint-carousel",
-                "type": "hint_carousel",
-                "data": {"hints": [{"name": "Test Hint", "text": "Test Text"}]},
-            },
-            {
-                "id": "test-button-list",
-                "type": "button_list",
-                "data": {
-                    "buttons": [
-                        {
-                            "id": "test-button",
-                            "active": True,
-                            "toggled": False,
-                            "variant": "contained",
-                            "text": "Test Button",
-                        }
-                    ]
-                },
-            },
-            {
-                "id": "test-selectable-options-grid",
-                "type": "selectable_options_grid",
-                "data": {
-                    "buttons": [
-                        {
-                            "id": "test-option",
-                            "image": {"url": "https://example.com/image.jpg"},
-                            "text": "Test Option",
-                            "active": True,
-                            "toggled": False,
-                        }
-                    ]
-                },
-            },
-            {
-                "id": "test-input",
-                "type": "text_input",
-                "data": {
-                    "default_value": "Test Default Value",
-                    "value": "Test Value",
-                    "label": "Test Label",
-                    "input_type": "text",
-                },
-            },
-        ],
-    }
-    assert result == json.dumps(expected_result)
tests/unit/test_elevenlabs.py
DELETED
@@ -1,184 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the ElevenLabsTTSServiceWithEndOfSpeech class."""
-
-import asyncio
-import base64
-import json
-from unittest.mock import AsyncMock, patch
-
-import pytest
-from loguru import logger
-from pipecat.frames.frames import TTSAudioRawFrame, TTSSpeakFrame, TTSStoppedFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.pipeline.task import PipelineTask
-from websockets.protocol import State
-
-from nvidia_pipecat.services.elevenlabs import ElevenLabsTTSServiceWithEndOfSpeech
-from nvidia_pipecat.utils.logging import setup_default_ace_logging
-from tests.unit.utils import FrameStorage, ignore_ids, run_interactive_test
-
-setup_default_ace_logging(level="TRACE")
-
-
-class MockWebSocket:
-    """Mock WebSocket for testing the ElevenLabs service.
-
-    Attributes:
-        messages_to_return: List of messages to return during testing.
-        sent_messages: List of messages sent through the socket.
-        state: Current WebSocket connection state.
-        close_rcvd: Close frame received flag.
-        close_rcvd_then_sent: Close frame received and sent flag.
-        close_sent: Close frame sent flag.
-        closed: Whether the WebSocket is closed.
-    """
-
-    def __init__(self, messages_to_return):
-        """Initialize MockWebSocket.
-
-        Args:
-            messages_to_return (list): List of messages to return during testing.
-        """
-        self.messages_to_return = messages_to_return
-        self.sent_messages = []
-        self.state = State.OPEN
-        self.close_rcvd = None
-        self.close_rcvd_then_sent = None
-        self.close_sent = None
-        self.closed = False
-
-    async def send(self, message: str) -> None:
-        """Sends a message through the mock socket.
-
-        Args:
-            message (str): Message to send.
-        """
-        self.sent_messages.append(json.loads(message))
-
-    async def ping(self) -> bool:
-        """Simulates a WebSocket heartbeat.
-
-        Returns:
-            bool: Always True for testing.
-        """
-        return True
-
-    async def close(self) -> bool:
-        """Closes the mock WebSocket connection.
-
-        Returns:
-            bool: Always True for testing.
-        """
-        self.state = State.CLOSED
-        self.closed = True
-        return True
-
-    async def __aiter__(self):
-        """Async iterator for messages.
-
-        Yields:
-            str: JSON-encoded message.
-        """
-        for msg in self.messages_to_return:
-            yield json.dumps(msg)
-        while self.state != State.CLOSED:
-            await asyncio.sleep(1.0)
-            yield "{}"
-
-
-@pytest.mark.asyncio()
-async def test_elevenlabs_tts_service_with_end_of_speech():
-    """Test ElevenLabsTTSServiceWithEndOfSpeech functionality.
-
-    Tests:
-        - End-of-speech boundary marker handling
-        - Audio message processing
-        - Alignment message processing
-        - TTSStoppedFrame generation
-
-    Raises:
-        AssertionError: If frame processing or timing is incorrect.
-    """
-    # Test audio data
-    test_audio = b"test_audio_data"
-    test_audio_b64 = base64.b64encode(test_audio).decode()
-
-    # Test cases with different message sequences
-    testcases = {
-        "Normal audio with boundary marker": {
-            "frames_to_send": [TTSSpeakFrame("Hello")],
-            "messages": [
-                {
-                    "audio": test_audio_b64,
-                    "alignment": {
-                        "chars": ["H", "e", "l", "l", "o", "\u200b"],
-                        "charStartTimesMs": [0, 3, 7, 9, 11, 12],
-                        "charDurationsMs": [3, 4, 2, 2, 1, 1],
-                    },
-                },
-            ],
-            "expected_frames": [
-                TTSAudioRawFrame(test_audio, 16000, 1),
-                TTSStoppedFrame(),
-            ],
-        },
-        "Multiple audio chunks": {
-            "frames_to_send": [TTSSpeakFrame("Test")],
-            "messages": [
-                {
-                    "audio": test_audio_b64,
-                    "alignment": {
-                        "chars": ["T", "e"],
-                        "charStartTimesMs": [0, 3],
-                        "charDurationsMs": [3, 4],
-                    },
-                },
-                {
-                    "audio": test_audio_b64,
-                    "alignment": {
-                        "chars": ["s", "t", "\u200b"],
-                        "charStartTimesMs": [7, 9, 11],
-                        "charDurationsMs": [2, 2, 1],
-                    },
-                },
-            ],
-            "expected_frames": [
-                TTSAudioRawFrame(test_audio, 16000, 1),
-                TTSAudioRawFrame(test_audio, 16000, 1),
-                TTSStoppedFrame(),
-            ],
-        },
-    }
-
-    for tc_name, tc_data in testcases.items():
-        logger.info(f"Verifying test case: {tc_name}")
-
-        # Create mock websocket with test messages
-        mock_websocket = MockWebSocket(tc_data["messages"])
-
-        with patch("pipecat.services.elevenlabs.tts.websockets.connect", new=AsyncMock()) as mock:
-            mock.return_value = mock_websocket
-            tts_service = ElevenLabsTTSServiceWithEndOfSpeech(
-                api_key="test_api_key", voice_id="test_voice_id", sample_rate=16000, channels=1
-            )
-
-            storage = FrameStorage()
-            pipeline = Pipeline([tts_service, storage])
-
-            async def test_routine(task: PipelineTask, test_data=tc_data, s=storage):
-                for frame in test_data["frames_to_send"]:
-                    await task.queue_frame(frame)
-                # Wait for all expected frames
-                for expected_frame in test_data["expected_frames"]:
-                    await s.wait_for_frame(ignore_ids(expected_frame))
-                    print(f"got frame to be sent {expected_frame}")
-
-                # TODO: investigate why we need to cancel here
-                await task.cancel()
-
-            await run_interactive_test(pipeline, test_coroutine=test_routine)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
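Note on the test above: the mocked alignment payloads end the character stream with a zero-width space ("\u200b"), which the service treats as an utterance boundary and answers with a TTSStoppedFrame. A minimal sketch of that boundary check, assuming the service inspects each alignment chunk's "chars" list (illustrative only, not the actual ElevenLabsTTSServiceWithEndOfSpeech code):

# Hypothetical sketch of the end-of-speech detection the test exercises;
# the real service in nvidia_pipecat may implement this differently.
BOUNDARY_MARKER = "\u200b"  # zero-width space closing an utterance

def utterance_finished(alignment: dict) -> bool:
    """Return True when an alignment chunk carries the boundary marker."""
    return BOUNDARY_MARKER in alignment.get("chars", [])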
tests/unit/test_frame_creation.py
DELETED
@@ -1,148 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for action frame creation and manipulation.
-
-This module tests the creation, validation, and comparison of various action frames used for bot and user actions.
-"""
-
-# ruff: noqa: F405
-
-from typing import Any
-
-import pytest
-from pipecat.frames.frames import TextFrame
-
-from nvidia_pipecat.frames.action import *  # noqa: F403
-from nvidia_pipecat.frames.action import (
-    FinishedFacialGestureBotActionFrame,
-    StartedFacialGestureBotActionFrame,
-    StartFacialGestureBotActionFrame,
-    StopFacialGestureBotActionFrame,
-)
-from tests.unit.utils import ignore_ids
-
-
-def test_action_frame_basic_usage():
-    """Tests basic action frame functionality.
-
-    Tests:
-        - Frame creation with parameters
-        - Action ID propagation
-        - Frame name generation
-        - Frame attribute access
-
-    Raises:
-        AssertionError: If frame attributes don't match expected values.
-    """
-    start_frame = StartFacialGestureBotActionFrame(facial_gesture="wink")
-    action_id = start_frame.action_id
-
-    started_frame = StartedFacialGestureBotActionFrame(action_id=action_id)
-    stop_frame = StopFacialGestureBotActionFrame(action_id=action_id)
-    finished_frame = FinishedFacialGestureBotActionFrame(action_id=action_id)
-
-    assert started_frame.action_id == action_id
-    assert stop_frame.action_id == action_id
-    assert stop_frame.name == "StopFacialGestureBotActionFrame#0"
-    assert finished_frame.action_id == action_id
-    assert finished_frame.name == "FinishedFacialGestureBotActionFrame#0"
-
-
-def test_required_parameters():
-    """Tests parameter validation in frame creation.
-
-    Tests:
-        - Required parameter enforcement
-        - Type checking
-        - Error handling
-
-    Raises:
-        TypeError: When required parameters are missing.
-    """
-    with pytest.raises(TypeError):
-        StartedFacialGestureBotActionFrame()  # type: ignore
-
-
-def test_action_frame_existence():
-    """Tests frame class contract compliance.
-
-    Tests:
-        - Frame class initialization
-        - Parameter handling
-        - Action ID validation
-        - Frame type verification
-
-    Raises:
-        TypeError: When frame initialization fails.
-        AssertionError: If frame attributes are incorrect.
-    """
-    actions: dict[str, dict[str, dict[str, Any]]] = {
-        "FacialGestureBotActionFrame": {
-            "Start": {"facial_gesture": "wink"},
-            "Started": {},
-            "Stop": {},
-            "Finished": {},
-        },
-        "GestureBotActionFrame": {
-            "Start": {"gesture": "wave"},
-            "Started": {},
-            "Stop": {},
-            "Finished": {},
-        },
-        "PostureBotActionFrame": {
-            "Start": {"posture": "listening"},
-            "Started": {},
-            "Stop": {},
-            "Finished": {},
-        },
-        "PositionBotActionFrame": {
-            "Start": {"position": "left"},
-            "Started": {},
-            "Stop": {},
-            "Updated": {"position_reached": "left"},
-            "Finished": {},
-        },
-        "AttentionUserActionFrame": {
-            "Started": {"attention_level": "attentive"},
-            "Updated": {"attention_level": "inattentive"},
-            "Finished": {},
-        },
-        "PresenceUserActionFrame": {
-            "Started": {},
-            "Finished": {},
-        },
-    }
-
-    for action_name, frame_type in actions.items():
-        for f, args in frame_type.items():
-            if f == "Start":
-                test = globals()[f"{f}{action_name}"](**args)
-                assert test.name
-            else:
-                with pytest.raises(TypeError):
-                    test = globals()[f"{f}{action_name}"]()
-
-                args["action_id"] = "1234"
-                test = globals()[f"{f}{action_name}"](**args)
-                assert test.action_id == "1234"
-
-
-def test_frame_comparison_ignoring_ids():
-    """Tests frame comparison with ID ignoring.
-
-    Tests:
-        - Frame equality comparison
-        - ID-independent comparison
-        - Content-based comparison
-
-    Raises:
-        AssertionError: If frame comparison results are incorrect.
-    """
-    a = TextFrame(text="test")
-    b = TextFrame(text="test")
-    c = TextFrame(text="something")
-
-    assert a != b
-    assert a == ignore_ids(b)
-    assert a != ignore_ids(c)
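For context on the comparisons above: ignore_ids comes from tests/unit/utils.py (also deleted in this commit) and wraps a frame so that equality ignores auto-generated identifiers, which is why two otherwise identical TextFrames compare unequal until wrapped. A rough sketch of the idea, assuming frames are dataclass-like objects whose generated fields are id and name (hypothetical; the deleted utils module holds the real implementation):

class IgnoreIds:
    """Equality wrapper that skips generated id/name fields (sketch)."""

    def __init__(self, frame):
        self.frame = frame

    def __eq__(self, other):
        skip = {"id", "name"}  # assumed names of the generated fields
        lhs = {k: v for k, v in vars(self.frame).items() if k not in skip}
        rhs = {k: v for k, v in vars(other).items() if k not in skip}
        return type(self.frame) is type(other) and lhs == rhs

def ignore_ids(frame):
    return IgnoreIds(frame)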
tests/unit/test_gesture.py
DELETED
@@ -1,94 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the FacialGestureProviderProcessor."""
-
-import pytest
-from pipecat.frames.frames import (
-    BotStartedSpeakingFrame,
-    BotStoppedSpeakingFrame,
-    StartInterruptionFrame,
-    UserStoppedSpeakingFrame,
-)
-from pipecat.tests.utils import SleepFrame
-
-from nvidia_pipecat.frames.action import StartFacialGestureBotActionFrame
-from nvidia_pipecat.processors.gesture_provider import FacialGestureProviderProcessor
-from tests.unit.utils import ignore_ids, run_test
-
-
-@pytest.mark.asyncio()
-async def test_gesture_provider_processor_interrupt():
-    """Test facial gesture generation for bot speech start and interruption.
-
-    Tests that the processor generates appropriate facial gestures when receiving
-    BotStartedSpeakingFrame and StartInterruptionFrame events.
-
-    The test verifies:
-        - Correct handling of BotStartedSpeakingFrame
-        - Correct handling of StartInterruptionFrame
-        - Generation of "Pensive" facial gesture
-    """
-    frames_to_send = [
-        BotStartedSpeakingFrame(),
-        SleepFrame(0.1),
-        StartInterruptionFrame(),
-    ]
-    expected_down_frames = [
-        ignore_ids(BotStartedSpeakingFrame()),
-        ignore_ids(StartInterruptionFrame()),
-        ignore_ids(StartFacialGestureBotActionFrame(facial_gesture="Pensive")),
-    ]
-
-    await run_test(
-        FacialGestureProviderProcessor(probability=1.0),
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-
-@pytest.mark.asyncio()
-async def test_gesture_provider_processor_bot_finished():
-    """Test facial gesture processing for bot speech completion.
-
-    Tests that the processor handles BotStoppedSpeakingFrame by passing it through
-    without generating additional gestures.
-
-    The test verifies:
-        - Correct passthrough of BotStoppedSpeakingFrame
-        - No additional gesture generation
-    """
-    frames_to_send = [BotStoppedSpeakingFrame()]
-    expected_down_frames = [
-        ignore_ids(BotStoppedSpeakingFrame()),
-    ]
-
-    await run_test(
-        FacialGestureProviderProcessor(probability=1.0),
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-
-@pytest.mark.asyncio()
-async def test_gesture_provider_processor_tts():
-    """Test facial gesture processing for user speech completion.
-
-    Tests that the processor generates a "Taunt" facial gesture when receiving
-    UserStoppedSpeakingFrame events.
-
-    The test verifies:
-        - Correct handling of UserStoppedSpeakingFrame
-        - Generation of "Taunt" facial gesture
-    """
-    frames_to_send = [UserStoppedSpeakingFrame()]
-    expected_down_frames = [
-        ignore_ids(UserStoppedSpeakingFrame()),
-        ignore_ids(StartFacialGestureBotActionFrame(facial_gesture="Taunt")),
-    ]
-
-    await run_test(
-        FacialGestureProviderProcessor(probability=1.0),
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
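All three tests pin probability=1.0 so gesture emission is deterministic. The parameter presumably gates emission on a random draw, roughly like this (an assumption about FacialGestureProviderProcessor, not its actual code):

import random

def maybe_gesture(probability: float, gesture: str) -> str | None:
    """Return the gesture to emit, or None when the random draw fails (sketch)."""
    return gesture if random.random() < probability else None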
tests/unit/test_guardrail.py
DELETED
@@ -1,110 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the GuardrailProcessor."""
-
-import pytest
-from pipecat.frames.frames import TranscriptionFrame, TTSSpeakFrame
-from pipecat.utils.time import time_now_iso8601
-
-from nvidia_pipecat.processors.guardrail import GuardrailProcessor
-from tests.unit.utils import ignore_ids, run_test
-
-
-@pytest.mark.asyncio
-async def test_blocked_word():
-    """Tests blocking functionality for explicitly blocked words.
-
-    Tests that the processor replaces transcription frames containing blocked words
-    with a TTSSpeakFrame containing a rejection message.
-
-    The test verifies:
-        - Input containing "football" is blocked
-        - Response is replaced with rejection message
-    """
-    guardrail_bot = GuardrailProcessor(blocked_words=["football"])
-    frames_to_send = [TranscriptionFrame(text="I love football", user_id="", timestamp=time_now_iso8601())]
-    expected_down_frames = [ignore_ids(TTSSpeakFrame("I am not allowed to answer this question"))]
-
-    await run_test(guardrail_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
-
-
-@pytest.mark.asyncio
-async def test_non_blocked_word():
-    """Tests passthrough of allowed words.
-
-    Tests that the processor allows transcription frames that don't contain
-    any blocked words to pass through unchanged.
-
-    The test verifies:
-        - Input without blocked words passes through
-        - Frame content remains unchanged
-    """
-    guardrail_bot = GuardrailProcessor(blocked_words=["football"])
-    timestamp = time_now_iso8601()
-    frames_to_send = [TranscriptionFrame(text="Tell me about Pasta", user_id="", timestamp=timestamp)]
-    expected_down_frames = [ignore_ids(TranscriptionFrame(text="Tell me about Pasta", user_id="", timestamp=timestamp))]
-
-    await run_test(guardrail_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
-
-
-@pytest.mark.asyncio
-async def test_substring_blocked_word():
-    """Tests substring matching behavior for blocked words.
-
-    Tests that the processor only blocks exact word matches and allows
-    words that contain blocked words as substrings.
-
-    The test verifies:
-        - Words containing blocked words as substrings are allowed
-        - Frame content remains unchanged
-    """
-    guardrail_bot = GuardrailProcessor(blocked_words=["foot"])
-    timestamp = time_now_iso8601()
-    frames_to_send = [TranscriptionFrame(text="I love football", user_id="", timestamp=timestamp)]
-    expected_down_frames = [ignore_ids(TranscriptionFrame(text="I love football", user_id="", timestamp=timestamp))]
-
-    await run_test(guardrail_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
-
-
-@pytest.mark.asyncio
-async def test_no_blocked_word():
-    """Tests default behavior with no blocked words configured.
-
-    Tests that the processor allows all transcription frames to pass through
-    when no blocked words are specified.
-
-    The test verifies:
-        - All input passes through when no words are blocked
-        - Frame content remains unchanged
-    """
-    guardrail_bot = GuardrailProcessor()
-    timestamp = time_now_iso8601()
-    frames_to_send = [TranscriptionFrame(text="What is your name", user_id="", timestamp=timestamp)]
-    expected_down_frames = [ignore_ids(TranscriptionFrame(text="What is your name", user_id="", timestamp=timestamp))]
-
-    await run_test(guardrail_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
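Taken together, the four cases pin down word-level rather than substring matching: "foot" does not block "football". A matcher consistent with that behavior could look like the following sketch (an assumption; the real GuardrailProcessor logic lives in nvidia_pipecat):

import re

def contains_blocked_word(text: str, blocked_words: list[str]) -> bool:
    """True when any blocked word appears as a whole word (sketch)."""
    words = re.findall(r"[a-z']+", text.lower())
    return any(blocked.lower() in words for blocked in blocked_words)

assert contains_blocked_word("I love football", ["football"])
assert not contains_blocked_word("I love football", ["foot"])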
tests/unit/test_message_broker.py
DELETED
@@ -1,111 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the MessageBroker."""
-
-import asyncio
-from datetime import timedelta
-
-import pytest
-
-from nvidia_pipecat.utils.message_broker import LocalQueueMessageBroker
-
-
-async def send_task(mb: LocalQueueMessageBroker) -> bool:
-    """Send test messages to the message broker.
-
-    Args:
-        mb: The LocalQueueMessageBroker instance to send messages to.
-
-    Returns:
-        bool: True if all messages were sent successfully.
-    """
-    for i in range(5):
-        if i % 2 == 0:
-            await mb.send_message("first", f"message {i}")
-        else:
-            await mb.send_message("second", f"message {i}")
-        await asyncio.sleep(0.08)
-    return True
-
-
-async def receive_task(mb: LocalQueueMessageBroker) -> list[str]:
-    """Receive messages from the message broker.
-
-    Args:
-        mb: The LocalQueueMessageBroker instance to receive messages from.
-
-    Returns:
-        list[str]: List of received message data.
-    """
-    result: list[str] = []
-    while len(result) < 5:
-        messages = await mb.receive_messages(timeout=timedelta(seconds=0.1))
-        for _, message_data in messages:
-            result.append(message_data)
-    return result
-
-
-@pytest.mark.asyncio()
-async def test_local_queue_message_broker():
-    """Tests basic concurrent send/receive functionality.
-
-    Tests that messages can be successfully sent and received when send and
-    receive tasks run concurrently.
-
-    The test verifies:
-        - All messages are received in order
-        - Send operation completes successfully
-        - Messages are correctly routed between channels
-    """
-    mb = LocalQueueMessageBroker(channels=["first", "second"])
-    task_1 = asyncio.create_task(send_task(mb))
-    task_2 = asyncio.create_task(receive_task(mb))
-
-    results = await asyncio.gather(task_1, task_2)
-
-    assert results == [True, ["message 0", "message 1", "message 2", "message 3", "message 4"]]
-
-
-@pytest.mark.asyncio()
-async def test_local_queue_message_broker_receive_first():
-    """Tests message delivery when receive starts before send.
-
-    Tests that no messages are lost when the receive task is started before
-    any messages are sent.
-
-    The test verifies:
-        - All messages are received in order despite delayed send
-        - Send operation completes successfully
-        - No messages are lost due to timing
-    """
-    mb = LocalQueueMessageBroker(channels=["first", "second"])
-    task_2 = asyncio.create_task(receive_task(mb))
-    await asyncio.sleep(0.2)
-    task_1 = asyncio.create_task(send_task(mb))
-
-    results = await asyncio.gather(task_1, task_2)
-
-    assert results == [True, ["message 0", "message 1", "message 2", "message 3", "message 4"]]
-
-
-@pytest.mark.asyncio()
-async def test_local_queue_message_broker_send_first():
-    """Tests message delivery when send completes before receive starts.
-
-    Tests that no messages are lost when all messages are sent before the
-    receive task begins.
-
-    The test verifies:
-        - All messages are received in order despite delayed receive
-        - Send operation completes successfully
-        - Messages are properly queued until received
-    """
-    mb = LocalQueueMessageBroker(channels=["first", "second"])
-    task_1 = asyncio.create_task(send_task(mb))
-    await asyncio.sleep(0.5)
-    task_2 = asyncio.create_task(receive_task(mb))
-
-    results = await asyncio.gather(task_1, task_2)
-
-    assert results == [True, ["message 0", "message 1", "message 2", "message 3", "message 4"]]
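The send-first and receive-first variants fix the broker's buffering contract: messages queue until drained, and receive_messages returns whatever arrived within the timeout as (channel, payload) tuples. A minimal sketch of that contract on top of asyncio.Queue (an assumed shape, not the shipped LocalQueueMessageBroker):

import asyncio
from datetime import timedelta

class SketchQueueBroker:
    """Toy broker reproducing the buffering contract the tests exercise."""

    def __init__(self, channels: list[str]):
        self._channels = set(channels)  # channel names, kept for signature parity
        self._queue: asyncio.Queue = asyncio.Queue()

    async def send_message(self, channel: str, data: str) -> None:
        await self._queue.put((channel, data))

    async def receive_messages(self, timeout: timedelta) -> list[tuple[str, str]]:
        messages: list[tuple[str, str]] = []
        try:
            # Block for at most `timeout` waiting for the first message...
            messages.append(await asyncio.wait_for(self._queue.get(), timeout.total_seconds()))
        except asyncio.TimeoutError:
            return messages
        # ...then drain anything else already buffered.
        while not self._queue.empty():
            messages.append(self._queue.get_nowait())
        return messages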
tests/unit/test_nvidia_aggregators.py
DELETED
@@ -1,396 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the Nvidia Aggregators."""
-
-import pytest
-from pipecat.frames.frames import (
-    LLMFullResponseEndFrame,
-    LLMFullResponseStartFrame,
-    LLMMessagesFrame,
-    StartInterruptionFrame,
-    TextFrame,
-    TranscriptionFrame,
-    UserStartedSpeakingFrame,
-    UserStoppedSpeakingFrame,
-)
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
-from pipecat.tests.utils import SleepFrame
-from pipecat.tests.utils import run_test as run_pipecat_test
-from pipecat.utils.time import time_now_iso8601
-
-from nvidia_pipecat.frames.riva import RivaInterimTranscriptionFrame
-from nvidia_pipecat.processors.nvidia_context_aggregator import (
-    NvidiaUserContextAggregator,
-    create_nvidia_context_aggregator,
-)
-
-
-@pytest.mark.asyncio()
-async def test_normal_flow():
-    """Test the normal flow of user and assistant interactions with interim transcriptions enabled.
-
-    Tests the sequence of events from user speech start through assistant response,
-    verifying proper handling of interim and final transcriptions.
-
-    The test verifies:
-        - User speech start frame handling
-        - Interim transcription processing
-        - Final transcription handling
-        - User speech stop frame handling
-        - Assistant response processing
-        - Context updates at each stage
-    """
-    messages = []
-    context = OpenAILLMContext(messages)
-    context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
-
-    pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
-    messages.append({"role": "system", "content": "This is system prompt"})
-    # Test Case 1: Normal flow with UserStartedSpeakingFrame first
-    frames_to_send = [
-        UserStartedSpeakingFrame(),
-        LLMMessagesFrame(messages),
-        RivaInterimTranscriptionFrame("Hello", "", time_now_iso8601(), None, stability=1.0),
-        TranscriptionFrame("Hello User Aggregator!", 1, 2),
-        SleepFrame(0.1),
-        UserStoppedSpeakingFrame(),
-        SleepFrame(0.1),
-    ]
-    # Assistant response
-    frames_to_send.extend(
-        [
-            LLMFullResponseStartFrame(),
-            TextFrame("Hello Assistant Aggregator!"),
-            LLMFullResponseEndFrame(),
-        ]
-    )
-    expected_down_frames = [
-        UserStartedSpeakingFrame,
-        StartInterruptionFrame,
-        StartInterruptionFrame,
-        OpenAILLMContextFrame,  # From first interim
-        UserStoppedSpeakingFrame,
-        OpenAILLMContextFrame,
-    ]
-    await run_pipecat_test(
-        pipeline,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-    # Verify final context state
-    assert context_aggregator.user().context.get_messages() == [
-        {"role": "user", "content": "Hello User Aggregator!"},
-        {"role": "assistant", "content": "Hello Assistant Aggregator!"},
-    ]
-
-
-@pytest.mark.asyncio()
-async def test_user_speaking_frame_delay_cases():
-    """Test handling of transcription frames that arrive before UserStartedSpeakingFrame.
-
-    Tests edge cases around transcription frame timing relative to the
-    UserStartedSpeakingFrame.
-
-    The test verifies:
-        - Interim frames before UserStartedSpeakingFrame are ignored
-        - Low stability interim frames are ignored
-        - Only processes transcriptions after UserStartedSpeakingFrame
-        - Context is updated correctly for valid frames
-    """
-    messages = []
-    context = OpenAILLMContext(messages)
-    context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
-
-    pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
-    messages.append({"role": "system", "content": "This is system prompt"})
-
-    # Test Case 2: RivaInterimTranscriptionFrames before UserStartedSpeakingFrame
-    frames_to_send = [
-        RivaInterimTranscriptionFrame(
-            "Testing", "", time_now_iso8601(), None, stability=0.5
-        ),  # Should be ignored (low stability)
-        RivaInterimTranscriptionFrame(
-            "Testing delayed", "", time_now_iso8601(), None, stability=1.0
-        ),  # Should be ignored (no UserStartedSpeakingFrame yet)
-        SleepFrame(0.5),
-        UserStartedSpeakingFrame(),
-        RivaInterimTranscriptionFrame(
-            "Testing after start", "", time_now_iso8601(), None, stability=1.0
-        ),  # Should be processed
-        TranscriptionFrame("Testing after start complete", 1, 2),
-        SleepFrame(0.1),
-        UserStoppedSpeakingFrame(),
-        SleepFrame(0.1),
-    ]
-
-    # Assistant response
-    frames_to_send.extend(
-        [
-            LLMFullResponseStartFrame(),
-            TextFrame("Hello Assistant Aggregator!"),
-            LLMFullResponseEndFrame(),
-        ]
-    )
-    expected_down_frames = [
-        UserStartedSpeakingFrame,
-        StartInterruptionFrame,
-        StartInterruptionFrame,
-        OpenAILLMContextFrame,  # from first interim after UserStartedSpeakingFrame
-        UserStoppedSpeakingFrame,
-        OpenAILLMContextFrame,
-    ]
-
-    await run_pipecat_test(
-        pipeline,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-    # Verify final context state
-    assert context_aggregator.user().context.get_messages() == [
-        {"role": "user", "content": "Testing after start complete"},
-        {"role": "assistant", "content": "Hello Assistant Aggregator!"},
-    ]
-
-
-@pytest.mark.asyncio()
-async def test_multiple_interims_with_final_transcription():
-    """Test handling of multiple interim transcription frames followed by a final transcription.
-
-    Tests the processing of a sequence of interim transcriptions followed by
-    a final transcription.
-
-    The test verifies:
-        - Multiple interim transcriptions are processed correctly
-        - Final transcription properly overwrites previous interims
-        - Context updates occur for each valid frame
-        - Message history maintains correct order
-    """
-    messages = []
-    context = OpenAILLMContext(messages)
-    context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
-
-    pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
-    messages.append({"role": "system", "content": "This is system prompt"})
-
-    # Test Case 3: Multiple interim frames with final transcription
-    frames_to_send = [
-        UserStartedSpeakingFrame(),
-        RivaInterimTranscriptionFrame("Hello", "", time_now_iso8601(), None, stability=1.0),
-        RivaInterimTranscriptionFrame("Hello Again", "", time_now_iso8601(), None, stability=1.0),
-        RivaInterimTranscriptionFrame("Hello Again User", "", time_now_iso8601(), None, stability=1.0),
-        TranscriptionFrame("Hello Again User Aggregator!", 1, 2),
-        SleepFrame(0.1),
-        UserStoppedSpeakingFrame(),
-        SleepFrame(0.1),
-    ]
-
-    # Assistant response
-    frames_to_send.extend(
-        [
-            LLMFullResponseStartFrame(),
-            TextFrame("Hello Assistant Aggregator!"),
-            LLMFullResponseEndFrame(),
-        ]
-    )
-    expected_down_frames = [
-        UserStartedSpeakingFrame,
-        StartInterruptionFrame,
-        StartInterruptionFrame,
-        StartInterruptionFrame,
-        StartInterruptionFrame,
-        OpenAILLMContextFrame,  # From final transcription
-        UserStoppedSpeakingFrame,
-        OpenAILLMContextFrame,
-    ]
-
-    await run_pipecat_test(
-        pipeline,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-    # Verify final context state
-    assert context_aggregator.user().context.get_messages() == [
-        {"role": "user", "content": "Hello Again User Aggregator!"},
-        {"role": "assistant", "content": "Hello Assistant Aggregator!"},
-    ]
-
-
-@pytest.mark.asyncio()
-async def test_transcription_after_user_stopped_speaking():
-    """Tests handling of late transcription frames.
-
-    Tests behavior when transcription frames arrive after UserStoppedSpeakingFrame.
-
-    The test verifies:
-        - Late transcriptions are still processed
-        - Context is updated with final transcription
-        - Assistant responses are handled correctly
-        - Message history maintains proper sequence
-    """
-    messages = []
-    context = OpenAILLMContext(messages)
-    context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
-
-    pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
-    messages.append({"role": "system", "content": "This is system prompt"})
-
-    # Test Case 4: TranscriptionFrame after UserStoppedSpeakingFrame
-    frames_to_send = [
-        UserStartedSpeakingFrame(),
-        RivaInterimTranscriptionFrame("Late", "", time_now_iso8601(), None, stability=1.0),
-        SleepFrame(0.1),
-        UserStoppedSpeakingFrame(),
-        SleepFrame(0.1),
-        TranscriptionFrame("Late transcription!", 1, 2),
-        SleepFrame(0.1),
-    ]
-
-    # Assistant response
-    frames_to_send.extend(
-        [
-            LLMFullResponseStartFrame(),
-            TextFrame("Hello Assistant Aggregator!"),
-            LLMFullResponseEndFrame(),
-        ]
-    )
-
-    expected_down_frames = [
-        UserStartedSpeakingFrame,
-        StartInterruptionFrame,
-        OpenAILLMContextFrame,  # From first interim
-        UserStoppedSpeakingFrame,
-        StartInterruptionFrame,
-        OpenAILLMContextFrame,  # From final after UserStoppedSpeakingFrame
-        OpenAILLMContextFrame,  # From assistant response
-    ]
-
-    await run_pipecat_test(
-        pipeline,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-    # Verify final context state
-    assert context_aggregator.user().context.get_messages() == [
-        {"role": "user", "content": "Late transcription!"},
-        {"role": "assistant", "content": "Hello Assistant Aggregator!"},
-    ]
-
-
-@pytest.mark.asyncio()
-async def test_no_interim_frames():
-    """Tests behavior when interim frames are disabled.
-
-    Tests the aggregator's handling of transcriptions when send_interims=False.
-
-    The test verifies:
-        - Interim frames are ignored
-        - Only final transcription is processed
-        - System prompts are preserved
-        - Context updates occur only for final transcription
-        - Assistant responses are processed correctly
-    """
-    messages = [{"role": "system", "content": "This is system prompt"}]
-    context = OpenAILLMContext(messages)
-    context_aggregator = create_nvidia_context_aggregator(context, send_interims=False)
-    pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
-
-    frames_to_send = [
-        UserStartedSpeakingFrame(),
-        LLMMessagesFrame(messages),
-        # These interim frames should be ignored due to send_interims=False
-        RivaInterimTranscriptionFrame("Hello", "", time_now_iso8601(), None, stability=1.0),
-        RivaInterimTranscriptionFrame("Hello there", "", time_now_iso8601(), None, stability=1.0),
-        RivaInterimTranscriptionFrame("Hello there user", "", time_now_iso8601(), None, stability=1.0),
-        # Only the final transcription should be processed
-        TranscriptionFrame("Hello there user final!", 1, 2),
-        SleepFrame(0.1),
-        UserStoppedSpeakingFrame(),
-        SleepFrame(0.1),
-        # Assistant response
-        LLMFullResponseStartFrame(),
-        TextFrame("Hello from assistant!"),
-        LLMFullResponseEndFrame(),
-    ]
-
-    expected_down_frames = [
-        UserStartedSpeakingFrame,
-        StartInterruptionFrame,
-        OpenAILLMContextFrame,  # Only from final transcription
-        UserStoppedSpeakingFrame,
-        OpenAILLMContextFrame,  # From assistant response
-    ]
-
-    await run_pipecat_test(
-        pipeline,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-    # Verify final context state
-    assert context_aggregator.user().context.get_messages() == [
-        {"role": "system", "content": "This is system prompt"},
-        {"role": "user", "content": "Hello there user final!"},
-        {"role": "assistant", "content": "Hello from assistant!"},
-    ]
-
-
-@pytest.mark.asyncio()
-async def test_get_truncated_context():
-    """Tests context truncation functionality.
-
-    Tests the get_truncated_context() method of NvidiaUserContextAggregator
-    with a specified chat history limit.
-
-    The test verifies:
-        - Context is truncated to specified limit
-        - System prompt is preserved
-        - Most recent messages are retained
-        - Message order is maintained
-    """
-    messages = [
-        {"role": "system", "content": "This is system prompt"},
-        {"role": "user", "content": "Hi, there!"},
-        {"role": "assistant", "content": "Hello, how may I assist you?"},
-        {"role": "user", "content": "How to be more productive?"},
-        {"role": "assistant", "content": "Prioritize the tasks, make a list..."},
-        {"role": "user", "content": "What is metaverse?"},
-        {
-            "role": "assistant",
-            "content": "The metaverse is envisioned as a digital ecosystem built on virtual 3D technology",
-        },
-        {
-            "role": "assistant",
-            "content": "It leverages 3D technology and digital "
-            "representation for creating virtual environments and user experiences",
-        },
-        {"role": "user", "content": "thanks, Bye!"},
-    ]
-    context = OpenAILLMContext(messages)
-    user = NvidiaUserContextAggregator(context=context, chat_history_limit=2)
-    truncated_context = await user.get_truncated_context()
-    assert truncated_context.get_messages() == [
-        {"role": "system", "content": "This is system prompt"},
-        {"role": "user", "content": "What is metaverse?"},
-        {
-            "role": "assistant",
-            "content": "The metaverse is envisioned as a digital ecosystem built on virtual 3D technology",
-        },
-        {
-            "role": "assistant",
-            "content": "It leverages 3D technology and digital "
-            "representation for creating virtual environments and user experiences",
-        },
-        {"role": "user", "content": "thanks, Bye!"},
-    ]
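The expected output of test_get_truncated_context implies the truncation rule: system messages survive, and the history is cut so that only the last chat_history_limit user turns (and everything after the earliest of them) remain. A sketch of that rule (hypothetical; the real NvidiaUserContextAggregator.get_truncated_context() may differ):

def truncate_history(messages: list[dict], chat_history_limit: int) -> list[dict]:
    """Keep system messages plus the last N user turns and their replies (sketch)."""
    system = [m for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]
    user_positions = [i for i, m in enumerate(rest) if m["role"] == "user"]
    keep_from = user_positions[-chat_history_limit] if len(user_positions) > chat_history_limit else 0
    return system + rest[keep_from:]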
tests/unit/test_nvidia_llm_service.py
DELETED
@@ -1,386 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the NvidiaLLMService.
-
-This module contains tests for the NvidiaLLMService class, focusing on core functionalities:
-- Think token filtering (including split tag handling)
-- Mistral message preprocessing
-- Token usage tracking
-- LLM responses and function calls
-"""
-
-from unittest.mock import DEFAULT, AsyncMock, patch
-
-import pytest
-from pipecat.frames.frames import LLMTextFrame
-from pipecat.metrics.metrics import LLMTokenUsage
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-
-from nvidia_pipecat.services.nvidia_llm import NvidiaLLMService
-
-
-# Custom mocks that mimic OpenAI classes without inheriting from them
-class MockCompletionUsage:
-    """Mock for CompletionUsage that mimics the structure."""
-
-    def __init__(self, prompt_tokens, completion_tokens, total_tokens):
-        """Initialize with token usage counts.
-
-        Args:
-            prompt_tokens: Number of tokens in the prompt
-            completion_tokens: Number of tokens in the completion
-            total_tokens: Total number of tokens
-        """
-        self.prompt_tokens = prompt_tokens
-        self.completion_tokens = completion_tokens
-        self.total_tokens = total_tokens
-
-
-class MockChoiceDelta:
-    """Mock for ChoiceDelta that mimics the structure."""
-
-    def __init__(self, content=None, tool_calls=None):
-        """Initialize with optional content and tool calls.
-
-        Args:
-            content: The text content of the delta
-            tool_calls: List of tool calls in the delta
-        """
-        self.content = content
-        self.tool_calls = tool_calls
-        self.function_call = None
-        self.role = None
-
-
-class MockChoice:
-    """Mock for Choice that mimics the structure."""
-
-    def __init__(self, delta, index=0, finish_reason=None):
-        """Initialize with delta, index, and finish reason.
-
-        Args:
-            delta: The delta containing content or tool calls
-            index: The index of this choice
-            finish_reason: Reason for finishing generation
-        """
-        self.delta = delta
-        self.index = index
-        self.finish_reason = finish_reason
-
-
-class MockChatCompletionChunk:
-    """Mock for ChatCompletionChunk that mimics the structure."""
-
-    def __init__(self, content=None, usage=None, id="mock-id", tool_calls=None):
-        """Initialize a mock of ChatCompletionChunk.
-
-        Args:
-            content: The text content in the chunk
-            usage: Token usage information
-            id: Chunk identifier
-            tool_calls: Any tool calls in the chunk
-        """
-        self.id = id
-        self.model = "mock-model"
-        self.object = "chat.completion.chunk"
-        self.created = 1234567890
-        self.usage = usage
-
-        if tool_calls:
-            self.choices = [MockChoice(MockChoiceDelta(tool_calls=tool_calls))]
-        else:
-            self.choices = [MockChoice(MockChoiceDelta(content=content))]
-
-
-class MockToolCall:
-    """Mock for ToolCall."""
-
-    def __init__(self, id="tool-id", function=None, index=0, type="function"):
-        """Initialize a tool call.
-
-        Args:
-            id: Tool call identifier
-            function: The function being called
-            index: Index of this tool call
-            type: Type of tool call
-        """
-        self.id = id
-        self.function = function
-        self.index = index
-        self.type = type
-
-
-class MockFunction:
-    """Mock for Function in a tool call."""
-
-    def __init__(self, name="", arguments=""):
-        """Initialize a function with name and arguments.
-
-        Args:
-            name: Name of the function
-            arguments: JSON string of function arguments
-        """
-        self.name = name
-        self.arguments = arguments
-
-
-class MockAsyncStream:
-    """Mock implementation of AsyncStream for testing."""
-
-    def __init__(self, chunks):
-        """Initialize with a list of chunks to yield.
-
-        Args:
-            chunks: List of chunks to return when iterating
-        """
-        self.chunks = chunks
-
-    def __aiter__(self):
-        """Return self as an async iterator."""
-        return self
-
-    async def __anext__(self):
-        """Return the next chunk or raise StopAsyncIteration."""
-        if not self.chunks:
-            raise StopAsyncIteration
-        return self.chunks.pop(0)
-
-
-@pytest.mark.asyncio
-async def test_mistral_message_preprocessing():
-    """Test the Mistral message preprocessing functionality."""
-    service = NvidiaLLMService(api_key="test_api_key", mistral_model_support=True)
-
-    # Test with alternating roles (already valid)
-    messages = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "Hello"},
-        {"role": "assistant", "content": "Hi there!"},
-        {"role": "user", "content": "How are you?"},
-    ]
-    processed = service._preprocess_messages_for_mistral(messages)
-    assert len(processed) == len(messages)  # No changes needed
-
-    # Test with consecutive messages from same role
-    messages = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "Hello"},
-        {"role": "user", "content": "How are you?"},
-    ]
-    processed = service._preprocess_messages_for_mistral(messages)
-    assert len(processed) == 2  # System + combined user
-    assert processed[1]["role"] == "user"
-    assert processed[1]["content"] == "Hello How are you?"
-
-    # Test with system message at end
-    messages = [
-        {"role": "user", "content": "Hello"},
-        {"role": "system", "content": "You are a helpful assistant."},
-    ]
-    processed = service._preprocess_messages_for_mistral(messages)
-    assert len(processed) == 2  # User + system
-    assert processed[0]["role"] == "user"
-    assert processed[1]["role"] == "system"
-
-
-@pytest.mark.asyncio
-async def test_filter_think_token_simple():
-    """Test the basic think token filtering functionality."""
-    service = NvidiaLLMService(api_key="test_api_key", filter_think_tokens=True)
-    service._reset_think_filter_state()
-
-    # Test with simple think content followed by real content
-    content = "I'm thinking about the answer</think>This is the actual response"
-    filtered = service._filter_think_token(content)
-    assert filtered == "This is the actual response"
-    assert service._seen_end_tag is True
-
-    # Subsequent content should pass through untouched
-    more_content = " and some more text"
-    filtered = service._filter_think_token(more_content)
-    assert filtered == " and some more text"
-
-
-@pytest.mark.asyncio
-async def test_filter_think_token_split():
-    """Test the think token filtering with split tags."""
-    service = NvidiaLLMService(api_key="test_api_key", filter_think_tokens=True)
-    service._reset_think_filter_state()
-
-    # First part with beginning of tag
-    content1 = "Let me think about this problem<"
-    filtered1 = service._filter_think_token(content1)
-    assert filtered1 == ""  # No output yet
-    assert service._partial_tag_buffer == "<"  # Partial tag saved
-
-    # Second part with rest of tag and response
-    content2 = "/think>Here's the answer"
-    filtered2 = service._filter_think_token(content2)
-    assert filtered2 == "Here's the answer"  # Output after tag
-    assert service._seen_end_tag is True
-
-
-@pytest.mark.asyncio
-async def test_filter_think_token_no_tag():
-    """Test what happens when no think tag is found."""
-    service = NvidiaLLMService(api_key="test_api_key", filter_think_tokens=True)
-    service._reset_think_filter_state()
-
-    # Add some content in multiple chunks
-    filtered1 = service._filter_think_token("This is a response")
-    filtered2 = service._filter_think_token(" with no think tag")
-    # Verify filtering behavior
-    assert filtered1 == filtered2 == ""  # No output during filtering
-    assert service._thinking_aggregation == "This is a response with no think tag"
-    # Test end-of-processing behavior
-    service.push_frame = AsyncMock()
-    await service.push_frame(LLMTextFrame(service._thinking_aggregation))
-    service._reset_think_filter_state()
-    # Verify results
-    service.push_frame.assert_called_once()
-    assert service.push_frame.call_args.args[0].text == "This is a response with no think tag"
-    assert service._thinking_aggregation == ""  # State was reset
-
-
-@pytest.mark.asyncio
-async def test_token_usage_tracking():
-    """Test the token usage tracking functionality."""
-    service = NvidiaLLMService(api_key="test_api_key")
-    service._is_processing = True
-
-    # Test initial accumulation of prompt tokens
-    tokens1 = LLMTokenUsage(prompt_tokens=10, completion_tokens=0, total_tokens=10)
-    await service.start_llm_usage_metrics(tokens1)
-    assert service._prompt_tokens == 10
-    assert service._has_reported_prompt_tokens is True
-
-    # Test incremental completion tokens
-    tokens2 = LLMTokenUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
-    await service.start_llm_usage_metrics(tokens2)
-    assert service._completion_tokens == 5
-
-    # Test more completion tokens
-    tokens3 = LLMTokenUsage(prompt_tokens=10, completion_tokens=8, total_tokens=18)
-    await service.start_llm_usage_metrics(tokens3)
-    assert service._completion_tokens == 8
-
-    # Test reporting duplicate prompt tokens
-    tokens4 = LLMTokenUsage(prompt_tokens=10, completion_tokens=10, total_tokens=20)
-    await service.start_llm_usage_metrics(tokens4)
-    assert service._prompt_tokens == 10  # Unchanged
-    assert service._completion_tokens == 10
-
-
-@pytest.mark.asyncio
-async def test_process_context_with_think_filtering():
-    """Test the full processing with think token filtering."""
-    with patch.multiple(
-        NvidiaLLMService,
-        create_client=DEFAULT,
-        _stream_chat_completions=DEFAULT,
-        start_ttfb_metrics=DEFAULT,
-        stop_ttfb_metrics=DEFAULT,
-        push_frame=DEFAULT,
-    ) as mocks:
-        service = NvidiaLLMService(api_key="test_api_key", filter_think_tokens=True)
-        mock_push_frame = mocks["push_frame"]
-
-        # Setup mock stream
-        chunks = [
-            MockChatCompletionChunk(content="Thinking<"),
-            MockChatCompletionChunk(content="/think>Real content"),
-            MockChatCompletionChunk(content=" continues"),
-        ]
-        mocks["_stream_chat_completions"].return_value = MockAsyncStream(chunks)
-
-        # Process context
-        context = OpenAILLMContext(messages=[{"role": "user", "content": "Test query"}])
-        await service._process_context(context)
-
-        # Verify frame content - empty frames during thinking, content after tag
-        frames = [call.args[0].text for call in mock_push_frame.call_args_list]
-        assert frames == ["", "Real content", " continues"]
-
-
-@pytest.mark.asyncio
-async def test_process_context_with_function_calls():
-    """Test handling of function calls from LLM."""
-    with (
-        patch.object(NvidiaLLMService, "create_client"),
-        patch.object(NvidiaLLMService, "_stream_chat_completions") as mock_stream,
-        patch.object(NvidiaLLMService, "has_function") as mock_has_function,
-        patch.object(NvidiaLLMService, "call_function") as mock_call_function,
-    ):
-        service = NvidiaLLMService(api_key="test_api_key")
-
-        # Create tool call chunks that come in parts
-        tool_call1 = MockToolCall(
-            id="call1", function=MockFunction(name="get_weather", arguments='{"location"'), index=0
-        )
-
-        tool_call2 = MockToolCall(id="call1", function=MockFunction(name="", arguments=':"New York"}'), index=0)
-
-        # Create chunks with tool calls
-        chunk1 = MockChatCompletionChunk(tool_calls=[tool_call1], id="chunk1")
-        chunk2 = MockChatCompletionChunk(tool_calls=[tool_call2], id="chunk2")
-
-        mock_stream.return_value = MockAsyncStream([chunk1, chunk2])
-        mock_has_function.return_value = True
-        mock_call_function.return_value = None
-
-        # Process a context
-        context = OpenAILLMContext(messages=[{"role": "user", "content": "What's the weather in New York?"}])
-        await service._process_context(context)
-
-        # Verify function was called with combined arguments
-        mock_call_function.assert_called_once()
-        args = mock_call_function.call_args.kwargs
-        assert args["function_name"] == "get_weather"
-        assert args["arguments"] == {"location": "New York"}
-        assert args["tool_call_id"] == "call1"
-
-
-@pytest.mark.asyncio
-async def test_process_context_with_mistral_preprocessing():
-    """Test processing context with Mistral message preprocessing."""
-    with (
-        patch.object(NvidiaLLMService, "create_client"),
-        patch.object(NvidiaLLMService, "_stream_chat_completions") as mock_stream,
-    ):
-        service = NvidiaLLMService(api_key="test_api_key", mistral_model_support=True)
-
-        # Setup mock stream
-        chunks = [MockChatCompletionChunk(content="I am a response")]
-        mock_stream.return_value = MockAsyncStream(chunks)
-
-        # Test 1: Combining consecutive user messages
-        context = OpenAILLMContext(
-            messages=[
-                {"role": "system", "content": "You are helpful."},
-                {"role": "user", "content": "Hello"},
-                {"role": "user", "content": "How are you?"},
-            ]
-        )
-        await service._process_context(context)
-
-        # Verify messages were combined
-        processed_messages = context.get_messages()
-        assert len(processed_messages) == 2  # System + combined user
-        assert processed_messages[1]["role"] == "user"
-        assert processed_messages[1]["content"] == "Hello How are you?"
-
-        # Verify stream was called (normal processing)
-        mock_stream.assert_called_once()
-
-        # Test 2: System message only - should skip processing
-        mock_stream.reset_mock()
-        system_only_context = OpenAILLMContext(
-            messages=[
-                {"role": "system", "content": "You are helpful."},
-            ]
-        )
-        await service._process_context(system_only_context)
-
-        # Verify that stream was not called (processing skipped)
-        mock_stream.assert_not_called()
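The three filter tests together specify the streaming think-token behavior: text before "</think>" is withheld, a tag split across chunks is buffered until it completes, and if no tag ever arrives the withheld text is flushed at the end of the response. A compact sketch of a filter with those properties (illustrative; the real NvidiaLLMService keeps this state in _partial_tag_buffer, _seen_end_tag, and _thinking_aggregation and may differ in detail):

class ThinkFilter:
    """Streaming filter that drops everything up to "</think>" (sketch)."""

    END_TAG = "</think>"

    def __init__(self):
        self.held = ""  # text withheld while inside the think block
        self.seen_end = False

    def feed(self, chunk: str) -> str:
        if self.seen_end:
            return chunk  # past the tag: pass through untouched
        self.held += chunk
        if self.END_TAG in self.held:
            self.seen_end = True
            return self.held.split(self.END_TAG, 1)[1]
        return ""  # still possibly inside the think block

    def flush(self) -> str:
        # End of stream without a tag: surface what was withheld.
        return "" if self.seen_end else self.held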
tests/unit/test_nvidia_rag_service.py
DELETED
@@ -1,261 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the NvidiaRAGService processor."""
-
-import httpx
-import pytest
-from loguru import logger
-from pipecat.frames.frames import ErrorFrame, LLMMessagesFrame, TextFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.pipeline.task import PipelineTask
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-
-from nvidia_pipecat.frames.nvidia_rag import NvidiaRAGCitation, NvidiaRAGCitationsFrame, NvidiaRAGSettingsFrame
-from nvidia_pipecat.services.nvidia_rag import NvidiaRAGService
-from tests.unit.utils import FrameStorage, ignore_ids, run_interactive_test, run_test
-
-
-class MockResponse:
-    """Mock response class for testing HTTP responses.
-
-    Attributes:
-        json: The JSON response data.
-    """
-
-    def __init__(self, json):
-        """Initialize MockResponse with JSON data.
-
-        Args:
-            json: The JSON response data.
-        """
-        self.json = json
-
-    async def aclose(self):
-        """Mock aclose method."""
-        pass
-
-    async def aiter_lines(self):
-        """Simulate chunk iteration for streaming response.
-
-        Yields:
-            tuple: A tuple containing data and None.
-        """
-        yield self.json
-
-
-@pytest.mark.asyncio
-async def test_nvidia_rag_service(mocker):
-    """Test NvidiaRAGService functionality with various test cases.
-
-    Tests different RAG service behaviors including successful responses,
-    citation handling, and error conditions.
-
-    Args:
-        mocker: Pytest mocker fixture for mocking HTTP responses.
-
-    The test verifies:
-        - Successful responses without citations
-        - Successful responses with citations
-        - Error handling for empty collection names
-        - Error handling for empty queries
-        - Error handling for incorrect message roles
-    """
-    testcases = {
-        "Success without citations": {
-            "collection_name": "collection123",
-            "messages": [
-                {
-                    "role": "system",
-                    "content": "You are a helpful Large Language Model. "
-                    "Your goal is to demonstrate your capabilities in a succinct way. "
-                    "Your output will be converted to audio so don't include special characters in your answers. "
-                    "Respond to what the user said in a creative and helpful way.",
-                },
-            ],
-            "response_json": 'data: {"id":"a886cc44-e2ce-4ea3-95f0-9ffb1171adb1",'
-            '"choices":[{"index":0,"message":{"role":"assistant","content":"this is rag response content"},'
-            '"delta":{"role":"assistant","content":""},"finish_reason":"[DONE]"}]}',
-            "result_frame": TextFrame("this is rag response content"),
-        },
-        "Success with citations": {
-            "collection_name": "collection123",
-            "messages": [
-                {
-                    "role": "system",
-                    "content": "You are a helpful Large Language Model. "
-                    "Your goal is to demonstrate your capabilities in a succinct way. "
-                    "Your output will be converted to audio so don't include special characters in your answers. "
-                    "Respond to what the user said in a creative and helpful way.",
-                },
-            ],
-            "response_json": 'data: {"id":"a886cc44-e2ce-4ea3-95f0-9ffb1171adb1",'
-            '"choices":[{"index":0,"message":{"role":"assistant","content":"this is rag response content"},'
-            '"delta":{"role":"assistant","content":""},"finish_reason":"[DONE]"}], "citations":{"total_results":0,'
-            '"results":[{"document_id": "", "content": "this is rag citation content", "document_type": "text",'
-            ' "document_name": "", "metadata": "", "score": 0.0}]}}',
-            "result_frame": NvidiaRAGCitationsFrame(
-                [
-                    NvidiaRAGCitation(
-                        document_id="",
-                        document_type="text",
-                        metadata="",
-                        score=0.0,
-                        document_name="",
-                        content=b"this is rag citation content",
-                    )
-                ]
-            ),
-        },
-        "Fail due to empty collection name": {
-            "collection_name": "",
-            "messages": [
-                {
-                    "role": "system",
-                    "content": "You are a helpful Large Language Model. "
-                    "Your goal is to demonstrate your capabilities in a succinct way. "
-                    "Your output will be converted to audio so don't include special characters in your answers. "
-                    "Respond to what the user said in a creative and helpful way.",
-                },
-            ],
-            "result_frame": ErrorFrame(
-                "An error occurred in http request to RAG endpoint, Error: No query or collection name is provided.."
-            ),
-        },
-        "Fail due to empty query": {
-            "collection_name": "collection123",
-            "messages": [
-                {
-                    "role": "system",
-                    "content": "",
-                },
-            ],
-            "result_frame": ErrorFrame(
-                "An error occurred in http request to RAG endpoint, Error: No query or collection name is provided.."
-            ),
-        },
-        "Fail due to incorrect role": {
-            "collection_name": "collection123",
-            "messages": [
-                {
-                    "role": "tool",
-                    "content": "",
-                },
-            ],
-            "result_frame": ErrorFrame(
-                "An error occurred in http request to RAG endpoint, Error: Unexpected role tool found!"
-            ),
-        },
-    }
-
-    for tc_name, tc_data in testcases.items():
-        logger.info(f"Verifying test case: {tc_name}")
-
-        resp = None
-        if "response_json" in tc_data:
-            resp = MockResponse(tc_data["response_json"])
-        else:
-            resp = MockResponse("{}")
-
-        mocker.patch("httpx.AsyncClient.post", return_value=resp)
-
-        rag = NvidiaRAGService(collection_name=tc_data["collection_name"])
-        storage1 = FrameStorage()
-        storage2 = FrameStorage()
-        context_aggregator = rag.create_context_aggregator(OpenAILLMContext(tc_data["messages"]))
-
-        pipeline = Pipeline([context_aggregator.user(), storage1, rag, storage2, context_aggregator.assistant()])
-
-        async def test_routine(task: PipelineTask, test_data=tc_data, s1=storage1, s2=storage2):
-            await task.queue_frame(LLMMessagesFrame(test_data["messages"]))
-
-            # Wait for the result frame
-            if "ErrorFrame" in test_data["result_frame"].name:
-                await s1.wait_for_frame(ignore_ids(test_data["result_frame"]))
-            else:
-                await s2.wait_for_frame(ignore_ids(test_data["result_frame"]))
-
-        await run_interactive_test(pipeline, test_coroutine=test_routine)
-
-        # Verify the frames in storage1
-        for frame_history_entry in storage1.history:
-            if frame_history_entry.frame.name.startswith("ErrorFrame"):
-                assert frame_history_entry.frame == ignore_ids(tc_data["result_frame"])
-
-        # Verify the frames in storage2
-        for frame_history_entry in storage2.history:
-            if (
-                frame_history_entry.frame.name.startswith("TextFrame")
-                and tc_data["result_frame"].__str__().startswith("TextFrame")
-                or frame_history_entry.frame.name.startswith("NvidiaRAGCitationsFrame")
-                and tc_data["result_frame"].__str__().startswith("NvidiaRAGCitationsFrame")
-            ):
-                assert frame_history_entry.frame == ignore_ids(tc_data["result_frame"])
-
-
-@pytest.mark.asyncio
-async def test_rag_service_sharing_session():
-    """Test session sharing behavior between NvidiaRAGService instances.
-
-    Tests the HTTP client session management across multiple RAG service
-    instances.
-
-    The test verifies:
-        - Session sharing between instances with same parameters
-        - Separate session handling for custom sessions
-        - Proper session cleanup
-    """
-    rags = []
-    rags.append(NvidiaRAGService(collection_name="collection_1"))
-    rags.append(NvidiaRAGService(collection_name="collection_1"))
-
-    initial_session = rags[1].shared_session
-
-    for rag in rags:
-        assert rag.shared_session is initial_session
-
-    new_session = httpx.AsyncClient()
-    rags.append(NvidiaRAGService(collection_name="collection_1", session=new_session))
-
-    assert rags[0].shared_session is initial_session
-    assert rags[1].shared_session is initial_session
-    assert rags[2].shared_session is new_session
-
-    await new_session.aclose()
-    for r in rags:
-        await r.cleanup()
-
-
-@pytest.mark.asyncio
-async def test_nvidia_rag_settings_frame_update(mocker):
-    """Tests NvidiaRAGService settings update functionality.
-
-    Tests the processing of NvidiaRAGSettingsFrame for dynamic configuration
-    updates.
-
-    Args:
-        mocker: Pytest mocker fixture for mocking HTTP responses.
-
-    The test verifies:
-        - Collection name updates
-        - Server URL updates
-        - Settings frame propagation
-    """
-    mocker.patch("httpx.AsyncClient.post", return_value="")
-
-    rag_settings_frame = NvidiaRAGSettingsFrame(
-        settings={"collection_name": "nvidia_blogs", "rag_server_url": "http://10.41.23.247:8081"}
-    )
-    rag = NvidiaRAGService(collection_name="collection123")
-
-    frames_to_send = [rag_settings_frame]
-    expected_down_frames = [rag_settings_frame]
-
-    await run_test(
-        rag,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-    assert rag.collection_name == "nvidia_blogs"
-    assert rag.rag_server_url == "http://10.41.23.247:8081"
tests/unit/test_nvidia_tts_response_cacher.py
DELETED
@@ -1,79 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the Nvidia TTS Response Cacher."""
-
-import pytest
-from pipecat.frames.frames import (
-    LLMFullResponseEndFrame,
-    LLMFullResponseStartFrame,
-    TTSAudioRawFrame,
-    UserStartedSpeakingFrame,
-    UserStoppedSpeakingFrame,
-)
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.tests.utils import run_test as run_pipecat_test
-
-from nvidia_pipecat.processors.nvidia_context_aggregator import NvidiaTTSResponseCacher
-
-
-@pytest.mark.asyncio()
-async def test_nvidia_tts_response_cacher():
-    """Tests NvidiaTTSResponseCacher's response deduplication functionality.
-
-    Tests the cacher's ability to deduplicate TTS audio responses in a sequence
-    of frames including user speech events and LLM responses.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-        - Correct handling of user speech start/stop frames
-        - Deduplication of identical TTS audio frames
-        - Preservation of LLM response start/end frames
-        - Frame ordering in pipeline output
-        - Only one TTS frame is retained
-    """
-    nvidia_tts_response_cacher = NvidiaTTSResponseCacher()
-    pipeline = Pipeline([nvidia_tts_response_cacher])
-
-    test_audio = b"\x52\x49\x46\x46\x24\x08\x00\x00\x57\x41\x56\x45\x66\x6d\x74\x20"
-    frames_to_send = [
-        UserStartedSpeakingFrame(),
-        LLMFullResponseStartFrame(),
-        TTSAudioRawFrame(
-            audio=test_audio,
-            sample_rate=16000,
-            num_channels=1,
-        ),
-        LLMFullResponseEndFrame(),
-        LLMFullResponseStartFrame(),
-        TTSAudioRawFrame(
-            audio=test_audio,
-            sample_rate=16000,
-            num_channels=1,
-        ),
-        LLMFullResponseEndFrame(),
-        UserStoppedSpeakingFrame(),
-    ]
-
-    expected_down_frames = [
-        UserStartedSpeakingFrame,
-        UserStoppedSpeakingFrame,
-        LLMFullResponseStartFrame,
-        TTSAudioRawFrame,
-        LLMFullResponseEndFrame,
-    ]
-
-    received_down_frames, received_up_frames = await run_pipecat_test(
-        pipeline,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-    # Verify we only got one TTS frame
-    tts_frames = [f for f in received_down_frames if isinstance(f, TTSAudioRawFrame)]
-    assert len(tts_frames) == 1
tests/unit/test_posture.py
DELETED
@@ -1,104 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the PostureProviderProcessor class."""
-
-import pytest
-from pipecat.frames.frames import BotStoppedSpeakingFrame, StartInterruptionFrame, TTSStartedFrame
-
-from nvidia_pipecat.frames.action import StartPostureBotActionFrame
-from nvidia_pipecat.processors.posture_provider import PostureProviderProcessor
-from tests.unit.utils import ignore_ids, run_test
-
-
-@pytest.mark.asyncio()
-async def test_posture_provider_processor_tts():
-    """Test TTSStartedFrame processing in PostureProviderProcessor.
-
-    Tests that the processor generates appropriate "Talking" posture when
-    text-to-speech begins.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-        - TTSStartedFrame is processed correctly
-        - "Talking" posture is generated
-        - Frames are emitted in correct order
-    """
-    frames_to_send = [TTSStartedFrame()]
-    expected_down_frames = [
-        ignore_ids(TTSStartedFrame()),
-        ignore_ids(StartPostureBotActionFrame(posture="Talking")),
-    ]
-
-    await run_test(
-        PostureProviderProcessor(),
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-
-@pytest.mark.asyncio()
-async def test_posture_provider_processor_bot_finished():
-    """Test BotStoppedSpeakingFrame processing in PostureProviderProcessor.
-
-    Tests that the processor generates appropriate "Attentive" posture when
-    bot stops speaking.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-        - BotStoppedSpeakingFrame is processed correctly
-        - "Attentive" posture is generated
-        - Frames are emitted in correct order
-    """
-    frames_to_send = [BotStoppedSpeakingFrame()]
-    expected_down_frames = [
-        ignore_ids(BotStoppedSpeakingFrame()),
-        ignore_ids(StartPostureBotActionFrame(posture="Attentive")),
-    ]
-
-    await run_test(
-        PostureProviderProcessor(),
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-
-@pytest.mark.asyncio()
-async def test_posture_provider_processor_interrupt():
-    """Tests posture generation for interruption events.
-
-    Tests that the processor generates appropriate "Listening" posture when
-    an interruption occurs.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-        - StartInterruptionFrame is processed correctly
-        - "Listening" posture is generated
-        - Frames are emitted in correct order
-    """
-    frames_to_send = [StartInterruptionFrame()]
-    expected_down_frames = [
-        ignore_ids(StartInterruptionFrame()),
-        ignore_ids(StartPostureBotActionFrame(posture="Listening")),
-    ]
-
-    await run_test(
-        PostureProviderProcessor(),
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
tests/unit/test_proactivity.py
DELETED
@@ -1,85 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the proactivity processor module.
-
-This module contains tests that verify the behavior of the ProactivityProcessor.
-"""
-
-import asyncio
-import os
-import sys
-
-import pytest
-
-sys.path.append(os.path.abspath("../../src"))
-
-from pipecat.frames.frames import TTSSpeakFrame, UserStoppedSpeakingFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.pipeline.task import PipelineTask
-
-from nvidia_pipecat.frames.action import StartedPresenceUserActionFrame
-from nvidia_pipecat.processors.proactivity import ProactivityProcessor
-from tests.unit.utils import FrameStorage, run_interactive_test
-
-
-@pytest.mark.asyncio
-async def test_proactive_bot_processor_timer_behavior():
-    """Test the ProactiveBotProcessor's timer and message behavior.
-
-    Tests the processor's ability to manage timer-based proactive messages
-    and handle timer resets based on user activity.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-        - Default message is sent after timer expiration
-        - Timer resets correctly on user activity
-        - Timer reset prevents premature message generation
-        - Frames are processed in correct order
-        - Message content matches configuration
-    """
-    proactivity = ProactivityProcessor(default_message="I'm here if you need me!", timer_duration=0.5)
-    storage = FrameStorage()
-    pipeline = Pipeline([proactivity, storage])
-
-    async def test_routine(task: PipelineTask):
-        """Inner test coroutine for proactivity testing.
-
-        Args:
-            task: PipelineTask instance for frame queueing.
-
-        The routine:
-            1. Sends initial presence frame
-            2. Waits for timer expiration
-            3. Verifies message generation
-            4. Tests timer reset behavior
-            5. Confirms no premature messages
-        """
-        await task.queue_frame(StartedPresenceUserActionFrame(action_id="1"))
-        # Wait for initial proactive message
-        await asyncio.sleep(0.6)
-
-        # Confirm at least one frame
-        assert len(storage.history) >= 1, "Expected at least one frame."
-
-        # Confirm correct text frame output
-        frame = storage.history[2].frame
-        assert isinstance(frame, TTSSpeakFrame)
-        assert frame.text == "I'm here if you need me!"
-
-        # Send another StartFrame to reset the timer
-        await task.queue_frame(UserStoppedSpeakingFrame())
-        await asyncio.sleep(0)
-
-        # Wait half the timer (0.5s) => no new message yet
-        frame_count_after_reset = len(storage.history)
-        await asyncio.sleep(0.3)
-        # Confirm no additional message arrived yet
-        assert frame_count_after_reset == len(storage.history)
-
-    await run_interactive_test(pipeline, test_coroutine=test_routine)
tests/unit/test_riva_asr_service.py
DELETED
@@ -1,523 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the RivaASRService.
-
-This module contains tests for the RivaASRService class, including initialization,
-ASR functionality, interruption handling, and integration tests.
-"""
-
-import asyncio
-import unittest
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from pipecat.frames.frames import (
-    CancelFrame,
-    EndFrame,
-    StartFrame,
-    StartInterruptionFrame,
-    StopInterruptionFrame,
-    TranscriptionFrame,
-    UserStartedSpeakingFrame,
-    UserStoppedSpeakingFrame,
-)
-from pipecat.transcriptions.language import Language
-
-from nvidia_pipecat.frames.riva import RivaInterimTranscriptionFrame
-from nvidia_pipecat.services.riva_speech import RivaASRService
-
-
-class TestRivaASRService(unittest.TestCase):
-    """Test suite for RivaASRService functionality."""
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_initialization_with_default_parameters(self, mock_asr_service, mock_auth):
-        """Tests RivaASRService initialization with default parameters.
-
-        Args:
-            mock_asr_service: Mock for the ASR service.
-            mock_auth: Mock for the authentication service.
-
-        The test verifies:
-            - Correct type initialization for all parameters
-            - Default server configuration
-            - Authentication setup
-            - Service parameter defaults
-        """
-        # Act
-        service = RivaASRService(api_key="test_api_key")
-
-        # Assert - only check types, not specific values
-        self.assertIsInstance(service._language_code, Language)
-        self.assertIsInstance(service._sample_rate, int)
-        self.assertIsInstance(service._model, str)
-        # Basic boolean parameter checks without restricting values
-        self.assertIsInstance(service._profanity_filter, bool)
-        self.assertIsInstance(service._automatic_punctuation, bool)
-        self.assertIsInstance(service._interim_results, bool)
-        self.assertIsInstance(service._max_alternatives, int)
-        self.assertIsInstance(service._generate_interruptions, bool)
-
-        # Verify Auth was called with correct parameters
-        # For server "grpc.nvcf.nvidia.com:443", use_ssl is automatically set to True
-        mock_auth.assert_called_with(
-            None,
-            True,  # Changed from False to True since default server is "grpc.nvcf.nvidia.com:443"
-            "grpc.nvcf.nvidia.com:443",
-            [["function-id", "1598d209-5e27-4d3c-8079-4751568b1081"], ["authorization", "Bearer test_api_key"]],
-        )
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_initialization_with_custom_parameters(self, mock_asr_service, mock_auth):
-        """Tests RivaASRService initialization with custom parameters.
-
-        Args:
-            mock_asr_service: Mock for the ASR service.
-            mock_auth: Mock for the authentication service.
-
-        The test verifies:
-            - Custom parameter values are set correctly
-            - Server configuration is customized
-            - Authentication with custom credentials
-            - Optional parameter handling
-        """
-        # Define test parameters
-        test_api_key = "test_api_key"
-        test_server = "custom_server:50051"
-        test_function_id = "custom_function_id"
-        test_language = Language.ES_ES
-        test_model = "custom_model"
-        test_sample_rate = 44100
-        test_channel_count = 2
-        test_max_alternatives = 2
-        test_boosted_words = {"boost": 1.0}
-        test_boosted_score = 5.0
-
-        # Act
-        service = RivaASRService(
-            api_key=test_api_key,
-            server=test_server,
-            function_id=test_function_id,
-            language=test_language,
-            model=test_model,
-            profanity_filter=True,
-            automatic_punctuation=True,
-            no_verbatim_transcripts=True,
-            boosted_lm_words=test_boosted_words,
-            boosted_lm_score=test_boosted_score,
-            sample_rate=test_sample_rate,
-            audio_channel_count=test_channel_count,
-            max_alternatives=test_max_alternatives,
-            interim_results=False,
-            generate_interruptions=True,
-            use_ssl=True,
-        )
-
-        # Assert - verify custom parameters were set correctly
-        self.assertEqual(service._language_code, test_language)
-        self.assertEqual(service._sample_rate, test_sample_rate)
-        self.assertEqual(service._model, test_model)
-        self.assertEqual(service._boosted_lm_words, test_boosted_words)
-        self.assertEqual(service._boosted_lm_score, test_boosted_score)
-        self.assertEqual(service._max_alternatives, test_max_alternatives)
-        self.assertEqual(service._audio_channel_count, test_channel_count)
-
-        # Verify Auth was called with correct parameters
-        mock_auth.assert_called_with(
-            None,
-            True,
-            test_server,
-            [["function-id", test_function_id], ["authorization", f"Bearer {test_api_key}"]],
-        )
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    def test_error_handling_during_initialization(self, mock_auth):
-        """Tests error handling during service initialization.
-
-        Args:
-            mock_auth: Mock for the authentication service.
-
-        The test verifies:
-            - Proper exception handling
-            - Error message formatting
-            - Service cleanup on failure
-        """
-        # Arrange
-        mock_auth.side_effect = Exception("Connection failed")
-
-        # Act & Assert
-        with self.assertRaises(Exception) as context:
-            RivaASRService(api_key="test_api_key")
-
-        # Verify the error message
-        self.assertTrue("Missing module: Connection failed" in str(context.exception))
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_can_generate_metrics(self, mock_asr_service, mock_auth):
-        """Test that the service can generate metrics."""
-        # Arrange
-        service = RivaASRService(api_key="test_api_key")
-
-        # Act & Assert
-        self.assertFalse(service.can_generate_metrics())
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_start_method(self, mock_asr_service, mock_auth):
-        """Test the start method of RivaASRService."""
-        # Arrange
-        service = RivaASRService(api_key="test_api_key")
-
-        # Create a completed mock task for expected return value
-        mock_task = MagicMock()
-        mock_task.done.return_value = True
-
-        # Use MagicMock to avoid coroutine awaiting issues
-        service.create_task = MagicMock(return_value=mock_task)
-
-        # Mock the response task handler
-        mock_response_coro = MagicMock()
-        service._response_task_handler = MagicMock(return_value=mock_response_coro)
-
-        # Create a mock StartFrame with the necessary attributes
-        mock_start_frame = MagicMock(spec=StartFrame)
-
-        # Act
-        async def run_test():
-            await service.start(mock_start_frame)
-
-        # Run the test
-        asyncio.run(run_test())
-
-        # Assert
-        # Verify create_task was called with the right handlers
-        service.create_task.assert_called_once_with(mock_response_coro)
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_stop_method(self, mock_asr_service, mock_auth):
-        """Test the stop method of RivaASRService."""
-        # Arrange
-        service = RivaASRService(api_key="test_api_key")
-        service._stop_tasks = AsyncMock()
-
-        # Create a mock EndFrame
-        mock_end_frame = MagicMock(spec=EndFrame)
-
-        # Act
-        async def run_test():
-            await service.stop(mock_end_frame)
-
-        # Run the test
-        asyncio.run(run_test())
-
-        # Assert
-        service._stop_tasks.assert_called_once()
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_cancel_method(self, mock_asr_service, mock_auth):
-        """Test the cancel method of RivaASRService."""
-        # Arrange
-        service = RivaASRService(api_key="test_api_key")
-        service._stop_tasks = AsyncMock()
-
-        # Create a mock CancelFrame
-        mock_cancel_frame = MagicMock(spec=CancelFrame)
-
-        # Act
-        async def run_test():
-            await service.cancel(mock_cancel_frame)
-
-        # Run the test
-        asyncio.run(run_test())
-
-        # Assert
-        service._stop_tasks.assert_called_once()
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_run_stt_yields_frames(self, mock_asr_service, mock_auth):
-        """Test that run_stt method yields frames."""
-        # Arrange
-        service = RivaASRService(api_key="test_api_key")
-        service._queue = AsyncMock()
-
-        # Create a completed mock task and mock thread task handler
-        mock_task = MagicMock()
-        mock_task.done.return_value = True
-
-        mock_thread_coro = MagicMock()
-        service._thread_task_handler = MagicMock(return_value=mock_thread_coro)
-        service._thread_task = mock_task
-
-        # Mock the create_task method to return our mock task and avoid coroutine warnings
-        service.create_task = MagicMock(return_value=mock_task)
-
-        # Act
-        async def run_test():
-            frames = []
-            audio_data = b"test_audio_data"
-            async for frame in service.run_stt(audio_data):
-                frames.append(frame)
-
-            # Assert
-            service._queue.put.assert_called_once_with(audio_data)
-            # run_stt yields a single None frame for RivaASRService
-            self.assertEqual(len(frames), 1)
-            self.assertIsNone(frames[0])
-            service.create_task.assert_called_once_with(mock_thread_coro)
-
-        # Run the test
-        asyncio.run(run_test())
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_handle_response_with_final_transcript(self, mock_asr_service, mock_auth):
-        """Tests handling of ASR responses with final transcripts.
-
-        Args:
-            mock_asr_service: Mock for the ASR service.
-            mock_auth: Mock for the authentication service.
-
-        The test verifies:
-            - Final transcript processing
-            - Metrics handling
-            - Frame generation
-            - Response completion handling
-        """
-        # Arrange
-        service = RivaASRService(api_key="test_api_key")
-        service.push_frame = AsyncMock()
-
-        # Use MagicMock instead of asyncio.Future to avoid event loop issues
-        service.stop_ttfb_metrics = AsyncMock()
-        service.stop_processing_metrics = AsyncMock()
-
-        # Create a mock response with final transcript
-        mock_result = MagicMock()
-        mock_alternative = MagicMock()
-        mock_alternative.transcript = "This is a final transcript"
-        mock_result.alternatives = [mock_alternative]
-        mock_result.is_final = True
-
-        mock_response = MagicMock()
-        mock_response.results = [mock_result]
-
-        # Act
-        async def run_test():
-            await service._handle_response(mock_response)
-
-        # Run the test
-        asyncio.run(run_test())
-
-        # Assert
-        service.stop_ttfb_metrics.assert_called_once()
-        service.stop_processing_metrics.assert_called_once()
-
-        # Verify that a TranscriptionFrame was pushed with the correct text
-        found = False
-        for call_args in service.push_frame.call_args_list:
-            frame = call_args[0][0]
-            if isinstance(frame, TranscriptionFrame) and frame.text == "This is a final transcript":
-                found = True
-                break
-        self.assertTrue(found, "No TranscriptionFrame with the expected text was pushed")
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_handle_response_with_interim_transcript(self, mock_asr_service, mock_auth):
-        """Test handling of ASR responses with interim transcript."""
-        # Arrange
-        service = RivaASRService(api_key="test_api_key")
-        service.push_frame = AsyncMock()
-
-        # Use AsyncMock directly instead of Future
-        service.stop_ttfb_metrics = AsyncMock()
-
-        # Create a mock response with interim transcript
-        mock_result = MagicMock()
-        mock_alternative = MagicMock()
-        mock_alternative.transcript = "This is an interim transcript"
-        mock_result.alternatives = [mock_alternative]
-        mock_result.is_final = False
-        mock_result.stability = 1.0  # High stability interim result
-
-        mock_response = MagicMock()
-        mock_response.results = [mock_result]
-
-        # Act
-        async def run_test():
-            await service._handle_response(mock_response)
-
-        # Run the test
-        asyncio.run(run_test())
-
-        # Assert
-        service.stop_ttfb_metrics.assert_called_once()
-
-        # Verify that a RivaInterimTranscriptionFrame was pushed with the correct text
-        found = False
-        for call_args in service.push_frame.call_args_list:
-            frame = call_args[0][0]
-            if (
-                isinstance(frame, RivaInterimTranscriptionFrame)
-                and frame.text == "This is an interim transcript"
-                and frame.stability == 1.0
-            ):
-                found = True
-                break
-        self.assertTrue(found, "No RivaInterimTranscriptionFrame with the expected text was pushed")
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
-    def test_handle_interruptions(self, mock_asr_service, mock_auth):
-        """Tests handling of speech interruptions.
-
-        Args:
-            mock_asr_service: Mock for the ASR service.
-            mock_auth: Mock for the authentication service.
-
-        The test verifies:
-            - Interruption start handling
-            - Interruption stop handling
-            - Frame sequence generation
-            - State management during interruptions
-        """
-        # Arrange
-        service = RivaASRService(api_key="test_api_key", generate_interruptions=True)
-        service.push_frame = AsyncMock()
-
-        # Use AsyncMock directly instead of Future
-        service._start_interruption = AsyncMock()
-        service._stop_interruption = AsyncMock()
-
-        # Mock the property to return True - avoids setting the property directly
-        type(service).interruptions_allowed = MagicMock(return_value=True)
-
-        # Act
-        async def run_test():
-            # Simulate interruption handling
-            user_started_frame = UserStartedSpeakingFrame()
-            user_stopped_frame = UserStoppedSpeakingFrame()
-
-            # Direct calls to _handle_interruptions
-            await service._handle_interruptions(user_started_frame)
-            await service._handle_interruptions(user_stopped_frame)
-
-        # Run the test
-        asyncio.run(run_test())
-
-        # Assert
-        service._start_interruption.assert_called_once()
-        service._stop_interruption.assert_called_once()
-
-        # Check that frames were pushed (check by type instead of exact equality)
-        pushed_frame_types = [type(call[0][0]) for call in service.push_frame.call_args_list]
-        self.assertIn(StartInterruptionFrame, pushed_frame_types, "No StartInterruptionFrame was pushed")
-        self.assertIn(StopInterruptionFrame, pushed_frame_types, "No StopInterruptionFrame was pushed")
-        self.assertIn(UserStartedSpeakingFrame, pushed_frame_types, "No UserStartedSpeakingFrame was pushed")
-        self.assertIn(UserStoppedSpeakingFrame, pushed_frame_types, "No UserStoppedSpeakingFrame was pushed")
-
-
-@pytest.mark.asyncio
-async def test_riva_asr_integration():
-    """Tests integration of RivaASRService components.
-
-    Tests the complete flow of the ASR service including initialization,
-    processing, and cleanup.
-
-    The test verifies:
-        - Service startup sequence
-        - Audio processing pipeline
-        - Response handling
-        - Service shutdown sequence
-        - Resource cleanup
-    """
-    with (
-        patch("nvidia_pipecat.services.riva_speech.riva.client.Auth"),
-        patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService") as mock_asr_service,
-    ):
-        # Setup mock ASR service
-        mock_instance = mock_asr_service.return_value
-
-        # Initialize service with interruptions enabled
-        service = RivaASRService(api_key="test_api_key", generate_interruptions=True)
-        service._asr_service = mock_instance
-
-        # Set up the response queue
-        service._response_queue = asyncio.Queue()
-
-        # Mock the _stop_tasks method directly instead of relying on task_manager
-        service._stop_tasks = AsyncMock()
-
-        # Create mock coroutines for handlers
-        thread_coro = MagicMock()
-        response_coro = MagicMock()
-
-        # Mock the handler methods to return the coroutines
-        service._thread_task_handler = MagicMock(return_value=thread_coro)
-        service._response_task_handler = MagicMock(return_value=response_coro)
-
-        # Create a mock task that's already completed
-        mock_task = MagicMock()
-        mock_task.done.return_value = True
-
-        # Mock task creation to return our completed task
-        service.create_task = MagicMock(return_value=mock_task)
-
-        # Set the tasks to our mock task
-        service._thread_task = mock_task
-        service._response_task = mock_task
-
-        # Create a mock result for testing
-        mock_result = MagicMock()
-        mock_alternative = MagicMock()
-        mock_alternative.transcript = "This is a test transcript"
-        mock_result.alternatives = [mock_alternative]
-        mock_result.is_final = True
-
-        mock_response = MagicMock()
-        mock_response.results = [mock_result]
-
-        # Create a mock StartFrame
-        mock_start_frame = MagicMock(spec=StartFrame)
-
-        # Start the service with a start frame
-        await service.start(mock_start_frame)
-
-        # Verify response task was created
-        service.create_task.assert_called_with(response_coro)
-
-        # Put a mock response in the queue
-        await service._response_queue.put(mock_response)
-
-        # Test some other functionality
-        audio_data = b"test_audio_data"
-
-        # Run the run_stt method
-        frames = []
-        async for frame in service.run_stt(audio_data):
-            frames.append(frame)
-
-        # Verify results
-        assert len(frames) == 1
-        assert frames[0] is None
-
-        # Verify thread task was created
-        service.create_task.assert_any_call(thread_coro)
-
-        # Simulate stopping the service
-        mock_end_frame = MagicMock(spec=EndFrame)
-        await service.stop(mock_end_frame)
-
-        # Verify stop_tasks was called
-        service._stop_tasks.assert_called_once()
-
-
-if __name__ == "__main__":
-    unittest.main()
tests/unit/test_riva_nmt_service.py
DELETED
@@ -1,197 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the Riva Neural Machine Translation (NMT) service.
-
-This module contains tests that verify the behavior of the RivaNMTService,
-including successful translations and various error cases for both STT and LLM outputs.
-"""
-
-import pytest
-from loguru import logger
-from pipecat.frames.frames import (
-    ErrorFrame,
-    LLMFullResponseEndFrame,
-    LLMFullResponseStartFrame,
-    LLMMessagesFrame,
-    TextFrame,
-    TranscriptionFrame,
-)
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.pipeline.task import PipelineTask
-from pipecat.transcriptions.language import Language
-
-from nvidia_pipecat.services.riva_nmt import RivaNMTService
-from tests.unit.utils import FrameStorage, ignore_ids, run_interactive_test
-
-
-class MockText:
-    """Mock class representing a translated text response.
-
-    Attributes:
-        text: The translated text content.
-    """
-
-    def __init__(self, text):
-        """Initialize MockText.
-
-        Args:
-            text (str): The translated text content.
-        """
-        self.text = text
-
-
-class MockTranslations:
-    """Mock class representing a collection of translations.
-
-    Attributes:
-        translations: List of MockText objects containing translations.
-    """
-
-    def __init__(self, text):
-        """Initialize MockTranslations.
-
-        Args:
-            text (str): The text to be wrapped in a MockText object.
-        """
-        self.translations = [MockText(text)]
-
-
-class MockRivaNMTClient:
-    """Mock class simulating the Riva NMT client.
-
-    Attributes:
-        translated_text: The text to return as translation.
-    """
-
-    def __init__(self, translated_text):
-        """Initialize MockRivaNMTClient.
-
-        Args:
-            translated_text (str): Text to be returned as translation.
-        """
-        self.translated_text = translated_text
-
-    def translate(self, arg1, arg2, arg3, arg4):
-        """Mock translation method.
-
-        Args:
-            arg1: Source language.
-            arg2: Target language.
-            arg3: Text to translate.
-            arg4: Additional options.
-
-        Returns:
-            MockTranslations: A mock translations object containing the pre-defined translated text.
-        """
-        return MockTranslations(self.translated_text)
-
-
-@pytest.mark.asyncio()
-async def test_riva_nmt_service(mocker):
-    """Test the RivaNMTService functionality.
-
-    Tests translation service behavior including successful translations
-    and error handling for both STT and LLM outputs.
-
-    Args:
-        mocker: Pytest mocker fixture for mocking dependencies.
-
-    The test verifies:
-        - STT output translation
-        - LLM output translation
-        - Empty input handling
-        - Missing language handling
-        - Error frame generation
-        - Frame sequence correctness
-    """
-    testcases = {
-        "Success: STT output translated": {
-            "source_language": Language.ES_US,
-            "target_language": Language.EN_US,
-            "input_frames": [TranscriptionFrame("Hola, por favor preséntate.", "", "")],
-            "translated_text": "Hello, please introduce yourself.",
-            "result_frame_name": "LLMMessagesFrame",
-            "result_frame": LLMMessagesFrame([{"role": "system", "content": "Hello, please introduce yourself."}]),
-        },
-        "Success: LLM output translated": {
-            "source_language": Language.EN_US,
-            "target_language": Language.ES_US,
-            "input_frames": [
-                LLMFullResponseStartFrame(),
-                TextFrame("Hello there!"),
-                TextFrame("Im an artificial intelligence model known as Llama."),
-                LLMFullResponseEndFrame(),
-            ],
-            "translated_text": "Hola Im un modelo de inteligencia artificial conocido como Llama",
-            "result_frame_name": "TextFrame",
-            "result_frame": TextFrame("Hola Im un modelo de inteligencia artificial conocido como Llama."),
-        },
-        "Fail due to empty input text": {
-            "source_language": Language.ES_US,
-            "target_language": Language.EN_US,
-            "input_frames": [TranscriptionFrame("", "", "")],
-            "translated_text": None,
-            "result_frame_name": "ErrorFrame",
-            "result_frame": ErrorFrame(
-                f"Error while translating the text from {Language.ES_US} to {Language.EN_US}, "
-                "Error: No input text provided for the translation..",
-            ),
-        },
-        "Fail due to no source language provided": {
-            "source_language": None,
-            "target_language": Language.EN_US,
-            "input_frames": [TranscriptionFrame("Hola, por favor preséntate.", "", "")],
-            "translated_text": None,
-            "error": Exception("No source language provided for the translation.."),
-        },
-        "Fail due to no target language provided": {
-            "source_language": Language.ES_US,
-            "target_language": None,
-            "input_frames": [TranscriptionFrame("Hola, por favor preséntate.", "", "")],
-            "translated_text": None,
-            "error": Exception("No target language provided for the translation.."),
-        },
-    }
-
-    for tc_name, tc_data in testcases.items():
-        logger.info(f"Verifying test case: {tc_name}")
-
-        mocker.patch(
-            "riva.client.NeuralMachineTranslationClient", return_value=MockRivaNMTClient(tc_data["translated_text"])
-        )
-
-        try:
-            nmt_service = RivaNMTService(
-                source_language=tc_data["source_language"], target_language=tc_data["target_language"]
-            )
-        except Exception as e:
-            assert str(e) == str(tc_data["error"])
-            continue
-
-        storage1 = FrameStorage()
-        storage2 = FrameStorage()
-
-        pipeline = Pipeline([storage1, nmt_service, storage2])
-
-        async def test_routine(task: PipelineTask, test_data=tc_data, s1=storage1, s2=storage2):
-            await task.queue_frames(test_data["input_frames"])
-
-            # Wait for the result frame
-            if "ErrorFrame" in test_data["result_frame"].name:
-                await s1.wait_for_frame(ignore_ids(test_data["result_frame"]))
-            else:
-                await s2.wait_for_frame(ignore_ids(test_data["result_frame"]))
-
-        await run_interactive_test(pipeline, test_coroutine=test_routine)
-
-        for frame_history_entry in storage1.history:
-            if frame_history_entry.frame.name.startswith("TextFrame"):
-                # ignoring input text frames getting stored in storage1
-                continue
-            if frame_history_entry.frame.name.startswith(tc_data["result_frame_name"]):
-                assert frame_history_entry.frame == ignore_ids(tc_data["result_frame"])
-
-        for frame_history_entry in storage2.history:
-            if frame_history_entry.frame.name.startswith(tc_data["result_frame_name"]):
-                assert frame_history_entry.frame == ignore_ids(tc_data["result_frame"])
tests/unit/test_riva_tts_service.py
DELETED
@@ -1,301 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the RivaTTSService.
-
-This module contains tests for the RivaTTSService class, including initialization,
-TTS frame generation, audio frame handling, and integration tests.
-"""
-
-import asyncio
-import unittest
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from pipecat.frames.frames import TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, TTSTextFrame
-from pipecat.transcriptions.language import Language
-
-from nvidia_pipecat.services.riva_speech import RivaTTSService
-
-
-class TestRivaTTSService(unittest.TestCase):
-    """Test suite for RivaTTSService functionality."""
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService")
-    def test_run_tts_yields_audio_frames(self, mock_speech_service, mock_auth):
-        """Tests that run_tts correctly yields audio frames in sequence.
-
-        Tests the complete flow of text-to-speech conversion including start,
-        text processing, audio generation, and completion frames.
-
-        Args:
-            mock_speech_service: Mock for the speech synthesis service.
-            mock_auth: Mock for the authentication service.
-
-        The test verifies:
-            - TTSStartedFrame generation
-            - TTSTextFrame content
-            - TTSAudioRawFrame generation
-            - TTSStoppedFrame generation
-            - Frame sequence order
-            - Audio data integrity
-        """
-        # Arrange
-        mock_audio_data = b"sample_audio_data"
-        mock_audio_frame = TTSAudioRawFrame(audio=mock_audio_data, sample_rate=16000, num_channels=1)
-
-        # Create a properly structured mock service
-        mock_service_instance = MagicMock()
-        mock_speech_service.return_value = mock_service_instance
-
-        # Set up the mock to return a regular list, not an async generator
-        # The service expects a list/iterable that it can call iter() on
-        mock_service_instance.synthesize_online.return_value = [mock_audio_frame]
-
-        # Create an instance of RivaTTSService
-        service = RivaTTSService(api_key="test_api_key")
-
-        # Ensure the service has the right mock
-        service._service = mock_service_instance
-
-        # Act
-        async def run_test():
-            frames = []
-            async for frame in service.run_tts("Hello, world!"):
-                frames.append(frame)
-
-            # Assert
-            self.assertEqual(
-                len(frames), 4
-            )  # Should yield 4 frames: TTSStartedFrame, TTSTextFrame, TTSAudioRawFrame, TTSStoppedFrame
-            self.assertIsInstance(frames[0], TTSStartedFrame)
-            self.assertIsInstance(frames[1], TTSTextFrame)
-            self.assertEqual(frames[1].text, "Hello, world!")
-            self.assertIsInstance(frames[2], TTSAudioRawFrame)
-            self.assertEqual(frames[2].audio, mock_audio_data)
-            self.assertIsInstance(frames[3], TTSStoppedFrame)
-
-        # Run the async test
-        asyncio.run(run_test())
-
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
-    @patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService")
-    def test_push_tts_frames(self, mock_speech_service, mock_auth):
-        """Tests that _push_tts_frames correctly processes TTS generation.
-
-        Tests the internal frame processing mechanism including metrics
-        handling and generator processing.
-
-        Args:
-            mock_speech_service: Mock for the speech synthesis service.
-            mock_auth: Mock for the authentication service.
-
-        The test verifies:
-            - Metrics start/stop timing
-            - Generator processing
-            - Frame processing order
-            - Method call sequence
-        """
-        # Arrange
-        mock_audio_frame = TTSAudioRawFrame(audio=b"sample_audio_data", sample_rate=16000, num_channels=1)
-
-        # Create a properly structured mock service
-        mock_service_instance = MagicMock()
-        mock_speech_service.return_value = mock_service_instance
-
-        # Return a regular list for synthesize_online
-        mock_service_instance.synthesize_online.return_value = [mock_audio_frame]
-
-        # Create an instance of RivaTTSService
-        service = RivaTTSService(api_key="test_api_key")
-        service._service = mock_service_instance
-        service.start_processing_metrics = AsyncMock()
-        service.stop_processing_metrics = AsyncMock()
-        service.process_generator = AsyncMock()
-
-        # Create a mock for run_tts instead of trying to capture its generator
-        async def mock_run_tts(text):
-            # This is the sequence of frames that would be yielded by run_tts
-            yield TTSStartedFrame()
-            yield TTSTextFrame(text)
-            yield mock_audio_frame
-            yield TTSStoppedFrame()
-
-        # Replace the run_tts method with our mock
-        with patch.object(service, "run_tts", side_effect=mock_run_tts):
-            # Act
-            async def run_test():
-                await service._push_tts_frames("Hello, world!")
-
-                # Assert
-                # Check that the necessary methods were called
-                service.start_processing_metrics.assert_called_once()
-                service.process_generator.assert_called_once()
-                service.stop_processing_metrics.assert_called_once()
-
-                # Verify call order using the call_args.called_before method
-                assert service.start_processing_metrics.call_args.called_before(service.process_generator.call_args)
-                assert service.process_generator.call_args.called_before(service.stop_processing_metrics.call_args)
assert service.process_generator.call_args.called_before(service.stop_processing_metrics.call_args)
|
140 |
-
|
141 |
-
# Run the async test
|
142 |
-
asyncio.run(run_test())
|
143 |
-
|
144 |
-
@patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
|
145 |
-
@patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService")
|
146 |
-
def test_init_with_different_parameters(self, mock_speech_service, mock_auth):
|
147 |
-
"""Tests initialization with various configuration parameters.
|
148 |
-
|
149 |
-
Tests service initialization with different combinations of
|
150 |
-
configuration options.
|
151 |
-
|
152 |
-
Args:
|
153 |
-
mock_speech_service: Mock for the speech synthesis service.
|
154 |
-
mock_auth: Mock for the authentication service.
|
155 |
-
|
156 |
-
The test verifies:
|
157 |
-
- Parameter assignment
|
158 |
-
- Default value handling
|
159 |
-
- Custom parameter validation
|
160 |
-
- Authentication configuration
|
161 |
-
- Service initialization
|
162 |
-
"""
|
163 |
-
# Define test parameters - users should be able to use any values
|
164 |
-
test_api_key = "test_api_key"
|
165 |
-
test_server = "custom_server:50051"
|
166 |
-
test_voice_id = "English-US.Male-1"
|
167 |
-
test_sample_rate = 22050
|
168 |
-
test_language = Language.ES_ES
|
169 |
-
test_zero_shot_quality = 10
|
170 |
-
test_model = "custom-tts-model"
|
171 |
-
test_dictionary = {"word": "pronunciation"}
|
172 |
-
test_audio_prompt_file = "test_audio.wav"
|
173 |
-
# Test initialization with different parameters
|
174 |
-
service = RivaTTSService(
|
175 |
-
api_key=test_api_key,
|
176 |
-
server=test_server,
|
177 |
-
voice_id=test_voice_id,
|
178 |
-
sample_rate=test_sample_rate,
|
179 |
-
language=test_language,
|
180 |
-
zero_shot_quality=test_zero_shot_quality,
|
181 |
-
model=test_model,
|
182 |
-
custom_dictionary=test_dictionary,
|
183 |
-
zero_shot_audio_prompt_file=test_audio_prompt_file,
|
184 |
-
use_ssl=True,
|
185 |
-
)
|
186 |
-
|
187 |
-
# Verify the parameters were set correctly
|
188 |
-
self.assertEqual(service._api_key, test_api_key)
|
189 |
-
self.assertEqual(service._voice_id, test_voice_id)
|
190 |
-
self.assertEqual(service._sample_rate, test_sample_rate)
|
191 |
-
self.assertEqual(service._language_code, test_language)
|
192 |
-
self.assertEqual(service._zero_shot_quality, test_zero_shot_quality)
|
193 |
-
self.assertEqual(service._model_name, test_model)
|
194 |
-
self.assertEqual(service._custom_dictionary, test_dictionary)
|
195 |
-
self.assertEqual(service._zero_shot_audio_prompt_file, test_audio_prompt_file)
|
196 |
-
|
197 |
-
# Verify Auth was called with correct parameters
|
198 |
-
mock_auth.assert_called_with(
|
199 |
-
None,
|
200 |
-
True, # use_ssl=True
|
201 |
-
test_server,
|
202 |
-
[["function-id", "0149dedb-2be8-4195-b9a0-e57e0e14f972"], ["authorization", f"Bearer {test_api_key}"]],
|
203 |
-
)
|
204 |
-
|
205 |
-
@patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
|
206 |
-
def test_riva_error_handling(self, mock_auth):
|
207 |
-
"""Tests error handling when Riva service initialization fails.
|
208 |
-
|
209 |
-
Tests the service's behavior when encountering initialization errors.
|
210 |
-
|
211 |
-
Args:
|
212 |
-
mock_auth: Mock for the authentication service.
|
213 |
-
|
214 |
-
The test verifies:
|
215 |
-
- Exception propagation
|
216 |
-
- Error message formatting
|
217 |
-
- Cleanup behavior
|
218 |
-
- Service state after error
|
219 |
-
"""
|
220 |
-
# Test error handling when Riva service initialization fails
|
221 |
-
mock_auth.side_effect = Exception("Connection failed")
|
222 |
-
|
223 |
-
# Assert that exception is raised and propagated
|
224 |
-
with self.assertRaises(Exception) as context:
|
225 |
-
RivaTTSService(api_key="test_api_key")
|
226 |
-
|
227 |
-
self.assertTrue("Missing module: Connection failed" in str(context.exception))
|
228 |
-
|
229 |
-
@patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
|
230 |
-
@patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService")
|
231 |
-
def test_can_generate_metrics(self, mock_speech_service, mock_auth):
|
232 |
-
"""Tests that the service reports capability to generate metrics.
|
233 |
-
|
234 |
-
Tests the metrics generation capability reporting functionality.
|
235 |
-
|
236 |
-
Args:
|
237 |
-
mock_speech_service: Mock for the speech synthesis service.
|
238 |
-
mock_auth: Mock for the authentication service.
|
239 |
-
|
240 |
-
The test verifies:
|
241 |
-
- Metrics capability reporting
|
242 |
-
- Consistency of capability flag
|
243 |
-
"""
|
244 |
-
# Test that the service can generate metrics
|
245 |
-
service = RivaTTSService(api_key="test_api_key")
|
246 |
-
self.assertTrue(service.can_generate_metrics())
|
247 |
-
|
248 |
-
|
249 |
-
@pytest.mark.asyncio
|
250 |
-
async def test_riva_tts_integration():
|
251 |
-
"""Tests integration of RivaTTSService components.
|
252 |
-
|
253 |
-
Tests the complete flow of the TTS service in an integrated environment.
|
254 |
-
|
255 |
-
The test verifies:
|
256 |
-
- Service initialization
|
257 |
-
- Frame generation sequence
|
258 |
-
- Audio chunk processing
|
259 |
-
- Frame content validation
|
260 |
-
- Service completion
|
261 |
-
"""
|
262 |
-
# Use parentheses for multiline with statements instead of backslashes
|
263 |
-
with (
|
264 |
-
patch("nvidia_pipecat.services.riva_speech.riva.client.Auth"),
|
265 |
-
patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService") as mock_service,
|
266 |
-
):
|
267 |
-
# Setup mock responses
|
268 |
-
mock_instance = mock_service.return_value
|
269 |
-
|
270 |
-
# Create audio frames for the response
|
271 |
-
audio_frame1 = TTSAudioRawFrame(audio=b"audio_chunk_1", sample_rate=16000, num_channels=1)
|
272 |
-
audio_frame2 = TTSAudioRawFrame(audio=b"audio_chunk_2", sample_rate=16000, num_channels=1)
|
273 |
-
|
274 |
-
# Return a list of frames, not an async generator
|
275 |
-
mock_instance.synthesize_online.return_value = [audio_frame1, audio_frame2]
|
276 |
-
|
277 |
-
# Initialize service and call its methods
|
278 |
-
service = RivaTTSService(api_key="test_api_key")
|
279 |
-
service._service = mock_instance
|
280 |
-
|
281 |
-
# Simulate running the service
|
282 |
-
collected_frames = []
|
283 |
-
async for frame in service.run_tts("Test sentence for TTS"):
|
284 |
-
collected_frames.append(frame)
|
285 |
-
|
286 |
-
# Verify the expected frames were produced
|
287 |
-
assert len(collected_frames) == 5 # Started, TextFrame, 2 audio chunks, stopped
|
288 |
-
assert isinstance(collected_frames[0], TTSStartedFrame)
|
289 |
-
assert isinstance(collected_frames[1], TTSTextFrame)
|
290 |
-
assert collected_frames[1].text == "Test sentence for TTS"
|
291 |
-
assert isinstance(collected_frames[2], TTSAudioRawFrame)
|
292 |
-
assert isinstance(collected_frames[3], TTSAudioRawFrame)
|
293 |
-
assert isinstance(collected_frames[4], TTSStoppedFrame)
|
294 |
-
|
295 |
-
# Verify audio content
|
296 |
-
assert collected_frames[2].audio == b"audio_chunk_1"
|
297 |
-
assert collected_frames[3].audio == b"audio_chunk_2"
|
298 |
-
|
299 |
-
|
300 |
-
if __name__ == "__main__":
|
301 |
-
unittest.main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/unit/test_speech_planner.py
DELETED
@@ -1,546 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the SpeechPlanner service.
-
-This module contains tests for the SpeechPlanner class, focusing on core functionalities:
-- Service initialization with various parameters
-- Frame processing for different frame types
-- Speech completion detection and LLM integration
-- Chat history management with context windows
-- Interruption handling and VAD integration
-- Label preprocessing and classification logic
-"""
-
-import tempfile
-from datetime import datetime
-from unittest.mock import AsyncMock, Mock, patch
-
-import pytest
-import yaml
-from pipecat.frames.frames import (
-    InterimTranscriptionFrame,
-    StartInterruptionFrame,
-    StopInterruptionFrame,
-    TranscriptionFrame,
-)
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-
-from nvidia_pipecat.services.speech_planner import SpeechPlanner
-
-
-class MockBaseMessageChunk:
-    """Mock for BaseMessageChunk that mimics the structure."""
-
-    def __init__(self, content=""):
-        """Initialize with content.
-
-        Args:
-            content: The text content of the chunk
-        """
-        self.content = content
-
-
-@pytest.fixture
-def mock_prompt_file():
-    """Create a temporary YAML prompt file for testing."""
-    prompt_data = {
-        "configurations": {"using_chat_history": False},
-        "prompts": {
-            "completion_prompt": (
-                "Evaluate whether the following user speech is sufficient:\n"
-                "1. Label1: Complete and coherent thought\n"
-                "2. Label2: Incomplete speech\n"
-                "3. Label3: User commands\n"
-                "4. Label4: Acknowledgments\n"
-                "User Speech: {transcript}\n"
-                "Only return Label1 or Label2 or Label3 or Label4."
-            )
-        },
-    }
-
-    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-        yaml.dump(prompt_data, f)
-        return f.name
-
-
-@pytest.fixture
-def mock_context():
-    """Create a mock OpenAI LLM context."""
-    context = Mock(spec=OpenAILLMContext)
-    context.get_messages.return_value = [
-        {"role": "user", "content": "Hello"},
-        {"role": "assistant", "content": "Hi there!"},
-        {"role": "user", "content": "How are you?"},
-        {"role": "assistant", "content": "I'm doing well!"},
-    ]
-    return context
-
-
-@pytest.fixture
-def speech_planner(mock_prompt_file, mock_context):
-    """Create a SpeechPlanner instance for testing."""
-    with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
-        mock_chat.return_value = AsyncMock()
-        planner = SpeechPlanner(
-            prompt_file=mock_prompt_file,
-            model="test-model",
-            api_key="test-key",
-            base_url="http://test-url",
-            context=mock_context,
-            context_window=2,
-        )
-        return planner
-
-
-class TestSpeechPlannerInitialization:
-    """Test SpeechPlanner initialization."""
-
-    def test_init_with_default_params(self, mock_prompt_file, mock_context):
-        """Test initialization with default parameters."""
-        with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
-            mock_chat.return_value = AsyncMock()
-
-            planner = SpeechPlanner(prompt_file=mock_prompt_file, context=mock_context)
-
-            assert planner.model_name == "nvdev/google/gemma-2b-it"
-            assert planner.context == mock_context
-            assert planner.context_window == 1
-            assert planner.user_speaking is None
-            assert planner.current_prediction is None
-
-    def test_init_with_custom_params(self, mock_prompt_file, mock_context):
-        """Test initialization with custom parameters."""
-        with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
-            mock_chat.return_value = AsyncMock()
-
-            params = SpeechPlanner.InputParams(temperature=0.7, max_tokens=100, top_p=0.9)
-
-            planner = SpeechPlanner(
-                prompt_file=mock_prompt_file,
-                model="custom-model",
-                api_key="custom-key",
-                base_url="http://custom-url",
-                context=mock_context,
-                params=params,
-                context_window=3,
-            )
-
-            assert planner.model_name == "custom-model"
-            assert planner.context_window == 3
-            assert planner._settings["temperature"] == 0.7
-            assert planner._settings["max_tokens"] == 100
-            assert planner._settings["top_p"] == 0.9
-
-    def test_init_loads_prompts(self, mock_prompt_file, mock_context):
-        """Test that initialization properly loads prompts from file."""
-        with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
-            mock_chat.return_value = AsyncMock()
-
-            planner = SpeechPlanner(prompt_file=mock_prompt_file, context=mock_context)
-
-            assert "configurations" in planner.prompts
-            assert "prompts" in planner.prompts
-            assert "completion_prompt" in planner.prompts["prompts"]
-            assert planner.prompts["configurations"]["using_chat_history"] is False
-
-
-### Adding tests for preprocess_pred function
-
-
-class TestPreprocessPred:
-    """Test the preprocess_pred function logic through end-to-end processing."""
-
-    @pytest.mark.asyncio
-    async def test_preprocess_pred_label1_complete(self, speech_planner):
-        """Test preprocess_pred with Label1 returns Complete via end-to-end processing."""
-        frame = TranscriptionFrame("Hello there", "user1", datetime.now())
-        test_cases = ["Label1", "Label 1", "The answer is Label1.", "I think this is Label 1 based on analysis."]
-
-        for case in test_cases:
-            # Mock the LLM to return the specific prediction
-            mock_chunks = [MockBaseMessageChunk(case)]
-
-            with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-                mock_stream.return_value.__aiter__.return_value = mock_chunks
-
-                await speech_planner._process_complete_context(frame)
-                assert speech_planner.current_prediction == "Complete", f"Failed for case: {case}"
-
-    @pytest.mark.asyncio
-    async def test_preprocess_pred_label2_incomplete(self, speech_planner):
-        """Test preprocess_pred with Label2 returns Incomplete via end-to-end processing."""
-        frame = TranscriptionFrame("Hello", "user1", datetime.now())
-        test_cases = ["Label2", "Label 2", "The answer is Label2.", "This should be Label 2."]
-
-        for case in test_cases:
-            # Mock the LLM to return the specific prediction
-            mock_chunks = [MockBaseMessageChunk(case)]
-
-            with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-                mock_stream.return_value.__aiter__.return_value = mock_chunks
-
-                await speech_planner._process_complete_context(frame)
-                assert speech_planner.current_prediction == "Incomplete", f"Failed for case: {case}"
-
-    @pytest.mark.asyncio
-    async def test_preprocess_pred_label3_complete(self, speech_planner):
-        """Test preprocess_pred with Label3 returns Complete via end-to-end processing."""
-        frame = TranscriptionFrame("Stop that", "user1", datetime.now())
-        test_cases = ["Label3", "Label 3", "The answer is Label3.", "I classify this as Label 3."]
-
-        for case in test_cases:
-            # Mock the LLM to return the specific prediction
-            mock_chunks = [MockBaseMessageChunk(case)]
-
-            with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-                mock_stream.return_value.__aiter__.return_value = mock_chunks
-
-                await speech_planner._process_complete_context(frame)
-                assert speech_planner.current_prediction == "Complete", f"Failed for case: {case}"
-
-    @pytest.mark.asyncio
-    async def test_preprocess_pred_label4_complete(self, speech_planner):
-        """Test preprocess_pred with Label4 returns Complete via end-to-end processing."""
-        frame = TranscriptionFrame("Okay", "user1", datetime.now())
-        test_cases = ["Label4", "Label 4", "The answer is Label4.", "This is Label 4 category."]
-
-        for case in test_cases:
-            # Mock the LLM to return the specific prediction
-            mock_chunks = [MockBaseMessageChunk(case)]
-
-            with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-                mock_stream.return_value.__aiter__.return_value = mock_chunks
-
-                await speech_planner._process_complete_context(frame)
-                assert speech_planner.current_prediction == "Complete", f"Failed for case: {case}"
-
-    @pytest.mark.asyncio
-    async def test_preprocess_pred_unrecognized_incomplete(self, speech_planner):
-        """Test preprocess_pred with unrecognized labels returns Incomplete via end-to-end processing."""
-        frame = TranscriptionFrame("Unknown input", "user1", datetime.now())
-        test_cases = ["Label5", "Unknown", "No label found", "", "Some random text"]
-
-        for case in test_cases:
-            # Mock the LLM to return the specific prediction
-            mock_chunks = [MockBaseMessageChunk(case)]
-
-            with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-                mock_stream.return_value.__aiter__.return_value = mock_chunks
-
-                await speech_planner._process_complete_context(frame)
-                assert speech_planner.current_prediction == "Incomplete", f"Failed for case: {case}"
-
-
-### Adding tests for chat_history management
-class TestChatHistory:
-    """Test chat history management."""
-
-    def test_get_chat_history_empty_messages(self, mock_prompt_file):
-        """Test get_chat_history with empty message list."""
-        context = Mock(spec=OpenAILLMContext)
-        context.get_messages.return_value = []
-
-        with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
-            mock_chat.return_value = AsyncMock()
-
-            planner = SpeechPlanner(prompt_file=mock_prompt_file, context=context, context_window=2)
-
-            history = planner.get_chat_history()
-            assert history == []
-
-    def test_get_chat_history_with_context_window(self, mock_prompt_file):
-        """Test get_chat_history respects context_window setting."""
-        messages = [
-            {"role": "user", "content": "First user message"},
-            {"role": "assistant", "content": "First assistant response"},
-            {"role": "user", "content": "Second user message"},
-            {"role": "assistant", "content": "Second assistant response"},
-            {"role": "user", "content": "Third user message"},
-            {"role": "assistant", "content": "Third assistant response"},
-        ]
-
-        context = Mock(spec=OpenAILLMContext)
-        context.get_messages.return_value = messages
-
-        with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
-            mock_chat.return_value = AsyncMock()
-
-            # Test with context_window=1 (should get last 2 messages)
-            planner = SpeechPlanner(prompt_file=mock_prompt_file, context=context, context_window=1)
-
-            history = planner.get_chat_history()
-            assert len(history) == 2
-            assert history[0]["content"] == "Third user message"
-            assert history[1]["content"] == "Third assistant response"
-
-    def test_get_chat_history_starts_with_user(self, mock_prompt_file):
-        """Test get_chat_history starts with user message."""
-        messages = [
-            {"role": "user", "content": "User message 1"},
-            {"role": "assistant", "content": "Assistant response 1"},
-            {"role": "user", "content": "User message 2"},
-            {"role": "assistant", "content": "Assistant response 2"},
-        ]
-
-        context = Mock(spec=OpenAILLMContext)
-        context.get_messages.return_value = messages
-
-        with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
-            mock_chat.return_value = AsyncMock()
-
-            planner = SpeechPlanner(prompt_file=mock_prompt_file, context=context, context_window=2)
-
-            history = planner.get_chat_history()
-            assert len(history) > 0
-            assert history[0]["role"] == "user"
-
-
-class TestFrameProcessing:
-    """Test frame processing functionality by testing the logic directly."""
-
-    def test_user_speaking_state_changes(self, speech_planner):
-        """Test that user speaking state changes correctly."""
-        # Test UserStartedSpeakingFrame logic
-        assert speech_planner.user_speaking is None
-        speech_planner.user_speaking = True
-        assert speech_planner.user_speaking is True
-
-        # Test UserStoppedSpeakingFrame logic
-        speech_planner.user_speaking = False
-        assert speech_planner.user_speaking is False
-
-    def test_bot_speaking_timestamp_tracking(self, speech_planner):
-        """Test bot speaking timestamp tracking."""
-        # Initially no timestamp
-        assert speech_planner.latest_bot_started_speaking_frame_timestamp is None
-
-        # Set timestamp (simulating BotStartedSpeakingFrame)
-        test_time = datetime.now()
-        speech_planner.latest_bot_started_speaking_frame_timestamp = test_time
-        assert speech_planner.latest_bot_started_speaking_frame_timestamp == test_time
-
-        # Clear timestamp (simulating BotStoppedSpeakingFrame)
-        speech_planner.latest_bot_started_speaking_frame_timestamp = None
-        assert speech_planner.latest_bot_started_speaking_frame_timestamp is None
-
-    def test_frame_state_management(self, speech_planner):
-        """Test frame state management without full processing."""
-        # Test last_frame tracking
-        frame = InterimTranscriptionFrame("Hello", "user1", datetime.now())
-        speech_planner.last_frame = frame
-        assert speech_planner.last_frame == frame
-
-        # Test clearing last_frame
-        speech_planner.last_frame = None
-        assert speech_planner.last_frame is None
-
-        # Test current_prediction state
-        speech_planner.current_prediction = "Complete"
-        assert speech_planner.current_prediction == "Complete"
-
-        speech_planner.current_prediction = "Incomplete"
-        assert speech_planner.current_prediction == "Incomplete"
-
-    def test_transcription_frame_conditions(self, speech_planner):
-        """Test the conditions for processing transcription frames."""
-        # Set up conditions for processing
-        speech_planner.user_speaking = False
-        speech_planner.current_prediction = "Incomplete"
-
-        # These conditions should allow processing
-        assert speech_planner.user_speaking is False
-        assert speech_planner.current_prediction == "Incomplete"
-
-        # Test conditions that would prevent processing
-        speech_planner.user_speaking = True
-        assert speech_planner.user_speaking is True  # Should prevent processing
-
-    @pytest.mark.asyncio
-    async def test_cancel_current_task_helper(self, speech_planner):
-        """Test the _cancel_current_task helper method."""
-        # Test with no current task
-        speech_planner._current_task = None
-        await speech_planner._cancel_current_task()
-        assert speech_planner._current_task is None
-
-        # Test with completed task
-        completed_task = Mock()
-        completed_task.done.return_value = True
-        completed_task.cancelled.return_value = False
-        speech_planner._current_task = completed_task
-
-        await speech_planner._cancel_current_task()
-        assert speech_planner._current_task is None
-
-        # Test with cancelled task
-        cancelled_task = Mock()
-        cancelled_task.done.return_value = False
-        cancelled_task.cancelled.return_value = True
-        speech_planner._current_task = cancelled_task
-
-        await speech_planner._cancel_current_task()
-        assert speech_planner._current_task is None
-
-        # Test with active task that needs cancellation
-        active_task = Mock()
-        active_task.done.return_value = False
-        active_task.cancelled.return_value = False
-        speech_planner._current_task = active_task
-
-        with patch.object(speech_planner, "cancel_task", new_callable=AsyncMock) as mock_cancel:
-            await speech_planner._cancel_current_task()
-            mock_cancel.assert_called_once_with(active_task)
-            assert speech_planner._current_task is None
-
-
-class TestCompletionDetection:
-    """Test speech completion detection."""
-
-    @pytest.mark.asyncio
-    async def test_process_complete_context_with_complete_prediction(self, speech_planner):
-        """Test _process_complete_context with complete prediction."""
-        frame = TranscriptionFrame("Hello there", "user1", datetime.now())
-
-        # Mock the LLM response
-        mock_chunks = [MockBaseMessageChunk("Label1")]
-
-        with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-            mock_stream.return_value.__aiter__.return_value = mock_chunks
-            with patch.object(speech_planner, "push_frame", new_callable=AsyncMock) as mock_push:
-                await speech_planner._process_complete_context(frame)
-
-                assert speech_planner.current_prediction == "Complete"
-
-                # Should push start/stop interruption frames and transcription
-                assert mock_push.call_count >= 3
-
-                # Verify the correct frames were pushed
-                call_args = [args[0][0] for args in mock_push.call_args_list]
-                assert any(isinstance(f, StartInterruptionFrame) for f in call_args)
-                assert any(isinstance(f, StopInterruptionFrame) for f in call_args)
-                assert any(isinstance(f, TranscriptionFrame) for f in call_args)
-
-    @pytest.mark.asyncio
-    async def test_process_complete_context_with_incomplete_prediction(self, speech_planner):
-        """Test _process_complete_context with incomplete prediction."""
-        frame = TranscriptionFrame("Hello", "user1", datetime.now())
-
-        # Mock the LLM response with Label2 (incomplete)
-        mock_chunks = [MockBaseMessageChunk("Label2")]
-
-        with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-            mock_stream.return_value.__aiter__.return_value = mock_chunks
-            with patch.object(speech_planner, "push_frame", new_callable=AsyncMock) as mock_push:
-                await speech_planner._process_complete_context(frame)
-
-                assert speech_planner.current_prediction == "Incomplete"
-                # Should not push any frames for incomplete prediction
-                mock_push.assert_not_called()
-
-    @pytest.mark.asyncio
-    async def test_process_complete_context_with_error(self, speech_planner):
-        """Test _process_complete_context handles errors gracefully with proper logging."""
-        frame = TranscriptionFrame("Hello", "user1", datetime.now())
-
-        # Mock an exception during processing
-        with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-            mock_stream.side_effect = Exception("LLM service error")
-
-            # Patch logger to verify warning is logged with stack trace
-            with patch("nvidia_pipecat.services.speech_planner.logger") as mock_logger:
-                await speech_planner._process_complete_context(frame)
-
-                # Should default to "Complete" on error
-                assert speech_planner.current_prediction == "Complete"
-
-                # Should log warning with stack trace information
-                mock_logger.warning.assert_called_once()
-                call_args = mock_logger.warning.call_args
-                assert "Disabling Smart EOU detection due to error" in call_args[0][0]
-                assert "LLM service error" in call_args[0][0]
-                assert call_args[1]["exc_info"] is True
-
-    @pytest.mark.asyncio
-    async def test_process_complete_context_with_chunk_content_error(self, speech_planner):
-        """Test _process_complete_context handles chunk content errors gracefully."""
-        frame = TranscriptionFrame("Hello", "user1", datetime.now())
-
-        # Mock chunks where one has content that causes concatenation error
-        class BadChunk:
-            def __init__(self):
-                # Use a number instead of None to pass the `if not chunk.content:` check
-                # but still cause TypeError during string concatenation
-                self.content = 42
-
-        mock_chunks = [MockBaseMessageChunk("Label"), BadChunk()]
-
-        with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
-            mock_stream.return_value.__aiter__.return_value = mock_chunks
-
-            # Patch logger to verify debug message is logged for chunk error
-            with patch("nvidia_pipecat.services.speech_planner.logger") as mock_logger:
-                await speech_planner._process_complete_context(frame)
-
-                # Should still get a prediction (either from first chunk or default "Complete")
-                assert speech_planner.current_prediction is not None
-
-                # Should log debug message for chunk content error
-                debug_calls = [
-                    call for call in mock_logger.debug.call_args_list if "Failed to append chunk content" in str(call)
-                ]
-                assert len(debug_calls) >= 1, (
-                    f"Expected debug log for chunk error, got calls: {mock_logger.debug.call_args_list}"
-                )
-
-
-class TestClientCreation:
-    """Test client creation functionality."""
-
-    def test_create_client_with_params(self, mock_prompt_file, mock_context):
-        """Test create_client method with parameters."""
-        with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat_class:
-            mock_client = AsyncMock()
-            mock_chat_class.return_value = mock_client
-
-            planner = SpeechPlanner(
-                prompt_file=mock_prompt_file,
-                context=mock_context,
-                model="test-model",
-                api_key="test-key",
-                base_url="http://test-url",
-            )
-
-            client = planner.create_client(api_key="custom-key", base_url="http://custom-url")
-
-            # Verify ChatNVIDIA was called with correct parameters
-            mock_chat_class.assert_called_with(base_url="http://custom-url", model="test-model", api_key="custom-key")
-            assert client == mock_client
-
-
-@pytest.mark.asyncio
-async def test_get_chat_completions(speech_planner):
-    """Test get_chat_completions method."""
-    messages = [{"role": "user", "content": "Test message"}]
-
-    # Mock the client's astream method to return an async iterator
-    mock_chunks = [MockBaseMessageChunk("Response chunk")]
-
-    async def mock_astream(*args, **kwargs):
-        for chunk in mock_chunks:
-            yield chunk
-
-    speech_planner._client.astream = mock_astream
-
-    result = await speech_planner.get_chat_completions(messages)
-
-    # The result should be the return value from astream
-    assert result is not None
-
-    # Convert result to list to verify contents
-    result_list = []
-    async for chunk in result:
-        result_list.append(chunk)
-
-    assert len(result_list) == 1
-    assert result_list[0].content == "Response chunk"
tests/unit/test_traced_processor.py
DELETED
@@ -1,159 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests tracing."""
-
-import pytest
-from opentelemetry import trace
-from pipecat.frames.frames import EndFrame, ErrorFrame, Frame, TextFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
-
-from nvidia_pipecat.utils.tracing import AttachmentStrategy, traceable, traced
-from tests.unit.utils import ignore_ids, run_test
-
-
-@pytest.mark.asyncio
-async def test_traced_processor_basic_usage():
-    """Tests basic tracing functionality in a pipeline processor.
-
-    Tests the @traceable and @traced decorators with different attachment
-    strategies in a simple pipeline configuration.
-
-    The test verifies:
-        - Span creation and attachment
-        - Event recording
-        - Nested span handling
-        - Generator tracing
-        - Frame processing
-    """
-    tracer = trace.get_tracer(__name__)
-
-    @traceable
-    class TestProcessor(FrameProcessor):
-        """Example processor demonstrating how to use the tracing utilities."""
-
-        @traced(attachment_strategy=AttachmentStrategy.NONE)
-        async def process_frame(self, frame, direction):
-            """Process a frame with tracing.
-
-            Args:
-                frame: The frame to process.
-                direction: The direction of frame flow.
-
-            The method demonstrates:
-                - Basic span creation
-                - Event recording
-                - Nested span handling
-                - Multiple tracing strategies
-            """
-            await super().process_frame(frame, direction)
-            trace.get_current_span().add_event("Before inner")
-            with tracer.start_as_current_span("inner") as span:
-                span.add_event("inner event")
-                await self.child()
-                await self.linked()
-                await self.none()
-            trace.get_current_span().add_event("After inner")
-            async for f in self.generator():
-                print(f"{f}")
-            await super().push_frame(frame, direction)
-
-        @traced
-        async def child(self):
-            """Example method with child attachment strategy.
-
-            This span is attached as CHILD, meaning it will be attached to
-            the class span if no parent exists, or to its parent otherwise.
-            """
-            trace.get_current_span().add_event("child")
-
-        @traced(attachment_strategy=AttachmentStrategy.LINK)
-        async def linked(self):
-            """Example method with link attachment strategy.
-
-            This span is attached as LINK, meaning it will be attached to
-            the class span but linked to its parent.
-            """
-            trace.get_current_span().add_event("linked")
-
-        @traced(attachment_strategy=AttachmentStrategy.NONE)
-        async def none(self):
-            """Example method with no attachment strategy.
-
-            This span is attached as NONE, meaning it will be attached to
-            the class span even if nested under another span.
-            """
-            trace.get_current_span().add_event("none")
-
-        @traced
-        async def generator(self):
-            """Example generator method with tracing.
-
-            Demonstrates tracing in a generator context.
-
-            Yields:
-                TextFrame: Text frames with sample content.
-            """
-            yield TextFrame("Hello, ")
-            trace.get_current_span().add_event("generated!")
-            yield TextFrame("World")
-
-    processor = TestProcessor()
-    pipeline = Pipeline([processor])
-
-    await run_test(
-        pipeline,
-        frames_to_send=[],
-        expected_down_frames=[],
-    )
-
-
-@pytest.mark.asyncio
-async def test_wrong_usage() -> None:
-    """Test that error message is raised when a processor is not traceable."""
-
-    class TestProcessor(FrameProcessor):
-        @traced
-        async def process_frame(self, frame: Frame, direction: FrameDirection):
-            await super().process_frame(frame, direction)
-            await super().push_frame(frame, direction)
-
-    class HasSeenError(FrameProcessor):
-        def __init__(self, **kwargs):
-            super().__init__(**kwargs)
-            self.seen_error = False
-
-        async def process_frame(self, frame: Frame, direction: FrameDirection):
-            await super().process_frame(frame, direction)
-            if isinstance(frame, ErrorFrame):
-                self.seen_error = True
-            elif isinstance(frame, EndFrame):
-                assert self.seen_error
-            await super().push_frame(frame, direction)
-
-    seen_error = HasSeenError()
-    processor = TestProcessor()
-    pipeline = Pipeline([seen_error, processor])
-    await run_test(
-        pipeline,
-        frames_to_send=[],
-        expected_down_frames=[],
-        expected_up_frames=[
-            ignore_ids(ErrorFrame("@traced annotation can only be used in classes inheriting from Traceable"))
-        ],
-    )
-
-
-@pytest.mark.asyncio
-async def test_no_processor() -> None:
-    """Test that a processor can be used without a pipeline."""
-
-    @traceable
-    class TestTraceable:
-        @traced
-        async def traced_test(self):
-            trace.get_current_span().add_event("I can use it as another utility function as well.")
-
-    test = TestTraceable()
-    await test.traced_test()
tests/unit/test_transcription_sync_processors.py
DELETED
@@ -1,262 +0,0 @@
|
|
1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
# SPDX-License-Identifier: BSD 2-Clause License
|
3 |
-
|
4 |
-
"""Unit tests for transcript synchronization processors.
|
5 |
-
|
6 |
-
This module contains tests that verify the behavior of transcript synchronization processors,
|
7 |
-
including both user and bot transcript synchronization with different TTS providers.
|
8 |
-
The tests ensure proper handling of speech events, transcriptions, and TTS frames.
|
9 |
-
"""
|
10 |
-
|
11 |
-
import pytest
|
12 |
-
from pipecat.frames.frames import (
|
13 |
-
BotStartedSpeakingFrame,
|
14 |
-
BotStoppedSpeakingFrame,
|
15 |
-
InterimTranscriptionFrame,
|
16 |
-
StartInterruptionFrame,
|
17 |
-
TranscriptionFrame,
|
18 |
-
TTSStartedFrame,
|
19 |
-
TTSStoppedFrame,
|
20 |
-
TTSTextFrame,
|
21 |
-
UserStartedSpeakingFrame,
|
22 |
-
UserStoppedSpeakingFrame,
|
23 |
-
)
|
24 |
-
from pipecat.tests.utils import SleepFrame
|
25 |
-
from pipecat.utils.time import time_now_iso8601
|
26 |
-
|
27 |
-
from nvidia_pipecat.frames.transcripts import (
|
28 |
-
BotUpdatedSpeakingTranscriptFrame,
|
29 |
-
UserStoppedSpeakingTranscriptFrame,
|
30 |
-
UserUpdatedSpeakingTranscriptFrame,
|
31 |
-
)
|
32 |
-
from nvidia_pipecat.processors.transcript_synchronization import (
|
33 |
-
BotTranscriptSynchronization,
|
34 |
-
UserTranscriptSynchronization,
|
35 |
-
)
|
36 |
-
from tests.unit.utils import ignore_ids, run_test
|
37 |
-
|
38 |
-
|
39 |
-
@pytest.mark.asyncio()
|
40 |
-
async def test_user_transcript_synchronization_processor():
|
41 |
-
"""Test the UserTranscriptSynchronization processor functionality.
|
42 |
-
|
43 |
-
Tests the complete flow of user speech transcription synchronization,
|
44 |
-
including interim and final transcriptions.
|
45 |
-
|
46 |
-
The test verifies:
|
47 |
-
- User speech start/stop handling
|
48 |
-
- Interim transcription processing
|
49 |
-
- Speaking transcript updates
|
50 |
-
- Final transcript generation
|
51 |
-
- Frame sequence ordering
|
52 |
-
- Multiple speech segment handling
|
53 |
-
"""
|
54 |
-
user_id = ""
|
55 |
-
interim_transcript_frames = [
|
56 |
-
InterimTranscriptionFrame("Hi", user_id, time_now_iso8601()),
|
57 |
-
InterimTranscriptionFrame("Hi there!", user_id, time_now_iso8601()),
|
58 |
-
InterimTranscriptionFrame("How are", user_id, time_now_iso8601()),
|
59 |
-
InterimTranscriptionFrame("How are you?", user_id, time_now_iso8601()),
|
60 |
-
]
|
61 |
-
finale_transcript_frame1 = TranscriptionFrame("Hi there!", user_id, time_now_iso8601())
|
62 |
-
finale_transcript_frame2 = TranscriptionFrame("How are you?", user_id, time_now_iso8601())
|
63 |
-
|
64 |
-
frames_to_send = [
|
65 |
-
UserStartedSpeakingFrame(),
|
66 |
-
interim_transcript_frames[0],
|
67 |
-
interim_transcript_frames[1],
|
68 |
-
SleepFrame(0.1),
|
69 |
-
UserStoppedSpeakingFrame(),
|
70 |
-
finale_transcript_frame1,
|
71 |
-
SleepFrame(0.1),
|
72 |
-
UserStartedSpeakingFrame(),
|
73 |
-
interim_transcript_frames[0],
|
74 |
-
interim_transcript_frames[1],
|
75 |
-
finale_transcript_frame1,
|
76 |
-
interim_transcript_frames[2],
|
77 |
-
interim_transcript_frames[3],
|
78 |
-
finale_transcript_frame2,
|
79 |
-
SleepFrame(0.1),
|
80 |
-
UserStoppedSpeakingFrame(),
|
81 |
-
]
|
82 |
-
|
83 |
-
expected_down_frames = [
|
84 |
-
ignore_ids(UserStartedSpeakingFrame()),
|
85 |
-
ignore_ids(UserUpdatedSpeakingTranscriptFrame("user started speaking")),
|
86 |
-
ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi")),
|
87 |
-
ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi there!")),
|
88 |
-
ignore_ids(interim_transcript_frames[0]),
|
89 |
-
ignore_ids(interim_transcript_frames[1]),
|
90 |
-
ignore_ids(UserStoppedSpeakingFrame()),
|
91 |
-
ignore_ids(UserStoppedSpeakingTranscriptFrame("Hi there!")),
|
92 |
-
ignore_ids(finale_transcript_frame1),
|
93 |
-
ignore_ids(UserStartedSpeakingFrame()),
|
94 |
-
ignore_ids(UserUpdatedSpeakingTranscriptFrame("user started speaking")),
|
95 |
-
ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi")),
|
96 |
-
ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi there!")),
|
97 |
-
ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi there! How are")),
|
98 |
-
ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi there! How are you?")),
|
99 |
-
ignore_ids(interim_transcript_frames[0]),
|
100 |
-
ignore_ids(interim_transcript_frames[1]),
|
101 |
-
ignore_ids(finale_transcript_frame1),
|
102 |
-
ignore_ids(interim_transcript_frames[2]),
|
103 |
-
ignore_ids(interim_transcript_frames[3]),
|
104 |
-
ignore_ids(finale_transcript_frame2),
|
105 |
-
ignore_ids(UserStoppedSpeakingFrame()),
|
106 |
-
ignore_ids(UserStoppedSpeakingTranscriptFrame("Hi there! How are you?")),
|
107 |
-
]
|
108 |
-
|
109 |
-
await run_test(
|
110 |
-
UserTranscriptSynchronization("user started speaking"),
|
111 |
-
frames_to_send=frames_to_send,
|
112 |
-
expected_down_frames=expected_down_frames,
|
113 |
-
)
|
114 |
-
|
115 |
-
|
116 |
-
@pytest.mark.asyncio()
|
117 |
-
async def test_bot_transcript_synchronization_processor_with_riva_tts():
|
118 |
-
"""Test the BotTranscriptSynchronization processor with Riva TTS.
|
119 |
-
|
120 |
-
Tests the synchronization of bot transcripts when using Riva TTS,
|
121 |
-
including speech events and interruption handling.
|
122 |
-
|
123 |
-
The test verifies:
|
124 |
-
- Bot speech start/stop handling
|
125 |
-
- TTS text frame processing
|
126 |
-
- Speaking transcript updates
|
127 |
-
- Interruption handling
|
128 |
-
- Frame sequence ordering
|
129 |
-
- Multiple sentence handling
|
130 |
-
"""
|
131 |
-
tts_text_frames = [
|
132 |
-
TTSTextFrame("Welcome user!"),
|
133 |
-
TTSTextFrame("How are you?"),
|
134 |
-
TTSTextFrame("Did you have a nice day?"),
|
135 |
-
]
|
136 |
-
|
137 |
-
frames_to_send = [
|
138 |
-
TTSStartedFrame(), # Bot sentence transcript 1
|
139 |
-
tts_text_frames[0],
|
140 |
-
SleepFrame(0.1), # Give time for transcript to be buffered
|
141 |
-
BotStartedSpeakingFrame(), # Start playing sentence 1
|
142 |
-
TTSStoppedFrame(),
|
143 |
-
BotStoppedSpeakingFrame(), # End of playing sentence 1
|
144 |
-
SleepFrame(0.1),
|
145 |
-
TTSStartedFrame(), # Bot sentence transcript 2
|
146 |
-
tts_text_frames[1],
|
147 |
-
SleepFrame(0.1), # Give time for transcript to be buffered
|
148 |
-
BotStartedSpeakingFrame(), # Start playing sentence 2
|
149 |
-
TTSStoppedFrame(),
|
150 |
-
BotStoppedSpeakingFrame(), # End of playing sentence 2
|
151 |
-
SleepFrame(0.1),
|
152 |
-
TTSStartedFrame(), # Bot sentence transcript 3
|
153 |
-
tts_text_frames[2],
|
154 |
-
SleepFrame(0.1), # Give time for transcript to be buffered
|
155 |
-
BotStartedSpeakingFrame(), # Start playing sentence 3
|
156 |
-
TTSStoppedFrame(),
|
157 |
-
BotStoppedSpeakingFrame(), # End of playing sentence 3
|
158 |
-
SleepFrame(0.1),
|
159 |
-
StartInterruptionFrame(), # User interrupts
|
160 |
-
TTSStartedFrame(), # Bot sentence 1 again
|
161 |
-
tts_text_frames[0],
|
162 |
-
SleepFrame(0.1), # Give time for transcript to be buffered
|
163 |
-
BotStartedSpeakingFrame(), # Start playing sentence 1
|
164 |
-
TTSStoppedFrame(),
|
165 |
-
BotStoppedSpeakingFrame(), # End of playing sentence 1
|
166 |
-
]
|
167 |
-
|
168 |
-
-    expected_down_frames = [
-        ignore_ids(TTSStartedFrame()),
-        ignore_ids(tts_text_frames[0]),
-        ignore_ids(BotStartedSpeakingFrame()),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user!")),
-        ignore_ids(BotStoppedSpeakingFrame()),
-        ignore_ids(TTSStoppedFrame()),
-        ignore_ids(TTSStartedFrame()),
-        ignore_ids(tts_text_frames[1]),
-        ignore_ids(BotStartedSpeakingFrame()),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("How are you?")),
-        ignore_ids(BotStoppedSpeakingFrame()),
-        ignore_ids(TTSStoppedFrame()),
-        ignore_ids(TTSStartedFrame()),
-        ignore_ids(tts_text_frames[2]),
-        ignore_ids(BotStartedSpeakingFrame()),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("Did you have a nice day?")),
-        ignore_ids(BotStoppedSpeakingFrame()),
-        ignore_ids(TTSStoppedFrame()),
-        ignore_ids(StartInterruptionFrame()),
-        ignore_ids(TTSStartedFrame()),
-        ignore_ids(tts_text_frames[0]),
-        ignore_ids(BotStartedSpeakingFrame()),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user!")),
-        ignore_ids(BotStoppedSpeakingFrame()),
-        ignore_ids(TTSStoppedFrame()),
-    ]
-
-    await run_test(
-        BotTranscriptSynchronization(),
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-
-@pytest.mark.asyncio()
-async def test_bot_transcript_synchronization_processor_with_elevenlabs_tts():
-    """Test the BotTranscriptSynchronization processor with ElevenLabs TTS.
-
-    Tests the synchronization of bot transcripts when using ElevenLabs TTS,
-    including partial text handling and concatenation.
-
-    The test verifies:
-    - Bot speech start/stop handling
-    - Partial TTS text processing
-    - Speaking transcript updates
-    - Text concatenation
-    - Frame sequence ordering
-    - Complete transcript assembly
-    """
-    tts_text_frames = [
-        TTSTextFrame("Welcome"),
-        TTSTextFrame("user!"),
-        TTSTextFrame("How"),
-        TTSTextFrame("are"),
-        TTSTextFrame("you?"),
-    ]
-
-    frames_to_send = [
-        TTSStartedFrame(),
-        tts_text_frames[0],
-        SleepFrame(0.1),
-        BotStartedSpeakingFrame(),
-        tts_text_frames[1],
-        tts_text_frames[2],
-        tts_text_frames[3],
-        tts_text_frames[4],
-        SleepFrame(0.1),
-        TTSStoppedFrame(),
-        SleepFrame(0.1),
-        BotStoppedSpeakingFrame(),
-    ]
-
-    expected_down_frames = [
-        ignore_ids(TTSStartedFrame()),
-        ignore_ids(tts_text_frames[0]),
-        ignore_ids(BotStartedSpeakingFrame()),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome")),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user!")),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user! How")),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user! How are")),
-        ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user! How are you?")),
-        ignore_ids(tts_text_frames[1]),
-        ignore_ids(tts_text_frames[2]),
-        ignore_ids(tts_text_frames[3]),
-        ignore_ids(tts_text_frames[4]),
-        ignore_ids(TTSStoppedFrame()),
-        ignore_ids(BotStoppedSpeakingFrame()),
-    ]
-
-    await run_test(
-        BotTranscriptSynchronization(),
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
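For reference, the second deleted test pins down the transcript concatenation rule for word-level TTS such as ElevenLabs: each partial TTSTextFrame extends a growing transcript, and one BotUpdatedSpeakingTranscriptFrame is emitted per partial. A minimal sketch of that aggregation rule (a hypothetical helper written for this note, not code from the deleted file):

def transcript_updates(partials: list[str]) -> list[str]:
    """Fold word-level TTS partials into the cumulative transcript strings the test expects."""
    updates: list[str] = []
    current = ""
    for text in partials:
        current = f"{current} {text}".strip()  # append with a single separating space
        updates.append(current)
    return updates

assert transcript_updates(["Welcome", "user!", "How", "are", "you?"]) == [
    "Welcome",
    "Welcome user!",
    "Welcome user! How",
    "Welcome user! How are",
    "Welcome user! How are you?",
]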
tests/unit/test_user_presence.py
DELETED
@@ -1,157 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Unit tests for the user presence frame processor."""
-
-import pytest
-from pipecat.frames.frames import (
-    FilterControlFrame,
-    LLMUpdateSettingsFrame,
-    StartInterruptionFrame,
-    TextFrame,
-    TTSSpeakFrame,
-    TTSStartedFrame,
-    UserStartedSpeakingFrame,
-)
-from pipecat.tests.utils import SleepFrame
-
-from nvidia_pipecat.frames.action import FinishedPresenceUserActionFrame, StartedPresenceUserActionFrame
-from nvidia_pipecat.processors.user_presence import UserPresenceProcesssor
-from tests.unit.utils import ignore, run_test
-
-
-@pytest.mark.asyncio
-async def test_user_presence_start():
-    """Tests user presence start handling with welcome message.
-
-    Tests that the processor correctly handles user presence start events
-    and generates appropriate welcome messages.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-    - StartedPresenceUserActionFrame processing
-    - Welcome message generation
-    - UserStartedSpeakingFrame handling
-    - Frame sequence ordering
-    - Message content accuracy
-    """
-    user_presence_bot = UserPresenceProcesssor(welcome_msg="Hey there!")
-    frames_to_send = [StartedPresenceUserActionFrame(action_id=123), SleepFrame(0.01), UserStartedSpeakingFrame()]
-    expected_down_frames = [
-        ignore(StartedPresenceUserActionFrame(action_id=123), "ids", "timestamps"),
-        ignore(TTSSpeakFrame("Hey there!"), "ids"),
-        ignore(UserStartedSpeakingFrame(), "ids", "timestamps"),
-    ]
-
-    await run_test(user_presence_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
-
-
-@pytest.mark.asyncio
-async def test_user_presence_finished():
-    """Tests user presence finish handling with farewell message.
-
-    Tests that the processor correctly handles user presence finish events
-    and generates appropriate farewell messages.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-    - StartedPresenceUserActionFrame processing
-    - FinishedPresenceUserActionFrame handling
-    - Farewell message generation
-    - StartInterruptionFrame sequencing
-    - Frame ordering
-    - Message content accuracy
-    """
-    user_presence_bot = UserPresenceProcesssor(farewell_msg="Bye bye!")
-    frames_to_send = [
-        StartedPresenceUserActionFrame(action_id=123),
-        SleepFrame(0.5),
-        FinishedPresenceUserActionFrame(action_id=123),
-        SleepFrame(0.5),
-    ]
-
-    expected_down_frames = [
-        ignore(StartedPresenceUserActionFrame(action_id=123), "ids", "timestamps"),
-        ignore(TTSSpeakFrame("Hello"), "ids"),
-        # Workaround: the StartInterruptionFrame is sent in response to the FinishedPresenceUserActionFrame.
-        # However, the test framework consistently logs it in the sequence prior to the FinishedPresenceUserActionFrame.
-        ignore(StartInterruptionFrame(), "ids", "timestamps"),
-        ignore(FinishedPresenceUserActionFrame(action_id=123), "ids", "timestamps"),
-        ignore(TTSSpeakFrame("Bye bye!"), "ids"),
-    ]
-
-    await run_test(user_presence_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
-
-
-@pytest.mark.asyncio
-async def test_user_presence():
-    """Tests behavior without presence frames.
-
-    Tests that no welcome/farewell messages are sent when no presence
-    frames are received.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-    - No messages without presence frames
-    - UserStartedSpeakingFrame handling
-    - Frame filtering behavior
-    """
-    user_presence_bot = UserPresenceProcesssor(welcome_msg="Hello", farewell_msg="Bye bye!")
-    frames_to_send = [UserStartedSpeakingFrame()]
-    expected_down_frames = []
-
-    await run_test(user_presence_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
-
-
-@pytest.mark.asyncio
-async def test_user_presence_system_frames():
-    """Tests system and control frame handling.
-
-    Tests that system and control frames are processed regardless of
-    user presence state.
-
-    Args:
-        None
-
-    Returns:
-        None
-
-    The test verifies:
-    - TTSStartedFrame processing
-    - FilterControlFrame handling
-    - LLMUpdateSettingsFrame processing
-    - TextFrame filtering
-    - Frame passthrough behavior
-    - Frame sequence preservation
-    """
-    user_presence_bot = UserPresenceProcesssor(welcome_msg="Hello", farewell_msg="Bye bye!")
-
-    frames_to_send = [
-        TTSStartedFrame(),
-        FilterControlFrame(),
-        LLMUpdateSettingsFrame(settings="ABC"),
-        TextFrame("How are you?"),
-    ]
-
-    expected_down_frames = [
-        ignore(TTSStartedFrame(), "ids", "timestamps"),
-        ignore(FilterControlFrame(), "ids", "timestamps"),
-        ignore(LLMUpdateSettingsFrame(settings="ABC"), "ids", "timestamps"),
-    ]
-
-    await run_test(user_presence_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
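Taken together, the four deleted tests describe a simple gating contract: presence start triggers the welcome message, presence end triggers an interruption followed by the farewell, system and control frames always pass, and other user frames pass only while a user is present. A toy model of that contract, inferred from the expected frame sequences above (this is not the deleted processor's implementation):

from dataclasses import dataclass


@dataclass
class Event:
    kind: str  # "presence_start", "presence_end", "system", or "user"
    payload: str = ""


def gate(events: list[Event], welcome: str, farewell: str) -> list[str]:
    out: list[str] = []
    present = False
    for e in events:
        if e.kind == "presence_start":
            present = True
            out.append(f"speak:{welcome}")
        elif e.kind == "presence_end":
            present = False
            out.append("interrupt")  # cut off any ongoing speech first
            out.append(f"speak:{farewell}")
        elif e.kind == "system":
            out.append(f"pass:{e.payload}")  # system/control frames always pass
        elif present:
            out.append(f"pass:{e.payload}")  # user frames pass only while present
    return out


# Mirrors test_user_presence: without presence frames, user speech is swallowed.
assert gate([Event("user", "How are you?")], "Hello", "Bye bye!") == []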
tests/unit/test_utils.py
DELETED
@@ -1,95 +0,0 @@
-"""Unit tests for utility processors."""
-
-import pytest
-from pipecat.frames.frames import Frame, TextFrame
-
-from nvidia_pipecat.processors.utils import FrameBlockingProcessor
-from tests.unit.utils import ignore_ids, run_test
-
-
-class TestFrame(Frame):
-    """Test frame for testing FrameBlockingProcessor."""
-
-    pass
-
-
-class TestResetFrame(Frame):
-    """Test frame for testing FrameBlockingProcessor reset."""
-
-    pass
-
-
-@pytest.mark.asyncio()
-async def test_frame_blocking_processor():
-    """Test that FrameBlockingProcessor blocks frames after threshold.
-
-    Verifies that:
-    - Frames are passed through until threshold is reached
-    - Frames are blocked after threshold
-    - Non-matching frame types are always passed through
-    """
-    # Create processor that blocks after 2 TextFrames
-    processor = FrameBlockingProcessor(block_after_frame=2, frame_type=TextFrame)
-
-    # Create test frames - mix of TextFrames and TestFrames
-    frames_to_send = [
-        TestFrame(),  # Should pass through
-        TextFrame("First"),  # Should pass through
-        TextFrame("Second"),  # Should pass through
-        TextFrame("Third"),  # Should be blocked
-        TestFrame(),  # Should pass through
-        TextFrame("Fourth"),  # Should be blocked
-    ]
-
-    # Expected frames - all TestFrames and the first 2 TextFrames
-    expected_down_frames = [
-        ignore_ids(TestFrame()),
-        ignore_ids(TextFrame("First")),
-        ignore_ids(TextFrame("Second")),
-        ignore_ids(TestFrame()),
-    ]
-
-    await run_test(
-        processor,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
-
-
-@pytest.mark.asyncio()
-async def test_frame_blocking_processor_with_reset():
-    """Test that FrameBlockingProcessor resets counter on reset frame type.
-
-    Verifies that:
-    - Counter resets when reset frame type is received
-    - Frames are blocked after threshold
-    - Counter can be reset multiple times
-    """
-    # Create processor that blocks after 2 TextFrames and resets on TestResetFrame
-    processor = FrameBlockingProcessor(block_after_frame=2, frame_type=TextFrame, reset_frame_type=TestResetFrame)
-
-    # Create test frames - mix of TextFrames and TestResetFrames
-    frames_to_send = [
-        TextFrame("First"),  # Should pass through
-        TextFrame("Second"),  # Should pass through
-        TextFrame("Third"),  # Should be blocked
-        TestResetFrame(),  # Should reset counter and pass through
-        TextFrame("Fourth"),  # Should pass through (counter reset)
-        TextFrame("Fifth"),  # Should pass through
-        TextFrame("Sixth"),  # Should be blocked
-    ]
-
-    # Expected frames - the TestResetFrame and all TextFrames except blocked ones
-    expected_down_frames = [
-        ignore_ids(TextFrame("First")),
-        ignore_ids(TextFrame("Second")),
-        ignore_ids(TestResetFrame()),
-        ignore_ids(TextFrame("Fourth")),
-        ignore_ids(TextFrame("Fifth")),
-    ]
-
-    await run_test(
-        processor,
-        frames_to_send=frames_to_send,
-        expected_down_frames=expected_down_frames,
-    )
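The two deleted tests fully determine the counting behavior: frames of the watched type pass until the threshold is reached, later matches are dropped, other frames always pass, and a reset frame restarts the count. A toy re-derivation of that logic (hypothetical helper, not the FrameBlockingProcessor source):

class A:  # stands in for TextFrame
    pass


class R:  # stands in for TestResetFrame
    pass


def filter_frames(frames, frame_type, block_after_frame, reset_frame_type=None):
    passed, count = [], 0
    for f in frames:
        if reset_frame_type is not None and isinstance(f, reset_frame_type):
            count = 0  # reset frame restarts the window and passes through
            passed.append(f)
        elif isinstance(f, frame_type):
            count += 1
            if count <= block_after_frame:
                passed.append(f)  # within the threshold: forward
        else:
            passed.append(f)  # frames of other types always pass
    return passed


kept = filter_frames([A(), A(), A(), R(), A(), A(), A()], A, 2, reset_frame_type=R)
assert len(kept) == 5  # two A's, the reset, then two more A's after the reset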
tests/unit/utils.py
DELETED
@@ -1,428 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: BSD 2-Clause License
-
-"""Utility functions and classes for testing pipelines and processors.
-
-This module provides various test utilities including frame processors, test runners,
-and helper functions for frame comparison in tests.
-"""
-
-import asyncio
-from asyncio.tasks import sleep
-from collections.abc import Sequence
-from copy import copy
-from dataclasses import dataclass
-from datetime import datetime, timedelta
-from typing import Any
-from unittest.mock import ANY
-
-import numpy as np
-from pipecat.frames.frames import EndFrame, Frame, InputAudioRawFrame, StartFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.pipeline.runner import PipelineRunner
-from pipecat.pipeline.task import PipelineParams, PipelineTask
-from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
-from pipecat.tests.utils import QueuedFrameProcessor
-from pipecat.tests.utils import run_test as run_pipecat_test
-
-from nvidia_pipecat.frames.action import ActionFrame
-
-
-def ignore(frame: Frame, *args: str):
-    """Return a copy of the frame with attributes listed in *args set to unittest.mock.ANY.
-
-    Any attribute listed in args will be ignored when using comparisons against the returned frame.
-
-    Args:
-        frame (Frame): Frame to create a copy from and set selected attributes to unittest.mock.ANY.
-        *args (str): Attribute names. Special values to ignore common sets of attributes:
-            'ids': ignore standard frame attributes related to frame ids
-            'all_ids': ignore standard frame attributes related to frame and action ids
-            'timestamps': ignore standard frame attributes containing timestamps
-
-    Returns:
-        Frame: A copy of the input frame with specified attributes set to ANY.
-
-    Raises:
-        ValueError: If an attribute name is not found in the frame.
-    """
-    new_frame = copy(frame)
-    for arg in args:
-        is_updated = False
-        if arg == "all_ids" or arg == "ids":
-            new_frame.id = ANY
-            new_frame.name = ANY
-            is_updated = True
-        if arg == "all_ids":
-            new_frame.action_id = ANY
-            is_updated = True
-        if arg == "timestamps":
-            new_frame.pts = ANY
-            new_frame.action_started_at = ANY
-            new_frame.action_finished_at = ANY
-            new_frame.action_updated_at = ANY
-            is_updated = True
-
-        if hasattr(new_frame, arg):
-            setattr(new_frame, arg, ANY)
-        elif not is_updated:
-            raise ValueError(
-                f"Frame {frame.__class__.__name__} has no attribute '{arg}' to ignore. Did you misspell it?"
-            )
-    return new_frame
-
-
-def ignore_ids(frame: Frame) -> Frame:
-    """Return a copy of the frame that matches any id and name.
-
-    This is useful if you do not want to assert a specific ID and name.
-
-    Args:
-        frame (Frame): The frame to create a copy from.
-
-    Returns:
-        Frame: A copy of the input frame with ID and name set to ANY.
-    """
-    new_frame = copy(frame)
-    new_frame.id = ANY
-    new_frame.name = ANY
-    if isinstance(frame, ActionFrame):
-        new_frame.action_id = ANY
-    return new_frame
-
-
-def ignore_timestamps(frame: Frame) -> Frame:
-    """Return a copy of the frame that matches frames ignoring any timestamps.
-
-    Args:
-        frame (Frame): The frame to create a copy from.
-
-    Returns:
-        Frame: A copy of the input frame with all timestamp fields set to ANY.
-    """
-    new_frame = copy(frame)
-    new_frame.pts = ANY
-    new_frame.action_started_at = ANY
-    new_frame.action_finished_at = ANY
-    new_frame.action_updated_at = ANY
-    return new_frame
-
-
-@dataclass
-class FrameHistoryEntry:
-    """Stores a frame and the frame direction.
-
-    Attributes:
-        frame (Frame): The stored frame.
-        direction (FrameDirection): The direction of the frame in the pipeline.
-    """
-
-    frame: Frame
-    direction: FrameDirection
-
-
-class FrameStorage(FrameProcessor):
-    """A frame processor that stores all received frames in memory for inspection.
-
-    This processor maintains a history of all frames that pass through it, along with their
-    direction, allowing for later inspection and verification in tests.
-
-    Attributes:
-        history (list[FrameHistoryEntry]): List of all frames that have passed through the processor.
-    """
-
-    def __init__(self, **kwargs) -> None:
-        """Initialize the FrameStorage processor.
-
-        Args:
-            **kwargs: Additional keyword arguments passed to the parent FrameProcessor.
-        """
-        super().__init__(**kwargs)
-        self.history: list[FrameHistoryEntry] = []
-
-    async def process_frame(self, frame: Frame, direction: FrameDirection):
-        """Process an incoming frame by storing it in history and forwarding it.
-
-        Args:
-            frame (Frame): The frame to process.
-            direction (FrameDirection): The direction the frame is traveling in the pipeline.
-        """
-        await super().process_frame(frame, direction)
-
-        self.history.append(FrameHistoryEntry(frame, direction))
-        await self.push_frame(frame, direction)
-
-    def frames_of_type(self, t) -> list[Frame]:
-        """Get all frames of a specific type from the history.
-
-        Args:
-            t: The type of frames to filter for.
-
-        Returns:
-            list[Frame]: List of all frames matching the specified type.
-        """
-        return [e.frame for e in self.history if isinstance(e.frame, t)]
-
-    async def wait_for_frame(self, frame: Frame, timeout: timedelta = timedelta(seconds=5.0)) -> None:
-        """Block until a matching frame is found in history.
-
-        Args:
-            frame (Frame): The frame to wait for.
-            timeout (timedelta, optional): Maximum time to wait. Defaults to 5 seconds.
-
-        Raises:
-            TimeoutError: If the frame is not found within the timeout period.
-        """
-        candidates = {}
-
-        def is_same_frame(frame_a, frame_b) -> bool:
-            if type(frame_a) is not type(frame_b):
-                return False
-
-            if frame_a == frame_b:
-                return True
-            else:
-                candidates[frame_a.id] = frame_a.__repr__()
-                return False
-
-        found = False
-        start_time = datetime.now()
-        while not found:
-            found = any([is_same_frame(entry.frame, frame) for entry in self.history])
-            if not found:
-                if datetime.now() - start_time > timeout:
-                    raise TimeoutError(
-                        "Frame not found until timeout reached.\n"
-                        f"EXPECTED:\n{frame.__repr__()}\n"
-                        f"FOUND:\n{'\n'.join(candidates.values())}"
-                    )
-                await asyncio.sleep(0.01)
-
-
-class SinusWaveProcessor(FrameProcessor):
-    """A frame processor that generates a sine wave audio signal.
-
-    This processor generates a continuous sine wave at 440 Hz (A4 note) when started,
-    and outputs it as audio frames. It is useful for testing audio processing pipelines.
-
-    Attributes:
-        duration (timedelta): The total duration of the sine wave to generate.
-        chunk_duration (float): Duration of each audio chunk in seconds.
-        audio_frame_count (int): Total number of audio frames to generate.
-        audio_task (asyncio.Task | None): Task handling the audio generation.
-    """
-
-    def __init__(
-        self,
-        *,
-        duration: timedelta,
-        **kwargs,
-    ):
-        """Initialize the SinusWaveProcessor.
-
-        Args:
-            duration (timedelta): The total duration of the sine wave to generate.
-            **kwargs: Additional keyword arguments passed to the parent FrameProcessor.
-        """
-        super().__init__(**kwargs)
-        self.duration = duration
-        self.audio_task: asyncio.Task | None = None
-
-        self.chunk_duration = 0.02  # 20 milliseconds
-        self.audio_frame_count = round(self.duration.total_seconds() / self.chunk_duration)
-
-    async def _cancel_audio_task(self):
-        """Cancel the running audio generation task if it exists."""
-        if self.audio_task and not self.audio_task.done():
-            await self.cancel_task(self.audio_task)
-            self.audio_task = None
-
-    async def stop(self):
-        """Stop the audio generation by canceling the audio task."""
-        await self._cancel_audio_task()
-
-    async def cleanup(self):
-        """Clean up resources by stopping audio generation and calling parent cleanup."""
-        await super().cleanup()
-        await self._cancel_audio_task()
-
-    async def process_frame(self, frame: Frame, direction: FrameDirection):
-        """Process incoming frames to start/stop audio generation.
-
-        Starts audio generation on StartFrame and stops on EndFrame.
-
-        Args:
-            frame (Frame): The frame to process.
-            direction (FrameDirection): The direction the frame is traveling in the pipeline.
-        """
-        await super().process_frame(frame, direction)
-        if isinstance(frame, StartFrame):
-            self.audio_task = self.create_task(self.start())
-        if isinstance(frame, EndFrame):
-            await self.stop()
-        await super().push_frame(frame, direction)
-
-    async def start(self):
-        """Start generating and outputting sine wave audio frames.
-
-        Generates a continuous 440 Hz sine wave, split into 20 ms chunks,
-        and outputs them as InputAudioRawFrame instances.
-        """
-        sample_rate = 16000  # Hz
-        frequency = 440  # Hz (A4 tone)
-
-        phase_offset = 0
-        for _ in range(self.audio_frame_count):
-            chunk_samples = int(sample_rate * self.chunk_duration)
-
-            # Generate the time axis
-            t = np.arange(chunk_samples) / sample_rate
-
-            # Create the sine wave with the given phase offset
-            sine_wave = 0.5 * np.sin(2 * np.pi * frequency * t + phase_offset)
-
-            # Calculate the new phase offset for the next chunk
-            phase_offset = (2 * np.pi * frequency * self.chunk_duration + phase_offset) % (2 * np.pi)
-
-            sine_wave_pcm = (sine_wave * 32767).astype(np.int16)
-
-            pipecat_frame = InputAudioRawFrame(audio=sine_wave_pcm.tobytes(), sample_rate=sample_rate, num_channels=1)
-            await super().push_frame(pipecat_frame)
-
-            await sleep(self.chunk_duration)
-
-
-async def run_test(
-    processor: FrameProcessor,
-    *,
-    frames_to_send: Sequence[Frame],
-    expected_down_frames: Sequence[Frame],
-    expected_up_frames: Sequence[Frame] = [],
-    ignore_start: bool = True,
-    start_metadata: dict[str, Any] | None = None,
-    send_end_frame: bool = True,
-) -> tuple[Sequence[Frame], Sequence[Frame]]:
-    """Run a test on a frame processor with predefined input and expected output frames.
-
-    Args:
-        processor (FrameProcessor): The processor to test.
-        frames_to_send (Sequence[Frame]): Frames to send through the processor.
-        expected_down_frames (Sequence[Frame]): Expected frames in downstream direction.
-        expected_up_frames (Sequence[Frame], optional): Expected frames in upstream direction.
-            Defaults to [].
-        ignore_start (bool, optional): Whether to ignore start frames. Defaults to True.
-        start_metadata (dict[str, Any], optional): Metadata to include in start frame.
-            Defaults to None.
-        send_end_frame (bool, optional): Whether to send an end frame. Defaults to True.
-
-    Returns:
-        tuple[Sequence[Frame], Sequence[Frame]]: Tuple of (received downstream frames, received upstream frames).
-
-    Raises:
-        AssertionError: If received frames don't match expected frames.
-    """
-    if start_metadata is None:
-        start_metadata = {}
-    received_down_frames, received_up_frames = await run_pipecat_test(
-        processor,
-        frames_to_send=frames_to_send,
-        expected_down_frames=[f.__class__ for f in expected_down_frames],
-        expected_up_frames=[f.__class__ for f in expected_up_frames],
-        ignore_start=ignore_start,
-        start_metadata=start_metadata,
-        send_end_frame=send_end_frame,
-    )
-
-    for real, expected in zip(received_up_frames, expected_up_frames, strict=True):
-        assert real == expected, f"Frame mismatch: \nreal: {repr(real)} \nexpected: {repr(expected)}"
-
-    for real, expected in zip(received_down_frames, expected_down_frames, strict=True):
-        assert real == expected, f"Frame mismatch: \nreal: {repr(real)} \nexpected: {repr(expected)}"
-
-    return received_down_frames, received_up_frames
-
-
-async def run_interactive_test(
-    processor: FrameProcessor,
-    *,
-    test_coroutine,
-    start_metadata: dict[str, Any] | None = None,
-    ignore_start: bool = True,
-    send_end_frame: bool = True,
-) -> tuple[Sequence[Frame], Sequence[Frame]]:
-    """Run an interactive test on a frame processor with a custom test coroutine.
-
-    This function allows for more complex testing scenarios where frames need to be
-    sent and received dynamically during the test.
-
-    Args:
-        processor (FrameProcessor): The processor to test.
-        test_coroutine: Coroutine function that implements the test logic.
-        start_metadata (dict[str, Any], optional): Metadata to include in start frame.
-            Defaults to None.
-        ignore_start (bool, optional): Whether to ignore start frames. Defaults to True.
-        send_end_frame (bool, optional): Whether to send an end frame. Defaults to True.
-
-    Returns:
-        tuple[Sequence[Frame], Sequence[Frame]]: Tuple of (received downstream frames, received upstream frames).
-    """
-    if start_metadata is None:
-        start_metadata = {}
-    received_up = asyncio.Queue()
-    received_down = asyncio.Queue()
-    source = QueuedFrameProcessor(
-        queue=received_up,
-        queue_direction=FrameDirection.UPSTREAM,
-        ignore_start=ignore_start,
-    )
-    sink = QueuedFrameProcessor(
-        queue=received_down,
-        queue_direction=FrameDirection.DOWNSTREAM,
-        ignore_start=ignore_start,
-    )
-
-    pipeline = Pipeline([source, processor, sink])
-
-    task = PipelineTask(pipeline, params=PipelineParams(start_metadata=start_metadata))
-
-    async def run_test():
-        # Just give a little head start to the runner.
-        await asyncio.sleep(0.01)
-        await test_coroutine(task)
-
-        if send_end_frame:
-            await task.queue_frame(EndFrame())
-
-    runner = PipelineRunner()
-    await asyncio.gather(runner.run(task), run_test())
-
-    #
-    # Down frames
-    #
-    received_down_frames: list[Frame] = []
-    while not received_down.empty():
-        frame = await received_down.get()
-        if not isinstance(frame, EndFrame) or not send_end_frame:
-            received_down_frames.append(frame)
-
-    print("received DOWN frames =", received_down_frames)
-
-    #
-    # Up frames
-    #
-    received_up_frames: list[Frame] = []
-    while not received_up.empty():
-        frame = await received_up.get()
-        received_up_frames.append(frame)
-
-    print("received UP frames =", received_up_frames)
-
-    return (received_down_frames, received_up_frames)
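The ignore/ignore_ids helpers above rely on a small trick: unittest.mock.ANY compares equal to anything, so a frame copy whose id and name are set to ANY matches the real frame regardless of the ids assigned at runtime. A self-contained illustration with a toy dataclass (not pipecat's actual Frame class):

import itertools
from dataclasses import dataclass, field
from unittest.mock import ANY

_ids = itertools.count(1)


@dataclass
class ToyFrame:
    text: str
    id: int = field(default_factory=lambda: next(_ids))


a = ToyFrame("hello")  # id == 1
b = ToyFrame("hello")  # id == 2, so a != b
assert a != b

expected = ToyFrame("hello")
expected.id = ANY  # the id field now matches any value
assert expected == a and expected == b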
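One detail of SinusWaveProcessor worth noting is the phase_offset bookkeeping: each 20 ms chunk must start exactly where the previous one ended, otherwise the 440 Hz tone clicks at every chunk boundary. The arithmetic can be checked in isolation:

import numpy as np

sample_rate, frequency, chunk_duration = 16000, 440, 0.02
chunk_samples = int(sample_rate * chunk_duration)

phase = 0.0
chunks = []
for _ in range(3):
    t = np.arange(chunk_samples) / sample_rate
    chunks.append(0.5 * np.sin(2 * np.pi * frequency * t + phase))
    # Advance the phase by exactly one chunk's worth of cycles.
    phase = (2 * np.pi * frequency * chunk_duration + phase) % (2 * np.pi)

# The concatenation equals one continuous sine wave of the same total length.
t_full = np.arange(3 * chunk_samples) / sample_rate
assert np.allclose(np.concatenate(chunks), 0.5 * np.sin(2 * np.pi * frequency * t_full))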
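Finally, run_interactive_test pairs naturally with FrameStorage: queue frames from the test coroutine, then block until the expected output shows up in the storage history. A hypothetical usage sketch (it assumes pipecat and the deleted module are importable; Pipeline instances are themselves frame processors, so a small pipeline can be passed where a single processor is expected):

import asyncio

from pipecat.frames.frames import TextFrame
from pipecat.pipeline.pipeline import Pipeline

from tests.unit.utils import FrameStorage, ignore_ids, run_interactive_test


async def main():
    storage = FrameStorage()

    async def test_logic(task):
        await task.queue_frame(TextFrame("hello"))
        # Block until the frame has traveled through the pipeline.
        await storage.wait_for_frame(ignore_ids(TextFrame("hello")))

    await run_interactive_test(Pipeline([storage]), test_coroutine=test_logic)


asyncio.run(main())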