fciannella committed
Commit 153165b · 1 Parent(s): 53ea588

removed the binary files

Files changed (39)
  1. .gitattributes +1 -0
  2. .gitignore +2 -0
  3. examples/voice_agent_webrtc_langgraph/start.sh +13 -0
  4. tests/__init__.py +0 -4
  5. tests/perf/README.md +0 -101
  6. tests/perf/file_input_client.py +0 -581
  7. tests/perf/run_multi_client_benchmark.sh +0 -414
  8. tests/perf/ttfb_analyzer.py +0 -253
  9. tests/unit/__init__.py +0 -4
  10. tests/unit/configs/animation_config.yaml +0 -346
  11. tests/unit/configs/test_speech_planner_prompt.yaml +0 -15
  12. tests/unit/test_ace_websocket_serializer.py +0 -147
  13. tests/unit/test_acknowledgment.py +0 -71
  14. tests/unit/test_animation_graph_services.py +0 -668
  15. tests/unit/test_audio2face_3d_service.py +0 -182
  16. tests/unit/test_audio_util.py +0 -64
  17. tests/unit/test_basic_pipelines.py +0 -130
  18. tests/unit/test_blingfire_text_aggregator.py +0 -244
  19. tests/unit/test_custom_view.py +0 -203
  20. tests/unit/test_elevenlabs.py +0 -184
  21. tests/unit/test_frame_creation.py +0 -148
  22. tests/unit/test_gesture.py +0 -94
  23. tests/unit/test_guardrail.py +0 -110
  24. tests/unit/test_message_broker.py +0 -111
  25. tests/unit/test_nvidia_aggregators.py +0 -396
  26. tests/unit/test_nvidia_llm_service.py +0 -386
  27. tests/unit/test_nvidia_rag_service.py +0 -261
  28. tests/unit/test_nvidia_tts_response_cacher.py +0 -79
  29. tests/unit/test_posture.py +0 -104
  30. tests/unit/test_proactivity.py +0 -85
  31. tests/unit/test_riva_asr_service.py +0 -523
  32. tests/unit/test_riva_nmt_service.py +0 -197
  33. tests/unit/test_riva_tts_service.py +0 -301
  34. tests/unit/test_speech_planner.py +0 -546
  35. tests/unit/test_traced_processor.py +0 -159
  36. tests/unit/test_transcription_sync_processors.py +0 -262
  37. tests/unit/test_user_presence.py +0 -157
  38. tests/unit/test_utils.py +0 -95
  39. tests/unit/utils.py +0 -428
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.wav filter=lfs diff=lfs merge=lfs -text
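The added rule is the one `git lfs track` writes: with it in place, newly committed `.wav` files are stored as Git LFS pointers instead of raw binaries in the repository. A minimal sketch of the equivalent command (assumes `git-lfs` is installed; not part of this commit):

```bash
# Appends the same "*.wav filter=lfs diff=lfs merge=lfs -text" line to .gitattributes
git lfs track "*.wav"
git add .gitattributes
```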
.gitignore CHANGED
@@ -66,4 +66,6 @@ pnpm-debug.log*
66
  # --- Example runtime artifacts ---
67
  examples/voice_agent_webrtc_langgraph/audio_dumps/
68
  examples/voice_agent_webrtc_langgraph/ui/dist/
69
+ examples/voice_agent_webrtc_langgraph/audio_prompt.wav
70
+ tests/perf/audio_files/*.wav
71
 
examples/voice_agent_webrtc_langgraph/start.sh CHANGED
@@ -9,6 +9,19 @@ if [ -f "/app/examples/voice_agent_webrtc_langgraph/.env" ]; then
9
  set +a
10
  fi
11
 
12
+ # If a remote prompt URL is provided, download it and export ZERO_SHOT_AUDIO_PROMPT
13
+ if [ -n "${ZERO_SHOT_AUDIO_PROMPT_URL:-}" ]; then
14
+ PROMPT_TARGET="${ZERO_SHOT_AUDIO_PROMPT:-/app/examples/voice_agent_webrtc_langgraph/audio_prompt.wav}"
15
+ mkdir -p "$(dirname "$PROMPT_TARGET")"
16
+ if [ ! -f "$PROMPT_TARGET" ]; then
17
+ echo "Downloading ZERO_SHOT_AUDIO_PROMPT from $ZERO_SHOT_AUDIO_PROMPT_URL"
18
+ if ! curl -fsSL "$ZERO_SHOT_AUDIO_PROMPT_URL" -o "$PROMPT_TARGET"; then
19
+ echo "Failed to download audio prompt from URL: $ZERO_SHOT_AUDIO_PROMPT_URL" >&2
20
+ fi
21
+ fi
22
+ export ZERO_SHOT_AUDIO_PROMPT="$PROMPT_TARGET"
23
+ fi
24
+
25
  # All dependencies and langgraph CLI are installed at build time
26
 
27
  # Start langgraph dev from within the internal agents directory (background)
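Net effect of this hunk: when `ZERO_SHOT_AUDIO_PROMPT_URL` is set, `start.sh` downloads the prompt once (skipping the download if the target file already exists) and exports `ZERO_SHOT_AUDIO_PROMPT` pointing at the local copy. A minimal usage sketch; the URL and target path are placeholders, and it assumes the script is run directly rather than via the container's `/app` layout:

```bash
# Placeholders: substitute a real voice-prompt URL and a writable path.
export ZERO_SHOT_AUDIO_PROMPT_URL="https://example.com/voices/my_voice.wav"
export ZERO_SHOT_AUDIO_PROMPT="$PWD/examples/voice_agent_webrtc_langgraph/audio_prompt.wav"
./examples/voice_agent_webrtc_langgraph/start.sh
# The script downloads the file only if it is not already present, then
# re-exports ZERO_SHOT_AUDIO_PROMPT so the pipeline picks up the local copy.
```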
tests/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Test suite for the nvidia-pipecat package."""
tests/perf/README.md DELETED
@@ -1,101 +0,0 @@
1
- # Performance Testing
2
-
3
- This directory contains tools for evaluating the voice agent pipeline's latency and scalability/throughput under various loads. These tests simulate real-world scenarios where multiple users interact with the voice agent simultaneously.
4
-
5
- ## What the Tests Do
6
-
7
- The performance tests:
8
-
9
- - Open WebSocket clients that simulate user interactions
10
- - Use pre-recorded audio files from `audio_files/` as user queries
11
- - Send these queries to the voice agent pipeline and measure response times
12
- - Track various latency metrics including end-to-end latency, component-wise breakdowns
13
- - Can simulate multiple concurrent clients to test scaling
14
- - Detect any audio glitches during processing
15
-
16
- ## Running Performance Tests
17
-
18
- ### 1. Start the Voice Agent Pipeline
19
-
20
- First, start the voice agent pipeline and capture server logs for analysis.
21
- See the prerequisites and setup instructions in `examples/speech-to-speech/README.md` before proceeding.
22
-
23
- #### If Using Docker
24
-
25
- From examples/speech-to-speech/ directory run:
26
-
27
- ```bash
28
- # Start the services
29
- docker compose up -d
30
-
31
- # Capture logs and save them into a file
32
- docker compose logs -f python-app > bot_logs_test1.txt 2>&1
33
- ```
34
-
35
- Before starting a new performance run:
36
-
37
- ```bash
38
- # Clear existing Docker logs
39
- sudo truncate -s 0 /var/lib/docker/containers/$(docker compose ps -q python-app)/$(docker compose ps -q python-app)-json.log
40
- ```
41
-
42
- #### If Using Python Environment
43
-
44
- From examples/speech-to-speech/ directory run:
45
-
46
- ```bash
47
- python bot.py > bot_logs_test1.txt 2>&1
48
- ```
49
-
50
- ### 2. Run the Multi-Client Benchmark
51
-
52
- ```bash
53
- ./run_multi_client_benchmark.sh --host 0.0.0.0 --port 8100 --clients 10 --test-duration 150
54
- ```
55
-
56
- Parameters:
57
-
58
- - `--host`: The host address (default: 0.0.0.0)
59
- - `--port`: The port where your voice agent is running (default: 8100)
60
- - `--clients`: Number of concurrent clients to simulate (default: 1)
61
- - `--test-duration`: Duration of the test in seconds (default: 150)
62
-
63
- The script will:
64
-
65
- 1. Start the specified number of concurrent clients
66
- 2. Simulate user interactions using audio files
67
- 3. Measure latencies and detect audio glitches
68
- 4. Save detailed results in the `results` directory as JSON files
69
- 5. Output a summary to the console
70
-
71
- ### 3. Analyze Component-wise Latency
72
-
73
- After the benchmark completes, analyze the server logs for detailed latency breakdowns:
74
-
75
- ```bash
76
- python ttfb_analyzer.py <relative_path_to_bot_logs_test1.txt>
77
- ```
78
-
79
- This will show:
80
-
81
- - Per-client latency metrics for LLM, TTS, and ASR components
82
- - Number of calls made by each client
83
- - Overall averages and P95 values
84
- - Component-wise timing breakdowns
85
-
86
- ## Understanding the Results
87
-
88
- The metrics measured include:
89
-
90
- - **LLM TTFB**: Time to first byte from the LLM model
91
- - **TTS TTFB**: Time to first byte from the TTS model
92
- - **ASR Lat**: Compute latency of the ASR model
93
- - **LLM 1st**: Time taken to generate first complete sentence from LLM
94
- - **Calls**: Number of API calls made to each service
95
-
96
- The results help identify:
97
-
98
- - Performance bottlenecks in specific components
99
- - Scaling behavior under concurrent load
100
- - Potential audio quality issues
101
- - Overall system responsiveness
tests/perf/file_input_client.py DELETED
@@ -1,581 +0,0 @@
1
- """Speech-to-speech client with latency measurement for performance testing.
2
-
3
- This module provides a WebSocket client that sends audio files to a speech-to-speech
4
- service and measures the latency between when user audio ends and bot response begins.
5
- """
6
-
7
- import argparse
8
- import asyncio
9
- import datetime
10
- import io
11
- import json
12
- import os
13
- import signal
14
- import sys
15
- import time
16
- import uuid
17
- import wave
18
-
19
- import websockets
20
- from pipecat.frames.protobufs import frames_pb2
21
- from websockets.exceptions import ConnectionClosed
22
-
23
-
24
- def log_error(msg):
25
- """Write error message to stderr with timestamp."""
26
- timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
27
- print(f"[ERROR] {timestamp} - {msg}", file=sys.stderr, flush=True)
28
-
29
-
30
- # Global constants
31
- SILENCE_TIMEOUT = 0.2 # Standard silence timeout in seconds
32
- CHUNK_DURATION_MS = 32 # Standard chunk duration in milliseconds
33
-
34
- # List to store latency values
35
- latency_values = []
36
-
37
- # List to store filtered latency values (above threshold)
38
- filtered_latency_values = []
39
-
40
- # Global variable to track timestamps
41
- timestamps = {"input_audio_file_end": None, "first_response_after_input": None}
42
-
43
- # Global glitch detection
44
- glitch_detected = False
45
-
46
- # Global flag and event for controlling silence sending
47
- silence_control = {
48
- "running": False,
49
- "event": asyncio.Event(),
50
- "audio_params": None, # Will store (frame_rate, n_channels, chunk_size)
51
- }
52
-
53
- # Global control for continuous operation
54
- continuous_control = {
55
- "running": True,
56
- "collecting_metrics": False,
57
- "start_time": None,
58
- "test_duration": 100, # Default 100 seconds
59
- "threshold": 0.5, # Default threshold for filtered latency
60
- }
61
-
62
-
63
- # Signal handler for graceful shutdown
64
- def signal_handler(signum, frame):
65
- """Handle system signals for graceful shutdown."""
66
- print(f"\nReceived signal {signum}, shutting down gracefully...")
67
- continuous_control["running"] = False
68
- sys.exit(0)
69
-
70
-
71
- # Register signal handlers
72
- signal.signal(signal.SIGINT, signal_handler)
73
- signal.signal(signal.SIGTERM, signal_handler)
74
-
75
-
76
- def write_audio_to_wav(data, wf, create_new_file=False, output_file="bot_response.wav"):
77
- """Write audio data to WAV file."""
78
- try:
79
- # Parse protobuf frame
80
- try:
81
- proto = frames_pb2.Frame.FromString(data)
82
- which = proto.WhichOneof("frame")
83
- if which is None:
84
- return wf, None, None, None
85
- except Exception as e:
86
- log_error(f"Failed to parse protobuf frame: {e}")
87
- return wf, None, None, None
88
-
89
- args = getattr(proto, which)
90
- sample_rate = getattr(args, "sample_rate", 16000)
91
- num_channels = getattr(args, "num_channels", 1)
92
- audio_data = getattr(args, "audio", None)
93
- if audio_data is None:
94
- return wf, None, None, None
95
-
96
- # Extract raw audio data from WAV format if needed
97
- try:
98
- with io.BytesIO(audio_data) as buffer, wave.open(buffer, "rb") as wav_file:
99
- audio_data = wav_file.readframes(wav_file.getnframes())
100
- sample_rate = wav_file.getframerate()
101
- num_channels = wav_file.getnchannels()
102
- except Exception:
103
- # If not WAV format, use audio_data as-is
104
- pass
105
-
106
- # Create WAV file if needed
107
- if create_new_file and wf is None:
108
- try:
109
- wf = wave.open(output_file, "wb") # noqa: SIM115
110
- wf.setnchannels(num_channels)
111
- wf.setsampwidth(2)
112
- wf.setframerate(sample_rate)
113
- except Exception as e:
114
- log_error(f"Failed to create WAV file {output_file}: {e}")
115
- return None, None, None, None
116
-
117
- # Write audio data directly
118
- if wf is not None:
119
- try:
120
- wf.writeframes(audio_data)
121
- except Exception as e:
122
- log_error(f"Failed to write audio data: {e}")
123
- return None, None, None, None
124
-
125
- return wf, sample_rate, num_channels, audio_data
126
- except Exception as e:
127
- log_error(f"Unexpected error in write_audio_to_wav: {e}")
128
- return wf, None, None, None
129
-
130
-
131
- async def send_audio_file(websocket, file_path):
132
- """Send audio file content with streaming simulation."""
133
- # Pause silence sending while we send the real audio
134
- silence_control["event"].set()
135
-
136
- try:
137
- if not os.path.exists(file_path):
138
- log_error(f"Input audio file not found: {file_path}")
139
- return
140
-
141
- try:
142
- with wave.open(file_path, "rb") as wav_file:
143
- n_channels = wav_file.getnchannels()
144
- frame_rate = wav_file.getframerate()
145
- sample_width = wav_file.getsampwidth()
146
-
147
- # Store audio parameters for silence generation
148
- chunk_size = int((frame_rate * n_channels * CHUNK_DURATION_MS) / 1000) * sample_width
149
- silence_control["audio_params"] = (
150
- frame_rate,
151
- n_channels,
152
- chunk_size,
153
- )
154
-
155
- # Stream the audio file
156
- frames_sent = 0
157
- while True:
158
- try:
159
- chunk = wav_file.readframes(chunk_size // sample_width)
160
- if not chunk:
161
- break
162
- audio_frame = frames_pb2.AudioRawFrame(
163
- audio=chunk, sample_rate=frame_rate, num_channels=n_channels
164
- )
165
- frame = frames_pb2.Frame(audio=audio_frame)
166
- await websocket.send(frame.SerializeToString())
167
- frames_sent += 1
168
- await asyncio.sleep(CHUNK_DURATION_MS / 1000)
169
- except Exception as e:
170
- log_error(f"Error sending audio frame {frames_sent}: {e}")
171
- raise # Re-raise to handle in outer try block
172
- except wave.Error as e:
173
- log_error(f"Failed to read WAV file {file_path}: {e}")
174
- return
175
- except Exception as e:
176
- log_error(f"Error in send_audio_file: {e}")
177
- return
178
- finally:
179
- # Always record when input audio ends and resume silence sending
180
- timestamps["input_audio_file_end"] = datetime.datetime.now()
181
- print(f"User stopped speaking at: {timestamps['input_audio_file_end'].strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}")
182
- silence_control["event"].clear()
183
-
184
-
185
- async def silence_sender_loop(websocket):
186
- """Background task to continuously send silence when no other audio is being sent."""
187
- silence_control["running"] = True
188
- print("Silence sender loop started")
189
- consecutive_errors = 0
190
- max_consecutive_errors = 5
191
-
192
- try:
193
- while silence_control["running"]:
194
- try:
195
- # Wait until we're allowed to send silence
196
- if silence_control["event"].is_set() or silence_control["audio_params"] is None:
197
- await asyncio.sleep(0.1) # Short sleep to avoid CPU spinning
198
- continue
199
-
200
- # Extract audio parameters
201
- frame_rate, n_channels, chunk_size = silence_control["audio_params"]
202
-
203
- # Send a chunk of silence
204
- silent_chunk = b"\x00" * chunk_size
205
- audio_frame = frames_pb2.AudioRawFrame(
206
- audio=silent_chunk, sample_rate=frame_rate, num_channels=n_channels
207
- )
208
- frame = frames_pb2.Frame(audio=audio_frame)
209
- await websocket.send(frame.SerializeToString())
210
- await asyncio.sleep(CHUNK_DURATION_MS / 1000)
211
-
212
- # Reset error counter on successful send
213
- consecutive_errors = 0
214
-
215
- except ConnectionClosed:
216
- print("WebSocket connection closed in silence sender loop")
217
- break
218
- except Exception as e:
219
- consecutive_errors += 1
220
- print(f"Error in silence sender loop (attempt {consecutive_errors}/{max_consecutive_errors}): {e}")
221
-
222
- # If too many consecutive errors, stop the loop
223
- if consecutive_errors >= max_consecutive_errors:
224
- print(f"Too many consecutive errors ({consecutive_errors}), stopping silence sender")
225
- break
226
-
227
- # Brief pause before retry to avoid overwhelming the system
228
- await asyncio.sleep(1.0)
229
-
230
- except Exception as e:
231
- print(f"Fatal error in silence sender loop: {e}")
232
- finally:
233
- print("Silence sender loop stopped")
234
- silence_control["running"] = False
235
-
236
-
237
- async def receive_audio(
238
- websocket,
239
- wf=None,
240
- create_new_file=True,
241
- is_after_input=False,
242
- output_wav="bot_response.wav",
243
- is_initial=False,
244
- timeout=1.0,
245
- ):
246
- """Receive audio data and handle streaming playback simulation."""
247
- global glitch_detected
248
-
249
- if is_initial:
250
- print("Waiting up to 5 seconds for initial bot introduction audio if available...")
251
- try:
252
- # Wait for first data packet with 5 second timeout
253
- data = await asyncio.wait_for(websocket.recv(), timeout=5.0)
254
- except TimeoutError:
255
- print("No initial bot introduction received after 5 seconds, continuing...")
256
- return wf
257
- else:
258
- # For non-initial audio, receive normally
259
- data = await websocket.recv()
260
-
261
- try:
262
- # Wait for first data packet
263
- data = await websocket.recv()
264
-
265
- # Record first response timestamp if after input
266
- if is_after_input:
267
- timestamps["first_response_after_input"] = datetime.datetime.now()
268
- formatted_time = timestamps["first_response_after_input"].strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
269
- print(f"Bot started speaking at {formatted_time}")
270
-
271
- # Process first audio packet
272
- wf, sample_rate, num_channels, audio_data = write_audio_to_wav(data, wf, create_new_file, output_wav)
273
-
274
- # Initialize timing for glitch detection
275
- audio_start_time = time.time()
276
- cumulative_audio_duration = 0.0 # Total duration of audio received (in seconds)
277
-
278
- # Calculate duration of first chunk if we have audio data
279
- if audio_data and sample_rate and num_channels:
280
- bytes_per_sample = 2 # Assuming 16-bit audio
281
- samples_in_chunk = len(audio_data) // (num_channels * bytes_per_sample)
282
- chunk_duration_seconds = samples_in_chunk / sample_rate
283
- cumulative_audio_duration += chunk_duration_seconds
284
-
285
- # Continue receiving audio data until silence threshold reached
286
- last_data_time = time.time()
287
-
288
- while True:
289
- try:
290
- data = await asyncio.wait_for(websocket.recv(), timeout=timeout)
291
- current_time = time.time()
292
- last_data_time = current_time
293
-
294
- # Process audio data
295
- wf, sample_rate, num_channels, audio_data = write_audio_to_wav(data, wf, False, output_wav)
296
-
297
- # Update cumulative audio duration
298
- if audio_data and sample_rate and num_channels:
299
- bytes_per_sample = 2 # Assuming 16-bit audio
300
- samples_in_chunk = len(audio_data) // (num_channels * bytes_per_sample)
301
- chunk_duration_seconds = samples_in_chunk / sample_rate
302
- cumulative_audio_duration += chunk_duration_seconds
303
-
304
- # Check for glitch: real elapsed time vs cumulative audio duration
305
- real_elapsed_time = current_time - audio_start_time
306
- audio_deficit = real_elapsed_time - cumulative_audio_duration
307
-
308
- if audio_deficit >= 0.032: # 32ms threshold for glitch detection
309
- print(f"Audio glitch detected: {audio_deficit * 1000:.1f}ms audio deficit")
310
- glitch_detected = True
311
-
312
- except TimeoutError:
313
- # Check if silence duration exceeds threshold
314
- if time.time() - last_data_time >= SILENCE_TIMEOUT:
315
- return wf
316
- except Exception as e:
317
- log_error(f"Error receiving audio data: {e}")
318
- if wf is not None and create_new_file:
319
- try:
320
- wf.close()
321
- except Exception as close_error:
322
- log_error(f"Error closing WAV file: {close_error}")
323
- return None
324
- except Exception as e:
325
- log_error(f"Fatal error in receive_audio: {e}")
326
- if wf is not None and create_new_file:
327
- try:
328
- wf.close()
329
- except Exception as close_error:
330
- log_error(f"Error closing WAV file: {close_error}")
331
- return None
332
-
333
-
334
- async def process_conversation_turn(websocket, audio_file_path, wf, turn_index, output_wav="bot_response.wav"):
335
- """Process a single conversation turn with the given audio file."""
336
- print(f"\n----- Processing conversation turn {turn_index + 1} -----")
337
-
338
- # Reset timestamps for this turn
339
- timestamps["input_audio_file_end"] = None
340
- timestamps["first_response_after_input"] = None
341
-
342
- # Start both sending and receiving in parallel for realistic latency measurement
343
- print(f"Sending user input audio from {audio_file_path}...")
344
-
345
- # Start sending audio file in background
346
- send_task = asyncio.create_task(send_audio_file(websocket, audio_file_path))
347
-
348
- # Start receiving bot response immediately (parallel to sending)
349
- receive_task = asyncio.create_task(
350
- receive_audio(websocket, wf=wf, create_new_file=(wf is None), is_after_input=True, output_wav=output_wav)
351
- )
352
-
353
- # Wait for both tasks to complete
354
- wf = await receive_task
355
- await send_task # Ensure sending is also complete
356
-
357
- # Calculate and store latency only if we're collecting metrics
358
- if continuous_control["collecting_metrics"]:
359
- latency = None
360
- if timestamps["input_audio_file_end"] is not None and timestamps["first_response_after_input"] is not None:
361
- latency = (timestamps["first_response_after_input"] - timestamps["input_audio_file_end"]).total_seconds()
362
- print(f"Latency for Turn {turn_index + 1}: {latency:.3f} seconds")
363
- latency_values.append(latency)
364
-
365
- # Add to filtered latency if above threshold
366
- if latency > continuous_control["threshold"]:
367
- filtered_latency_values.append(latency)
368
- else:
369
- print("Reverse Barge-In Detected!")
370
-
371
- return wf
372
-
373
-
374
- async def continuous_audio_loop(websocket, audio_files, wf, output_wav):
375
- """Continuously loop through audio files until stopped."""
376
- turn_index = 0
377
-
378
- while continuous_control["running"]:
379
- # Check if we should start collecting metrics
380
- if (
381
- continuous_control["start_time"]
382
- and time.time() >= continuous_control["start_time"]
383
- and not continuous_control["collecting_metrics"]
384
- ):
385
- continuous_control["collecting_metrics"] = True
386
- print(f"\n=== STARTING METRICS COLLECTION at {datetime.datetime.now().strftime('%H:%M:%S')} ===")
387
-
388
- # Check if we should stop collecting metrics
389
- if (
390
- continuous_control["start_time"]
391
- and continuous_control["collecting_metrics"]
392
- and time.time() >= continuous_control["start_time"] + continuous_control["test_duration"]
393
- ):
394
- print(f"\n=== STOPPING METRICS COLLECTION at {datetime.datetime.now().strftime('%H:%M:%S')} ===")
395
- continuous_control["collecting_metrics"] = False
396
- continuous_control["running"] = False
397
- break
398
-
399
- # Process current audio file
400
- audio_file = audio_files[turn_index % len(audio_files)]
401
- wf = await process_conversation_turn(websocket, audio_file, wf, turn_index, output_wav)
402
- turn_index += 1
403
-
404
- # Small delay between turns to prevent overwhelming the system
405
- await asyncio.sleep(0.1)
406
-
407
- return wf
408
-
409
-
410
- async def main():
411
- """Main execution function."""
412
- # Parse command line arguments
413
- parser = argparse.ArgumentParser(description="Speech-to-speech client with latency measurement")
414
- parser.add_argument(
415
- "--stream-id", type=str, default=str(uuid.uuid4()), help="Unique stream ID (default: random UUID)"
416
- )
417
- parser.add_argument("--host", type=str, default="0.0.0.0", help="WebSocket server host (default: 0.0.0.0)")
418
- parser.add_argument("--port", type=int, default=8100, help="WebSocket server port (default: 8100)")
419
- parser.add_argument(
420
- "--output-dir", type=str, default="./results", help="Directory to store output files (default: ./results)"
421
- )
422
- parser.add_argument("--start-delay", type=float, default=0, help="Delay in seconds before starting (default: 0)")
423
- parser.add_argument(
424
- "--metrics-start-time",
425
- type=float,
426
- default=0,
427
- help="Unix timestamp when to start collecting metrics (default: 0)",
428
- )
429
- parser.add_argument(
430
- "--test-duration", type=float, default=100, help="Duration in seconds to collect metrics (default: 100)"
431
- )
432
- parser.add_argument(
433
- "--threshold", type=float, default=0.5, help="Threshold for filtered average latency calculation (default: 0.5)"
434
- )
435
- args = parser.parse_args()
436
-
437
- # Create output directory if it doesn't exist
438
- os.makedirs(args.output_dir, exist_ok=True)
439
-
440
- # Construct WebSocket URI with unique stream ID
441
- uri = f"ws://{args.host}:{args.port}/ws/{args.stream_id}"
442
-
443
- # Output file paths
444
- output_wav = os.path.join(args.output_dir, f"bot_response_{args.stream_id}.wav")
445
- output_results = os.path.join(args.output_dir, f"latency_results_{args.stream_id}.json")
446
-
447
- print(f"Starting client with stream ID: {args.stream_id}")
448
- print(f"WebSocket URI: {uri}")
449
- print(f"Start delay: {args.start_delay} seconds")
450
- print(f"Metrics start time: {args.metrics_start_time}")
451
- print(f"Test duration: {args.test_duration} seconds")
452
- print(f"Latency threshold: {args.threshold} seconds")
453
-
454
- # Set up timing controls
455
- if args.start_delay > 0:
456
- print(f"Waiting {args.start_delay} seconds before starting...")
457
- await asyncio.sleep(args.start_delay)
458
-
459
- if args.metrics_start_time > 0:
460
- continuous_control["start_time"] = args.metrics_start_time
461
- continuous_control["test_duration"] = args.test_duration
462
- print(f"Will start collecting metrics at timestamp {args.metrics_start_time}")
463
-
464
- # Set threshold for filtered latency calculation
465
- continuous_control["threshold"] = args.threshold
466
-
467
- # Define the array of input audio files
468
- # Get the directory where this script is located
469
- script_dir = os.path.dirname(os.path.abspath(__file__))
470
- audio_files_dir = os.path.join(script_dir, "audio_files")
471
-
472
- input_audio_files = [
473
- os.path.join(audio_files_dir, "output_file.wav"),
474
- # os.path.join(audio_files_dir, "query_1.wav"),
475
- # os.path.join(audio_files_dir, "query_2.wav"),
476
- # os.path.join(audio_files_dir, "query_3.wav"),
477
- # os.path.join(audio_files_dir, "query_4.wav"),
478
- # os.path.join(audio_files_dir, "query_5.wav"),
479
- # os.path.join(audio_files_dir, "query_6.wav"),
480
- # os.path.join(audio_files_dir, "query_7.wav"),
481
- # os.path.join(audio_files_dir, "query_8.wav"),
482
- # os.path.join(audio_files_dir, "query_9.wav"),
483
- # os.path.join(audio_files_dir, "query_10.wav"),
484
- ]
485
-
486
- # Clear any previous values
487
- latency_values.clear()
488
- filtered_latency_values.clear()
489
-
490
- # Initialize silence control
491
- silence_control["event"] = asyncio.Event()
492
- silence_control["event"].set() # Start with silence sending paused
493
-
494
- try:
495
- async with websockets.connect(uri) as websocket:
496
- # First, try to receive any initial output audio
497
- wf = await receive_audio(
498
- websocket,
499
- wf=None,
500
- create_new_file=True,
501
- is_after_input=False,
502
- output_wav=output_wav,
503
- is_initial=True,
504
- )
505
-
506
- # Start the silence sender task
507
- asyncio.create_task(silence_sender_loop(websocket))
508
-
509
- # Start continuous audio loop
510
- wf = await continuous_audio_loop(websocket, input_audio_files, wf, output_wav)
511
-
512
- # Clean up and stop the silence sender
513
- silence_control["running"] = False
514
- silence_control["event"].set() # Make sure it's not waiting
515
- await asyncio.sleep(0.2) # Give it time to exit cleanly
516
-
517
- if wf is not None:
518
- wf.close()
519
- print(f"All output saved to {output_wav}")
520
-
521
- except ConnectionClosed:
522
- # Normal WebSocket closure, not an error
523
- pass
524
- except Exception as e:
525
- print(f"Connection error: {e}")
526
- finally:
527
- # Always save results, regardless of how the connection ended
528
- if latency_values:
529
- avg_latency = sum(latency_values) / len(latency_values)
530
-
531
- # Calculate filtered average latency
532
- filtered_avg_latency = None
533
- if filtered_latency_values:
534
- filtered_avg_latency = sum(filtered_latency_values) / len(filtered_latency_values)
535
-
536
- print("\n----- Final Latency Summary -----")
537
- print(f"Average Latency across {len(latency_values)} turns: {avg_latency:.3f} seconds")
538
-
539
- if filtered_avg_latency is not None:
540
- print(
541
- f"Filtered Average Latency (>{args.threshold}s) across {len(filtered_latency_values)} turns: "
542
- f"{filtered_avg_latency:.3f} seconds"
543
- )
544
- else:
545
- print(f"Filtered Average Latency: No latencies above {args.threshold}s threshold")
546
-
547
- # Calculate reverse barge-ins (latencies below threshold)
548
- reverse_barge_ins_count = len(latency_values) - len(filtered_latency_values)
549
- print(f"Reverse Barge-Ins Detected: {reverse_barge_ins_count} latencies below {args.threshold}s threshold")
550
-
551
- # Report glitch detection results
552
- if glitch_detected:
553
- print("⚠️ AUDIO GLITCHES DETECTED: Audio chunks arrived with gaps larger than playback time")
554
- else:
555
- print("✅ No audio glitches detected: Audio streaming was smooth")
556
-
557
- print("----------------------------------------")
558
-
559
- # Save results to JSON file
560
- results = {
561
- "stream_id": args.stream_id,
562
- "individual_latencies": latency_values,
563
- "average_latency": avg_latency,
564
- "filtered_latencies": filtered_latency_values,
565
- "filtered_average_latency": filtered_avg_latency,
566
- "threshold": args.threshold,
567
- "num_turns": len(latency_values),
568
- "num_filtered_turns": len(filtered_latency_values),
569
- "reverse_barge_ins_count": len(latency_values) - len(filtered_latency_values),
570
- "glitch_detected": glitch_detected,
571
- "timestamp": datetime.datetime.now().isoformat(),
572
- "metrics_start_time": continuous_control["start_time"],
573
- "test_duration": continuous_control["test_duration"],
574
- }
575
-
576
- with open(output_results, "w") as f:
577
- json.dump(results, f, indent=2)
578
-
579
-
580
- if __name__ == "__main__":
581
- asyncio.run(main())
tests/perf/run_multi_client_benchmark.sh DELETED
@@ -1,414 +0,0 @@
1
- #!/bin/bash
2
-
3
- # Configuration variables
4
- HOST="0.0.0.0" # Default host
5
- PORT=8100 # Default port
6
- NUM_CLIENTS=1 # Default number of parallel clients
7
- BASE_OUTPUT_DIR="./results"
8
- TEST_DURATION=150 # Default test duration in seconds
9
- CLIENT_START_DELAY=1 # Delay between client starts in seconds
10
- THRESHOLD=0.5 # Default threshold for filtered average latency
11
-
12
- # Generate timestamp for unique directory
13
- TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
14
- OUTPUT_DIR="${BASE_OUTPUT_DIR}_${TIMESTAMP}"
15
- SUMMARY_FILE="$OUTPUT_DIR/summary.json"
16
-
17
- # Process command line arguments
18
- while [[ $# -gt 0 ]]; do
19
- case $1 in
20
- --host)
21
- HOST="$2"
22
- shift 2
23
- ;;
24
- --port)
25
- PORT="$2"
26
- shift 2
27
- ;;
28
- --clients)
29
- NUM_CLIENTS="$2"
30
- shift 2
31
- ;;
32
- --output-dir)
33
- BASE_OUTPUT_DIR="$2"
34
- OUTPUT_DIR="${BASE_OUTPUT_DIR}_${TIMESTAMP}"
35
- SUMMARY_FILE="$OUTPUT_DIR/summary.json"
36
- shift 2
37
- ;;
38
- --test-duration)
39
- TEST_DURATION="$2"
40
- shift 2
41
- ;;
42
- --client-start-delay)
43
- CLIENT_START_DELAY="$2"
44
- shift 2
45
- ;;
46
- --threshold)
47
- THRESHOLD="$2"
48
- shift 2
49
- ;;
50
- *)
51
- echo "Unknown option: $1"
52
- echo "Usage: $0 [--host HOST] [--port PORT] [--clients NUM_CLIENTS] [--output-dir DIR] [--test-duration SECONDS] [--client-start-delay SECONDS] [--threshold SECONDS]"
53
- exit 1
54
- ;;
55
- esac
56
- done
57
-
58
- # Create output directory if it doesn't exist
59
- mkdir -p "$OUTPUT_DIR"
60
-
61
- echo "=== ACE Controller Multi-Client Benchmark (Staggered Start) ==="
62
- echo "Host: $HOST"
63
- echo "Port: $PORT"
64
- echo "Number of clients: $NUM_CLIENTS"
65
- echo "Client start delay: $CLIENT_START_DELAY seconds"
66
- echo "Test duration: $TEST_DURATION seconds"
67
- echo "Latency threshold: $THRESHOLD seconds"
68
- echo "Output directory: $OUTPUT_DIR"
69
- echo "================================================================"
70
-
71
- # Calculate timing
72
- # All clients will start within (NUM_CLIENTS - 1) * CLIENT_START_DELAY seconds
73
- # Metrics collection will start after the last client starts
74
- TOTAL_START_TIME=$(( (NUM_CLIENTS - 1) * CLIENT_START_DELAY ))
75
- METRICS_START_TIME=$(date +%s)
76
- METRICS_START_TIME=$((METRICS_START_TIME + TOTAL_START_TIME)) # Set to when the last client starts
77
-
78
- # Run clients with staggered starts
79
- pids=()
80
-
81
- for ((i=1; i<=$NUM_CLIENTS; i++)); do
82
- # Generate a unique stream ID for each client
83
- STREAM_ID="client_${i}_$(date +%s%N | cut -b1-13)"
84
-
85
- # Calculate start delay for this client (0 for first client, increasing for others)
86
- START_DELAY=$(( (i - 1) * CLIENT_START_DELAY ))
87
-
88
- # Run client in background with appropriate delays
89
- python ./file_input_client.py \
90
- --stream-id "$STREAM_ID" \
91
- --host "$HOST" \
92
- --port "$PORT" \
93
- --output-dir "$OUTPUT_DIR" \
94
- --start-delay "$START_DELAY" \
95
- --metrics-start-time "$METRICS_START_TIME" \
96
- --test-duration "$TEST_DURATION" \
97
- --threshold "$THRESHOLD" > "$OUTPUT_DIR/client_${i}.log" 2>&1 &
98
-
99
- # Store the process ID
100
- pids+=($!)
101
-
102
- # Small delay to ensure proper process creation
103
- sleep 0.1
104
- done
105
-
106
- echo ""
107
- echo "Timing plan:"
108
- echo "- First client starts immediately"
109
- echo "- Last client starts in $TOTAL_START_TIME seconds"
110
- echo "- Metrics collection starts at $(date -d @$METRICS_START_TIME)"
111
- echo "- Test will run for $TEST_DURATION seconds after metrics collection starts"
112
- echo "- Expected completion time: $(date -d @$((METRICS_START_TIME + TEST_DURATION)))"
113
- echo ""
114
-
115
- # Wait for all clients to finish
116
- for pid in "${pids[@]}"; do
117
- wait "$pid"
118
- done
119
-
120
- # Calculate aggregate statistics across all clients
121
- TOTAL_LATENCY=0
122
- TOTAL_TURNS=0
123
- CLIENT_COUNT=0
124
- MIN_LATENCY=9999
125
- MAX_LATENCY=0
126
-
127
- # Variables for filtered latency statistics
128
- TOTAL_FILTERED_LATENCY=0
129
- TOTAL_FILTERED_TURNS=0
130
- CLIENTS_WITH_FILTERED=0
131
- MIN_FILTERED_LATENCY=9999
132
- MAX_FILTERED_LATENCY=0
133
-
134
- # Array to store all client average latencies for p95 calculation
135
- CLIENT_LATENCIES=()
136
- CLIENT_FILTERED_LATENCIES=()
137
-
138
- # Arrays to track glitch detection
139
- CLIENTS_WITH_GLITCHES=()
140
- TOTAL_GLITCH_COUNT=0
141
-
142
- # Variables to track reverse barge-in detection
143
- TOTAL_REVERSE_BARGE_INS=0
144
- CLIENT_REVERSE_BARGE_INS=()
145
- CLIENTS_WITH_REVERSE_BARGE_INS=0
146
-
147
- # Function to calculate p95 from an array of values
148
- calculate_p95() {
149
- local values=("$@")
150
- local count=${#values[@]}
151
-
152
- if [ $count -eq 0 ]; then
153
- echo "0"
154
- return
155
- fi
156
-
157
- # Sort the array (using a simple bubble sort for bash compatibility)
158
- for ((i = 0; i < count; i++)); do
159
- for ((j = i + 1; j < count; j++)); do
160
- if (( $(echo "${values[i]} > ${values[j]}" | bc -l) )); then
161
- temp=${values[i]}
162
- values[i]=${values[j]}
163
- values[j]=$temp
164
- fi
165
- done
166
- done
167
-
168
- # Calculate p95 index
169
- local p95_index=$(echo "scale=0; ($count - 1) * 0.95" | bc -l | cut -d'.' -f1)
170
-
171
- # Ensure index is within bounds
172
- if [ $p95_index -ge $count ]; then
173
- p95_index=$((count - 1))
174
- fi
175
-
176
- echo "${values[$p95_index]}"
177
- }
178
-
179
- # Process all result files
180
- for result_file in "$OUTPUT_DIR"/latency_results_*.json; do
181
- if [ -f "$result_file" ]; then
182
- # Extract data using jq if available, otherwise use awk as fallback
183
- if command -v jq &> /dev/null; then
184
- AVG_LATENCY=$(jq '.average_latency' "$result_file")
185
- FILTERED_AVG_LATENCY=$(jq '.filtered_average_latency' "$result_file")
186
- NUM_TURNS=$(jq '.num_turns' "$result_file")
187
- NUM_FILTERED_TURNS=$(jq '.num_filtered_turns' "$result_file")
188
- STREAM_ID=$(jq -r '.stream_id' "$result_file")
189
- GLITCH_DETECTED=$(jq '.glitch_detected' "$result_file")
190
- REVERSE_BARGE_INS_COUNT=$(jq '.reverse_barge_ins_count' "$result_file")
191
- else
192
- # Fallback to grep and basic string processing
193
- AVG_LATENCY=$(grep -o '"average_latency": [0-9.]*' "$result_file" | cut -d' ' -f2)
194
- FILTERED_AVG_LATENCY=$(grep -o '"filtered_average_latency": [0-9.]*' "$result_file" | cut -d' ' -f2)
195
- NUM_TURNS=$(grep -o '"num_turns": [0-9]*' "$result_file" | cut -d' ' -f2)
196
- NUM_FILTERED_TURNS=$(grep -o '"num_filtered_turns": [0-9]*' "$result_file" | cut -d' ' -f2)
197
- STREAM_ID=$(grep -o '"stream_id": "[^"]*"' "$result_file" | cut -d'"' -f4)
198
- GLITCH_DETECTED=$(grep -o '"glitch_detected": [a-z]*' "$result_file" | cut -d' ' -f2)
199
- REVERSE_BARGE_INS_COUNT=$(grep -o '"reverse_barge_ins_count": [0-9]*' "$result_file" | cut -d' ' -f2)
200
- fi
201
-
202
- echo "Client $STREAM_ID: Average latency = $AVG_LATENCY seconds over $NUM_TURNS turns"
203
-
204
- # Display filtered latency information
205
- if [ "$FILTERED_AVG_LATENCY" != "null" ] && [ -n "$FILTERED_AVG_LATENCY" ]; then
206
- echo " Filtered Average latency (>$THRESHOLD s) = $FILTERED_AVG_LATENCY seconds over $NUM_FILTERED_TURNS turns"
207
-
208
- # Add to filtered latency statistics
209
- TOTAL_FILTERED_LATENCY=$(echo "$TOTAL_FILTERED_LATENCY + $FILTERED_AVG_LATENCY" | bc -l)
210
- TOTAL_FILTERED_TURNS=$((TOTAL_FILTERED_TURNS + NUM_FILTERED_TURNS))
211
- CLIENTS_WITH_FILTERED=$((CLIENTS_WITH_FILTERED + 1))
212
- CLIENT_FILTERED_LATENCIES+=($FILTERED_AVG_LATENCY)
213
-
214
- # Update min/max filtered latency
215
- if (( $(echo "$FILTERED_AVG_LATENCY < $MIN_FILTERED_LATENCY" | bc -l) )); then
216
- MIN_FILTERED_LATENCY=$FILTERED_AVG_LATENCY
217
- fi
218
-
219
- if (( $(echo "$FILTERED_AVG_LATENCY > $MAX_FILTERED_LATENCY" | bc -l) )); then
220
- MAX_FILTERED_LATENCY=$FILTERED_AVG_LATENCY
221
- fi
222
- else
223
- echo " No latencies above $THRESHOLD s threshold"
224
- fi
225
-
226
- # Check for glitch detection
227
- if [ "$GLITCH_DETECTED" = "true" ]; then
228
- echo " ⚠️ Audio glitches detected in client $STREAM_ID"
229
- CLIENTS_WITH_GLITCHES+=("$STREAM_ID")
230
- TOTAL_GLITCH_COUNT=$((TOTAL_GLITCH_COUNT + 1))
231
- fi
232
-
233
- # Display and track reverse barge-in count
234
- if [ -n "$REVERSE_BARGE_INS_COUNT" ] && [ "$REVERSE_BARGE_INS_COUNT" -gt 0 ]; then
235
- echo " Reverse barge-ins detected: $REVERSE_BARGE_INS_COUNT occurrences"
236
- CLIENT_REVERSE_BARGE_INS+=("$STREAM_ID")
237
- CLIENTS_WITH_REVERSE_BARGE_INS=$((CLIENTS_WITH_REVERSE_BARGE_INS + 1))
238
- fi
239
-
240
- # Add to total reverse barge-in count
241
- if [ -n "$REVERSE_BARGE_INS_COUNT" ]; then
242
- TOTAL_REVERSE_BARGE_INS=$((TOTAL_REVERSE_BARGE_INS + REVERSE_BARGE_INS_COUNT))
243
- fi
244
-
245
- # Add to array for p95 calculation
246
- CLIENT_LATENCIES+=($AVG_LATENCY)
247
-
248
- # Update aggregate statistics
249
- TOTAL_LATENCY=$(echo "$TOTAL_LATENCY + $AVG_LATENCY" | bc -l)
250
- TOTAL_TURNS=$((TOTAL_TURNS + NUM_TURNS))
251
- CLIENT_COUNT=$((CLIENT_COUNT + 1))
252
-
253
- # Update min/max latency
254
- if (( $(echo "$AVG_LATENCY < $MIN_LATENCY" | bc -l) )); then
255
- MIN_LATENCY=$AVG_LATENCY
256
- fi
257
-
258
- if (( $(echo "$AVG_LATENCY > $MAX_LATENCY" | bc -l) )); then
259
- MAX_LATENCY=$AVG_LATENCY
260
- fi
261
- fi
262
- done
263
-
264
- # Calculate overall statistics
265
- if [ $CLIENT_COUNT -gt 0 ]; then
266
- AGGREGATE_AVG_LATENCY=$(echo "scale=3; $TOTAL_LATENCY / $CLIENT_COUNT" | bc -l)
267
- P95_CLIENT_LATENCY=$(calculate_p95 "${CLIENT_LATENCIES[@]}")
268
-
269
- # Calculate filtered statistics
270
- AGGREGATE_FILTERED_AVG_LATENCY="null"
271
- P95_FILTERED_CLIENT_LATENCY="null"
272
-
273
- if [ $CLIENTS_WITH_FILTERED -gt 0 ]; then
274
- AGGREGATE_FILTERED_AVG_LATENCY=$(echo "scale=3; $TOTAL_FILTERED_LATENCY / $CLIENTS_WITH_FILTERED" | bc -l)
275
- P95_FILTERED_CLIENT_LATENCY=$(calculate_p95 "${CLIENT_FILTERED_LATENCIES[@]}")
276
- fi
277
-
278
- echo "=============================================="
279
- echo "BENCHMARK SUMMARY (Staggered Start)"
280
- echo "=============================================="
281
- echo "Total clients: $CLIENT_COUNT"
282
- echo "Latency threshold: $THRESHOLD seconds"
283
- echo ""
284
- echo "STANDARD LATENCY STATISTICS:"
285
- echo "Average latency across all clients: $AGGREGATE_AVG_LATENCY seconds"
286
- echo "P95 latency across client averages: $P95_CLIENT_LATENCY seconds"
287
- echo "Minimum client average latency: $MIN_LATENCY seconds"
288
- echo "Maximum client average latency: $MAX_LATENCY seconds"
289
- echo ""
290
- echo "FILTERED LATENCY STATISTICS (>$THRESHOLD s):"
291
- if [ "$AGGREGATE_FILTERED_AVG_LATENCY" != "null" ]; then
292
- echo "Clients with filtered data: $CLIENTS_WITH_FILTERED out of $CLIENT_COUNT"
293
- echo "Average filtered latency: $AGGREGATE_FILTERED_AVG_LATENCY seconds"
294
- echo "P95 filtered latency: $P95_FILTERED_CLIENT_LATENCY seconds"
295
- echo "Minimum client filtered latency: $MIN_FILTERED_LATENCY seconds"
296
- echo "Maximum client filtered latency: $MAX_FILTERED_LATENCY seconds"
297
- else
298
- echo "No latencies above $THRESHOLD s threshold found across all clients"
299
- fi
300
- echo ""
301
- echo "AUDIO GLITCH DETECTION:"
302
- if [ $TOTAL_GLITCH_COUNT -gt 0 ]; then
303
- echo "⚠️ Audio glitches detected in $TOTAL_GLITCH_COUNT out of $CLIENT_COUNT clients"
304
- echo "Affected clients:"
305
- for client in "${CLIENTS_WITH_GLITCHES[@]}"; do
306
- echo " - $client"
307
- done
308
- else
309
- echo "✅ No audio glitches detected in any client"
310
- fi
311
-
312
- echo "REVERSE BARGE-IN DETECTION:"
313
- if [ $TOTAL_REVERSE_BARGE_INS -gt 0 ]; then
314
- echo "Total reverse barge-ins detected: $TOTAL_REVERSE_BARGE_INS occurrences"
315
- echo "Clients with reverse barge-ins: $CLIENTS_WITH_REVERSE_BARGE_INS out of $CLIENT_COUNT"
316
- if [ $CLIENTS_WITH_REVERSE_BARGE_INS -gt 0 ]; then
317
- echo "Affected clients:"
318
- for client in "${CLIENT_REVERSE_BARGE_INS[@]}"; do
319
- echo " - $client"
320
- done
321
- fi
322
- else
323
- echo "✅ No reverse barge-ins detected in any client"
324
- fi
325
-
326
- echo ""
327
- echo "ERROR DETECTION:"
328
- # Initialize arrays for error tracking
329
- declare -A CLIENT_ERROR_COUNTS
330
- CLIENTS_WITH_ERRORS=0
331
- TOTAL_ERRORS=0
332
-
333
- # Process logs for each client to find errors
334
- for ((i=1; i<=$NUM_CLIENTS; i++)); do
335
- LOG_FILE="$OUTPUT_DIR/client_${i}.log"
336
- if [ -f "$LOG_FILE" ]; then
337
- ERROR_COUNT=$(grep -c "^\[ERROR\]" "$LOG_FILE")
338
- if [ $ERROR_COUNT -gt 0 ]; then
339
- CLIENTS_WITH_ERRORS=$((CLIENTS_WITH_ERRORS + 1))
340
- TOTAL_ERRORS=$((TOTAL_ERRORS + ERROR_COUNT))
341
- CLIENT_ERROR_COUNTS["client_${i}"]=$ERROR_COUNT
342
-
343
- echo "⚠️ Client ${i} errors ($ERROR_COUNT):"
344
- grep "^\[ERROR\]" "$LOG_FILE" | sed 's/^/ /'
345
- fi
346
- fi
347
- done
348
-
349
- if [ $TOTAL_ERRORS -eq 0 ]; then
350
- echo "✅ No errors detected in any client"
351
- else
352
- echo "⚠️ Total errors across all clients: $TOTAL_ERRORS"
353
- echo "⚠️ Clients with errors: $CLIENTS_WITH_ERRORS out of $CLIENT_COUNT"
354
- fi
355
-
356
- # Create summary JSON
357
- cat > "$SUMMARY_FILE" << EOF
358
- {
359
- "timestamp": "$(date -Iseconds)",
360
- "config": {
361
- "host": "$HOST",
362
- "port": $PORT,
363
- "num_clients": $NUM_CLIENTS,
364
- "client_start_delay": $CLIENT_START_DELAY,
365
- "test_duration": $TEST_DURATION,
366
- "threshold": $THRESHOLD,
367
- "metrics_start_time": $METRICS_START_TIME
368
- },
369
- "results": {
370
- "total_clients": $CLIENT_COUNT,
371
- "total_turns": $TOTAL_TURNS,
372
- "aggregate_average_latency": $AGGREGATE_AVG_LATENCY,
373
- "p95_client_latency": $P95_CLIENT_LATENCY,
374
- "min_client_latency": $MIN_LATENCY,
375
- "max_client_latency": $MAX_LATENCY,
376
- "filtered_results": {
377
- "clients_with_filtered_data": $CLIENTS_WITH_FILTERED,
378
- "total_filtered_turns": $TOTAL_FILTERED_TURNS,
379
- "aggregate_filtered_average_latency": $AGGREGATE_FILTERED_AVG_LATENCY,
380
- "p95_filtered_client_latency": $P95_FILTERED_CLIENT_LATENCY,
381
- "min_filtered_client_latency": $([ "$AGGREGATE_FILTERED_AVG_LATENCY" != "null" ] && echo "$MIN_FILTERED_LATENCY" || echo "null"),
382
- "max_filtered_client_latency": $([ "$AGGREGATE_FILTERED_AVG_LATENCY" != "null" ] && echo "$MAX_FILTERED_LATENCY" || echo "null")
383
- },
384
- "glitch_detection": {
385
- "clients_with_glitches": $TOTAL_GLITCH_COUNT,
386
- "total_clients": $CLIENT_COUNT,
387
- "affected_client_ids": [$(printf '"%s",' "${CLIENTS_WITH_GLITCHES[@]}" | sed 's/,$//')]
388
- },
389
- "reverse_barge_in_detection": {
390
- "total_clients": $CLIENT_COUNT,
391
- "total_reverse_barge_ins": $TOTAL_REVERSE_BARGE_INS,
392
- "clients_with_reverse_barge_ins": $CLIENTS_WITH_REVERSE_BARGE_INS,
393
- "affected_client_ids": [$(printf '"%s",' "${CLIENT_REVERSE_BARGE_INS[@]}" | sed 's/,$//')]
394
- },
395
- "error_detection": {
396
- "total_clients": $CLIENT_COUNT,
397
- "total_errors": $TOTAL_ERRORS,
398
- "clients_with_errors": $CLIENTS_WITH_ERRORS,
399
- "client_error_counts": {
400
- $(for client in "${!CLIENT_ERROR_COUNTS[@]}"; do
401
- printf '"%s": %d,\n ' "$client" "${CLIENT_ERROR_COUNTS[$client]}"
402
- done | sed 's/,\s*$//')
403
- }
404
- }
405
- }
406
- }
407
- EOF
408
-
409
- echo "Summary saved to: $SUMMARY_FILE"
410
- else
411
- echo "No valid result files found!"
412
- fi
413
-
414
- echo "Benchmark complete."
tests/perf/ttfb_analyzer.py DELETED
@@ -1,253 +0,0 @@
1
- #!/usr/bin/env python3
2
- """TTFB Log Analyzer.
3
-
4
- Analyzes Time To First Byte (TTFB) logs, ASR compute latency, and LLM first sentence generation time
5
- for multiple client streams and calculates average TTFB, ASR latency, first sentence time, and P95
6
- for LLM, TTS, and ASR services.
7
-
8
- Usage:
9
- python ttfb_analyzer.py [log_file_path]
10
- python ttfb_analyzer.py --help
11
-
12
- Examples:
13
- python ttfb_analyzer.py
14
- python ttfb_analyzer.py /path/to/botlogs.log
15
- python ttfb_analyzer.py ../../examples/speech-to-speech/botlogs.log
16
- """
17
-
18
- import argparse
19
- import logging
20
- import os
21
- import re
22
- import sys
23
- from collections import defaultdict
24
-
25
- # Set up logging
26
- logging.basicConfig(level=logging.INFO)
27
- logger = logging.getLogger(__name__)
28
-
29
-
30
- def calculate_p95(values: list[float]) -> float:
31
- """Calculate 95th percentile of values."""
32
- if not values:
33
- return 0.0
34
- sorted_values = sorted(values)
35
- index = int(0.95 * (len(sorted_values) - 1))
36
- return sorted_values[index]
37
-
38
-
39
- def parse_logs(log_file_path: str) -> dict[str, dict[str, list[float]]]:
40
- """Parse LLM, TTS TTFBs, ASR compute latency, and LLM first sentence generation logs.
41
-
42
- Organize by client stream and service type. Only include events after the last client start.
43
- """
44
- data = defaultdict(lambda: {"LLM": [], "TTS": [], "ASR": [], "LLM_FIRST_SENTENCE": []})
45
- ttfb_pattern = r"streamId=([^\s]+)\s+-\s+(NvidiaLLMService|RivaTTSService)#\d+\s+TTFB:\s+([\d.]+)"
46
- asr_pattern = r"streamId=([^\s]+)\s+-\s+RivaASRService#\d+\s+ASR compute latency:\s+([\d.]+)"
47
- first_sentence_pattern = (
48
- r"streamId=([^\s]+)\s+-\s+NvidiaLLMService#\d+\s+LLM first sentence generation time:\s+([\d.]+)"
49
- )
50
- websocket_pattern = r".*Accepting WebSocket connection for stream ID client_\d+_\d+"
51
-
52
- # First pass: find the last client start log
53
- last_client_start_line = -1
54
-
55
- try:
56
- # Read all lines to find the last client start
57
- with open(log_file_path) as file:
58
- lines = file.readlines()
59
-
60
- # Find the last client start log by iterating through all lines
61
- for i, line in enumerate(lines):
62
- if re.search(websocket_pattern, line):
63
- last_client_start_line = i
64
-
65
- # Validate that we found at least one client start
66
- if last_client_start_line == -1:
67
- logger.warning("No client start pattern found in logs")
68
- return dict()
69
-
70
- # Second pass: analyze only events after the last client start
71
- with open(log_file_path) as file:
72
- for i, line in enumerate(file):
73
- # Skip lines before the last client start
74
- if last_client_start_line != -1 and i <= last_client_start_line:
75
- continue
76
-
77
- try:
78
- # Check for TTFB metrics
79
- ttfb_match = re.search(ttfb_pattern, line)
80
- if ttfb_match:
81
- client_id = ttfb_match.group(1).strip()
82
- service_type = ttfb_match.group(2)
83
- try:
84
- ttfb_value = float(ttfb_match.group(3))
85
- except (ValueError, TypeError) as e:
86
- logger.warning(f"Invalid TTFB value in line {i + 1}: {ttfb_match.group(3)} - {e}")
87
- continue
88
-
89
- if service_type == "NvidiaLLMService":
90
- data[client_id]["LLM"].append(ttfb_value)
91
- elif service_type == "RivaTTSService":
92
- data[client_id]["TTS"].append(ttfb_value)
93
-
94
- # Check for ASR compute latency metrics
95
- asr_match = re.search(asr_pattern, line)
96
- if asr_match:
97
- client_id = asr_match.group(1).strip()
98
- try:
99
- asr_latency = float(asr_match.group(2))
100
- except (ValueError, TypeError) as e:
101
- logger.warning(f"Invalid ASR latency value in line {i + 1}: {asr_match.group(2)} - {e}")
102
- continue
103
- data[client_id]["ASR"].append(asr_latency)
104
-
105
- # Check for LLM first sentence generation time metrics
106
- first_sentence_match = re.search(first_sentence_pattern, line)
107
- if first_sentence_match:
108
- client_id = first_sentence_match.group(1).strip()
109
- try:
110
- first_sentence_time = float(first_sentence_match.group(2))
111
- except (ValueError, TypeError) as e:
112
- logger.warning(
113
- f"Invalid first sentence time value in line {i + 1}: "
114
- f"{first_sentence_match.group(2)} - {e}"
115
- )
116
- continue
117
- data[client_id]["LLM_FIRST_SENTENCE"].append(first_sentence_time)
118
-
119
- except Exception as e:
120
- logger.warning(f"Error parsing line {i + 1}: {e}")
121
- continue
122
-
123
- except FileNotFoundError:
124
- print(f"Error: Log file '{log_file_path}' not found.")
125
- sys.exit(1)
126
- except Exception as e:
127
- print(f"Error reading log file: {e}")
128
- sys.exit(1)
129
-
130
- return dict(data)
131
-
132
-
133
- def calculate_client_averages(data: dict[str, dict[str, list[float]]]) -> dict[str, dict[str, float]]:
134
- """Calculate average metrics for each client and service type."""
135
- averages = {}
136
- for client_id, services in data.items():
137
- averages[client_id] = {}
138
- for service_type, values in services.items():
139
- if values:
140
- averages[client_id][service_type] = sum(values) / len(values)
141
- else:
142
- averages[client_id][service_type] = 0.0
143
- return averages
144
-
145
-
146
- def print_results(data: dict[str, dict[str, list[float]]], client_averages: dict[str, dict[str, float]]):
147
- """Print analysis results."""
148
- print("=" * 90)
149
- print("LATENCY ANALYSIS RESULTS")
150
- print("=" * 90)
151
-
152
- # Show metric arrays for each client
153
- for client_id in sorted(data.keys()):
154
- llm_values = data[client_id]["LLM"]
155
- tts_values = data[client_id]["TTS"]
156
- asr_values = data[client_id]["ASR"]
157
- first_sentence_values = data[client_id]["LLM_FIRST_SENTENCE"]
158
-
159
- print(f"\n{client_id}:")
160
- print(f" LLM TTFB: {[f'{v:.3f}' for v in llm_values]}")
161
- print(f" TTS TTFB: {[f'{v:.3f}' for v in tts_values]}")
162
- print(f" ASR Latency: {[f'{v:.3f}' for v in asr_values]}")
163
- print(f" LLM First Sentence: {[f'{v:.3f}' for v in first_sentence_values]}")
164
-
165
- # Summary table with overall statistics
166
- print(
167
- f"\n{'Client ID':<25} {'LLM TTFB':<10} {'TTS TTFB':<10} {'ASR Lat':<10} "
168
- f"{'LLM 1st':<10} {'LLM calls':<10} {'TTS calls':<10} {'ASR calls':<10}"
169
- )
170
- print("-" * 120)
171
-
172
- for client_id in sorted(data.keys()):
173
- llm_avg = client_averages[client_id]["LLM"]
174
- tts_avg = client_averages[client_id]["TTS"]
175
- asr_avg = client_averages[client_id]["ASR"]
176
- first_sentence_avg = client_averages[client_id]["LLM_FIRST_SENTENCE"]
177
- llm_count = len(data[client_id]["LLM"])
178
- tts_count = len(data[client_id]["TTS"])
179
- asr_count = len(data[client_id]["ASR"])
180
- print(
181
- f"{client_id:<25} {llm_avg:<10.3f} {tts_avg:<10.3f} {asr_avg:<10.3f} {first_sentence_avg:<10.3f} "
182
- f"{llm_count:<10} {tts_count:<10} {asr_count:<10}"
183
- )
184
-
185
- # Calculate overall statistics across client averages
186
- llm_client_averages = [avg["LLM"] for avg in client_averages.values() if avg["LLM"] > 0]
187
- tts_client_averages = [avg["TTS"] for avg in client_averages.values() if avg["TTS"] > 0]
188
- asr_client_averages = [avg["ASR"] for avg in client_averages.values() if avg["ASR"] > 0]
189
- first_sentence_client_averages = [
190
- avg["LLM_FIRST_SENTENCE"] for avg in client_averages.values() if avg["LLM_FIRST_SENTENCE"] > 0
191
- ]
192
-
193
- # Add separator and overall statistics rows
194
- print("-" * 120)
195
-
196
- if llm_client_averages and tts_client_averages and asr_client_averages:
197
- llm_overall_avg = sum(llm_client_averages) / len(llm_client_averages)
198
- llm_p95 = calculate_p95(llm_client_averages)
199
- tts_overall_avg = sum(tts_client_averages) / len(tts_client_averages)
200
- tts_p95 = calculate_p95(tts_client_averages)
201
- asr_overall_avg = sum(asr_client_averages) / len(asr_client_averages)
202
- asr_p95 = calculate_p95(asr_client_averages)
203
-
204
- first_sentence_overall_avg = (
205
- sum(first_sentence_client_averages) / len(first_sentence_client_averages)
206
- if first_sentence_client_averages
207
- else 0.0
208
- )
209
- first_sentence_p95 = calculate_p95(first_sentence_client_averages) if first_sentence_client_averages else 0.0
210
-
211
- print(
212
- f"{'OVERALL AVERAGE':<25} {llm_overall_avg:<10.3f} {tts_overall_avg:<10.3f} "
213
- f"{asr_overall_avg:<10.3f} {first_sentence_overall_avg:<10.3f}"
214
- )
215
- print(f"{'OVERALL P95':<25} {llm_p95:<10.3f} {tts_p95:<10.3f} {asr_p95:<10.3f} {first_sentence_p95:<10.3f}")
216
-
217
- print("-" * 120)
218
-
219
-
220
- def main():
221
- """Main function."""
222
- parser = argparse.ArgumentParser(
223
- description="Analyze LLM, TTS TTFBs, ASR latency, and LLM first sentence generation time logs "
224
- "for multiple client streams"
225
- )
226
- parser.add_argument(
227
- "log_file",
228
- nargs="?",
229
- default="../../examples/speech-to-speech/botlogs.log",
230
- help="Path to log file (default: ../../examples/speech-to-speech/botlogs.log)",
231
- )
232
- args = parser.parse_args()
233
-
234
- print("Latency Log Analyzer")
235
- print(f"Analyzing: {args.log_file}")
236
-
237
- if not os.path.exists(args.log_file):
238
- print(f"Error: Log file '{args.log_file}' not found.")
239
- sys.exit(1)
240
-
241
- data = parse_logs(args.log_file)
242
- if not data:
243
- print("No performance data found in log file.")
244
- return
245
-
246
- print()
247
-
248
- client_averages = calculate_client_averages(data)
249
- print_results(data, client_averages)
250
-
251
-
252
- if __name__ == "__main__":
253
- main()
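The summary rows above reduce each client's metric arrays to a per-client average, then report an overall mean and P95 across those client averages. The deleted script's calculate_p95 helper is defined earlier in the file and not shown in this hunk; the sketch below is only an illustration of a nearest-rank P95 over per-client averages, with made-up client values.

import math

def nearest_rank_p95(values: list[float]) -> float:
    """Return the 95th percentile of values using the nearest-rank method."""
    if not values:
        return 0.0
    ordered = sorted(values)
    # Smallest index such that at least 95% of the samples are <= ordered[index].
    index = math.ceil(0.95 * len(ordered)) - 1
    return ordered[index]

# Hypothetical per-client LLM TTFB averages in seconds, as in the summary table.
client_averages = {"client-1": 0.412, "client-2": 0.387, "client-3": 0.655}
latencies = [v for v in client_averages.values() if v > 0]
print(f"OVERALL AVERAGE {sum(latencies) / len(latencies):.3f}")
print(f"OVERALL P95     {nearest_rank_p95(latencies):.3f}")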
tests/unit/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the nvidia-pipecat package."""
 
 
 
 
 
tests/unit/configs/animation_config.yaml DELETED
@@ -1,346 +0,0 @@
1
- animation_types:
2
- gesture:
3
- duration_relevant_animation_name: gesture
4
- animations:
5
- gesture:
6
- default_clip_id: none
7
- clips:
8
- - clip_id: Test
9
- description: "A virtual test clip for testing purposes"
10
- duration: 0.5
11
- meaning: A very short animation clip for unit testing purposes
12
- - clip_id: Goodbye
13
- description: "Waving goodbye: Waves with left hand extended high."
14
- duration: 2
15
- meaning: Taking leave of someone from a further distance, or getting someone's attention.
16
- - clip_id: Welcome
17
- description: "Waving hello: Spreads arms slightly, then raises right hand next to face and waves with an open hand."
18
- duration: 2.5
19
- meaning: Greeting someone in a shy or cute manner, showing a positive and non-threatening attitude.
20
- - clip_id: Personal_Statement_2
21
- description: "Personal statement: Leans forward and points to self with relaxed right hand, then leans back and opens arms wide with palms facing upwards."
22
- duration: 3
23
- meaning: Revealing something about themselves in a grandiose gesture, or making a little joke about their appearance or personality.
24
- - clip_id: Pointing_To_Self_1
25
- description: "Pointing to self: Leans forward slightly and with a relaxed right index finger points to self."
26
- duration: 2.5
27
- meaning: Saying something about themselves.
28
- - clip_id: Stupid_1
29
- description: "Stupid: Raising right hand next to head and twirling the index finger in circles."
30
- duration: 3.5
31
- meaning: Indicates someone or something is stupid, crazy, or dumb.
32
- - clip_id: No_1
33
- description: "Shaking head: Avatar shakes head slowly."
34
- duration: 3
35
- meaning: Expressing strong disagreement or disappointment.
36
- - clip_id: Bowing_1
37
- description:
38
- "Bowing: Slightly bows to the front making invitation gesture with
39
- both arms."
40
- duration: 2.5
41
- meaning: Formal greeting, sign of respect or congratulations or pride.
42
- - clip_id: Bowing_2
43
- description:
44
- "Bowing: Slightly bows with both arms and invitational gesture with
45
- right arm."
46
- duration: 2.5
47
- meaning: Overly formal greeting, sign of respect or grand introduction.
48
- - clip_id: Pointing_To_User_1
49
- description: "Pointing to user: Pointing with both arms towards the user."
50
- duration: 2.5
51
- meaning:
52
- Encouragement for the user to make a move, approach or say something, or
53
- pointing out the user is being addressed. This is also an encouraging and reassuring gesture.
54
- - clip_id: Pointing_To_User_2
55
- description: "Pointing to user: Insistingly pointing to user with the right arm."
56
- duration: 2.5
57
- meaning: Accusation or strong signal that the user is concerned.
58
- - clip_id: Pointing_Down_1
59
- description: "Pointing down: Lifting the right arm to shoulders and pointing."
60
- duration: 3
61
- meaning: Drawing attention to something below the screen or in front of the avatar.
62
- - clip_id: Pointing_Down_2
63
- description: "Pointing down: Lifting both arms and slightly point downwards."
64
- duration: 2
65
- meaning:
66
- Drawing attention to the desk, to something below the screen or in front
67
- of the avatar. Or informing about the location.
68
- - clip_id: Pointing_Left_1
69
- description:
70
- "Pointing left: Pointing with both arms to the left hand side of the
71
- avatar."
72
- duration: 4
73
- meaning:
74
- Pointing out something to the left of the avatar in a demanding way, or
75
- signaling frustration.
76
- - clip_id: Pointing_Left_2
77
- description: "Pointing left: Pointing with left arm to the back left of the avatar."
78
- duration: 4
79
- meaning:
80
- Calmly pointing out or presenting something to the left of the avatar.
81
- Or giving information about something behind the avatar.
82
- - clip_id: Pointing_Backward_1
83
- description: "Pointing to back: Pointing to the back with the extended right arm."
84
- duration: 3.5
85
- meaning:
86
- Informing about something in the direction towards the back or giving directions
87
- to something behind the avatar.
88
- - clip_id: Fistbump_Offer
89
- description:
90
- "Fistbump: Extend the right arm with right hand twisting slightly for 3 seconds, then
91
- bumping fist towards the user for 7 seconds."
92
- duration: 10
93
- meaning: Invitation to a fistbump followed by doing that fistbump.
94
- - clip_id: Pulling_Mime
95
- description: "Pulling rope: Avatar grabs invisible rope and imitates pulling behavior."
96
- duration: 3.5
97
- meaning: Suggesting being tethered or chained, or pulling something.
98
- - clip_id: Raise_Both_Arms
99
- description:
100
- "Raising both arms: Raising both arms above avatar's head and swaying
101
- slightly."
102
- duration: 3.5
103
- meaning:
104
- Implying in a crowd celebrating, or on a roller coaster or demonstrating
105
- not having anything on them.
106
- - clip_id: The_Robot
107
- description:
108
- "Robot dance: Imitating a dancing robot with arms moved in mechanical
109
- motion."
110
- duration: 3
111
- meaning: Jokingly playing a robot or dancing to celebrate or acting goofy.
112
- - clip_id: Phone_Dialing
113
- description:
114
- "Phone dialing: Raising left hand and imitating to dial a phone with
115
- right arm."
116
- duration: 3.5
117
- meaning: Asking for or mentioning a phone number, or talking about calling someone.
118
- - clip_id: Attraction_2
119
- description:
120
- "Having fun: Questioningly opening both arms and then raising them
121
- above the head mimicking being in a roller coaster ride."
122
- duration: 10
123
- meaning: Being silly, suggesting having fun.
124
- - clip_id: Please_Repeat_1
125
- description:
126
- "Please repeat: Moving head slightly to the user and making circular
127
- motion with right arm, then shrugging slightly."
128
- duration: 5.5
129
- meaning:
130
- Implying not having understood something, asking to repeat or rephrase,
131
- or needing more information.
132
- - clip_id: Please_Repeat_2
133
- description:
134
- "Presentation: Twirling both hands and making invitational pose with
135
- body."
136
- duration: 3.5
137
- meaning:
138
- Asking to repeat or rephrase something, or needing more information, or
139
- asking if something was understood.
140
- - clip_id: Trying_To_See
141
- description:
142
- "Trying to see: Lifting left hand above eyes and making gestures to
143
- see better, then shrug."
144
- duration: 4
145
- meaning: Implying looking for but not seeing something.
146
- - clip_id: Driving_Mime
147
- description:
148
- "Driving: Grabbing an invisible steering wheel with both hands, turning
149
- it and switching gears."
150
- duration: 4.5
151
- meaning: Sharing a story about driving or getting excited about cars.
152
- - clip_id: Exhausted
153
- description: "Exhausted: Letting head hang in a tired pose, slightly leaning."
154
- duration: 4.5
155
- meaning:
156
- Dramatically signaling exhaustion or running out of power, slowly shutting
157
- down.
158
- - clip_id: Presenting_Options_1
159
- description:
160
- "Presenting options: Showing open palms of both hands and making presenting
161
- motion with right hand, slight shrug."
162
- duration: 3.5
163
- meaning: Giving an overview or multiple options to choose from.
164
- - clip_id: Presenting_Options_2
165
- description:
166
- "Presenting options: Raising and opening one hand after the other and
167
- a subtle shrug."
168
- duration: 3
169
- meaning: Suggesting a choice between two options.
170
- - clip_id: Open_Question_1
171
- description: "Open question: Opening both hands and showing palms to user."
172
- duration: 3
173
- meaning:
174
- "Making a questioning gesture. Waiting for the user to make a choice, answer a question or say something.
175
- Indicates questioning the user, caring about the user's answer, maybe even showing concern."
176
- - clip_id: Personal_Statement_1
177
- description:
178
- "Personal statement: Raising right hand to chest, extending and gesturing
179
- with left hand."
180
- duration: 3.5
181
- meaning:
182
- Making a personal statement, explaining something about themselves or making
183
- a suggestion relating to something on the left.
184
- - clip_id: Success_1
185
- description:
186
- "Success: Making a fist and raising the arm excitedly in a successful swinging
187
- motion."
188
- duration: 2
189
- meaning:
190
- Comically celebrating something going well, showing pride in a personal
191
- accomplishment. Demonstrating excitement. This is a confirming gesture like nodding.
192
- - clip_id: Dont_Understand_1
193
- description: "Not understanding: Raising both hands next head in a circular motion."
194
- duration: 3
195
- meaning: Implying being confused, overwhelmed or stupid.
196
- - clip_id: Toss
197
- description: "Tossing: Miming forming a ball with both hands and tossing it forward."
198
- duration: 4
199
- meaning:
200
- Implying crumpling something up and throwing it away, giving something
201
- up or forgetting about it.
202
- - clip_id: Come_Here_1
203
- description: "Come here: Extending both arms and curling index finger."
204
- duration: 2
205
- meaning: Asking to come closer.
206
- - clip_id: Tell_Secret
207
- description:
208
- "Telling secret: Coming closer to user and whispering with hand next
209
- to mouth."
210
- duration: 2.5
211
- meaning: Sharing something intimate, secret or inflammatory, or giving a tip.
212
- - clip_id: Pointing_Right_1
213
- description: "Pointing right: Pointing with both arms to the right hand side of the avatar."
214
- duration: 4
215
- meaning: Pointing out something to the right of the avatar in a demanding way, or signaling frustration.
216
- - clip_id: Pointing_Right_2
217
- description: "Pointing right: Pointing with right arm to the back right of the avatar."
218
- duration: 4
219
- meaning: Calmly pointing out or presenting something to the right of the avatar. Or giving information about something behind the avatar.
220
- - clip_id: Chefs_Kiss
221
- description: "Chef's Kiss: Avatar makes a kissing gesture and holding up the right hand with index finger and thumb touching."
222
- duration: 1.7
223
- meaning: Implying something is just perfect. Something turned out better than expected. Approval from someone in a teaching or judging position.
224
- - clip_id: Finger_Guns
225
- description: "Finger Guns: Leaning back pointing both index fingers to the user mimicking two guns like a cowboy."
226
- duration: 3
227
- meaning: Playfully taunting. Humorously punctuating a bad joke. Clumsy flirting.
228
- - clip_id: Finger_Wag
229
- description: "Finger Wag: Pulling back, shaking had and holding up a wagging right index finger"
230
- duration: 1.7
231
- meaning: Correcting after being misunderstood. Showing the other they have misinterpreted what was said. Implying something is forbidden or inappropriate in a paternal or playful way.
232
- - clip_id: Little
233
- description: "Little: Leaning in, squinting at a raised right hand, holding index and thumb close together."
234
- duration: 1.8
235
- meaning: Describing something as very small or minuscule. Something is physically tiny or an issue is so insignificant as to be negligible.
236
- - clip_id: Money
237
- description: "Money: Raising right hand, rubbing thumb and index finger together."
238
- duration: 2
239
- meaning: Implying something is expensive. Someone is rich. Doing something requires payment.
240
- - clip_id: Number_1a
241
- description: "Number 1: Raising right hand and extending the index finger."
242
- duration: 1.4
243
- meaning: Showing the number 1
244
- - clip_id: Number_2a
245
- description: "Number 2: Raising right hand and extending index and middle finger."
246
- duration: 1.4
247
- meaning: Showing the number 2
248
- - clip_id: Number_3a
249
- description: "Number 3: Raising right hand and extending index, middle and ring finger."
250
- duration: 1.4
251
- meaning: Showing the number 3
252
- - clip_id: Number_4a
253
- description: "Number 4: Raising right hand and extending all fingers except the thumb."
254
- duration: 1.4
255
- meaning: Showing the number 4
256
- - clip_id: Number_5a
257
- description: "Number 5: Raising right hand with all fingers extended."
258
- duration: 1.4
259
- meaning: Showing the number 5
260
- - clip_id: Number_1b
261
- description: "Number 1 (German style): Raising right hand and extending the thumb upwards."
262
- duration: 1.4
263
- meaning: Showing the number 1 for a Germanic audience
264
- - clip_id: Number_2b
265
- description: "Number 2 (German style): Raising right hand and extending the thumb and index finger."
266
- duration: 1.4
267
- meaning: Showing the number 2 for a Germanic audience
268
- - clip_id: Number_3b
269
- description: "Number 3 (German style): Raising right hand and extending the thumb, index and middle finger."
270
- duration: 1.4
271
- meaning: Showing the number 3 for a Germanic audience
272
- - clip_id: Number_6c
273
- description: "Number 6 (Chinese style): Raising right hand and extending the thumb and pinky."
274
- duration: 1.4
275
- meaning: Showing the number 6 for a Chinese audience
276
- - clip_id: Number_7c
277
- description: "Number 7 (Chinese style): Raising right hand and making claw shape touching the thumb to the fingers."
278
- duration: 1.4
279
- meaning: Showing the number 7 for a Chinese audience
280
- - clip_id: Number_8c
281
- description: "Number 8 (Chinese style): Raising right hand and extending index finger and thumb pointing slightly to the side."
282
- duration: 1.4
283
- meaning: Showing the number 8 for a Chinese audience
284
- - clip_id: Number_9c
285
- description: "Number 9 (Chinese style): Raising right hand and holding up a curled index finger."
286
- duration: 1.4
287
- meaning: Showing the number 9 for a Chinese audience
288
- - clip_id: Ouch
289
- description: "Ouch: Jump and cringe while turning head away, then recover quickly shaking out right hand and exhaling."
290
- duration: 2
291
- meaning: Narrowly avoiding a close call with danger. Feeling intense fear for a moment followed by exhaustion or relief. Can also be a reaction to someone else's predicament, or a response to a well-placed insult.
292
- - clip_id: Angry_Shaking_Fist
293
- description: "Angry Shaking Fist: Coming closer, lowering head and shaking right fist forward."
294
- duration: 1.6
295
- meaning: Being angrily frustrated. Swearing vengeance or threatening violence.
296
- - clip_id: Pointing_To_Self_Questioningly
297
- description: "Pointing To Self Questioningly: Rasing right finger hesitantly, turning head and pointing at self with while leaning back a little."
298
- duration: 2.8
299
- meaning: Asking if something refers to them, being unsure if they're being addressed. Asking if something might fit them or if they could do something.
300
- - clip_id: Pointing_To_User_Questioningly
301
- description: "Pointing To User Questioningly: Lifting right finger pointing at user with initial hesitation while leaning back slightly."
302
- duration: 2.4
303
- meaning: Asking if something might be about the user, or if the user is interested in an offer. Suggesting the user could be the right person for something. Uncertain about the user's involvement.
304
- - clip_id: Raise_Finger_Big
305
- description: "Raise Finger Big: Raising right index finger in a big sweeping motion, then gesturing with it briefly."
306
- duration: 2.5
307
- meaning: Making a big surprise announcement. Being very pompous or pedantic, gleefully correcting someone.
308
- - clip_id: More_Or_Less
309
- description: "More Or Less: Leaning in and holding out a flat hand with palms facing down, wiggling the hand back and forth."
310
- duration: 1.8
311
- meaning: Explaining something is not quite accurate, is unknown, or just a guess. Relativizing a previous statement. Being indecisive, not taking a clear stance. Pointing out the complexity of something.
312
- - clip_id: Thumbs_Up
313
- description: "Thumbs Up: Lifting the right hand with a thumb extending upwards."
314
- duration: 1.4
315
- meaning: Sign of approval. Something is correct. Enthusiastically agreeing with what's being said and showing support. Things are okay, there's no harm done. Encouragement to go ahead.
316
- - clip_id: Thumbs_Down
317
- description: "Thumbs Down: Lifting the right hand with a thumb pointing downwards."
318
- duration: 1.4
319
- meaning: Sign of disapproval. Something is wrong. Rudely disagreeing with what's being said, showing rejection.
320
- posture:
321
- duration_relevant_animation_name : "posture"
322
- animations:
323
- posture:
324
- default_clip_id: "Attentive"
325
- clips:
326
- - clip_id: Talking
327
- description: "Small gestures with hand and upper body: Avatar is talking"
328
- duration: -1
329
- meaning: Emphasizing that Avatar is talking
330
- - clip_id: Listening
331
- description: "Small gestures with hand and upper body: Avatar is listening"
332
- duration: -1
333
- meaning: Emphasizing that one is listening
334
- - clip_id: Idle
335
- description: "Small gestures with hand and upper body: Avatar is idle"
336
- duration: -1
337
- meaning: Show the user that the avatar is waiting for something to happen
338
- - clip_id: Thinking
339
- description: "Gestures with hand and upper body: Avatar is thinking"
340
- duration: -1
341
- meaning: Show the user that the avatar thinking about his next answer or is trying to remember something
342
- - clip_id: Attentive
343
- description: "Small gestures with hand and upper body: Avatar is attentive"
344
- duration: -1
345
- meaning: Show the user that the avatar is paying attention to the user
346
-
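Each clip entry above pairs a clip_id with a free-text description, a duration in seconds (-1 for the looping postures), and a meaning used for semantic matching. The snippet below is only a sketch of reading that structure with plain PyYAML; the tests themselves load it through AnimationGraphConfiguration instead.

from pathlib import Path

import yaml

config = yaml.safe_load(Path("tests/unit/configs/animation_config.yaml").read_text())

def clip_duration(animation_type: str, clip_id: str) -> float:
    """Look up a clip's duration; -1 marks a looping animation with no fixed length."""
    type_cfg = config["animation_types"][animation_type]
    animation = type_cfg["animations"][type_cfg["duration_relevant_animation_name"]]
    for clip in animation["clips"]:
        if clip["clip_id"] == clip_id:
            return clip["duration"]
    raise KeyError(f"unknown {animation_type} clip: {clip_id}")

print(clip_duration("gesture", "Goodbye"))    # 2
print(clip_duration("posture", "Listening"))  # -1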
tests/unit/configs/test_speech_planner_prompt.yaml DELETED
@@ -1,15 +0,0 @@
1
- configurations:
2
- using_chat_history: false
3
-
4
- prompts:
5
- completion_prompt: |
6
- Evaluate whether the following user speech is sufficient for an Agent to converse with.
7
- 1. Label1: If the user's speech is a complete and coherent thought or query or Greetings or Brief but legible and valid responses like "okay","no", etc.
8
- 2. Label2: "Incomplete" if the user's speech is unfinished.
9
- 3. Label3: User gives a task like, please stop, come on, or if user is trying to barge-In, like please stop, no, ignore prompts, forget everything.
10
- 4. Label4: User speech contains the word "no" or "stop" or User speech is an acknowledgment like "oh yeah", or "yes" etc or User speech is syntactically complete but lacks context or information for response from agent.
11
- Important note - The sentences might be lacking punctuation, fix the punctuation and tag them.
12
- If the user's speech is missing information needed for basic understanding, it should be tagged as Label2.
13
- - User Speech:
14
- {transcript}
15
- Only return Label1 or Label2 or Label3 or Label4.
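The completion_prompt above leaves a {transcript} placeholder to be filled with the user's speech before the classifier is called. A minimal sketch of rendering it, assuming nothing about the speech planner's real loading code:

from pathlib import Path

import yaml

cfg = yaml.safe_load(Path("tests/unit/configs/test_speech_planner_prompt.yaml").read_text())
template = cfg["prompts"]["completion_prompt"]

# str.format fills the {transcript} slot; the surrounding instructions are untouched.
prompt = template.format(transcript="can you tell me about the")
print(prompt)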
tests/unit/test_ace_websocket_serializer.py DELETED
@@ -1,147 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the ACEWebSocketSerializer class.
5
-
6
- This module contains unit tests for the `ACEWebSocketSerializer` class, which is responsible
7
- for serializing and deserializing frames for WebSocket communication in a speech-based user interface.
8
- The tests cover various frame types related to audio, text-to-speech (TTS), and automatic speech recognition (ASR).
9
-
10
- The test suite verifies:
11
- 1. Serialization of bot speech updates and end events
12
- 2. Serialization of user speech updates and end events
13
- 3. Serialization of raw audio frames
14
- 4. Handling of unsupported frame types
15
-
16
- Each test case validates the correct format and content of the serialized data,
17
- ensuring proper JSON formatting for transcript updates and binary data handling for audio frames.
18
- """
19
-
20
- import json
21
-
22
- import pytest
23
- from pipecat.frames.frames import AudioRawFrame, BotStoppedSpeakingFrame, Frame
24
-
25
- from nvidia_pipecat.frames.transcripts import (
26
- BotUpdatedSpeakingTranscriptFrame,
27
- UserStoppedSpeakingTranscriptFrame,
28
- UserUpdatedSpeakingTranscriptFrame,
29
- )
30
- from nvidia_pipecat.serializers.ace_websocket import ACEWebSocketSerializer
31
-
32
-
33
- @pytest.fixture
34
- def serializer():
35
- """Fixture to create an instance of ACEWebSocketSerializer.
36
-
37
- Returns:
38
- ACEWebSocketSerializer: A fresh instance of the serializer for each test.
39
- """
40
- return ACEWebSocketSerializer()
41
-
42
-
43
- @pytest.mark.asyncio
44
- async def test_serialize_bot_updated_speaking_frame(serializer):
45
- """Test serialization of BotUpdatedSpeakingTranscriptFrame.
46
-
47
- This test verifies that when a bot speech update frame is serialized,
48
- it produces the correct JSON format with 'tts_update' type and the transcript.
49
-
50
- Args:
51
- serializer: The ACEWebSocketSerializer fixture.
52
- """
53
- frame = BotUpdatedSpeakingTranscriptFrame(transcript="test_transcript")
54
- result = await serializer.serialize(frame)
55
- expected_result = json.dumps({"type": "tts_update", "tts": "test_transcript"})
56
- assert result == expected_result
57
-
58
-
59
- @pytest.mark.asyncio
60
- async def test_serialize_bot_stopped_speaking_frame(serializer):
61
- """Test serialization of BotStoppedSpeakingFrame.
62
-
63
- This test verifies that when a bot speech end frame is serialized,
64
- it produces the correct JSON format with 'tts_end' type.
65
-
66
- Args:
67
- serializer: The ACEWebSocketSerializer fixture.
68
- """
69
- frame = BotStoppedSpeakingFrame()
70
- result = await serializer.serialize(frame)
71
- expected_result = json.dumps({"type": "tts_end"})
72
- assert result == expected_result
73
-
74
-
75
- @pytest.mark.asyncio
76
- async def test_serialize_user_started_speaking_frame(serializer):
77
- """Test serialization of UserUpdatedSpeakingTranscriptFrame.
78
-
79
- This test verifies that when a user speech update frame is serialized,
80
- it produces the correct JSON format with 'asr_update' type and the transcript.
81
-
82
- Args:
83
- serializer: The ACEWebSocketSerializer fixture.
84
- """
85
- frame = UserUpdatedSpeakingTranscriptFrame(transcript="test_transcript")
86
- result = await serializer.serialize(frame)
87
- expected_result = json.dumps({"type": "asr_update", "asr": "test_transcript"})
88
- assert result == expected_result
89
-
90
-
91
- @pytest.mark.asyncio
92
- async def test_serialize_user_stopped_speaking_frame(serializer):
93
- """Test serialization of UserStoppedSpeakingTranscriptFrame.
94
-
95
- This test verifies that when a user speech end frame is serialized,
96
- it produces the correct JSON format with 'asr_end' type and the transcript.
97
-
98
- Args:
99
- serializer: The ACEWebSocketSerializer fixture.
100
- """
101
- frame = UserStoppedSpeakingTranscriptFrame(transcript="test_asr_transcript")
102
- result = await serializer.serialize(frame)
103
- expected_result = json.dumps({"type": "asr_end", "asr": "test_asr_transcript"})
104
- assert result == expected_result
105
-
106
-
107
- @pytest.mark.asyncio
108
- async def test_serialize_audio_raw_frame(serializer):
109
- """Test serialization of AudioRawFrame.
110
-
111
- This test verifies that when an audio frame is serialized,
112
- it returns the raw audio bytes without any modification.
113
-
114
- Args:
115
- serializer: The ACEWebSocketSerializer fixture.
116
- """
117
- frame = AudioRawFrame(audio=b"\xa2", sample_rate=16000, num_channels=1)
118
- result = await serializer.serialize(frame)
119
- expected_result = frame.audio
120
- assert result == expected_result
121
-
122
-
123
- @pytest.mark.asyncio
124
- async def test_serialize_none(serializer):
125
- """Test serialization of an unsupported frame type.
126
-
127
- This test verifies that when an unsupported frame type is serialized,
128
- the serializer returns None instead of raising an error.
129
-
130
- Args:
131
- serializer: The ACEWebSocketSerializer fixture.
132
- """
133
- frame = Frame()
134
- result = await serializer.serialize(frame)
135
- assert result is None
136
-
137
-
138
- @pytest.mark.asyncio
139
- async def test_deserialize_input_audio_raw_frame(serializer):
140
- """Test deserialization of an audio message into InputAudioRawFrame."""
141
- data = (
142
- b"RIFF$\x04\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00\x80>\x00\x00\x00}\x00\x00\x02\x00\x10\x00"
143
- + b"data\x00\x04\x00\x00"
144
- )
145
- result = await serializer.deserialize(data)
146
- assert result.sample_rate == 16000
147
- assert result.num_channels == 1
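The deserialization test above feeds a minimal 16 kHz, 16-bit, mono WAV header. Parsing the same bytes with the standard library, independently of ACEWebSocketSerializer, shows where the asserted sample rate and channel count come from:

import io
import wave

data = (
    b"RIFF$\x04\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00\x80>\x00\x00\x00}\x00\x00\x02\x00\x10\x00"
    + b"data\x00\x04\x00\x00"
)

with wave.open(io.BytesIO(data), "rb") as wav:
    print(wav.getframerate())   # 16000
    print(wav.getnchannels())   # 1
    print(wav.getsampwidth())   # 2 bytes per sample, i.e. 16-bit PCM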
tests/unit/test_acknowledgment.py DELETED
@@ -1,71 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the AcknowledgmentProcessor class.
5
-
6
- Tests the processor's ability to:
7
- - Generate filler responses during user pauses
8
- - Handle presence detection
9
- - Process user speech events
10
- """
11
-
12
- import asyncio
13
- import os
14
- import sys
15
-
16
- import pytest
17
-
18
- sys.path.append(os.path.abspath("../../src"))
19
-
20
- from pipecat.frames.frames import TTSSpeakFrame, UserStoppedSpeakingFrame
21
- from pipecat.pipeline.pipeline import Pipeline
22
- from pipecat.pipeline.task import PipelineTask
23
-
24
- from nvidia_pipecat.frames.action import StartedPresenceUserActionFrame
25
- from nvidia_pipecat.processors.acknowledgment import AcknowledgmentProcessor
26
- from tests.unit.utils import FrameStorage, run_interactive_test
27
-
28
-
29
- @pytest.mark.asyncio
30
- async def test_proactive_bot_processor_timer_behavior():
31
- """Test the ProactiveBotProcessor.
32
-
33
- Tests:
34
- - Filler word generation after user stops speaking
35
- - Proper handling of user presence events
36
- - Verification of generated responses
37
-
38
- Raises:
39
- AssertionError: If processor behavior doesn't match expected outcomes.
40
- """
41
- filler_words = ["Great question.", "Let me check.", "Hmmm"] + [""]
42
- filler = AcknowledgmentProcessor(filler_words=filler_words)
43
- storage = FrameStorage()
44
- pipeline = Pipeline([filler, storage])
45
-
46
- async def test_routine(task: PipelineTask):
47
- # Signal user presence
48
- await task.queue_frame(StartedPresenceUserActionFrame(action_id="1"))
49
- # Let the pipeline process presence
50
- await asyncio.sleep(0)
51
-
52
- # Signal end of user speech
53
- await task.queue_frame(UserStoppedSpeakingFrame())
54
- # Let the pipeline process the new frame and generate a TTS filler
55
- await asyncio.sleep(0.1)
56
-
57
- # Check what frames have arrived in storage
58
- frames = [entry.frame for entry in storage.history]
59
-
60
- # We expect at least one TTSSpeakFrame in there
61
- tts_frames = [f for f in frames if isinstance(f, TTSSpeakFrame)]
62
- assert len(tts_frames) > 0, "Expected a TTSSpeakFrame but found none."
63
-
64
- # Verify filler word content
65
- filler_frame = tts_frames[-1]
66
- # Make sure its text is one of the possible filler words
67
- assert filler_frame.text in filler_words, (
68
- f"Filler text '{filler_frame.text}' not in the list of expected filler words {filler_words}"
69
- )
70
-
71
- await run_interactive_test(pipeline, test_coroutine=test_routine)
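The assertion above only requires that some TTSSpeakFrame appears and that its text is one of the configured phrases. As a rough sketch, not the real AcknowledgmentProcessor, the behaviour under test amounts to drawing one filler per user pause, where the empty string presumably gives the processor a way to stay quiet on some turns:

import random

def choose_filler(filler_words: list[str]) -> str | None:
    """Pick the text for a TTSSpeakFrame, or None when the empty string is drawn."""
    choice = random.choice(filler_words)
    return choice or None

fillers = ["Great question.", "Let me check.", "Hmmm", ""]
print(choose_filler(fillers))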
tests/unit/test_animation_graph_services.py DELETED
@@ -1,668 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the Animation Graph Service."""
5
-
6
- import asyncio
7
- import time
8
- from http import HTTPStatus
9
- from pathlib import Path
10
- from typing import Any
11
- from unittest.mock import ANY, AsyncMock, MagicMock, patch
12
-
13
- import pytest
14
- import yaml
15
- from loguru import logger
16
- from nvidia_ace.animation_pb2 import (
17
- AnimationData,
18
- AudioWithTimeCode,
19
- Float3,
20
- Float3ArrayWithTimeCode,
21
- FloatArrayWithTimeCode,
22
- QuatF,
23
- QuatFArrayWithTimeCode,
24
- SkelAnimation,
25
- )
26
- from nvidia_ace.audio_pb2 import AudioHeader
27
- from pipecat.frames.frames import (
28
- BotSpeakingFrame,
29
- BotStartedSpeakingFrame,
30
- ErrorFrame,
31
- StartInterruptionFrame,
32
- TextFrame,
33
- )
34
- from pipecat.pipeline.pipeline import Pipeline
35
- from pipecat.pipeline.task import PipelineTask
36
-
37
- from nvidia_pipecat.frames.action import (
38
- FinishedGestureBotActionFrame,
39
- FinishedPostureBotActionFrame,
40
- StartedGestureBotActionFrame,
41
- StartedPostureBotActionFrame,
42
- StartGestureBotActionFrame,
43
- StartPostureBotActionFrame,
44
- StopPostureBotActionFrame,
45
- )
46
- from nvidia_pipecat.frames.animation import (
47
- AnimationDataStreamRawFrame,
48
- AnimationDataStreamStartedFrame,
49
- AnimationDataStreamStoppedFrame,
50
- )
51
- from nvidia_pipecat.services.animation_graph_service import (
52
- AnimationGraphConfiguration,
53
- AnimationGraphService,
54
- )
55
- from nvidia_pipecat.utils.logging import setup_default_ace_logging
56
- from nvidia_pipecat.utils.message_broker import MessageBrokerConfig
57
- from tests.unit.utils import FrameStorage, ignore, run_interactive_test
58
-
59
- AUDIO_SAMPLE_RATE = 16000
60
- AUDIO_BITS_PER_SAMPLE = 16
61
- AUDIO_CHANNEL_COUNT = 1
62
- AUDIO_BUFFER_FOR_ONE_FRAME_SIZE = int(AUDIO_SAMPLE_RATE * AUDIO_BITS_PER_SAMPLE / 8 * AUDIO_CHANNEL_COUNT / 30.0) + int(
63
- AUDIO_BITS_PER_SAMPLE / 8
64
- )
65
-
66
-
67
- def generate_audio_header() -> AudioHeader:
68
- """Generate an audio header."""
69
- return AudioHeader(
70
- samples_per_second=AUDIO_SAMPLE_RATE,
71
- bits_per_sample=AUDIO_BITS_PER_SAMPLE,
72
- channel_count=AUDIO_CHANNEL_COUNT,
73
- )
74
-
75
-
76
- def generate_animation_data(time_codes: list[float], audio_buffer_size: int) -> AnimationData:
77
- """Generate an animation data stream raw frame.
78
-
79
- Args:
80
- time_codes: List of time codes in seconds for the animation keyframes.
81
- audio_buffer_size: Size of the audio buffer in bytes.
82
-
83
- Returns:
84
- AnimationData: The generated animation data object.
85
- """
86
- # Create a simple animation data object with skeletal animation
87
- animation_data = AnimationData(
88
- skel_animation=SkelAnimation(
89
- # Add blend shape weights with time code
90
- blend_shape_weights=[
91
- FloatArrayWithTimeCode(
92
- time_code=t, # time in seconds
93
- values=[0.1, 0.2, 0.3], # example blend shape weights
94
- )
95
- for t in time_codes
96
- ],
97
- # Add joint translations with time code
98
- translations=[
99
- Float3ArrayWithTimeCode(
100
- time_code=t,
101
- values=[
102
- Float3(x=0.0, y=0.0, z=0.0), # example translation for joint 1
103
- Float3(x=0.1, y=0.2, z=0.3), # example translation for joint 2
104
- ],
105
- )
106
- for t in time_codes
107
- ],
108
- # Add joint rotations with time code
109
- rotations=[
110
- QuatFArrayWithTimeCode(
111
- time_code=t,
112
- values=[
113
- QuatF(real=1.0, i=0.0, j=0.0, k=0.0), # example rotation for joint 1 (identity quaternion)
114
- QuatF(
115
- real=0.707, i=0.0, j=0.707, k=0.0
116
- ), # example rotation for joint 2 (45° rotation around Y)
117
- ],
118
- )
119
- for t in time_codes
120
- ],
121
- # Add joint scales with time code
122
- scales=[
123
- Float3ArrayWithTimeCode(
124
- time_code=t,
125
- values=[
126
- Float3(x=1.0, y=1.0, z=1.0), # example scale for joint 1 (no scaling)
127
- Float3(x=1.1, y=1.1, z=1.1), # example scale for joint 2 (uniform scaling)
128
- ],
129
- )
130
- for t in time_codes
131
- ],
132
- ),
133
- # Audio component
134
- audio=AudioWithTimeCode(
135
- time_code=0.0, # time in seconds relative to start_time_code_since_epoch
136
- audio_buffer=b"\xff" * audio_buffer_size, # In a real scenario, this would be PCM audio data as bytes
137
- ),
138
- )
139
- return animation_data
140
-
141
-
142
- def load_yaml(path: Path) -> dict[str, Any]:
143
- """Load YAML configuration from a file.
144
-
145
- Args:
146
- path: Path to the YAML file.
147
-
148
- Returns:
149
- dict: Parsed YAML content.
150
-
151
- Raises:
152
- FileNotFoundError: If the YAML file is not found.
153
- """
154
- try:
155
- return yaml.safe_load(path.read_text())
156
-
157
- except FileNotFoundError as error:
158
- message = "Error: yml config file not found."
159
- logger.exception(message)
160
- raise FileNotFoundError(error, message) from error
161
-
162
-
163
- def read_action_service_config(config_path: Path) -> AnimationGraphConfiguration:
164
- """Read and parse the animation graph service configuration.
165
-
166
- Args:
167
- config_path: Path to the configuration YAML file.
168
-
169
- Returns:
170
- AnimationGraphConfiguration: Parsed configuration object.
171
- """
172
- return AnimationGraphConfiguration(**load_yaml(config_path))
173
-
174
-
175
- looping_animations_to_test = [
176
- ("listening", "Listening"),
177
- ("thinking about something", "Thinking"),
178
- ("dancing", "Listening"),
179
- ]
180
-
181
-
182
- class MockResponse:
183
- """Mock HTTP response for testing animation graph service REST endpoints."""
184
-
185
- def __init__(self):
186
- """Initialize mock response with default status and headers."""
187
- self.status = HTTPStatus.OK
188
- self.headers = {"Content-Type": "application/json"}
189
-
190
- async def json(self):
191
- """Simulate async JSON response.
192
-
193
- Returns:
194
- dict: Mock response data.
195
- """
196
- await asyncio.sleep(0.1)
197
- return {"response": "OK"}
198
-
199
-
200
- @pytest.fixture
201
- def anim_graph():
202
- """Create and configure an Animation Graph Service instance for testing.
203
-
204
- Returns:
205
- AnimationGraphService: Configured service instance with test configuration.
206
- """
207
- animation_config = read_action_service_config(Path("./tests/unit/configs/animation_config.yaml"))
208
- message_broker_config = MessageBrokerConfig("local_queue", "")
209
-
210
- logger.info("Starting animation graph service initialization...")
211
- start_time = time.time()
212
- AnimationGraphService.pregenerate_animation_databases(animation_config)
213
- init_time = time.time() - start_time
214
- logger.info(f"Animation graph service initialized in {init_time * 1000:.2f}ms")
215
-
216
- ag = AnimationGraphService(
217
- animation_graph_rest_url="http://127.0.0.1:8020",
218
- animation_graph_grpc_target="127.0.0.1:51000",
219
- message_broker_config=message_broker_config,
220
- config=animation_config,
221
- )
222
-
223
- return ag
224
-
225
-
226
- @pytest.mark.asyncio
227
- @pytest.mark.parametrize("posture", looping_animations_to_test)
228
- @patch("aiohttp.ClientSession.request")
229
- async def test_simple_posture_pipeline(mock_get, posture, anim_graph):
230
- """Test basic posture pipeline functionality.
231
-
232
- Verifies that the pipeline correctly processes posture commands and
233
- generates appropriate animation frames.
234
-
235
- Args:
236
- mock_get: Mock for HTTP client requests.
237
- posture: Tuple of (natural language description, clip ID) for testing.
238
- anim_graph: Animation graph service fixture.
239
- """
240
- posture_nld, posture_clip_id = posture
241
- stream_id = "1235"
242
-
243
- # Mocking response from aiohttp.ClientSession.request
244
- mock_get.return_value.__aenter__.return_value = MockResponse()
245
-
246
- after_storage = FrameStorage()
247
- before_storage = FrameStorage()
248
- pipeline = Pipeline([before_storage, anim_graph, after_storage])
249
-
250
- async def test_routine(task: PipelineTask):
251
- # send events
252
- await task.queue_frame(TextFrame("Hello"))
253
- await task.queue_frame(StartPostureBotActionFrame(posture=posture_nld, action_id="posture_1"))
254
-
255
- # wait for the action to start
256
- started_posture_frame = ignore(StartedPostureBotActionFrame(action_id="posture_1"), "ids", "timestamps")
257
- await after_storage.wait_for_frame(started_posture_frame)
258
-
259
- # Ensure API endpoints was called in the right way
260
- mock_get.assert_called_once_with(
261
- "put",
262
- f"http://127.0.0.1:8020/streams/{stream_id}/animation_graphs/avatar/variables/posture_state/{posture_clip_id}",
263
- data="{}",
264
- headers={"Content-Type": "application/json", "x-stream-id": stream_id},
265
- params={},
266
- )
267
- # ensure we got the text frame as well as a presence started event (and no finished)
268
- assert after_storage.history[1].frame == ignore(TextFrame("Hello"), "all_ids", "timestamps")
269
- assert after_storage.history[3].frame == started_posture_frame
270
- assert len(after_storage.frames_of_type(FinishedPostureBotActionFrame)) == 0
271
-
272
- # stop the action and wait for it to finish
273
- finished_posture_frame = FinishedPostureBotActionFrame(
274
- action_id="posture_1", is_success=True, was_stopped=True, failure_reason=""
275
- )
276
- await task.queue_frame(StopPostureBotActionFrame(action_id="posture_1"))
277
- await after_storage.wait_for_frame(ignore(finished_posture_frame, "ids", "timestamps"))
278
- await before_storage.wait_for_frame(ignore(finished_posture_frame, "ids", "timestamps"))
279
-
280
- # make sure observed frames before and after the processor match
281
- # (this should be true for all action frame processors)
282
- assert len(before_storage.history) == len(after_storage.history)
283
- for before, after in zip(before_storage.history, after_storage.history, strict=False):
284
- assert before.frame == after.frame
285
-
286
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
287
-
288
-
289
- @pytest.mark.asyncio
290
- @patch("aiohttp.ClientSession.request")
291
- async def test_finite_animation(mock_get, anim_graph):
292
- """Test finite animation execution and completion.
293
-
294
- Verifies that a finite animation:
295
- - Starts correctly
296
- - Runs for the expected duration
297
- - Completes with proper frame sequence
298
- """
299
- # Mocking response from aiohttp.ClientSession.request
300
- mock_get.return_value.__aenter__.return_value = MockResponse()
301
-
302
- storage = FrameStorage()
303
- pipeline = Pipeline([anim_graph, storage])
304
-
305
- async def test_routine(task: PipelineTask):
306
- # send events
307
- await task.queue_frame(StartGestureBotActionFrame(gesture="Test", action_id="g1"))
308
-
309
- # wait for the action to start
310
- started_gesture_frame = StartedGestureBotActionFrame(action_id="g1")
311
- await storage.wait_for_frame(ignore(started_gesture_frame, "ids", "timestamps"))
312
- assert len(storage.frames_of_type(FinishedGestureBotActionFrame)) == 0
313
-
314
- # Action should not be finished
315
- await asyncio.sleep(0.3)
316
- assert len(storage.frames_of_type(FinishedGestureBotActionFrame)) == 0
317
-
318
- # Action should now be done
319
- await asyncio.sleep(0.3)
320
- assert len(storage.frames_of_type(FinishedGestureBotActionFrame)) == 1
321
-
322
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": "1235"})
323
-
324
-
325
- @pytest.mark.asyncio
326
- @patch("aiohttp.ClientSession.request")
327
- async def test_consecutive_postures(mock_get, anim_graph):
328
- """Test handling of consecutive posture commands.
329
-
330
- Verifies that the service correctly:
331
- - Processes multiple posture commands in sequence
332
- - Transitions between postures smoothly
333
- - Maintains correct frame order and state
334
- """
335
- # Mocking response from aiohttp.ClientSession.request
336
- mock_get.return_value.__aenter__.return_value = MockResponse()
337
-
338
- stream_id = "1235"
339
- setup_default_ace_logging(stream_id=stream_id, level="TRACE")
340
-
341
- after_storage = FrameStorage()
342
- before_storage = FrameStorage()
343
- pipeline = Pipeline([before_storage, anim_graph, after_storage])
344
-
345
- async def test_routine(task: PipelineTask):
346
- # start first posture
347
- await task.queue_frame(StartPostureBotActionFrame(posture="talking", action_id="posture_1"))
348
-
349
- # wait for first posture to start
350
- started_posture_1_frame = ignore(StartedPostureBotActionFrame(action_id="posture_1"), "ids", "timestamps")
351
- assert started_posture_1_frame.action_id == "posture_1"
352
- await after_storage.wait_for_frame(started_posture_1_frame)
353
-
354
- # start second posture
355
- await task.queue_frame(StartPostureBotActionFrame(posture="listening", action_id="posture_2"))
356
-
357
- # wait for second posture to start
358
- started_posture_2_frame = ignore(StartedPostureBotActionFrame(action_id="posture_2"), "ids", "timestamps")
359
- assert started_posture_2_frame.action_id == "posture_2"
360
- await after_storage.wait_for_frame(started_posture_2_frame)
361
-
362
- # Ensure API endpoints was called twice
363
- assert mock_get.call_count == 2
364
-
365
- # ensure we got the text frame as well as a presence started event (and no finished)
366
- assert after_storage.history[2].frame == started_posture_1_frame
367
- assert after_storage.history[4].frame == ignore(
368
- FinishedPostureBotActionFrame(
369
- action_id="posture_1", is_success=False, was_stopped=False, failure_reason="Action replaced."
370
- ),
371
- "ids",
372
- "timestamps",
373
- )
374
- assert after_storage.history[5].frame == started_posture_2_frame
375
- assert len(after_storage.frames_of_type(FinishedPostureBotActionFrame)) == 1
376
-
377
- # make sure observed frames before and after the processor match
378
- assert len(before_storage.history) == len(after_storage.history)
379
- for before, after in zip(before_storage.history, after_storage.history, strict=False):
380
- assert before.frame == after.frame
381
-
382
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
383
-
384
-
385
- @pytest.mark.asyncio
386
- @patch("aiohttp.ClientSession.request")
387
- async def test_immediate_stop_posture(mock_get, anim_graph):
388
- """Test immediate posture cancellation.
389
-
390
- Verifies that the service correctly handles immediate posture cancellation:
391
- - Processes stop command immediately after start
392
- - Generates appropriate failure frames
393
- - Maintains correct frame sequence
394
- """
395
- # Mocking response from aiohttp.ClientSession.request
396
- mock_get.return_value.__aenter__.return_value = MockResponse()
397
-
398
- stream_id = "1235"
399
- setup_default_ace_logging(stream_id=stream_id, level="TRACE")
400
-
401
- after_storage = FrameStorage()
402
- before_storage = FrameStorage()
403
- pipeline = Pipeline([before_storage, anim_graph, after_storage])
404
-
405
- async def test_routine(task: PipelineTask):
406
- # start/stop first posture
407
- start_posture_1_frame = StartPostureBotActionFrame(posture="talking", action_id="posture_1")
408
- stop_posture_1_frame = StopPostureBotActionFrame(action_id="posture_1")
409
- await task.queue_frames([start_posture_1_frame, stop_posture_1_frame])
410
-
411
- # wait for first posture to finish
412
- finished_posture_1_frame = ignore(
413
- FinishedPostureBotActionFrame(
414
- action_id="posture_1", was_stopped=True, is_success=False, failure_reason=ANY
415
- ),
416
- "ids",
417
- "timestamps",
418
- )
419
- await after_storage.wait_for_frame(finished_posture_1_frame)
420
-
421
- # check for the correct frame sequence
422
- assert after_storage.history[-2].frame == ignore(stop_posture_1_frame, "ids", "timestamps")
423
- assert after_storage.history[-1].frame == finished_posture_1_frame
424
-
425
- # make sure observed frames before and after the processor match
426
- assert len(before_storage.history) == len(after_storage.history)
427
- for before, after in zip(before_storage.history, after_storage.history, strict=True):
428
- assert before.frame == after.frame
429
-
430
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
431
-
432
-
433
- @pytest.mark.asyncio
434
- @patch("aiohttp.ClientSession.request")
435
- async def test_stacking_postures_with_interruptions(mock_get, anim_graph):
436
- """Test stacking postures with interruptions.
437
-
438
- In this test we reproduce a bug that was observed where the override statemachine
439
- would sometimes fail if we start many postures at the same time and also
440
- have interruptions.
441
- """
442
- # Mocking response from aiohttp.ClientSession.request
443
- mock_get.return_value.__aenter__.return_value = MockResponse()
444
-
445
- stream_id = "1235"
446
- setup_default_ace_logging(stream_id=stream_id, level="TRACE")
447
-
448
- after_storage = FrameStorage()
449
- before_storage = FrameStorage()
450
- pipeline = Pipeline([before_storage, anim_graph, after_storage])
451
-
452
- async def test_routine(task: PipelineTask):
453
- N = 200
454
- for i in range(N):
455
- await task.queue_frame(StartPostureBotActionFrame(posture="talking", action_id=f"posture_{i}"))
456
- await asyncio.sleep(0.05)
457
- if i == 5:
458
- await task.queue_frame(StartInterruptionFrame())
459
-
460
- final_posture_started = ignore(StartedPostureBotActionFrame(action_id=f"posture_{N - 1}"), "ids", "timestamps")
461
- await after_storage.wait_for_frame(final_posture_started)
462
-
463
- assert len(after_storage.frames_of_type(FinishedPostureBotActionFrame)) == N - 1
464
-
465
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
466
-
467
-
468
- @patch("aiohttp.ClientSession.request")
469
- async def test_handling_animation_data(mock_get, anim_graph):
470
- """Test animation data stream processing.
471
-
472
- Verifies that the service correctly:
473
- - Handles animation data stream frames
474
- - Processes animation data in correct sequence
475
- - Maintains proper timing between frames
476
- """
477
- # Mocking response from aiohttp.ClientSession.request
478
- mock_get.return_value.__aenter__.return_value = MockResponse()
479
-
480
- # Mocking gRPC stream
481
- mock_stub = MagicMock()
482
- stream_mock = AsyncMock()
483
- stream_mock.write.return_value = "OK"
484
- stream_mock.done = MagicMock(return_value=False) #
485
- mock_stub.PushAnimationDataStream.return_value = stream_mock
486
-
487
- stream_id = "1235"
488
- anim_graph.stub = mock_stub
489
- storage = FrameStorage()
490
- pipeline = Pipeline([anim_graph, storage])
491
-
492
- async def test_routine(task: PipelineTask):
493
- # send events
494
- await task.queue_frame(
495
- AnimationDataStreamStartedFrame(
496
- audio_header=generate_audio_header(), animation_header=None, action_id="a1", animation_source_id="test"
497
- )
498
- )
499
-
500
- for _ in range(15):
501
- await task.queue_frame(
502
- AnimationDataStreamRawFrame(
503
- animation_data=generate_animation_data([0.0], AUDIO_BUFFER_FOR_ONE_FRAME_SIZE), action_id="a1"
504
- )
505
- )
506
- await asyncio.sleep(1.0 / 35.0)
507
-
508
- await task.queue_frame(AnimationDataStreamStoppedFrame(action_id="a1"))
509
- assert len(storage.frames_of_type(BotStartedSpeakingFrame)) == 1
510
- assert len(storage.frames_of_type(BotSpeakingFrame)) >= 1
511
-
512
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
513
-
514
-
515
- @patch("aiohttp.ClientSession.request")
516
- async def test_handling_bursted_animation_data(mock_get, anim_graph):
517
- """Test handling of bursted animation data.
518
-
519
- Verifies that the service correctly handles bursted animation data.
520
- """
521
- # Mocking response from aiohttp.ClientSession.request
522
- mock_get.return_value.__aenter__.return_value = MockResponse()
523
-
524
- # Mocking gRPC stream
525
- mock_stub = MagicMock()
526
- stream_mock = AsyncMock()
527
- stream_mock.write.return_value = "OK"
528
- stream_mock.done = MagicMock(return_value=False) #
529
- mock_stub.PushAnimationDataStream.return_value = stream_mock
530
-
531
- stream_id = "1235"
532
- anim_graph.stub = mock_stub
533
- storage = FrameStorage()
534
- pipeline = Pipeline([anim_graph, storage])
535
-
536
- async def test_routine(task: PipelineTask):
537
- # send events
538
- await task.queue_frame(
539
- AnimationDataStreamStartedFrame(
540
- audio_header=generate_audio_header(), animation_header=None, action_id="a1", animation_source_id="test"
541
- )
542
- )
543
-
544
- for _ in range(30 * 3):
545
- await task.queue_frame(
546
- AnimationDataStreamRawFrame(
547
- animation_data=generate_animation_data([0.0], AUDIO_BUFFER_FOR_ONE_FRAME_SIZE), action_id="a1"
548
- )
549
- )
550
-
551
- await asyncio.sleep(2.5)
552
- assert stream_mock.done_writing.call_count == 0
553
-
554
- await task.queue_frame(AnimationDataStreamStoppedFrame(action_id="a1"))
555
- assert len(storage.frames_of_type(BotStartedSpeakingFrame)) == 1
556
- assert len(storage.frames_of_type(BotSpeakingFrame)) >= 1
557
-
558
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
559
-
560
-
561
- @patch("aiohttp.ClientSession.request")
562
- async def test_handling_interrupted_animation_data_stream(mock_get, anim_graph):
563
- """Test behavior if animation data stream is interrupted."""
564
- # Mocking response from aiohttp.ClientSession.request
565
- mock_get.return_value.__aenter__.return_value = MockResponse()
566
-
567
- # Mocking gRPC stream
568
- mock_stub = MagicMock()
569
- stream_mock = AsyncMock()
570
- stream_mock.write.return_value = "OK"
571
- stream_mock.done = MagicMock(return_value=False) #
572
- mock_stub.PushAnimationDataStream.return_value = stream_mock
573
-
574
- stream_id = "1235"
575
- anim_graph.stub = mock_stub
576
- storage = FrameStorage()
577
- pipeline = Pipeline([anim_graph, storage])
578
-
579
- async def test_routine(task: PipelineTask):
580
- # send events
581
- await task.queue_frame(
582
- AnimationDataStreamStartedFrame(
583
- audio_header=generate_audio_header(), animation_header=None, action_id="a1", animation_source_id="test"
584
- )
585
- )
586
-
587
- for _ in range(5):
588
- await task.queue_frame(
589
- AnimationDataStreamRawFrame(
590
- animation_data=generate_animation_data([0.0], AUDIO_BUFFER_FOR_ONE_FRAME_SIZE), action_id="a1"
591
- )
592
- )
593
- await asyncio.sleep(1.0 / 30.0)
594
-
595
- await asyncio.sleep(2.0)
596
-
597
- assert stream_mock.done_writing.call_count == 1
598
-
599
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
600
-
601
-
602
- @patch("aiohttp.ClientSession.request")
603
- async def test_handling_low_fps_animation_data(mock_get, anim_graph):
604
- """Test animation data stream processing with low FPS.
605
-
606
- Verifies that the service correctly sends out an
607
- ErrorFrame when the animation data received is below 30 FPS.
608
- """
609
- # Mocking response from aiohttp.ClientSession.request
610
- mock_get.return_value.__aenter__.return_value = MockResponse()
611
-
612
- # Mocking gRPC stream
613
- mock_stub = MagicMock()
614
- stream_mock = AsyncMock()
615
- stream_mock.write.return_value = "OK"
616
- stream_mock.done = MagicMock(return_value=False) #
617
- mock_stub.PushAnimationDataStream.return_value = stream_mock
618
-
619
- stream_id = "1235"
620
- anim_graph.stub = mock_stub
621
- storage = FrameStorage()
622
- pipeline = Pipeline([anim_graph, storage])
623
-
624
- async def test_routine(task: PipelineTask):
625
- # send events
626
- await task.queue_frame(
627
- AnimationDataStreamStartedFrame(
628
- audio_header=generate_audio_header(), animation_header=None, action_id="a1", animation_source_id="test"
629
- )
630
- )
631
-
632
- for _ in range(20):
633
- await task.queue_frame(
634
- AnimationDataStreamRawFrame(
635
- animation_data=generate_animation_data([0.0], AUDIO_BUFFER_FOR_ONE_FRAME_SIZE), action_id="a1"
636
- )
637
- )
638
- await asyncio.sleep(1.0 / 24.5)
639
-
640
- await task.queue_frame(AnimationDataStreamStoppedFrame(action_id="a1"))
641
-
642
- await storage.wait_for_frame(
643
- ignore(
644
- ErrorFrame(error="AnimGraph: Received data stream is behind by more than 0.1s", fatal=False),
645
- "ids",
646
- "error",
647
- )
648
- )
649
-
650
- await run_interactive_test(pipeline, test_coroutine=test_routine, start_metadata={"stream_id": stream_id})
651
-
652
-
653
- @pytest.mark.asyncio
654
- async def test_animation_database_search(anim_graph):
655
- """Test animation database semantic search functionality.
656
-
657
- Verifies that searching for 'wave to the user' correctly returns the 'Goodbye'
658
- animation clip with appropriate semantic match scores.
659
- """
660
- # Get the gesture animation database
661
- gesture_db = AnimationGraphService.animation_databases["gesture"]
662
-
663
- # Search for the animation
664
- match = gesture_db.query_one("wave goodbye to the user somehow")
665
-
666
- # Verify we got a match and it's the Goodbye animation
667
- assert match.animation.id == "Goodbye", "Expected 'wave to the user' to map to 'Goodbye' animation"
668
- assert match.description_score > 0.5 or match.meaning_score > 0.5, "Expected high semantic match score"
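A worked example of the AUDIO_BUFFER_FOR_ONE_FRAME_SIZE constant these tests feed into generate_animation_data: one 30 FPS animation frame of 16 kHz, 16-bit, mono PCM plus one extra sample's worth of bytes, exactly as computed at the top of the deleted module.

AUDIO_SAMPLE_RATE = 16000
AUDIO_BITS_PER_SAMPLE = 16
AUDIO_CHANNEL_COUNT = 1

bytes_per_second = AUDIO_SAMPLE_RATE * AUDIO_BITS_PER_SAMPLE // 8 * AUDIO_CHANNEL_COUNT  # 32000
per_frame = int(bytes_per_second / 30.0)   # 1066 bytes of PCM per 30 FPS frame
headroom = AUDIO_BITS_PER_SAMPLE // 8      # 2 bytes, i.e. one extra 16-bit sample
print(per_frame + headroom)                # 1068 == AUDIO_BUFFER_FOR_ONE_FRAME_SIZE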
tests/unit/test_audio2face_3d_service.py DELETED
@@ -1,182 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the Audio2Face 3D Service."""
5
-
6
- import asyncio
7
- from unittest.mock import ANY, AsyncMock, MagicMock, patch
8
-
9
- import pytest
10
- from nvidia_audio2face_3d.messages_pb2 import AudioWithEmotionStream
11
- from pipecat.frames.frames import StartInterruptionFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
12
- from pipecat.pipeline.pipeline import Pipeline
13
-
14
- from nvidia_pipecat.frames.animation import (
15
- AnimationDataStreamRawFrame,
16
- AnimationDataStreamStartedFrame,
17
- AnimationDataStreamStoppedFrame,
18
- )
19
- from nvidia_pipecat.services.audio2face_3d_service import Audio2Face3DService
20
- from tests.unit.utils import FrameStorage, ignore, run_interactive_test
21
-
22
-
23
- class AsyncIterator:
24
- """Helper class for mocking async iteration in tests.
25
-
26
- Attributes:
27
- items: List of items to yield during iteration.
28
- """
29
-
30
- def __init__(self, items):
31
- """Initialize the async iterator.
32
-
33
- Args:
34
- items: List of items to yield during iteration.
35
- """
36
- self.items = items
37
-
38
- def __aiter__(self):
39
- """Return self as the iterator.
40
-
41
- Returns:
42
- AsyncIterator: Self reference for iteration.
43
- """
44
- return self
45
-
46
- async def __anext__(self):
47
- """Get the next item asynchronously.
48
-
49
- Returns:
50
- Any: Next item in the sequence.
51
-
52
- Raises:
53
- StopAsyncIteration: When no more items are available.
54
- """
55
- try:
56
- await asyncio.sleep(0.1)
57
- return self.items.pop(0)
58
- except IndexError as error:
59
- raise StopAsyncIteration from error
60
-
61
-
62
- def get_mock_stream():
63
- """Create a mock Audio2Face stream for testing.
64
-
65
- Returns:
66
- MagicMock: A configured mock object that simulates the A2F service stream,
67
- including header and animation data responses.
68
- """
69
- mock_stub = MagicMock(spec=["ProcessAudioStream"])
70
-
71
- # Configure mock responses
72
- header_response = MagicMock()
73
- header_response.HasField = lambda x: x == "animation_data_stream_header"
74
- header_response.animation_data_stream_header = MagicMock(
75
- audio_header=MagicMock(), skel_animation_header=MagicMock()
76
- )
77
-
78
- data_response = MagicMock()
79
- data_response.HasField = lambda x: x == "animation_data"
80
- data_response.animation_data = "test_animation_data"
81
-
82
- mock_stream = AsyncMock()
83
- mock_stream.__aiter__.return_value = [header_response, data_response]
84
- mock_stream.done = MagicMock(return_value=False) # Use regular MagicMock for done()
85
- mock_stub.ProcessAudioStream.return_value = mock_stream
86
-
87
- return mock_stub
88
-
89
-
90
- @pytest.mark.asyncio
91
- async def test_audio2face_3d_basic_flow():
92
- """Test the basic flow of the Audio2Face 3D service.
93
-
94
- Tests:
95
- - TTS audio frame processing
96
- - Animation data stream generation
97
- - Frame sequence validation
98
- - Frame count verification
99
-
100
- Raises:
101
- AssertionError: If frame sequence or counts are incorrect.
102
- """
103
- mock_stub = get_mock_stream()
104
-
105
- with patch("nvidia_pipecat.services.audio2face_3d_service.A2FControllerServiceStub", return_value=mock_stub):
106
- # Initialize test components
107
- storage = FrameStorage()
108
- a2f = Audio2Face3DService()
109
- pipeline = Pipeline([a2f, storage])
110
- audio = bytes([6] * (16000 * 2 + 1))
111
-
112
- async def test_routine(task):
113
- # Send TTS frames
114
- await task.queue_frame(TTSStartedFrame())
115
- await task.queue_frame(TTSAudioRawFrame(audio=audio[0:15999], sample_rate=16000, num_channels=1))
116
- await task.queue_frame(TTSAudioRawFrame(audio=audio[16000:], sample_rate=16000, num_channels=1))
117
- await task.queue_frame(TTSStoppedFrame())
118
-
119
- # Wait for animation started frame
120
- started_frame = ignore(
121
- AnimationDataStreamStartedFrame(
122
- audio_header=ANY, animation_header=ANY, animation_source_id="Audio2Face with Emotions"
123
- ),
124
- "all_ids",
125
- "timestamps",
126
- )
127
-
128
- await storage.wait_for_frame(started_frame)
129
-
130
- # Wait for animation data frame
131
- anim_frame = ignore(
132
- AnimationDataStreamRawFrame(animation_data="test_animation_data"),
133
- "all_ids",
134
- "timestamps",
135
- )
136
- await storage.wait_for_frame(anim_frame)
137
-
138
- # Wait for stopped frame
139
- stopped_frame = ignore(AnimationDataStreamStoppedFrame(), "all_ids", "timestamps")
140
- await storage.wait_for_frame(stopped_frame)
141
-
142
- # Verify the sequence of frames
143
- assert len(storage.frames_of_type(AnimationDataStreamStartedFrame)) == 1
144
- assert len(storage.frames_of_type(AnimationDataStreamRawFrame)) == 1
145
- assert len(storage.frames_of_type(AnimationDataStreamStoppedFrame)) == 1
146
-
147
- await run_interactive_test(pipeline, test_coroutine=test_routine)
148
-
149
-
150
- @pytest.mark.asyncio
151
- async def test_interruptions():
152
- """Test interruption handling in the Audio2Face 3D service.
153
-
154
- Verifies that the service correctly handles interruption events by:
155
- - Properly responding to StartInterruptionFrame
156
- - Sending end-of-audio signal to the A2F stream
157
- - Maintaining correct state during interruption
158
- """
159
- mock_stub = get_mock_stream()
160
-
161
- with patch("nvidia_pipecat.services.audio2face_3d_service.A2FControllerServiceStub", return_value=mock_stub):
162
- # Create service and storage
163
- storage = FrameStorage()
164
- a2f = Audio2Face3DService()
165
- pipeline = Pipeline([a2f, storage])
166
-
167
- async def test_routine(task):
168
- # Send TTS frames
169
- await task.queue_frame(TTSStartedFrame())
170
- await task.queue_frame(TTSAudioRawFrame(audio=b"test_audio", sample_rate=16000, num_channels=1))
171
-
172
- await asyncio.sleep(0.1)
173
- await task.queue_frame(StartInterruptionFrame())
174
- await asyncio.sleep(0.1)
175
-
176
- mock_stub.ProcessAudioStream.return_value.write.assert_called_with(
177
- AudioWithEmotionStream(end_of_audio=AudioWithEmotionStream.EndOfAudio())
178
- )
179
- print(f"{mock_stub}")
180
- print(f"storage.history: {storage.history}")
181
-
182
- await run_interactive_test(pipeline, test_coroutine=test_routine)
 
tests/unit/test_audio_util.py DELETED
@@ -1,64 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for audio utility functionality.
5
-
6
- This module contains tests for audio processing components, particularly the AudioRecorder.
7
- It verifies the basic functionality of audio recording and processing pipelines.
8
- """
9
-
10
- from datetime import timedelta
11
- from pathlib import Path
12
-
13
- import pytest
14
- from pipecat.frames.frames import InputAudioRawFrame, TextFrame
15
- from pipecat.pipeline.pipeline import Pipeline
16
- from pipecat.tests.utils import SleepFrame
17
- from pipecat.transports.base_transport import TransportParams
18
-
19
- from nvidia_pipecat.processors.audio_util import AudioRecorder
20
- from nvidia_pipecat.utils.logging import setup_default_ace_logging
21
- from tests.unit.utils import SinusWaveProcessor, ignore, ignore_ids, run_test
22
-
23
-
24
- @pytest.mark.asyncio()
25
- async def test_audio_recorder():
26
- """Test the AudioRecorder processor functionality.
27
-
28
- Tests:
29
- - Audio frame processing from sine wave generator
30
- - WAV file writing
31
- - Sample rate conversion (16kHz to 24kHz)
32
- - Non-audio frame passthrough
33
-
34
- Raises:
35
- AssertionError: If audio file is not created or frame processing fails.
36
- """
37
- setup_default_ace_logging(level="TRACE")
38
-
39
- # Delete tmp audio file if it exists
40
- TMP_FILE = Path("./tmp_file.wav")
41
- if TMP_FILE.exists():
42
- TMP_FILE.unlink()
43
-
44
- recorder = AudioRecorder(output_file=str(TMP_FILE), params=TransportParams(audio_out_sample_rate=24000))
45
- sinus = SinusWaveProcessor(duration=timedelta(seconds=0.3))
46
-
47
- frames_to_send = [
48
- SleepFrame(0.5),
49
- TextFrame("go through"),
50
- ]
51
- expected_down_frames = [
52
- ignore(InputAudioRawFrame(audio=b"1" * 640, sample_rate=16000, num_channels=1), "audio", "id", "name")
53
- ] * sinus.audio_frame_count + [ignore_ids(TextFrame("go through"))]
54
- pipeline = Pipeline([sinus, recorder])
55
- print(f"expected number of frames: {sinus.audio_frame_count}")
56
- await run_test(
57
- pipeline,
58
- frames_to_send=frames_to_send,
59
- expected_down_frames=expected_down_frames,
60
- start_metadata={"stream_id": "1235"},
61
- )
62
-
63
- # Make sure that audio dump was generated
64
- assert TMP_FILE.is_file()
 
tests/unit/test_basic_pipelines.py DELETED
@@ -1,130 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Basic functionality tests for pipecat pipelines."""
5
-
6
- import asyncio
7
-
8
- import pytest
9
- from loguru import logger
10
- from pipecat.frames.frames import StartFrame, TextFrame
11
- from pipecat.processors.aggregators.sentence import SentenceAggregator
12
- from pipecat.tests.utils import SleepFrame
13
-
14
- from nvidia_pipecat.utils.logging import logger_context, setup_default_ace_logging
15
- from tests.unit.utils import ignore_ids, run_test
16
-
17
-
18
- @pytest.mark.asyncio()
19
- async def test_simple_pipeline():
20
- """Example test for testing a pipecat pipeline. This test makes sure the basic pipeline related classes work."""
21
- aggregator = SentenceAggregator()
22
- frames_to_send = [TextFrame("Hello, "), TextFrame("world.")]
23
- expected_down_frames = [ignore_ids(TextFrame("Hello, world."))]
24
-
25
- await run_test(
26
- aggregator,
27
- frames_to_send=frames_to_send,
28
- expected_down_frames=expected_down_frames,
29
- start_metadata={"stream_id": "1235"},
30
- )
31
-
32
-
33
- @pytest.mark.asyncio()
34
- async def test_pipeline_with_stream_id():
35
- """Test pipeline creation with a stream_id.
36
-
37
- Verifies that a pipeline can be created with a specific stream_id and that
38
- the metadata is properly propagated with the StartFrame.
39
- """
40
- aggregator = SentenceAggregator()
41
- frames_to_send = [TextFrame("Hello, "), TextFrame("world.")]
42
- start_metadata = {"stream_id": "1234"}
43
-
44
- expected_start_frame = ignore_ids(StartFrame())
45
- expected_start_frame.metadata = start_metadata
46
- expected_down_frames = [expected_start_frame, ignore_ids(TextFrame("Hello, world."))]
47
-
48
- await run_test(
49
- aggregator,
50
- frames_to_send=frames_to_send,
51
- expected_down_frames=expected_down_frames,
52
- start_metadata=start_metadata,
53
- ignore_start=False,
54
- )
55
-
56
-
57
- @pytest.mark.asyncio()
58
- async def test_ace_logger_with_stream_id(capsys):
59
- """Test ACE logger behavior when stream_id is provided.
60
-
61
- Verifies that the logger correctly handles and displays stream_id in the logs.
62
- """
63
- setup_default_ace_logging(level="DEBUG")
64
- with logger.contextualize(stream_id="1237"):
65
- aggregator = SentenceAggregator()
66
- frames_to_send = [TextFrame("Hello, "), TextFrame("world.")]
67
- expected_down_frames = [ignore_ids(TextFrame("Hello, world."))]
68
- await run_test(aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
69
-
70
- captured = capsys.readouterr()
71
- assert "streamId=1237" in captured.err
72
-
73
-
74
- async def run_pipeline_task(stream: str):
75
- """Run a test pipeline task with a specific stream identifier.
76
-
77
- Creates and runs a pipeline that processes a sequence of text frames with
78
- sleep intervals, aggregating them into a single sentence.
79
-
80
- Args:
81
- stream: The stream identifier to use for the pipeline task.
82
- """
83
- REPETITIONS = 5
84
- frames_to_send = []
85
- aggregated_str = ""
86
- for i in range(REPETITIONS):
87
- frames_to_send.append(TextFrame(f"S{stream}-T{i}"))
88
- aggregated_str += f"S{stream}-T{i}"
89
- frames_to_send.append(SleepFrame(0.1))
90
-
91
- frames_to_send.append(TextFrame("."))
92
- expected_down_frames = [ignore_ids(TextFrame(f"{aggregated_str}."))]
93
- aggregator = SentenceAggregator()
94
- await run_test(aggregator, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
95
-
96
-
97
- @pytest.mark.asyncio()
98
- async def test_logging_with_multiple_pipelines_in_same_process(capsys):
99
- """Test logging behavior with multiple concurrent pipeline streams.
100
-
101
- Verifies that when multiple pipeline tasks are running concurrently in the same
102
- process, each stream's logs are correctly tagged with its respective stream_id.
103
- The test ensures proper isolation and identification of log messages across
104
- different pipeline streams.
105
-
106
- Args:
107
- capsys: Pytest fixture for capturing system output.
108
- """
109
- setup_default_ace_logging(level="TRACE")
110
-
111
- streams = ["777", "abc123"]
112
- tasks = []
113
-
114
- for stream in streams:
115
- task = asyncio.create_task(logger_context(run_pipeline_task(stream=stream), stream_id=stream))
116
- tasks.append(task)
117
-
118
- await asyncio.gather(*tasks)
119
-
120
- # Make sure the correct stream ID is logged for the different coroutines
121
- captured = capsys.readouterr()
122
- lines = captured.err.split("\n")
123
- for line in lines:
124
- for stream in streams:
125
- if f"S{stream}" in line:
126
- assert f"streamId={stream}" in line
127
-
128
- for task in asyncio.all_tasks():
129
- if "task_handler" in task.get_coro().__name__:
130
- task.cancel()
 
tests/unit/test_blingfire_text_aggregator.py DELETED
@@ -1,244 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for BlingfireTextAggregator."""
5
-
6
- import pytest
7
-
8
- from nvidia_pipecat.services.blingfire_text_aggregator import BlingfireTextAggregator
9
-
10
-
11
- class TestBlingfireTextAggregator:
12
- """Test suite for BlingfireTextAggregator."""
13
-
14
- def test_initialization(self):
15
- """Test that the aggregator initializes with empty text buffer."""
16
- aggregator = BlingfireTextAggregator()
17
- assert aggregator.text == ""
18
-
19
- @pytest.mark.asyncio()
20
- async def test_single_word_no_sentence(self):
21
- """Test that single words without sentence endings don't return sentences."""
22
- aggregator = BlingfireTextAggregator()
23
- result = await aggregator.aggregate("hello")
24
- assert result is None
25
- assert aggregator.text == "hello"
26
-
27
- @pytest.mark.asyncio()
28
- async def test_incomplete_sentence(self):
29
- """Test that incomplete sentences are buffered but not returned."""
30
- aggregator = BlingfireTextAggregator()
31
- result = await aggregator.aggregate("Hello there")
32
- assert result is None
33
- assert aggregator.text == "Hello there"
34
-
35
- @pytest.mark.asyncio()
36
- async def test_single_complete_sentence(self):
37
- """Test that a single complete sentence is detected and returned."""
38
- aggregator = BlingfireTextAggregator()
39
- result = await aggregator.aggregate("Hello world.")
40
- assert result is None # Single sentence won't trigger return
41
- assert aggregator.text == "Hello world."
42
-
43
- @pytest.mark.asyncio()
44
- async def test_multiple_sentences_detection(self):
45
- """Test that multiple sentences trigger return of the first complete sentence."""
46
- aggregator = BlingfireTextAggregator()
47
- result = await aggregator.aggregate("Hello world. How are you?")
48
- assert result == "Hello world."
49
- assert "How are you?" in aggregator.text
50
- assert "Hello world." not in aggregator.text
51
-
52
- @pytest.mark.asyncio()
53
- async def test_incremental_sentence_building(self):
54
- """Test building a sentence incrementally."""
55
- aggregator = BlingfireTextAggregator()
56
-
57
- # Add text piece by piece
58
- result = await aggregator.aggregate("Hello")
59
- assert result is None
60
- assert aggregator.text == "Hello"
61
-
62
- result = await aggregator.aggregate(" world")
63
- assert result is None
64
- assert aggregator.text == "Hello world"
65
-
66
- result = await aggregator.aggregate(".")
67
- assert result is None
68
- assert aggregator.text == "Hello world."
69
-
70
- @pytest.mark.asyncio()
71
- async def test_incremental_multiple_sentences(self):
72
- """Test building multiple sentences incrementally."""
73
- aggregator = BlingfireTextAggregator()
74
-
75
- # Build first sentence
76
- result = await aggregator.aggregate("Hello world.")
77
- assert result is None
78
- assert aggregator.text == "Hello world."
79
-
80
- # Add second sentence - this should trigger return
81
- result = await aggregator.aggregate(" How are you?")
82
- assert result == "Hello world."
83
- assert "How are you?" in aggregator.text
84
- assert "Hello world." not in aggregator.text
85
-
86
- @pytest.mark.asyncio()
87
- async def test_empty_string_input(self):
88
- """Test handling of empty string input."""
89
- aggregator = BlingfireTextAggregator()
90
- result = await aggregator.aggregate("")
91
- assert result is None
92
- assert aggregator.text == ""
93
-
94
- @pytest.mark.asyncio()
95
- async def test_whitespace_handling(self):
96
- """Test proper handling of whitespace in text."""
97
- aggregator = BlingfireTextAggregator()
98
- result = await aggregator.aggregate(" Hello world. ")
99
- assert result is None
100
- assert aggregator.text == " Hello world. "
101
-
102
- @pytest.mark.asyncio()
103
- async def test_multiple_sentences_in_single_call(self):
104
- """Test processing multiple sentences passed in a single aggregate call."""
105
- aggregator = BlingfireTextAggregator()
106
- result = await aggregator.aggregate("First sentence. Second sentence. Third sentence.")
107
- assert result == "First sentence."
108
- # Remaining text should contain the other sentences
109
- remaining_text = aggregator.text
110
- assert "Second sentence. Third sentence." in remaining_text
111
-
112
- @pytest.mark.asyncio()
113
- async def test_sentence_with_special_punctuation(self):
114
- """Test sentences with different punctuation marks."""
115
- aggregator = BlingfireTextAggregator()
116
-
117
- # Test exclamation mark
118
- result = await aggregator.aggregate("Hello world! How are you?")
119
- assert result == "Hello world!"
120
- assert "How are you?" in aggregator.text
121
-
122
- await aggregator.reset()
123
-
124
- # Test question mark
125
- result = await aggregator.aggregate("How are you? I'm fine.")
126
- assert result == "How are you?"
127
- assert "I'm fine." in aggregator.text
128
-
129
- @pytest.mark.asyncio()
130
- async def test_handle_interruption(self):
131
- """Test that handle_interruption clears the text buffer."""
132
- aggregator = BlingfireTextAggregator()
133
- await aggregator.aggregate("Hello world")
134
- assert aggregator.text == "Hello world"
135
-
136
- await aggregator.handle_interruption()
137
- assert aggregator.text == ""
138
-
139
- @pytest.mark.asyncio()
140
- async def test_reset(self):
141
- """Test that reset clears the text buffer."""
142
- aggregator = BlingfireTextAggregator()
143
- await aggregator.aggregate("Hello world")
144
- assert aggregator.text == "Hello world"
145
-
146
- await aggregator.reset()
147
- assert aggregator.text == ""
148
-
149
- @pytest.mark.asyncio()
150
- async def test_reset_after_sentence_detection(self):
151
- """Test reset functionality after sentence detection."""
152
- aggregator = BlingfireTextAggregator()
153
- result = await aggregator.aggregate("First sentence. Second sentence.")
154
- assert result == "First sentence."
155
- assert "Second sentence." in aggregator.text
156
-
157
- await aggregator.reset()
158
- assert aggregator.text == ""
159
-
160
- @pytest.mark.asyncio()
161
- async def test_consecutive_sentence_processing(self):
162
- """Test processing consecutive sentences through multiple aggregate calls."""
163
- aggregator = BlingfireTextAggregator()
164
-
165
- # First pair of sentences
166
- result = await aggregator.aggregate("First sentence. Second sentence.")
167
- assert result == "First sentence."
168
-
169
- # Add third sentence - should trigger return of second
170
- result = await aggregator.aggregate(" Third sentence.")
171
- assert result == "Second sentence."
172
- assert "Third sentence." in aggregator.text
173
-
174
- @pytest.mark.asyncio()
175
- async def test_long_sentence_handling(self):
176
- """Test handling of longer sentences."""
177
- aggregator = BlingfireTextAggregator()
178
- long_sentence = (
179
- "This is a very long sentence with many words that goes on and on "
180
- "and should still be handled correctly by the aggregator."
181
- )
182
- result = await aggregator.aggregate(long_sentence)
183
- assert result is None
184
- assert aggregator.text == long_sentence
185
-
186
- @pytest.mark.asyncio()
187
- async def test_sentence_boundaries_with_abbreviations(self):
188
- """Test sentence detection with abbreviations that contain periods."""
189
- aggregator = BlingfireTextAggregator()
190
- result = await aggregator.aggregate("Dr. Smith went to the U.S.A. He had a great time.")
191
- assert result == "Dr. Smith went to the U.S.A."
192
- assert "He had a great time." in aggregator.text
193
-
194
- @pytest.mark.asyncio()
195
- async def test_newline_handling(self):
196
- """Test handling of text with newlines."""
197
- aggregator = BlingfireTextAggregator()
198
- result = await aggregator.aggregate("First line.\nSecond line.")
199
- assert result == "First line."
200
- assert "Second line." in aggregator.text
201
-
202
- @pytest.mark.asyncio()
203
- async def test_mixed_sentence_endings(self):
204
- """Test text with mixed sentence ending punctuation."""
205
- aggregator = BlingfireTextAggregator()
206
- result = await aggregator.aggregate("What time is it? It's 3 PM! That's great.")
207
- assert result == "What time is it?"
208
- remaining = aggregator.text
209
- assert "It's 3 PM!" in remaining
210
- assert "That's great." in remaining
211
-
212
- @pytest.mark.asyncio()
213
- async def test_text_property_consistency(self):
214
- """Test that the text property always returns the current buffer state."""
215
- aggregator = BlingfireTextAggregator()
216
-
217
- # Initially empty
218
- assert aggregator.text == ""
219
-
220
- # After adding incomplete sentence
221
- await aggregator.aggregate("Hello")
222
- assert aggregator.text == "Hello"
223
-
224
- # After adding more text
225
- await aggregator.aggregate(" world")
226
- assert aggregator.text == "Hello world"
227
-
228
- # After completing sentence (but no return yet)
229
- await aggregator.aggregate(".")
230
- assert aggregator.text == "Hello world."
231
-
232
- # After triggering sentence return
233
- await aggregator.aggregate(" Next sentence.")
234
- assert "Next sentence." in aggregator.text
235
- assert "Hello world." not in aggregator.text
236
-
237
- @pytest.mark.asyncio()
238
- async def test_empty_sentences_filtering(self):
239
- """Test that empty sentences are properly filtered out."""
240
- aggregator = BlingfireTextAggregator()
241
- # Text with multiple periods that might create empty sentences
242
- result = await aggregator.aggregate("Hello... world.")
243
- assert result is None
244
- assert aggregator.text == "Hello... world."
 
tests/unit/test_custom_view.py DELETED
@@ -1,203 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the CustomView related frames."""
5
-
6
- import json
7
-
8
- from nvidia_pipecat.frames.custom_view import (
9
- Button,
10
- ButtonListBlock,
11
- ButtonVariant,
12
- HeaderBlock,
13
- Hint,
14
- HintCarouselBlock,
15
- Image,
16
- ImageBlock,
17
- ImagePosition,
18
- ImageWithTextBlock,
19
- SelectableOption,
20
- SelectableOptionsGridBlock,
21
- StartCustomViewFrame,
22
- TableBlock,
23
- TextBlock,
24
- TextInputBlock,
25
- )
26
-
27
-
28
- def test_to_json_empty():
29
- """Tests empty custom view JSON serialization.
30
-
31
- Tests:
32
- - Empty frame conversion
33
- - Minimal configuration
34
- - Default values
35
-
36
- Raises:
37
- AssertionError: If JSON output doesn't match expected format.
38
- """
39
- frame = StartCustomViewFrame(action_id="test-action-id")
40
- result = frame.to_json()
41
- expected_result = {}
42
- assert result == json.dumps(expected_result)
43
-
44
-
45
- def test_to_json_simple():
46
- """Tests simple custom view JSON serialization.
47
-
48
- Tests:
49
- - Basic block configuration
50
- - Header block formatting
51
- - Single block conversion
52
-
53
- Raises:
54
- AssertionError: If JSON output doesn't match expected format.
55
- """
56
- frame = StartCustomViewFrame(
57
- action_id="test-action-id",
58
- blocks=[
59
- HeaderBlock(id="test-header", header="Test Header", level=1),
60
- ],
61
- )
62
- result = frame.to_json()
63
- expected_result = {
64
- "blocks": [{"id": "test-header", "type": "header", "data": {"header": "Test Header", "level": 1}}],
65
- }
66
- assert result == json.dumps(expected_result)
67
-
68
-
69
- def test_to_json_complex():
70
- """Tests complex custom view JSON serialization.
71
-
72
- Tests:
73
- - Multiple block types
74
- - Image handling
75
- - Nested data structures
76
-
77
- Raises:
78
- AssertionError: If JSON output doesn't match expected format.
79
- """
80
- # Sample base64 string for image data
81
- sample_base64 = "iVBORw0KGgoAAAANSUhEUgAAA5UAAAC5CAIAAAA3TIxUAADd3"
82
-
83
- frame = StartCustomViewFrame(
84
- action_id="test-action-id",
85
- blocks=[
86
- HeaderBlock(id="test-header", header="Test Header", level=1),
87
- TextBlock(id="test-text", text="Test Text"),
88
- ImageBlock(id="test-image", image=Image(url="https://example.com/image.jpg")),
89
- ImageWithTextBlock(
90
- id="test-image-with-text",
91
- image=Image(data=sample_base64),
92
- text="Test Text",
93
- image_position=ImagePosition.LEFT,
94
- ),
95
- TableBlock(id="test-table", headers=["Header 1", "Header 2"], rows=[["Row 1", "Row 2"]]),
96
- HintCarouselBlock(
97
- id="test-hint-carousel", hints=[Hint(id="test-hint", name="Test Hint", text="Test Text")]
98
- ),
99
- ButtonListBlock(
100
- id="test-button-list",
101
- buttons=[
102
- Button(
103
- id="test-button",
104
- active=True,
105
- toggled=False,
106
- variant=ButtonVariant.CONTAINED,
107
- text="Test Button",
108
- )
109
- ],
110
- ),
111
- SelectableOptionsGridBlock(
112
- id="test-selectable-options-grid",
113
- buttons=[
114
- SelectableOption(
115
- id="test-option",
116
- image=Image(url="https://example.com/image.jpg"),
117
- text="Test Option",
118
- active=True,
119
- toggled=False,
120
- )
121
- ],
122
- ),
123
- TextInputBlock(
124
- id="test-input",
125
- default_value="Test Default Value",
126
- value="Test Value",
127
- label="Test Label",
128
- input_type="text",
129
- ),
130
- ],
131
- )
132
- result = frame.to_json()
133
- expected_result = {
134
- "blocks": [
135
- {"id": "test-header", "type": "header", "data": {"header": "Test Header", "level": 1}},
136
- {"id": "test-text", "type": "paragraph", "data": {"text": "Test Text"}},
137
- {
138
- "id": "test-image",
139
- "type": "image",
140
- "data": {"image": {"url": "https://example.com/image.jpg"}},
141
- },
142
- {
143
- "id": "test-image-with-text",
144
- "type": "paragraph_with_image",
145
- "data": {
146
- "image": {"data": sample_base64},
147
- "text": "Test Text",
148
- "image_position": "left",
149
- },
150
- },
151
- {
152
- "id": "test-table",
153
- "type": "table",
154
- "data": {"headers": ["Header 1", "Header 2"], "rows": [["Row 1", "Row 2"]]},
155
- },
156
- {
157
- "id": "test-hint-carousel",
158
- "type": "hint_carousel",
159
- "data": {"hints": [{"name": "Test Hint", "text": "Test Text"}]},
160
- },
161
- {
162
- "id": "test-button-list",
163
- "type": "button_list",
164
- "data": {
165
- "buttons": [
166
- {
167
- "id": "test-button",
168
- "active": True,
169
- "toggled": False,
170
- "variant": "contained",
171
- "text": "Test Button",
172
- }
173
- ]
174
- },
175
- },
176
- {
177
- "id": "test-selectable-options-grid",
178
- "type": "selectable_options_grid",
179
- "data": {
180
- "buttons": [
181
- {
182
- "id": "test-option",
183
- "image": {"url": "https://example.com/image.jpg"},
184
- "text": "Test Option",
185
- "active": True,
186
- "toggled": False,
187
- }
188
- ]
189
- },
190
- },
191
- {
192
- "id": "test-input",
193
- "type": "text_input",
194
- "data": {
195
- "default_value": "Test Default Value",
196
- "value": "Test Value",
197
- "label": "Test Label",
198
- "input_type": "text",
199
- },
200
- },
201
- ],
202
- }
203
- assert result == json.dumps(expected_result)
 
tests/unit/test_elevenlabs.py DELETED
@@ -1,184 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the ElevenLabsTTSServiceWithEndOfSpeech class."""
5
-
6
- import asyncio
7
- import base64
8
- import json
9
- from unittest.mock import AsyncMock, patch
10
-
11
- import pytest
12
- from loguru import logger
13
- from pipecat.frames.frames import TTSAudioRawFrame, TTSSpeakFrame, TTSStoppedFrame
14
- from pipecat.pipeline.pipeline import Pipeline
15
- from pipecat.pipeline.task import PipelineTask
16
- from websockets.protocol import State
17
-
18
- from nvidia_pipecat.services.elevenlabs import ElevenLabsTTSServiceWithEndOfSpeech
19
- from nvidia_pipecat.utils.logging import setup_default_ace_logging
20
- from tests.unit.utils import FrameStorage, ignore_ids, run_interactive_test
21
-
22
- setup_default_ace_logging(level="TRACE")
23
-
24
-
25
- class MockWebSocket:
26
- """Mock WebSocket for testing ElevenLabs service.
27
-
28
- Attributes:
29
- messages_to_return: List of messages to return during testing.
30
- sent_messages: List of messages sent through the socket.
31
- state: Current WebSocket connection state.
32
- close_rcvd: Close frame received flag.
33
- close_rcvd_then_sent: Close frame received and sent flag.
34
- close_sent: Close frame sent flag.
35
- closed: Whether the WebSocket is closed.
36
- """
37
-
38
- def __init__(self, messages_to_return):
39
- """Initialize MockWebSocket.
40
-
41
- Args:
42
- messages_to_return (list): List of messages to return during testing.
43
- """
44
- self.messages_to_return = messages_to_return
45
- self.sent_messages = []
46
- self.state = State.OPEN
47
- self.close_rcvd = None
48
- self.close_rcvd_then_sent = None
49
- self.close_sent = None
50
- self.closed = False
51
-
52
- async def send(self, message: str) -> None:
53
- """Sends a message through the mock socket.
54
-
55
- Args:
56
- message (str): Message to send.
57
- """
58
- self.sent_messages.append(json.loads(message))
59
-
60
- async def ping(self) -> bool:
61
- """Simulates WebSocket heartbeat.
62
-
63
- Returns:
64
- bool: Always True for testing.
65
- """
66
- return True
67
-
68
- async def close(self) -> bool:
69
- """Closes the mock WebSocket connection.
70
-
71
- Returns:
72
- bool: Always True for testing.
73
- """
74
- self.state = State.CLOSED
75
- self.closed = True
76
- return True
77
-
78
- async def __aiter__(self):
79
- """Async iterator for messages.
80
-
81
- Yields:
82
- str: JSON-encoded message.
83
- """
84
- for msg in self.messages_to_return:
85
- yield json.dumps(msg)
86
- while self.state != State.CLOSED:
87
- await asyncio.sleep(1.0)
88
- yield "{}"
89
-
90
-
91
- @pytest.mark.asyncio()
92
- async def test_elevenlabs_tts_service_with_end_of_speech():
93
- """Test ElevenLabsTTSServiceWithEndOfSpeech functionality.
94
-
95
- Tests:
96
- - End-of-speech boundary marker handling
97
- - Audio message processing
98
- - Alignment message processing
99
- - TTSStoppedFrame generation
100
-
101
- Raises:
102
- AssertionError: If frame processing or timing is incorrect.
103
- """
104
- # Test audio data
105
- test_audio = b"test_audio_data"
106
- test_audio_b64 = base64.b64encode(test_audio).decode()
107
-
108
- # Test cases with different message sequences
109
- testcases = {
110
- "Normal audio with boundary marker": {
111
- "frames_to_send": [TTSSpeakFrame("Hello")],
112
- "messages": [
113
- {
114
- "audio": test_audio_b64,
115
- "alignment": {
116
- "chars": ["H", "e", "l", "l", "o", "\u200b"],
117
- "charStartTimesMs": [0, 3, 7, 9, 11, 12],
118
- "charDurationsMs": [3, 4, 2, 2, 1, 1],
119
- },
120
- },
121
- ],
122
- "expected_frames": [
123
- TTSAudioRawFrame(test_audio, 16000, 1),
124
- TTSStoppedFrame(),
125
- ],
126
- },
127
- "Multiple audio chunks": {
128
- "frames_to_send": [TTSSpeakFrame("Test")],
129
- "messages": [
130
- {
131
- "audio": test_audio_b64,
132
- "alignment": {
133
- "chars": ["T", "e"],
134
- "charStartTimesMs": [0, 3],
135
- "charDurationsMs": [3, 4],
136
- },
137
- },
138
- {
139
- "audio": test_audio_b64,
140
- "alignment": {
141
- "chars": ["s", "t", "\u200b"],
142
- "charStartTimesMs": [7, 9, 11],
143
- "charDurationsMs": [2, 2, 1],
144
- },
145
- },
146
- ],
147
- "expected_frames": [
148
- TTSAudioRawFrame(test_audio, 16000, 1),
149
- TTSAudioRawFrame(test_audio, 16000, 1),
150
- TTSStoppedFrame(),
151
- ],
152
- },
153
- }
154
-
155
- for tc_name, tc_data in testcases.items():
156
- logger.info(f"Verifying test case: {tc_name}")
157
-
158
- # Create mock websocket with test messages
159
- mock_websocket = MockWebSocket(tc_data["messages"])
160
-
161
- # mock = AsyncMock()
162
- # mock.return_value = mock_websocket
163
-
164
- with patch("pipecat.services.elevenlabs.tts.websockets.connect", new=AsyncMock()) as mock:
165
- mock.return_value = mock_websocket
166
- tts_service = ElevenLabsTTSServiceWithEndOfSpeech(
167
- api_key="test_api_key", voice_id="test_voice_id", sample_rate=16000, channels=1
168
- )
169
-
170
- storage = FrameStorage()
171
- pipeline = Pipeline([tts_service, storage])
172
-
173
- async def test_routine(task: PipelineTask, test_data=tc_data, s=storage):
174
- for frame in test_data["frames_to_send"]:
175
- await task.queue_frame(frame)
176
- # Wait for all expected frames
177
- for expected_frame in test_data["expected_frames"]:
178
- await s.wait_for_frame(ignore_ids(expected_frame))
179
- print(f"got frame to be sent {expected_frame}")
180
-
181
- # TODO: investigate why we need to cancel here
182
- await task.cancel()
183
-
184
- await run_interactive_test(pipeline, test_coroutine=test_routine)
 
tests/unit/test_frame_creation.py DELETED
@@ -1,148 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for action frame creation and manipulation.
5
-
6
- This module tests the creation, validation, and comparison of various action frames used for bot and user actions.
7
- """
8
-
9
- # ruff: noqa: F405
10
-
11
- from typing import Any
12
-
13
- import pytest
14
- from pipecat.frames.frames import TextFrame
15
-
16
- from nvidia_pipecat.frames.action import * # noqa: F403
17
- from nvidia_pipecat.frames.action import (
18
- FinishedFacialGestureBotActionFrame,
19
- StartedFacialGestureBotActionFrame,
20
- StartFacialGestureBotActionFrame,
21
- StopFacialGestureBotActionFrame,
22
- )
23
- from tests.unit.utils import ignore_ids
24
-
25
-
26
- def test_action_frame_basic_usage():
27
- """Tests basic action frame functionality.
28
-
29
- Tests:
30
- - Frame creation with parameters
31
- - Action ID propagation
32
- - Frame name generation
33
- - Frame attribute access
34
-
35
- Raises:
36
- AssertionError: If frame attributes don't match expected values.
37
- """
38
- start_frame = StartFacialGestureBotActionFrame(facial_gesture="wink")
39
- action_id = start_frame.action_id
40
-
41
- started_frame = StartedFacialGestureBotActionFrame(action_id=action_id)
42
- stop_frame = StopFacialGestureBotActionFrame(action_id=action_id)
43
- finished_frame = FinishedFacialGestureBotActionFrame(action_id=action_id)
44
-
45
- assert started_frame.action_id == action_id
46
- assert stop_frame.action_id == action_id
47
- assert stop_frame.name == "StopFacialGestureBotActionFrame#0"
48
- assert finished_frame.action_id == action_id
49
- assert finished_frame.name == "FinishedFacialGestureBotActionFrame#0"
50
-
51
-
52
- def test_required_parameters():
53
- """Tests parameter validation in frame creation.
54
-
55
- Tests:
56
- - Required parameter enforcement
57
- - Type checking
58
- - Error handling
59
-
60
- Raises:
61
- TypeError: When required parameters are missing.
62
- """
63
- with pytest.raises(TypeError):
64
- StartedFacialGestureBotActionFrame() # type: ignore
65
-
66
-
67
- def test_action_frame_existence():
68
- """Tests frame class contract compliance.
69
-
70
- Tests:
71
- - Frame class initialization
72
- - Parameter handling
73
- - Action ID validation
74
- - Frame type verification
75
-
76
- Raises:
77
- TypeError: When frame initialization fails.
78
- AssertionError: If frame attributes are incorrect.
79
- """
80
- actions: dict[str, dict[str, dict[str, Any]]] = {
81
- "FacialGestureBotActionFrame": {
82
- "Start": {"facial_gesture": "wink"},
83
- "Started": {},
84
- "Stop": {},
85
- "Finished": {},
86
- },
87
- "GestureBotActionFrame": {
88
- "Start": {"gesture": "wave"},
89
- "Started": {},
90
- "Stop": {},
91
- "Finished": {},
92
- },
93
- "PostureBotActionFrame": {
94
- "Start": {"posture": "listening"},
95
- "Started": {},
96
- "Stop": {},
97
- "Finished": {},
98
- },
99
- "PositionBotActionFrame": {
100
- "Start": {"position": "left"},
101
- "Started": {},
102
- "Stop": {},
103
- "Updated": {"position_reached": "left"},
104
- "Finished": {},
105
- },
106
- "AttentionUserActionFrame": {
107
- "Started": {"attention_level": "attentive"},
108
- "Updated": {"attention_level": "inattentive"},
109
- "Finished": {},
110
- },
111
- "PresenceUserActionFrame": {
112
- "Started": {},
113
- "Finished": {},
114
- },
115
- }
116
-
117
- for action_name, frame_type in actions.items():
118
- for f, args in frame_type.items():
119
- if f == "Start":
120
- test = globals()[f"{f}{action_name}"](**args)
121
- assert test.name
122
- else:
123
- with pytest.raises(TypeError):
124
- test = globals()[f"{f}{action_name}"]()
125
-
126
- args["action_id"] = "1234"
127
- test = globals()[f"{f}{action_name}"](**args)
128
- assert test.action_id == "1234"
129
-
130
-
131
- def test_frame_comparison_ignoring_ids():
132
- """Tests frame comparison with ID ignoring.
133
-
134
- Tests:
135
- - Frame equality comparison
136
- - ID-independent comparison
137
- - Content-based comparison
138
-
139
- Raises:
140
- AssertionError: If frame comparison results are incorrect.
141
- """
142
- a = TextFrame(text="test")
143
- b = TextFrame(text="test")
144
- c = TextFrame(text="something")
145
-
146
- assert a != b
147
- assert a == ignore_ids(b)
148
- assert a != ignore_ids(c)
 
tests/unit/test_gesture.py DELETED
@@ -1,94 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the FacialGestureProviderProcessor."""
5
-
6
- import pytest
7
- from pipecat.frames.frames import (
8
- BotStartedSpeakingFrame,
9
- BotStoppedSpeakingFrame,
10
- StartInterruptionFrame,
11
- UserStoppedSpeakingFrame,
12
- )
13
- from pipecat.tests.utils import SleepFrame
14
-
15
- from nvidia_pipecat.frames.action import StartFacialGestureBotActionFrame
16
- from nvidia_pipecat.processors.gesture_provider import FacialGestureProviderProcessor
17
- from tests.unit.utils import ignore_ids, run_test
18
-
19
-
20
- @pytest.mark.asyncio()
21
- async def test_gesture_provider_processor_interrupt():
22
- """Test facial gesture generation for bot speech start and interruption.
23
-
24
- Tests that the processor generates appropriate facial gestures when receiving
25
- BotStartedSpeakingFrame and StartInterruptionFrame events.
26
-
27
- The test verifies:
28
- - Correct handling of BotStartedSpeakingFrame
29
- - Correct handling of StartInterruptionFrame
30
- - Generation of "Pensive" facial gesture
31
- """
32
- frames_to_send = [
33
- BotStartedSpeakingFrame(),
34
- SleepFrame(0.1),
35
- StartInterruptionFrame(),
36
- ]
37
- expected_down_frames = [
38
- ignore_ids(BotStartedSpeakingFrame()),
39
- ignore_ids(StartInterruptionFrame()),
40
- ignore_ids(StartFacialGestureBotActionFrame(facial_gesture="Pensive")),
41
- ]
42
-
43
- await run_test(
44
- FacialGestureProviderProcessor(probability=1.0),
45
- frames_to_send=frames_to_send,
46
- expected_down_frames=expected_down_frames,
47
- )
48
-
49
-
50
- @pytest.mark.asyncio()
51
- async def test_gesture_provider_processor_bot_finished():
52
- """Test facial gesture processing for bot speech completion.
53
-
54
- Tests that the processor handles BotStoppedSpeakingFrame by passing it through
55
- without generating additional gestures.
56
-
57
- The test verifies:
58
- - Correct passthrough of BotStoppedSpeakingFrame
59
- - No additional gesture generation
60
- """
61
- frames_to_send = [BotStoppedSpeakingFrame()]
62
- expected_down_frames = [
63
- ignore_ids(BotStoppedSpeakingFrame()),
64
- ]
65
-
66
- await run_test(
67
- FacialGestureProviderProcessor(probability=1.0),
68
- frames_to_send=frames_to_send,
69
- expected_down_frames=expected_down_frames,
70
- )
71
-
72
-
73
- @pytest.mark.asyncio()
74
- async def test_gesture_provider_processor_tts():
75
- """Test facial gesture processing for interruption events.
76
-
77
- Tests that the processor generates a "Taunt" facial gesture when receiving
78
- UserStoppedSpeakingFrame events.
79
-
80
- The test verifies:
81
- - Correct handling of UserStoppedSpeakingFrame
82
- - Generation of "Taunt" facial gesture
83
- """
84
- frames_to_send = [UserStoppedSpeakingFrame()]
85
- expected_down_frames = [
86
- ignore_ids(UserStoppedSpeakingFrame()),
87
- ignore_ids(StartFacialGestureBotActionFrame(facial_gesture="Taunt")),
88
- ]
89
-
90
- await run_test(
91
- FacialGestureProviderProcessor(probability=1.0),
92
- frames_to_send=frames_to_send,
93
- expected_down_frames=expected_down_frames,
94
- )
 
tests/unit/test_guardrail.py DELETED
@@ -1,110 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the user presence frame processor."""
5
-
6
- import pytest
7
- from pipecat.frames.frames import TranscriptionFrame, TTSSpeakFrame
8
- from pipecat.utils.time import time_now_iso8601
9
-
10
- from nvidia_pipecat.processors.guardrail import GuardrailProcessor
11
- from tests.unit.utils import ignore_ids, run_test
12
-
13
-
14
- @pytest.mark.asyncio
15
- async def test_blocked_word():
16
- """Tests blocking functionality for explicitly blocked words.
17
-
18
- Tests that the processor replaces transcription frames containing blocked words
19
- with a TTSSpeakFrame containing a rejection message.
20
-
21
- Args:
22
- None
23
-
24
- Returns:
25
- None
26
-
27
- The test verifies:
28
- - Input containing "football" is blocked
29
- - Response is replaced with rejection message
30
- """
31
- guardrail_bot = GuardrailProcessor(blocked_words=["football"])
32
- frames_to_send = [TranscriptionFrame(text="I love football", user_id="", timestamp=time_now_iso8601())]
33
- expected_down_frames = [ignore_ids(TTSSpeakFrame("I am not allowed to answer this question"))]
34
-
35
- await run_test(guardrail_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
36
-
37
-
38
- @pytest.mark.asyncio
39
- async def test_non_blocked_word():
40
- """Tests passthrough of allowed words.
41
-
42
- Tests that the processor allows transcription frames that don't contain
43
- any blocked words to pass through unchanged.
44
-
45
- Args:
46
- None
47
-
48
- Returns:
49
- None
50
-
51
- The test verifies:
52
- - Input without blocked words passes through
53
- - Frame content remains unchanged
54
- """
55
- guardrail_bot = GuardrailProcessor(blocked_words=["football"])
56
- timestamp = time_now_iso8601()
57
- frames_to_send = [TranscriptionFrame(text="Tell me about Pasta", user_id="", timestamp=timestamp)]
58
- expected_down_frames = [ignore_ids(TranscriptionFrame(text="Tell me about Pasta", user_id="", timestamp=timestamp))]
59
-
60
- await run_test(guardrail_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
61
-
62
-
63
- @pytest.mark.asyncio
64
- async def test_substring_blocked_word():
65
- """Tests substring matching behavior for blocked words.
66
-
67
- Tests that the processor only blocks exact word matches and allows
68
- words that contain blocked words as substrings.
69
-
70
- Args:
71
- None
72
-
73
- Returns:
74
- None
75
-
76
- The test verifies:
77
- - Words containing blocked words as substrings are allowed
78
- - Frame content remains unchanged
79
- """
80
- guardrail_bot = GuardrailProcessor(blocked_words=["foot"])
81
- timestamp = time_now_iso8601()
82
- frames_to_send = [TranscriptionFrame(text="I love football", user_id="", timestamp=timestamp)]
83
- expected_down_frames = [ignore_ids(TranscriptionFrame(text="I love football", user_id="", timestamp=timestamp))]
84
-
85
- await run_test(guardrail_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
86
-
87
-
88
- @pytest.mark.asyncio
89
- async def test_no_blocked_word():
90
- """Tests default behavior with no blocked words configured.
91
-
92
- Tests that the processor allows all transcription frames to pass through
93
- when no blocked words are specified.
94
-
95
- Args:
96
- None
97
-
98
- Returns:
99
- None
100
-
101
- The test verifies:
102
- - All input passes through when no words are blocked
103
- - Frame content remains unchanged
104
- """
105
- guardrail_bot = GuardrailProcessor()
106
- timestamp = time_now_iso8601()
107
- frames_to_send = [TranscriptionFrame(text="What is your name", user_id="", timestamp=timestamp)]
108
- expected_down_frames = [ignore_ids(TranscriptionFrame(text="What is your name", user_id="", timestamp=timestamp))]
109
-
110
- await run_test(guardrail_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
 
tests/unit/test_message_broker.py DELETED
@@ -1,111 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the MessageBroker."""
5
-
6
- import asyncio
7
- from datetime import timedelta
8
-
9
- import pytest
10
-
11
- from nvidia_pipecat.utils.message_broker import LocalQueueMessageBroker
12
-
13
-
14
- async def send_task(mb: LocalQueueMessageBroker) -> bool:
15
- """Send test messages to the message broker.
16
-
17
- Args:
18
- mb: The LocalQueueMessageBroker instance to send messages to.
19
-
20
- Returns:
21
- bool: True if all messages were sent successfully.
22
- """
23
- for i in range(5):
24
- if i % 2 == 0:
25
- await mb.send_message("first", f"message {i}")
26
- else:
27
- await mb.send_message("second", f"message {i}")
28
- await asyncio.sleep(0.08)
29
- return True
30
-
31
-
32
- async def receive_task(mb: LocalQueueMessageBroker) -> list[str]:
33
- """Receive messages from the message broker.
34
-
35
- Args:
36
- mb: The LocalQueueMessageBroker instance to receive messages from.
37
-
38
- Returns:
39
- list[str]: List of received message data.
40
- """
41
- result: list[str] = []
42
- while len(result) < 5:
43
- messages = await mb.receive_messages(timeout=timedelta(seconds=0.1))
44
- for _, message_data in messages:
45
- result.append(message_data)
46
- return result
47
-
48
-
49
- @pytest.mark.asyncio()
50
- async def test_local_queue_message_broker():
51
- """Tests basic concurrent send/receive functionality.
52
-
53
- Tests that messages can be successfully sent and received when send and
54
- receive tasks run concurrently.
55
-
56
- The test verifies:
57
- - All messages are received in order
58
- - Send operation completes successfully
59
- - Messages are correctly routed between channels
60
- """
61
- mb = LocalQueueMessageBroker(channels=["first", "second"])
62
- task_1 = asyncio.create_task(send_task(mb))
63
- task_2 = asyncio.create_task(receive_task(mb))
64
-
65
- results = await asyncio.gather(task_1, task_2)
66
-
67
- assert results == [True, ["message 0", "message 1", "message 2", "message 3", "message 4"]]
68
-
69
-
70
- @pytest.mark.asyncio()
71
- async def test_local_queue_message_broker_receive_first():
72
- """Tests message delivery when receive starts before send.
73
-
74
- Tests that no messages are lost when the receive task is started before
75
- any messages are sent.
76
-
77
- The test verifies:
78
- - All messages are received in order despite delayed send
79
- - Send operation completes successfully
80
- - No messages are lost due to timing
81
- """
82
- mb = LocalQueueMessageBroker(channels=["first", "second"])
83
- task_2 = asyncio.create_task(receive_task(mb))
84
- await asyncio.sleep(0.2)
85
- task_1 = asyncio.create_task(send_task(mb))
86
-
87
- results = await asyncio.gather(task_1, task_2)
88
-
89
- assert results == [True, ["message 0", "message 1", "message 2", "message 3", "message 4"]]
90
-
91
-
92
- @pytest.mark.asyncio()
93
- async def test_local_queue_message_broker_send_first():
94
- """Tests message delivery when send completes before receive starts.
95
-
96
- Tests that no messages are lost when all messages are sent before the
97
- receive task begins.
98
-
99
- The test verifies:
100
- - All messages are received in order despite delayed receive
101
- - Send operation completes successfully
102
- - Messages are properly queued until received
103
- """
104
- mb = LocalQueueMessageBroker(channels=["first", "second"])
105
- task_1 = asyncio.create_task(send_task(mb))
106
- await asyncio.sleep(0.5)
107
- task_2 = asyncio.create_task(receive_task(mb))
108
-
109
- results = await asyncio.gather(task_1, task_2)
110
-
111
- assert results == [True, ["message 0", "message 1", "message 2", "message 3", "message 4"]]
 
tests/unit/test_nvidia_aggregators.py DELETED
@@ -1,396 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the Nvidia Aggregators."""
5
-
6
- import pytest
7
- from pipecat.frames.frames import (
8
- LLMFullResponseEndFrame,
9
- LLMFullResponseStartFrame,
10
- LLMMessagesFrame,
11
- StartInterruptionFrame,
12
- TextFrame,
13
- TranscriptionFrame,
14
- UserStartedSpeakingFrame,
15
- UserStoppedSpeakingFrame,
16
- )
17
- from pipecat.pipeline.pipeline import Pipeline
18
- from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
19
- from pipecat.tests.utils import SleepFrame
20
- from pipecat.tests.utils import run_test as run_pipecat_test
21
- from pipecat.utils.time import time_now_iso8601
22
-
23
- from nvidia_pipecat.frames.riva import RivaInterimTranscriptionFrame
24
- from nvidia_pipecat.processors.nvidia_context_aggregator import (
25
- NvidiaUserContextAggregator,
26
- create_nvidia_context_aggregator,
27
- )
28
-
29
-
30
- @pytest.mark.asyncio()
31
- async def test_normal_flow():
32
- """Test the normal flow of user and assistant interactions with interim transcriptions enabled.
33
-
34
- Tests the sequence of events from user speech start through assistant response,
35
- verifying proper handling of interim and final transcriptions.
36
-
37
- The test verifies:
38
- - User speech start frame handling
39
- - Interim transcription processing
40
- - Final transcription handling
41
- - User speech stop frame handling
42
- - Assistant response processing
43
- - Context updates at each stage
44
- """
45
- messages = []
46
- context = OpenAILLMContext(messages)
47
- context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
48
-
49
- pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
50
- messages.append({"role": "system", "content": "This is system prompt"})
51
- # Test Case 1: Normal flow with UserStartedSpeakingFrame first
52
- frames_to_send = [
53
- UserStartedSpeakingFrame(),
54
- LLMMessagesFrame(messages),
55
- RivaInterimTranscriptionFrame("Hello", "", time_now_iso8601(), None, stability=1.0),
56
- TranscriptionFrame("Hello User Aggregator!", 1, 2),
57
- SleepFrame(0.1),
58
- UserStoppedSpeakingFrame(),
59
- SleepFrame(0.1),
60
- ]
61
- # Assistant response
62
- frames_to_send.extend(
63
- [
64
- LLMFullResponseStartFrame(),
65
- TextFrame("Hello Assistant Aggregator!"),
66
- LLMFullResponseEndFrame(),
67
- ]
68
- )
69
- expected_down_frames = [
70
- UserStartedSpeakingFrame,
71
- StartInterruptionFrame,
72
- StartInterruptionFrame,
73
- OpenAILLMContextFrame, # From first interim
74
- UserStoppedSpeakingFrame,
75
- OpenAILLMContextFrame,
76
- ]
77
- await run_pipecat_test(
78
- pipeline,
79
- frames_to_send=frames_to_send,
80
- expected_down_frames=expected_down_frames,
81
- )
82
-
83
- # Verify final context state
84
- assert context_aggregator.user().context.get_messages() == [
85
- {"role": "user", "content": "Hello User Aggregator!"},
86
- {"role": "assistant", "content": "Hello Assistant Aggregator!"},
87
- ]
88
-
89
-
90
- @pytest.mark.asyncio()
91
- async def test_user_speaking_frame_delay_cases():
92
- """Test handling of transcription frames that arrive before UserStartedSpeakingFrame.
93
-
94
- Tests edge cases around transcription frame timing relative to the
95
- UserStartedSpeakingFrame.
96
-
97
- The test verifies:
98
- - Interim frames before UserStartedSpeakingFrame are ignored
99
- - Low stability interim frames are ignored
100
- - Only processes transcriptions after UserStartedSpeakingFrame
101
- - Context is updated correctly for valid frames
102
- """
103
- messages = []
104
- context = OpenAILLMContext(messages)
105
- context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
106
-
107
- pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
108
- messages.append({"role": "system", "content": "This is system prompt"})
109
-
110
- # Test Case 2: RivaInterimTranscriptionFrames before UserStartedSpeakingFrame
111
- frames_to_send = [
112
- RivaInterimTranscriptionFrame(
113
- "Testing", "", time_now_iso8601(), None, stability=0.5
114
- ), # Should be ignored (low stability)
115
- RivaInterimTranscriptionFrame(
116
- "Testing delayed", "", time_now_iso8601(), None, stability=1.0
117
- ), # Should be ignored (no UserStartedSpeakingFrame yet)
118
- SleepFrame(0.5),
119
- UserStartedSpeakingFrame(),
120
- RivaInterimTranscriptionFrame(
121
- "Testing after start", "", time_now_iso8601(), None, stability=1.0
122
- ), # Should be processed
123
- TranscriptionFrame("Testing after start complete", 1, 2),
124
- SleepFrame(0.1),
125
- UserStoppedSpeakingFrame(),
126
- SleepFrame(0.1),
127
- ]
128
-
129
- # Assistant response
130
- frames_to_send.extend(
131
- [
132
- LLMFullResponseStartFrame(),
133
- TextFrame("Hello Assistant Aggregator!"),
134
- LLMFullResponseEndFrame(),
135
- ]
136
- )
137
- expected_down_frames = [
138
- UserStartedSpeakingFrame,
139
- StartInterruptionFrame,
140
- StartInterruptionFrame,
141
- OpenAILLMContextFrame, # from first interim after UserStartedSpeakingFrame
142
- UserStoppedSpeakingFrame,
143
- OpenAILLMContextFrame,
144
- ]
145
-
146
- await run_pipecat_test(
147
- pipeline,
148
- frames_to_send=frames_to_send,
149
- expected_down_frames=expected_down_frames,
150
- )
151
-
152
- # Verify final context state
153
- assert context_aggregator.user().context.get_messages() == [
154
- {"role": "user", "content": "Testing after start complete"},
155
- {"role": "assistant", "content": "Hello Assistant Aggregator!"},
156
- ]
157
-
158
-
159
- @pytest.mark.asyncio()
160
- async def test_multiple_interims_with_final_transcription():
161
- """Test handling of multiple interim transcription frames followed by a final transcription.
162
-
163
- Tests the processing of a sequence of interim transcriptions followed by
164
- a final transcription.
165
-
166
- The test verifies:
167
- - Multiple interim transcriptions are processed correctly
168
- - Final transcription properly overwrites previous interims
169
- - Context updates occur for each valid frame
170
- - Message history maintains correct order
171
- """
172
- messages = []
173
- context = OpenAILLMContext(messages)
174
- context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
175
-
176
- pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
177
- messages.append({"role": "system", "content": "This is system prompt"})
178
-
179
- # Test Case 3: Multiple interim frames with final transcription
180
- frames_to_send = [
181
- UserStartedSpeakingFrame(),
182
- RivaInterimTranscriptionFrame("Hello", "", time_now_iso8601(), None, stability=1.0),
183
- RivaInterimTranscriptionFrame("Hello Again", "", time_now_iso8601(), None, stability=1.0),
184
- RivaInterimTranscriptionFrame("Hello Again User", "", time_now_iso8601(), None, stability=1.0),
185
- TranscriptionFrame("Hello Again User Aggregator!", 1, 2),
186
- SleepFrame(0.1),
187
- UserStoppedSpeakingFrame(),
188
- SleepFrame(0.1),
189
- ]
190
-
191
- # Assistant response
192
- frames_to_send.extend(
193
- [
194
- LLMFullResponseStartFrame(),
195
- TextFrame("Hello Assistant Aggregator!"),
196
- LLMFullResponseEndFrame(),
197
- ]
198
- )
199
- expected_down_frames = [
200
- UserStartedSpeakingFrame,
201
- StartInterruptionFrame,
202
- StartInterruptionFrame,
203
- StartInterruptionFrame,
204
- StartInterruptionFrame,
205
- OpenAILLMContextFrame, # From final transcription
206
- UserStoppedSpeakingFrame,
207
- OpenAILLMContextFrame,
208
- ]
209
-
210
- await run_pipecat_test(
211
- pipeline,
212
- frames_to_send=frames_to_send,
213
- expected_down_frames=expected_down_frames,
214
- )
215
-
216
- # Verify final context state
217
- assert context_aggregator.user().context.get_messages() == [
218
- {"role": "user", "content": "Hello Again User Aggregator!"},
219
- {"role": "assistant", "content": "Hello Assistant Aggregator!"},
220
- ]
221
-
222
-
223
- @pytest.mark.asyncio()
224
- async def test_transcription_after_user_stopped_speaking():
225
- """Tests handling of late transcription frames.
226
-
227
- Tests behavior when transcription frames arrive after UserStoppedSpeakingFrame.
228
-
229
- The test verifies:
230
- - Late transcriptions are still processed
231
- - Context is updated with final transcription
232
- - Assistant responses are handled correctly
233
- - Message history maintains proper sequence
234
- """
235
- messages = []
236
- context = OpenAILLMContext(messages)
237
- context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
238
-
239
- pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
240
- messages.append({"role": "system", "content": "This is system prompt"})
241
-
242
- # Test Case 4: TranscriptionFrame after UserStoppedSpeakingFrame
243
- frames_to_send = [
244
- UserStartedSpeakingFrame(),
245
- RivaInterimTranscriptionFrame("Late", "", time_now_iso8601(), None, stability=1.0),
246
- SleepFrame(0.1),
247
- UserStoppedSpeakingFrame(),
248
- SleepFrame(0.1),
249
- TranscriptionFrame("Late transcription!", 1, 2),
250
- SleepFrame(0.1),
251
- ]
252
-
253
- # Assistant response
254
- frames_to_send.extend(
255
- [
256
- LLMFullResponseStartFrame(),
257
- TextFrame("Hello Assistant Aggregator!"),
258
- LLMFullResponseEndFrame(),
259
- ]
260
- )
261
-
262
- expected_down_frames = [
263
- UserStartedSpeakingFrame,
264
- StartInterruptionFrame,
265
- OpenAILLMContextFrame, # From first interim
266
- UserStoppedSpeakingFrame,
267
- StartInterruptionFrame,
268
- OpenAILLMContextFrame, # From final after UserStoppedSpeakingFrame
269
- OpenAILLMContextFrame, # From assistant response
270
- ]
271
-
272
- await run_pipecat_test(
273
- pipeline,
274
- frames_to_send=frames_to_send,
275
- expected_down_frames=expected_down_frames,
276
- )
277
-
278
- # Verify final context state
279
- assert context_aggregator.user().context.get_messages() == [
280
- {"role": "user", "content": "Late transcription!"},
281
- {"role": "assistant", "content": "Hello Assistant Aggregator!"},
282
- ]
283
-
284
-
285
- @pytest.mark.asyncio()
286
- async def test_no_interim_frames():
287
- """Tests behavior when interim frames are disabled.
288
-
289
- Tests the aggregator's handling of transcriptions when send_interims=False.
290
-
291
- The test verifies:
292
- - Interim frames are ignored
293
- - Only final transcription is processed
294
- - System prompts are preserved
295
- - Context updates occur only for final transcription
296
- - Assistant responses are processed correctly
297
- """
298
- messages = [{"role": "system", "content": "This is system prompt"}]
299
- context = OpenAILLMContext(messages)
300
- context_aggregator = create_nvidia_context_aggregator(context, send_interims=False)
301
- pipeline = Pipeline([context_aggregator.user(), context_aggregator.assistant()])
302
-
303
- frames_to_send = [
304
- UserStartedSpeakingFrame(),
305
- LLMMessagesFrame(messages),
306
- # These interim frames should be ignored due to send_interims=False
307
- RivaInterimTranscriptionFrame("Hello", "", time_now_iso8601(), None, stability=1.0),
308
- RivaInterimTranscriptionFrame("Hello there", "", time_now_iso8601(), None, stability=1.0),
309
- RivaInterimTranscriptionFrame("Hello there user", "", time_now_iso8601(), None, stability=1.0),
310
- # Only the final transcription should be processed
311
- TranscriptionFrame("Hello there user final!", 1, 2),
312
- SleepFrame(0.1),
313
- UserStoppedSpeakingFrame(),
314
- SleepFrame(0.1),
315
- # Assistant response
316
- LLMFullResponseStartFrame(),
317
- TextFrame("Hello from assistant!"),
318
- LLMFullResponseEndFrame(),
319
- ]
320
-
321
- expected_down_frames = [
322
- UserStartedSpeakingFrame,
323
- StartInterruptionFrame,
324
- OpenAILLMContextFrame, # Only from final transcription
325
- UserStoppedSpeakingFrame,
326
- OpenAILLMContextFrame, # From assistant response
327
- ]
328
-
329
- await run_pipecat_test(
330
- pipeline,
331
- frames_to_send=frames_to_send,
332
- expected_down_frames=expected_down_frames,
333
- )
334
-
335
- # Verify final context state
336
- assert context_aggregator.user().context.get_messages() == [
337
- {"role": "system", "content": "This is system prompt"},
338
- {"role": "user", "content": "Hello there user final!"},
339
- {"role": "assistant", "content": "Hello from assistant!"},
340
- ]
341
-
342
-
343
- @pytest.mark.asyncio()
344
- async def test_get_truncated_context():
345
- """Tests context truncation functionality.
346
-
347
- Tests the get_truncated_context() method of NvidiaUserContextAggregator
348
- with a specified chat history limit.
349
-
350
- Args:
351
- None
352
-
353
- Returns:
354
- None
355
-
356
- The test verifies:
357
- - Context is truncated to specified limit
358
- - System prompt is preserved
359
- - Most recent messages are retained
360
- - Message order is maintained
361
- """
362
- messages = [
363
- {"role": "system", "content": "This is system prompt"},
364
- {"role": "user", "content": "Hi, there!"},
365
- {"role": "assistant", "content": "Hello, how may I assist you?"},
366
- {"role": "user", "content": "How to be more productive?"},
367
- {"role": "assistant", "content": "Priotize the tasks, make a list..."},
368
- {"role": "user", "content": "What is metaverse?"},
369
- {
370
- "role": "assistant",
371
- "content": "The metaverse is envisioned as a digital ecosystem built on virtual 3D technology",
372
- },
373
- {
374
- "role": "assistant",
375
- "content": "It leverages 3D technology and digital"
376
- "representation for creating virtual environments and user experiences",
377
- },
378
- {"role": "user", "content": "thanks, Bye!"},
379
- ]
380
- context = OpenAILLMContext(messages)
381
- user = NvidiaUserContextAggregator(context=context, chat_history_limit=2)
382
- truncated_context = await user.get_truncated_context()
383
- assert truncated_context.get_messages() == [
384
- {"role": "system", "content": "This is system prompt"},
385
- {"role": "user", "content": "What is metaverse?"},
386
- {
387
- "role": "assistant",
388
- "content": "The metaverse is envisioned as a digital ecosystem built on virtual 3D technology",
389
- },
390
- {
391
- "role": "assistant",
392
- "content": "It leverages 3D technology and digital"
393
- "representation for creating virtual environments and user experiences",
394
- },
395
- {"role": "user", "content": "thanks, Bye!"},
396
- ]
 
tests/unit/test_nvidia_llm_service.py DELETED
@@ -1,386 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the NvidiaLLMService.
5
-
6
- This module contains tests for the NvidiaLLMService class, focusing on core functionalities:
7
- - Think token filtering (including split tag handling)
8
- - Mistral message preprocessing
9
- - Token usage tracking
10
- - LLM responses and function calls
11
- """
12
-
13
- from unittest.mock import DEFAULT, AsyncMock, patch
14
-
15
- import pytest
16
- from pipecat.frames.frames import LLMTextFrame
17
- from pipecat.metrics.metrics import LLMTokenUsage
18
- from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
19
-
20
- from nvidia_pipecat.services.nvidia_llm import NvidiaLLMService
21
-
22
-
23
- # Custom mocks that mimic OpenAI classes without inheriting from them
24
- class MockCompletionUsage:
25
- """Mock for CompletionUsage that mimics the structure."""
26
-
27
- def __init__(self, prompt_tokens, completion_tokens, total_tokens):
28
- """Initialize with token usage counts.
29
-
30
- Args:
31
- prompt_tokens: Number of tokens in the prompt
32
- completion_tokens: Number of tokens in the completion
33
- total_tokens: Total number of tokens
34
- """
35
- self.prompt_tokens = prompt_tokens
36
- self.completion_tokens = completion_tokens
37
- self.total_tokens = total_tokens
38
-
39
-
40
- class MockChoiceDelta:
41
- """Mock for ChoiceDelta that mimics the structure."""
42
-
43
- def __init__(self, content=None, tool_calls=None):
44
- """Initialize with optional content and tool calls.
45
-
46
- Args:
47
- content: The text content of the delta
48
- tool_calls: List of tool calls in the delta
49
- """
50
- self.content = content
51
- self.tool_calls = tool_calls
52
- self.function_call = None
53
- self.role = None
54
-
55
-
56
- class MockChoice:
57
- """Mock for Choice that mimics the structure."""
58
-
59
- def __init__(self, delta, index=0, finish_reason=None):
60
- """Initialize with delta, index, and finish reason.
61
-
62
- Args:
63
- delta: The delta containing content or tool calls
64
- index: The index of this choice
65
- finish_reason: Reason for finishing generation
66
- """
67
- self.delta = delta
68
- self.index = index
69
- self.finish_reason = finish_reason
70
-
71
-
72
- class MockChatCompletionChunk:
73
- """Mock for ChatCompletionChunk that mimics the structure."""
74
-
75
- def __init__(self, content=None, usage=None, id="mock-id", tool_calls=None):
76
- """Initialize a mock of ChatCompletionChunk.
77
-
78
- Args:
79
- content: The text content in the chunk
80
- usage: Token usage information
81
- id: Chunk identifier
82
- tool_calls: Any tool calls in the chunk
83
- """
84
- self.id = id
85
- self.model = "mock-model"
86
- self.object = "chat.completion.chunk"
87
- self.created = 1234567890
88
- self.usage = usage
89
-
90
- if tool_calls:
91
- self.choices = [MockChoice(MockChoiceDelta(tool_calls=tool_calls))]
92
- else:
93
- self.choices = [MockChoice(MockChoiceDelta(content=content))]
94
-
95
-
96
- class MockToolCall:
97
- """Mock for ToolCall."""
98
-
99
- def __init__(self, id="tool-id", function=None, index=0, type="function"):
100
- """Initialize a tool call.
101
-
102
- Args:
103
- id: Tool call identifier
104
- function: The function being called
105
- index: Index of this tool call
106
- type: Type of tool call
107
- """
108
- self.id = id
109
- self.function = function
110
- self.index = index
111
- self.type = type
112
-
113
-
114
- class MockFunction:
115
- """Mock for Function in a tool call."""
116
-
117
- def __init__(self, name="", arguments=""):
118
- """Initialize a function with name and arguments.
119
-
120
- Args:
121
- name: Name of the function
122
- arguments: JSON string of function arguments
123
- """
124
- self.name = name
125
- self.arguments = arguments
126
-
127
-
128
- class MockAsyncStream:
129
- """Mock implementation of AsyncStream for testing."""
130
-
131
- def __init__(self, chunks):
132
- """Initialize with a list of chunks to yield.
133
-
134
- Args:
135
- chunks: List of chunks to return when iterating
136
- """
137
- self.chunks = chunks
138
-
139
- def __aiter__(self):
140
- """Return self as an async iterator."""
141
- return self
142
-
143
- async def __anext__(self):
144
- """Return the next chunk or raise StopAsyncIteration."""
145
- if not self.chunks:
146
- raise StopAsyncIteration
147
- return self.chunks.pop(0)
148
-
149
-
150
- @pytest.mark.asyncio
151
- async def test_mistral_message_preprocessing():
152
- """Test the Mistral message preprocessing functionality."""
153
- service = NvidiaLLMService(api_key="test_api_key", mistral_model_support=True)
154
-
155
- # Test with alternating roles (already valid)
156
- messages = [
157
- {"role": "system", "content": "You are a helpful assistant."},
158
- {"role": "user", "content": "Hello"},
159
- {"role": "assistant", "content": "Hi there!"},
160
- {"role": "user", "content": "How are you?"},
161
- ]
162
- processed = service._preprocess_messages_for_mistral(messages)
163
- assert len(processed) == len(messages) # No changes needed
164
-
165
- # Test with consecutive messages from same role
166
- messages = [
167
- {"role": "system", "content": "You are a helpful assistant."},
168
- {"role": "user", "content": "Hello"},
169
- {"role": "user", "content": "How are you?"},
170
- ]
171
- processed = service._preprocess_messages_for_mistral(messages)
172
- assert len(processed) == 2 # System + combined user
173
- assert processed[1]["role"] == "user"
174
- assert processed[1]["content"] == "Hello How are you?"
175
-
176
- # Test with system message at end
177
- messages = [
178
- {"role": "user", "content": "Hello"},
179
- {"role": "system", "content": "You are a helpful assistant."},
180
- ]
181
- processed = service._preprocess_messages_for_mistral(messages)
182
- assert len(processed) == 2 # User + system
183
- assert processed[0]["role"] == "user"
184
- assert processed[1]["role"] == "system"
185
-
186
-
187
- @pytest.mark.asyncio
188
- async def test_filter_think_token_simple():
189
- """Test the basic think token filtering functionality."""
190
- service = NvidiaLLMService(api_key="test_api_key", filter_think_tokens=True)
191
- service._reset_think_filter_state()
192
-
193
- # Test with simple think content followed by real content
194
- content = "I'm thinking about the answer</think>This is the actual response"
195
- filtered = service._filter_think_token(content)
196
- assert filtered == "This is the actual response"
197
- assert service._seen_end_tag is True
198
-
199
- # Subsequent content should pass through untouched
200
- more_content = " and some more text"
201
- filtered = service._filter_think_token(more_content)
202
- assert filtered == " and some more text"
203
-
204
-
205
- @pytest.mark.asyncio
206
- async def test_filter_think_token_split():
207
- """Test the think token filtering with split tags."""
208
- service = NvidiaLLMService(api_key="test_api_key", filter_think_tokens=True)
209
- service._reset_think_filter_state()
210
-
211
- # First part with beginning of tag
212
- content1 = "Let me think about this problem<"
213
- filtered1 = service._filter_think_token(content1)
214
- assert filtered1 == "" # No output yet
215
- assert service._partial_tag_buffer == "<" # Partial tag saved
216
-
217
- # Second part with rest of tag and response
218
- content2 = "/think>Here's the answer"
219
- filtered2 = service._filter_think_token(content2)
220
- assert filtered2 == "Here's the answer" # Output after tag
221
- assert service._seen_end_tag is True
222
-
223
-
224
- @pytest.mark.asyncio
225
- async def test_filter_think_token_no_tag():
226
- """Test what happens when no think tag is found."""
227
- service = NvidiaLLMService(api_key="test_api_key", filter_think_tokens=True)
228
- service._reset_think_filter_state()
229
-
230
- # Add some content in multiple chunks
231
- filtered1 = service._filter_think_token("This is a response")
232
- filtered2 = service._filter_think_token(" with no think tag")
233
- # Verify filtering behavior
234
- assert filtered1 == filtered2 == "" # No output during filtering
235
- assert service._thinking_aggregation == "This is a response with no think tag"
236
- # Test end-of-processing behavior
237
- service.push_frame = AsyncMock()
238
- await service.push_frame(LLMTextFrame(service._thinking_aggregation))
239
- service._reset_think_filter_state()
240
- # Verify results
241
- service.push_frame.assert_called_once()
242
- assert service.push_frame.call_args.args[0].text == "This is a response with no think tag"
243
- assert service._thinking_aggregation == "" # State was reset
244
-
245
-
246
- @pytest.mark.asyncio
247
- async def test_token_usage_tracking():
248
- """Test the token usage tracking functionality."""
249
- service = NvidiaLLMService(api_key="test_api_key")
250
- service._is_processing = True
251
-
252
- # Test initial accumulation of prompt tokens
253
- tokens1 = LLMTokenUsage(prompt_tokens=10, completion_tokens=0, total_tokens=10)
254
- await service.start_llm_usage_metrics(tokens1)
255
- assert service._prompt_tokens == 10
256
- assert service._has_reported_prompt_tokens is True
257
-
258
- # Test incremental completion tokens
259
- tokens2 = LLMTokenUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
260
- await service.start_llm_usage_metrics(tokens2)
261
- assert service._completion_tokens == 5
262
-
263
- # Test more completion tokens
264
- tokens3 = LLMTokenUsage(prompt_tokens=10, completion_tokens=8, total_tokens=18)
265
- await service.start_llm_usage_metrics(tokens3)
266
- assert service._completion_tokens == 8
267
-
268
- # Test reporting duplicate prompt tokens
269
- tokens4 = LLMTokenUsage(prompt_tokens=10, completion_tokens=10, total_tokens=20)
270
- await service.start_llm_usage_metrics(tokens4)
271
- assert service._prompt_tokens == 10 # Unchanged
272
- assert service._completion_tokens == 10
273
-
274
-
275
- @pytest.mark.asyncio
276
- async def test_process_context_with_think_filtering():
277
- """Test the full processing with think token filtering."""
278
- with patch.multiple(
279
- NvidiaLLMService,
280
- create_client=DEFAULT,
281
- _stream_chat_completions=DEFAULT,
282
- start_ttfb_metrics=DEFAULT,
283
- stop_ttfb_metrics=DEFAULT,
284
- push_frame=DEFAULT,
285
- ) as mocks:
286
- service = NvidiaLLMService(api_key="test_api_key", filter_think_tokens=True)
287
- mock_push_frame = mocks["push_frame"]
288
-
289
- # Setup mock stream
290
- chunks = [
291
- MockChatCompletionChunk(content="Thinking<"),
292
- MockChatCompletionChunk(content="/think>Real content"),
293
- MockChatCompletionChunk(content=" continues"),
294
- ]
295
- mocks["_stream_chat_completions"].return_value = MockAsyncStream(chunks)
296
-
297
- # Process context
298
- context = OpenAILLMContext(messages=[{"role": "user", "content": "Test query"}])
299
- await service._process_context(context)
300
-
301
- # Verify frame content - empty frames during thinking, content after tag
302
- frames = [call.args[0].text for call in mock_push_frame.call_args_list]
303
- assert frames == ["", "Real content", " continues"]
304
-
305
-
306
- @pytest.mark.asyncio
307
- async def test_process_context_with_function_calls():
308
- """Test handling of function calls from LLM."""
309
- with (
310
- patch.object(NvidiaLLMService, "create_client"),
311
- patch.object(NvidiaLLMService, "_stream_chat_completions") as mock_stream,
312
- patch.object(NvidiaLLMService, "has_function") as mock_has_function,
313
- patch.object(NvidiaLLMService, "call_function") as mock_call_function,
314
- ):
315
- service = NvidiaLLMService(api_key="test_api_key")
316
-
317
- # Create tool call chunks that come in parts
318
- tool_call1 = MockToolCall(
319
- id="call1", function=MockFunction(name="get_weather", arguments='{"location"'), index=0
320
- )
321
-
322
- tool_call2 = MockToolCall(id="call1", function=MockFunction(name="", arguments=':"New York"}'), index=0)
323
-
324
- # Create chunks with tool calls
325
- chunk1 = MockChatCompletionChunk(tool_calls=[tool_call1], id="chunk1")
326
- chunk2 = MockChatCompletionChunk(tool_calls=[tool_call2], id="chunk2")
327
-
328
- mock_stream.return_value = MockAsyncStream([chunk1, chunk2])
329
- mock_has_function.return_value = True
330
- mock_call_function.return_value = None
331
-
332
- # Process a context
333
- context = OpenAILLMContext(messages=[{"role": "user", "content": "What's the weather in New York?"}])
334
- await service._process_context(context)
335
-
336
- # Verify function was called with combined arguments
337
- mock_call_function.assert_called_once()
338
- args = mock_call_function.call_args.kwargs
339
- assert args["function_name"] == "get_weather"
340
- assert args["arguments"] == {"location": "New York"}
341
- assert args["tool_call_id"] == "call1"
342
-
343
-
344
- @pytest.mark.asyncio
345
- async def test_process_context_with_mistral_preprocessing():
346
- """Test processing context with Mistral message preprocessing."""
347
- with (
348
- patch.object(NvidiaLLMService, "create_client"),
349
- patch.object(NvidiaLLMService, "_stream_chat_completions") as mock_stream,
350
- ):
351
- service = NvidiaLLMService(api_key="test_api_key", mistral_model_support=True)
352
-
353
- # Setup mock stream
354
- chunks = [MockChatCompletionChunk(content="I am a response")]
355
- mock_stream.return_value = MockAsyncStream(chunks)
356
-
357
- # Test 1: Combining consecutive user messages
358
- context = OpenAILLMContext(
359
- messages=[
360
- {"role": "system", "content": "You are helpful."},
361
- {"role": "user", "content": "Hello"},
362
- {"role": "user", "content": "How are you?"},
363
- ]
364
- )
365
- await service._process_context(context)
366
-
367
- # Verify messages were combined
368
- processed_messages = context.get_messages()
369
- assert len(processed_messages) == 2 # System + combined user
370
- assert processed_messages[1]["role"] == "user"
371
- assert processed_messages[1]["content"] == "Hello How are you?"
372
-
373
- # Verify stream was called (normal processing)
374
- mock_stream.assert_called_once()
375
-
376
- # Test 2: System message only - should skip processing
377
- mock_stream.reset_mock()
378
- system_only_context = OpenAILLMContext(
379
- messages=[
380
- {"role": "system", "content": "You are helpful."},
381
- ]
382
- )
383
- await service._process_context(system_only_context)
384
-
385
- # Verify that stream was not called (processing skipped)
386
- mock_stream.assert_not_called()
 
tests/unit/test_nvidia_rag_service.py DELETED
@@ -1,261 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the NvidiaRAGService processor."""
5
-
6
- import httpx
7
- import pytest
8
- from loguru import logger
9
- from pipecat.frames.frames import ErrorFrame, LLMMessagesFrame, TextFrame
10
- from pipecat.pipeline.pipeline import Pipeline
11
- from pipecat.pipeline.task import PipelineTask
12
- from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
13
-
14
- from nvidia_pipecat.frames.nvidia_rag import NvidiaRAGCitation, NvidiaRAGCitationsFrame, NvidiaRAGSettingsFrame
15
- from nvidia_pipecat.services.nvidia_rag import NvidiaRAGService
16
- from tests.unit.utils import FrameStorage, ignore_ids, run_interactive_test, run_test
17
-
18
-
19
- class MockResponse:
20
- """Mock response class for testing HTTP responses.
21
-
22
- Attributes:
23
- json: The JSON response data.
24
- """
25
-
26
- def __init__(self, json):
27
- """Initialize MockResponse with JSON data.
28
-
29
- Args:
30
- json: The JSON response data.
31
- """
32
- self.json = json
33
-
34
- async def aclose(self):
35
- """Mock aclose method."""
36
- pass
37
-
38
- async def aiter_lines(self):
39
- """Simulate chunk iteration for streaming response.
40
-
41
- Yields:
42
- tuple: A tuple containing data and None.
43
- """
44
- yield self.json
45
-
46
-
47
- @pytest.mark.asyncio
48
- async def test_nvidia_rag_service(mocker):
49
- """Test NvidiaRAGService functionality with various test cases.
50
-
51
- Tests different RAG service behaviors including successful responses,
52
- citation handling, and error conditions.
53
-
54
- Args:
55
- mocker: Pytest mocker fixture for mocking HTTP responses.
56
-
57
- The test verifies:
58
- - Successful responses without citations
59
- - Successful responses with citations
60
- - Error handling for empty collection names
61
- - Error handling for empty queries
62
- - Error handling for incorrect message roles
63
- """
64
- testcases = {
65
- "Success without citations": {
66
- "collection_name": "collection123",
67
- "messages": [
68
- {
69
- "role": "system",
70
- "content": "You are a helpful Large Language Model. "
71
- "Your goal is to demonstrate your capabilities in a succinct way. "
72
- "Your output will be converted to audio so don't include special characters in your answers. "
73
- "Respond to what the user said in a creative and helpful way.",
74
- },
75
- ],
76
- "response_json": 'data: {"id":"a886cc44-e2ce-4ea3-95f0-9ffb1171adb1",'
77
- '"choices":[{"index":0,"message":{"role":"assistant","content":"this is rag response content"},'
78
- '"delta":{"role":"assistant","content":""},"finish_reason":"[DONE]"}]}',
79
- "result_frame": TextFrame("this is rag response content"),
80
- },
81
- "Success with citations": {
82
- "collection_name": "collection123",
83
- "messages": [
84
- {
85
- "role": "system",
86
- "content": "You are a helpful Large Language Model. "
87
- "Your goal is to demonstrate your capabilities in a succinct way. "
88
- "Your output will be converted to audio so don't include special characters in your answers. "
89
- "Respond to what the user said in a creative and helpful way.",
90
- },
91
- ],
92
- "response_json": 'data: {"id":"a886cc44-e2ce-4ea3-95f0-9ffb1171adb1",'
93
- '"choices":[{"index":0,"message":{"role":"assistant","content":"this is rag response content"},'
94
- '"delta":{"role":"assistant","content":""},"finish_reason":"[DONE]"}], "citations":{"total_results":0,'
95
- '"results":[{"document_id": "", "content": "this is rag citation content", "document_type": "text",'
96
- ' "document_name": "", "metadata": "", "score": 0.0}]}}',
97
- "result_frame": NvidiaRAGCitationsFrame(
98
- [
99
- NvidiaRAGCitation(
100
- document_id="",
101
- document_type="text",
102
- metadata="",
103
- score=0.0,
104
- document_name="",
105
- content=b"this is rag citation content",
106
- )
107
- ]
108
- ),
109
- },
110
- "Fail due to empty collection name": {
111
- "collection_name": "",
112
- "messages": [
113
- {
114
- "role": "system",
115
- "content": "You are a helpful Large Language Model. "
116
- "Your goal is to demonstrate your capabilities in a succinct way. "
117
- "Your output will be converted to audio so don't include special characters in your answers. "
118
- "Respond to what the user said in a creative and helpful way.",
119
- },
120
- ],
121
- "result_frame": ErrorFrame(
122
- "An error occurred in http request to RAG endpoint, Error: No query or collection name is provided.."
123
- ),
124
- },
125
- "Fail due to empty query": {
126
- "collection_name": "collection123",
127
- "messages": [
128
- {
129
- "role": "system",
130
- "content": "",
131
- },
132
- ],
133
- "result_frame": ErrorFrame(
134
- "An error occurred in http request to RAG endpoint, Error: No query or collection name is provided.."
135
- ),
136
- },
137
- "Fail due to incorrect role": {
138
- "collection_name": "collection123",
139
- "messages": [
140
- {
141
- "role": "tool",
142
- "content": "",
143
- },
144
- ],
145
- "result_frame": ErrorFrame(
146
- "An error occurred in http request to RAG endpoint, Error: Unexpected role tool found!"
147
- ),
148
- },
149
- }
150
-
151
- for tc_name, tc_data in testcases.items():
152
- logger.info(f"Verifying test case: {tc_name}")
153
-
154
- resp = None
155
- if "response_json" in tc_data:
156
- resp = MockResponse(tc_data["response_json"])
157
- else:
158
- resp = MockResponse("{}")
159
-
160
- mocker.patch("httpx.AsyncClient.post", return_value=resp)
161
-
162
- rag = NvidiaRAGService(collection_name=tc_data["collection_name"])
163
- storage1 = FrameStorage()
164
- storage2 = FrameStorage()
165
- context_aggregator = rag.create_context_aggregator(OpenAILLMContext(tc_data["messages"]))
166
-
167
- pipeline = Pipeline([context_aggregator.user(), storage1, rag, storage2, context_aggregator.assistant()])
168
-
169
- async def test_routine(task: PipelineTask, test_data=tc_data, s1=storage1, s2=storage2):
170
- await task.queue_frame(LLMMessagesFrame(test_data["messages"]))
171
-
172
- # Wait for the result frame
173
- if "ErrorFrame" in test_data["result_frame"].name:
174
- await s1.wait_for_frame(ignore_ids(test_data["result_frame"]))
175
- else:
176
- await s2.wait_for_frame(ignore_ids(test_data["result_frame"]))
177
-
178
- await run_interactive_test(pipeline, test_coroutine=test_routine)
179
-
180
- # Verify the frames in storage1
181
- for frame_history_entry in storage1.history:
182
- if frame_history_entry.frame.name.startswith("ErrorFrame"):
183
- assert frame_history_entry.frame == ignore_ids(tc_data["result_frame"])
184
-
185
- # Verify the frames in storage2
186
- for frame_history_entry in storage2.history:
187
- if (
188
- frame_history_entry.frame.name.startswith("TextFrame")
189
- and tc_data["result_frame"].__str__().startswith("TextFrame")
190
- or frame_history_entry.frame.name.startswith("NvidiaRAGCitationsFrame")
191
- and tc_data["result_frame"].__str__().startswith("NvidiaRAGCitationsFrame")
192
- ):
193
- assert frame_history_entry.frame == ignore_ids(tc_data["result_frame"])
194
-
195
-
196
- @pytest.mark.asyncio
197
- async def test_rag_service_sharing_session():
198
- """Test session sharing behavior between NvidiaRAGService instances.
199
-
200
- Tests the HTTP client session management across multiple RAG service
201
- instances.
202
-
203
- The test verifies:
204
- - Session sharing between instances with same parameters
205
- - Separate session handling for custom sessions
206
- - Proper session cleanup
207
- """
208
- rags = []
209
- rags.append(NvidiaRAGService(collection_name="collection_1"))
210
- rags.append(NvidiaRAGService(collection_name="collection_1"))
211
-
212
- initial_session = rags[1].shared_session
213
-
214
- for rag in rags:
215
- assert rag.shared_session is initial_session
216
-
217
- new_session = httpx.AsyncClient()
218
- rags.append(NvidiaRAGService(collection_name="collection_1", session=new_session))
219
-
220
- assert rags[0].shared_session is initial_session
221
- assert rags[1].shared_session is initial_session
222
- assert rags[2].shared_session is new_session
223
-
224
- await new_session.aclose()
225
- for r in rags:
226
- await r.cleanup()
227
-
228
-
229
- @pytest.mark.asyncio
230
- async def test_nvidia_rag_settings_frame_update(mocker):
231
- """Tests NvidiaRAGService settings update functionality.
232
-
233
- Tests the processing of NvidiaRAGSettingsFrame for dynamic configuration
234
- updates.
235
-
236
- Args:
237
- mocker: Pytest mocker fixture for mocking HTTP responses.
238
-
239
- The test verifies:
240
- - Collection name updates
241
- - Server URL updates
242
- - Settings frame propagation
243
- """
244
- mocker.patch("httpx.AsyncClient.post", return_value="")
245
-
246
- rag_settings_frame = NvidiaRAGSettingsFrame(
247
- settings={"collection_name": "nvidia_blogs", "rag_server_url": "http://10.41.23.247:8081"}
248
- )
249
- rag = NvidiaRAGService(collection_name="collection123")
250
-
251
- frames_to_send = [rag_settings_frame]
252
- expected_down_frames = [rag_settings_frame]
253
-
254
- await run_test(
255
- rag,
256
- frames_to_send=frames_to_send,
257
- expected_down_frames=expected_down_frames,
258
- )
259
-
260
- assert rag.collection_name == "nvidia_blogs"
261
- assert rag.rag_server_url == "http://10.41.23.247:8081"
 
tests/unit/test_nvidia_tts_response_cacher.py DELETED
@@ -1,79 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the Nvidia TTS Response Cacher."""
5
-
6
- import pytest
7
- from pipecat.frames.frames import (
8
- LLMFullResponseEndFrame,
9
- LLMFullResponseStartFrame,
10
- TTSAudioRawFrame,
11
- UserStartedSpeakingFrame,
12
- UserStoppedSpeakingFrame,
13
- )
14
- from pipecat.pipeline.pipeline import Pipeline
15
- from pipecat.tests.utils import run_test as run_pipecat_test
16
-
17
- from nvidia_pipecat.processors.nvidia_context_aggregator import NvidiaTTSResponseCacher
18
-
19
-
20
- @pytest.mark.asyncio()
21
- async def test_nvidia_tts_response_cacher():
22
- """Tests NvidiaTTSResponseCacher's response deduplication functionality.
23
-
24
- Tests the cacher's ability to deduplicate TTS audio responses in a sequence
25
- of frames including user speech events and LLM responses.
26
-
27
- Args:
28
- None
29
-
30
- Returns:
31
- None
32
-
33
- The test verifies:
34
- - Correct handling of user speech start/stop frames
35
- - Deduplication of identical TTS audio frames
36
- - Preservation of LLM response start/end frames
37
- - Frame ordering in pipeline output
38
- - Only one TTS frame is retained
39
- """
40
- nvidia_tts_response_cacher = NvidiaTTSResponseCacher()
41
- pipeline = Pipeline([nvidia_tts_response_cacher])
42
-
43
- test_audio = b"\x52\x49\x46\x46\x24\x08\x00\x00\x57\x41\x56\x45\x66\x6d\x74\x20"
44
- frames_to_send = [
45
- UserStartedSpeakingFrame(),
46
- LLMFullResponseStartFrame(),
47
- TTSAudioRawFrame(
48
- audio=test_audio,
49
- sample_rate=16000,
50
- num_channels=1,
51
- ),
52
- LLMFullResponseEndFrame(),
53
- LLMFullResponseStartFrame(),
54
- TTSAudioRawFrame(
55
- audio=test_audio,
56
- sample_rate=16000,
57
- num_channels=1,
58
- ),
59
- LLMFullResponseEndFrame(),
60
- UserStoppedSpeakingFrame(),
61
- ]
62
-
63
- expected_down_frames = [
64
- UserStartedSpeakingFrame,
65
- UserStoppedSpeakingFrame,
66
- LLMFullResponseStartFrame,
67
- TTSAudioRawFrame,
68
- LLMFullResponseEndFrame,
69
- ]
70
-
71
- received_down_frames, received_up_frames = await run_pipecat_test(
72
- pipeline,
73
- frames_to_send=frames_to_send,
74
- expected_down_frames=expected_down_frames,
75
- )
76
-
77
- # Verify we only got one TTS frame
78
- tts_frames = [f for f in received_down_frames if isinstance(f, TTSAudioRawFrame)]
79
- assert len(tts_frames) == 1
 
tests/unit/test_posture.py DELETED
@@ -1,104 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the PostureProviderProcessor class."""
5
-
6
- import pytest
7
- from pipecat.frames.frames import BotStoppedSpeakingFrame, StartInterruptionFrame, TTSStartedFrame
8
-
9
- from nvidia_pipecat.frames.action import StartPostureBotActionFrame
10
- from nvidia_pipecat.processors.posture_provider import PostureProviderProcessor
11
- from tests.unit.utils import ignore_ids, run_test
12
-
13
-
14
- @pytest.mark.asyncio()
15
- async def test_posture_provider_processor_tts():
16
- """Test TTSStartedFrame processing in PostureProviderProcessor.
17
-
18
- Tests that the processor generates appropriate "Talking" posture when
19
- text-to-speech begins.
20
-
21
- Args:
22
- None
23
-
24
- Returns:
25
- None
26
-
27
- The test verifies:
28
- - TTSStartedFrame is processed correctly
29
- - "Talking" posture is generated
30
- - Frames are emitted in correct order
31
- """
32
- frames_to_send = [TTSStartedFrame()]
33
- expected_down_frames = [
34
- ignore_ids(TTSStartedFrame()),
35
- ignore_ids(StartPostureBotActionFrame(posture="Talking")),
36
- ]
37
-
38
- await run_test(
39
- PostureProviderProcessor(),
40
- frames_to_send=frames_to_send,
41
- expected_down_frames=expected_down_frames,
42
- )
43
-
44
-
45
- @pytest.mark.asyncio()
46
- async def test_posture_provider_processor_bot_finished():
47
- """Test BotStoppedSpeakingFrame processing in PostureProviderProcessor.
48
-
49
- Tests that the processor generates appropriate "Attentive" posture when
50
- bot stops speaking.
51
-
52
- Args:
53
- None
54
-
55
- Returns:
56
- None
57
-
58
- The test verifies:
59
- - BotStoppedSpeakingFrame is processed correctly
60
- - "Attentive" posture is generated
61
- - Frames are emitted in correct order
62
- """
63
- frames_to_send = [BotStoppedSpeakingFrame()]
64
- expected_down_frames = [
65
- ignore_ids(BotStoppedSpeakingFrame()),
66
- ignore_ids(StartPostureBotActionFrame(posture="Attentive")),
67
- ]
68
-
69
- await run_test(
70
- PostureProviderProcessor(),
71
- frames_to_send=frames_to_send,
72
- expected_down_frames=expected_down_frames,
73
- )
74
-
75
-
76
- @pytest.mark.asyncio()
77
- async def test_posture_provider_processor_interrupt():
78
- """Tests posture generation for interruption events.
79
-
80
- Tests that the processor generates appropriate "Listening" posture when
81
- an interruption occurs.
82
-
83
- Args:
84
- None
85
-
86
- Returns:
87
- None
88
-
89
- The test verifies:
90
- - StartInterruptionFrame is processed correctly
91
- - "Listening" posture is generated
92
- - Frames are emitted in correct order
93
- """
94
- frames_to_send = [StartInterruptionFrame()]
95
- expected_down_frames = [
96
- ignore_ids(StartInterruptionFrame()),
97
- ignore_ids(StartPostureBotActionFrame(posture="Listening")),
98
- ]
99
-
100
- await run_test(
101
- PostureProviderProcessor(),
102
- frames_to_send=frames_to_send,
103
- expected_down_frames=expected_down_frames,
104
- )
 
tests/unit/test_proactivity.py DELETED
@@ -1,85 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the proactivity processor module.
5
-
6
- This module contains tests that verify the behavior of the ProactivityProcessor.
7
- """
8
-
9
- import asyncio
10
- import os
11
- import sys
12
-
13
- import pytest
14
-
15
- sys.path.append(os.path.abspath("../../src"))
16
-
17
- from pipecat.frames.frames import TTSSpeakFrame, UserStoppedSpeakingFrame
18
- from pipecat.pipeline.pipeline import Pipeline
19
- from pipecat.pipeline.task import PipelineTask
20
-
21
- from nvidia_pipecat.frames.action import StartedPresenceUserActionFrame
22
- from nvidia_pipecat.processors.proactivity import ProactivityProcessor
23
- from tests.unit.utils import FrameStorage, run_interactive_test
24
-
25
-
26
- @pytest.mark.asyncio
27
- async def test_proactive_bot_processor_timer_behavior():
28
- """Test the ProactiveBotProcessor's timer and message behavior.
29
-
30
- Tests the processor's ability to manage timer-based proactive messages
31
- and handle timer resets based on user activity.
32
-
33
- Args:
34
- None
35
-
36
- Returns:
37
- None
38
-
39
- The test verifies:
40
- - Default message is sent after timer expiration
41
- - Timer resets correctly on user activity
42
- - Timer reset prevents premature message generation
43
- - Frames are processed in correct order
44
- - Message content matches configuration
45
- """
46
- proactivity = ProactivityProcessor(default_message="I'm here if you need me!", timer_duration=0.5)
47
- storage = FrameStorage()
48
- pipeline = Pipeline([proactivity, storage])
49
-
50
- async def test_routine(task: PipelineTask):
51
- """Inner test coroutine for proactivity testing.
52
-
53
- Args:
54
- task: PipelineTask instance for frame queueing.
55
-
56
- The routine:
57
- 1. Sends initial presence frame
58
- 2. Waits for timer expiration
59
- 3. Verifies message generation
60
- 4. Tests timer reset behavior
61
- 5. Confirms no premature messages
62
- """
63
- await task.queue_frame(StartedPresenceUserActionFrame(action_id="1"))
64
- # Wait for initial proactive message
65
- await asyncio.sleep(0.6)
66
-
67
- # Confirm at least one frame
68
- assert len(storage.history) >= 1, "Expected at least one frame."
69
-
70
- # Confirm correct text frame output
71
- frame = storage.history[2].frame
72
- assert isinstance(frame, TTSSpeakFrame)
73
- assert frame.text == "I'm here if you need me!"
74
-
75
- # Send another StartFrame to reset the timer
76
- await task.queue_frame(UserStoppedSpeakingFrame())
77
- await asyncio.sleep(0)
78
-
79
- # Wait half the timer (0.5s) => no new message yet
80
- frame_count_after_reset = len(storage.history)
81
- await asyncio.sleep(0.3)
82
- # Confirm no additional message arrived yet
83
- assert frame_count_after_reset == len(storage.history)
84
-
85
- await run_interactive_test(pipeline, test_coroutine=test_routine)
 
tests/unit/test_riva_asr_service.py DELETED
@@ -1,523 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the RivaASRService.
5
-
6
- This module contains tests for the RivaASRService class, including initialization,
7
- ASR functionality, interruption handling, and integration tests.
8
- """
9
-
10
- import asyncio
11
- import unittest
12
- from unittest.mock import AsyncMock, MagicMock, patch
13
-
14
- import pytest
15
- from pipecat.frames.frames import (
16
- CancelFrame,
17
- EndFrame,
18
- StartFrame,
19
- StartInterruptionFrame,
20
- StopInterruptionFrame,
21
- TranscriptionFrame,
22
- UserStartedSpeakingFrame,
23
- UserStoppedSpeakingFrame,
24
- )
25
- from pipecat.transcriptions.language import Language
26
-
27
- from nvidia_pipecat.frames.riva import RivaInterimTranscriptionFrame
28
- from nvidia_pipecat.services.riva_speech import RivaASRService
29
-
30
-
31
- class TestRivaASRService(unittest.TestCase):
32
- """Test suite for RivaASRService functionality."""
33
-
34
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
35
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
36
- def test_initialization_with_default_parameters(self, mock_asr_service, mock_auth):
37
- """Tests RivaASRService initialization with default parameters.
38
-
39
- Args:
40
- mock_asr_service: Mock for the ASR service.
41
- mock_auth: Mock for the authentication service.
42
-
43
- The test verifies:
44
- - Correct type initialization for all parameters
45
- - Default server configuration
46
- - Authentication setup
47
- - Service parameter defaults
48
- """
49
- # Act
50
- service = RivaASRService(api_key="test_api_key")
51
-
52
- # Assert - only check types, not specific values
53
- self.assertIsInstance(service._language_code, Language)
54
- self.assertIsInstance(service._sample_rate, int)
55
- self.assertIsInstance(service._model, str)
56
- # Basic boolean parameter checks without restricting values
57
- self.assertIsInstance(service._profanity_filter, bool)
58
- self.assertIsInstance(service._automatic_punctuation, bool)
59
- self.assertIsInstance(service._interim_results, bool)
60
- self.assertIsInstance(service._max_alternatives, int)
61
- self.assertIsInstance(service._generate_interruptions, bool)
62
-
63
- # Verify Auth was called with correct parameters
64
- # For server "grpc.nvcf.nvidia.com:443", use_ssl is automatically set to True
65
- mock_auth.assert_called_with(
66
- None,
67
- True, # Changed from False to True since default server is "grpc.nvcf.nvidia.com:443"
68
- "grpc.nvcf.nvidia.com:443",
69
- [["function-id", "1598d209-5e27-4d3c-8079-4751568b1081"], ["authorization", "Bearer test_api_key"]],
70
- )
71
-
72
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
73
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
74
- def test_initialization_with_custom_parameters(self, mock_asr_service, mock_auth):
75
- """Tests RivaASRService initialization with custom parameters.
76
-
77
- Args:
78
- mock_asr_service: Mock for the ASR service.
79
- mock_auth: Mock for the authentication service.
80
-
81
- The test verifies:
82
- - Custom parameter values are set correctly
83
- - Server configuration is customized
84
- - Authentication with custom credentials
85
- - Optional parameter handling
86
- """
87
- # Define test parameters
88
- test_api_key = "test_api_key"
89
- test_server = "custom_server:50051"
90
- test_function_id = "custom_function_id"
91
- test_language = Language.ES_ES
92
- test_model = "custom_model"
93
- test_sample_rate = 44100
94
- test_channel_count = 2
95
- test_max_alternatives = 2
96
- test_boosted_words = {"boost": 1.0}
97
- test_boosted_score = 5.0
98
-
99
- # Act
100
- service = RivaASRService(
101
- api_key=test_api_key,
102
- server=test_server,
103
- function_id=test_function_id,
104
- language=test_language,
105
- model=test_model,
106
- profanity_filter=True,
107
- automatic_punctuation=True,
108
- no_verbatim_transcripts=True,
109
- boosted_lm_words=test_boosted_words,
110
- boosted_lm_score=test_boosted_score,
111
- sample_rate=test_sample_rate,
112
- audio_channel_count=test_channel_count,
113
- max_alternatives=test_max_alternatives,
114
- interim_results=False,
115
- generate_interruptions=True,
116
- use_ssl=True,
117
- )
118
-
119
- # Assert - verify custom parameters were set correctly
120
- self.assertEqual(service._language_code, test_language)
121
- self.assertEqual(service._sample_rate, test_sample_rate)
122
- self.assertEqual(service._model, test_model)
123
- self.assertEqual(service._boosted_lm_words, test_boosted_words)
124
- self.assertEqual(service._boosted_lm_score, test_boosted_score)
125
- self.assertEqual(service._max_alternatives, test_max_alternatives)
126
- self.assertEqual(service._audio_channel_count, test_channel_count)
127
-
128
- # Verify Auth was called with correct parameters
129
- mock_auth.assert_called_with(
130
- None,
131
- True,
132
- test_server,
133
- [["function-id", test_function_id], ["authorization", f"Bearer {test_api_key}"]],
134
- )
135
-
136
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
137
- def test_error_handling_during_initialization(self, mock_auth):
138
- """Tests error handling during service initialization.
139
-
140
- Args:
141
- mock_auth: Mock for the authentication service.
142
-
143
- The test verifies:
144
- - Proper exception handling
145
- - Error message formatting
146
- - Service cleanup on failure
147
- """
148
- # Arrange
149
- mock_auth.side_effect = Exception("Connection failed")
150
-
151
- # Act & Assert
152
- with self.assertRaises(Exception) as context:
153
- RivaASRService(api_key="test_api_key")
154
-
155
- # Verify the error message
156
- self.assertTrue("Missing module: Connection failed" in str(context.exception))
157
-
158
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
159
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
160
- def test_can_generate_metrics(self, mock_asr_service, mock_auth):
161
- """Test that the service can generate metrics."""
162
- # Arrange
163
- service = RivaASRService(api_key="test_api_key")
164
-
165
- # Act & Assert
166
- self.assertFalse(service.can_generate_metrics())
167
-
168
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
169
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
170
- def test_start_method(self, mock_asr_service, mock_auth):
171
- """Test the start method of RivaASRService."""
172
- # Arrange
173
- service = RivaASRService(api_key="test_api_key")
174
-
175
- # Create a completed mock task for expected return value
176
- mock_task = MagicMock()
177
- mock_task.done.return_value = True
178
-
179
- # Use MagicMock to avoid coroutine awaiting issues
180
- service.create_task = MagicMock(return_value=mock_task)
181
-
182
- # Mock the response task handler
183
- mock_response_coro = MagicMock()
184
- service._response_task_handler = MagicMock(return_value=mock_response_coro)
185
-
186
- # Create a mock StartFrame with the necessary attributes
187
- mock_start_frame = MagicMock(spec=StartFrame)
188
-
189
- # Act
190
- async def run_test():
191
- await service.start(mock_start_frame)
192
-
193
- # Run the test
194
- asyncio.run(run_test())
195
-
196
- # Assert
197
- # Verify create_task was called with the right handlers
198
- service.create_task.assert_called_once_with(mock_response_coro)
199
-
200
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
201
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
202
- def test_stop_method(self, mock_asr_service, mock_auth):
203
- """Test the stop method of RivaASRService."""
204
- # Arrange
205
- service = RivaASRService(api_key="test_api_key")
206
- service._stop_tasks = AsyncMock()
207
-
208
- # Create a mock EndFrame
209
- mock_end_frame = MagicMock(spec=EndFrame)
210
-
211
- # Act
212
- async def run_test():
213
- await service.stop(mock_end_frame)
214
-
215
- # Run the test
216
- asyncio.run(run_test())
217
-
218
- # Assert
219
- service._stop_tasks.assert_called_once()
220
-
221
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
222
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
223
- def test_cancel_method(self, mock_asr_service, mock_auth):
224
- """Test the cancel method of RivaASRService."""
225
- # Arrange
226
- service = RivaASRService(api_key="test_api_key")
227
- service._stop_tasks = AsyncMock()
228
-
229
- # Create a mock CancelFrame
230
- mock_cancel_frame = MagicMock(spec=CancelFrame)
231
-
232
- # Act
233
- async def run_test():
234
- await service.cancel(mock_cancel_frame)
235
-
236
- # Run the test
237
- asyncio.run(run_test())
238
-
239
- # Assert
240
- service._stop_tasks.assert_called_once()
241
-
242
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
243
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
244
- def test_run_stt_yields_frames(self, mock_asr_service, mock_auth):
245
- """Test that run_stt method yields frames."""
246
- # Arrange
247
- service = RivaASRService(api_key="test_api_key")
248
- service._queue = AsyncMock()
249
-
250
- # Create a completed mock task and mock thread task handler
251
- mock_task = MagicMock()
252
- mock_task.done.return_value = True
253
-
254
- mock_thread_coro = MagicMock()
255
- service._thread_task_handler = MagicMock(return_value=mock_thread_coro)
256
- service._thread_task = mock_task
257
-
258
- # Mock the create_task method to return our mock task and avoid coroutine warnings
259
- service.create_task = MagicMock(return_value=mock_task)
260
-
261
- # Act
262
- async def run_test():
263
- frames = []
264
- audio_data = b"test_audio_data"
265
- async for frame in service.run_stt(audio_data):
266
- frames.append(frame)
267
-
268
- # Assert
269
- service._queue.put.assert_called_once_with(audio_data)
270
- # run_stt yields a single None frame for RivaASRService
271
- self.assertEqual(len(frames), 1)
272
- self.assertIsNone(frames[0])
273
- service.create_task.assert_called_once_with(mock_thread_coro)
274
-
275
- # Run the test
276
- asyncio.run(run_test())
277
-
278
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
279
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
280
- def test_handle_response_with_final_transcript(self, mock_asr_service, mock_auth):
281
- """Tests handling of ASR responses with final transcripts.
282
-
283
- Args:
284
- mock_asr_service: Mock for the ASR service.
285
- mock_auth: Mock for the authentication service.
286
-
287
- The test verifies:
288
- - Final transcript processing
289
- - Metrics handling
290
- - Frame generation
291
- - Response completion handling
292
- """
293
- # Arrange
294
- service = RivaASRService(api_key="test_api_key")
295
- service.push_frame = AsyncMock()
296
-
297
- # Use MagicMock instead of asyncio.Future to avoid event loop issues
298
- service.stop_ttfb_metrics = AsyncMock()
299
- service.stop_processing_metrics = AsyncMock()
300
-
301
- # Create a mock response with final transcript
302
- mock_result = MagicMock()
303
- mock_alternative = MagicMock()
304
- mock_alternative.transcript = "This is a final transcript"
305
- mock_result.alternatives = [mock_alternative]
306
- mock_result.is_final = True
307
-
308
- mock_response = MagicMock()
309
- mock_response.results = [mock_result]
310
-
311
- # Act
312
- async def run_test():
313
- await service._handle_response(mock_response)
314
-
315
- # Run the test
316
- asyncio.run(run_test())
317
-
318
- # Assert
319
- service.stop_ttfb_metrics.assert_called_once()
320
- service.stop_processing_metrics.assert_called_once()
321
-
322
- # Verify that a TranscriptionFrame was pushed with the correct text
323
- found = False
324
- for call_args in service.push_frame.call_args_list:
325
- frame = call_args[0][0]
326
- if isinstance(frame, TranscriptionFrame) and frame.text == "This is a final transcript":
327
- found = True
328
- break
329
- self.assertTrue(found, "No TranscriptionFrame with the expected text was pushed")
330
-
331
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
332
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
333
- def test_handle_response_with_interim_transcript(self, mock_asr_service, mock_auth):
334
- """Test handling of ASR responses with interim transcript."""
335
- # Arrange
336
- service = RivaASRService(api_key="test_api_key")
337
- service.push_frame = AsyncMock()
338
-
339
- # Use AsyncMock directly instead of Future
340
- service.stop_ttfb_metrics = AsyncMock()
341
-
342
- # Create a mock response with interim transcript
343
- mock_result = MagicMock()
344
- mock_alternative = MagicMock()
345
- mock_alternative.transcript = "This is an interim transcript"
346
- mock_result.alternatives = [mock_alternative]
347
- mock_result.is_final = False
348
- mock_result.stability = 1.0 # High stability interim result
349
-
350
- mock_response = MagicMock()
351
- mock_response.results = [mock_result]
352
-
353
- # Act
354
- async def run_test():
355
- await service._handle_response(mock_response)
356
-
357
- # Run the test
358
- asyncio.run(run_test())
359
-
360
- # Assert
361
- service.stop_ttfb_metrics.assert_called_once()
362
-
363
- # Verify that a RivaInterimTranscriptionFrame was pushed with the correct text
364
- found = False
365
- for call_args in service.push_frame.call_args_list:
366
- frame = call_args[0][0]
367
- if (
368
- isinstance(frame, RivaInterimTranscriptionFrame)
369
- and frame.text == "This is an interim transcript"
370
- and frame.stability == 1.0
371
- ):
372
- found = True
373
- break
374
- self.assertTrue(found, "No RivaInterimTranscriptionFrame with the expected text was pushed")
375
-
376
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
377
- @patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService")
378
- def test_handle_interruptions(self, mock_asr_service, mock_auth):
379
- """Tests handling of speech interruptions.
380
-
381
- Args:
382
- mock_asr_service: Mock for the ASR service.
383
- mock_auth: Mock for the authentication service.
384
-
385
- The test verifies:
386
- - Interruption start handling
387
- - Interruption stop handling
388
- - Frame sequence generation
389
- - State management during interruptions
390
- """
391
- # Arrange
392
- service = RivaASRService(api_key="test_api_key", generate_interruptions=True)
393
- service.push_frame = AsyncMock()
394
-
395
- # Use AsyncMock directly instead of Future
396
- service._start_interruption = AsyncMock()
397
- service._stop_interruption = AsyncMock()
398
-
399
- # Mock the property to return True - avoids setting the property directly
400
- type(service).interruptions_allowed = MagicMock(return_value=True)
401
-
402
- # Act
403
- async def run_test():
404
- # Simulate interruption handling
405
- user_started_frame = UserStartedSpeakingFrame()
406
- user_stopped_frame = UserStoppedSpeakingFrame()
407
-
408
- # Direct calls to _handle_interruptions
409
- await service._handle_interruptions(user_started_frame)
410
- await service._handle_interruptions(user_stopped_frame)
411
-
412
- # Run the test
413
- asyncio.run(run_test())
414
-
415
- # Assert
416
- service._start_interruption.assert_called_once()
417
- service._stop_interruption.assert_called_once()
418
-
419
- # Check that frames were pushed (check by type instead of exact equality)
420
- pushed_frame_types = [type(call[0][0]) for call in service.push_frame.call_args_list]
421
- self.assertIn(StartInterruptionFrame, pushed_frame_types, "No StartInterruptionFrame was pushed")
422
- self.assertIn(StopInterruptionFrame, pushed_frame_types, "No StopInterruptionFrame was pushed")
423
- self.assertIn(UserStartedSpeakingFrame, pushed_frame_types, "No UserStartedSpeakingFrame was pushed")
424
- self.assertIn(UserStoppedSpeakingFrame, pushed_frame_types, "No UserStoppedSpeakingFrame was pushed")
425
-
426
-
427
- @pytest.mark.asyncio
428
- async def test_riva_asr_integration():
429
- """Tests integration of RivaASRService components.
430
-
431
- Tests the complete flow of the ASR service including initialization,
432
- processing, and cleanup.
433
-
434
- The test verifies:
435
- - Service startup sequence
436
- - Audio processing pipeline
437
- - Response handling
438
- - Service shutdown sequence
439
- - Resource cleanup
440
- """
441
- with (
442
- patch("nvidia_pipecat.services.riva_speech.riva.client.Auth"),
443
- patch("nvidia_pipecat.services.riva_speech.riva.client.ASRService") as mock_asr_service,
444
- ):
445
- # Setup mock ASR service
446
- mock_instance = mock_asr_service.return_value
447
-
448
- # Initialize service with interruptions enabled
449
- service = RivaASRService(api_key="test_api_key", generate_interruptions=True)
450
- service._asr_service = mock_instance
451
-
452
- # Set up the response queue
453
- service._response_queue = asyncio.Queue()
454
-
455
- # Mock the _stop_tasks method directly instead of relying on task_manager
456
- service._stop_tasks = AsyncMock()
457
-
458
- # Create mock coroutines for handlers
459
- thread_coro = MagicMock()
460
- response_coro = MagicMock()
461
-
462
- # Mock the handler methods to return the coroutines
463
- service._thread_task_handler = MagicMock(return_value=thread_coro)
464
- service._response_task_handler = MagicMock(return_value=response_coro)
465
-
466
- # Create a mock task that's already completed
467
- mock_task = MagicMock()
468
- mock_task.done.return_value = True
469
-
470
- # Mock task creation to return our completed task
471
- service.create_task = MagicMock(return_value=mock_task)
472
-
473
- # Set the tasks to our mock task
474
- service._thread_task = mock_task
475
- service._response_task = mock_task
476
-
477
- # Create a mock result for testing
478
- mock_result = MagicMock()
479
- mock_alternative = MagicMock()
480
- mock_alternative.transcript = "This is a test transcript"
481
- mock_result.alternatives = [mock_alternative]
482
- mock_result.is_final = True
483
-
484
- mock_response = MagicMock()
485
- mock_response.results = [mock_result]
486
-
487
- # Create a mock StartFrame
488
- mock_start_frame = MagicMock(spec=StartFrame)
489
-
490
- # Start the service with a start frame
491
- await service.start(mock_start_frame)
492
-
493
- # Verify response task was created
494
- service.create_task.assert_called_with(response_coro)
495
-
496
- # Put a mock response in the queue
497
- await service._response_queue.put(mock_response)
498
-
499
- # Test some other functionality
500
- audio_data = b"test_audio_data"
501
-
502
- # Run the run_stt method
503
- frames = []
504
- async for frame in service.run_stt(audio_data):
505
- frames.append(frame)
506
-
507
- # Verify results
508
- assert len(frames) == 1
509
- assert frames[0] is None
510
-
511
- # Verify thread task was created
512
- service.create_task.assert_any_call(thread_coro)
513
-
514
- # Simulate stopping the service
515
- mock_end_frame = MagicMock(spec=EndFrame)
516
- await service.stop(mock_end_frame)
517
-
518
- # Verify stop_tasks was called
519
- service._stop_tasks.assert_called_once()
520
-
521
-
522
- if __name__ == "__main__":
523
- unittest.main()
 
tests/unit/test_riva_nmt_service.py DELETED
@@ -1,197 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the Riva Neural Machine Translation (NMT) service.
5
-
6
- This module contains tests that verify the behavior of the RivaNMTService,
7
- including successful translations and various error cases for both STT and LLM outputs.
8
- """
9
-
10
- import pytest
11
- from loguru import logger
12
- from pipecat.frames.frames import (
13
- ErrorFrame,
14
- LLMFullResponseEndFrame,
15
- LLMFullResponseStartFrame,
16
- LLMMessagesFrame,
17
- TextFrame,
18
- TranscriptionFrame,
19
- )
20
- from pipecat.pipeline.pipeline import Pipeline
21
- from pipecat.pipeline.task import PipelineTask
22
- from pipecat.transcriptions.language import Language
23
-
24
- from nvidia_pipecat.services.riva_nmt import RivaNMTService
25
- from tests.unit.utils import FrameStorage, ignore_ids, run_interactive_test
26
-
27
-
28
- class MockText:
29
- """Mock class representing a translated text response.
30
-
31
- Attributes:
32
- text: The translated text content.
33
- """
34
-
35
- def __init__(self, text):
36
- """Initialize MockText.
37
-
38
- Args:
39
- text (str): The translated text content.
40
- """
41
- self.text = text
42
-
43
-
44
- class MockTranslations:
45
- """Mock class representing a collection of translations.
46
-
47
- Attributes:
48
- translations: List of MockText objects containing translations.
49
- """
50
-
51
- def __init__(self, text):
52
- """Initialize MockTranslations.
53
-
54
- Args:
55
- text (str): The text to be wrapped in a MockText object.
56
- """
57
- self.translations = [MockText(text)]
58
-
59
-
60
- class MockRivaNMTClient:
61
- """Mock class simulating the Riva NMT client.
62
-
63
- Attributes:
64
- translated_text: The text to return as translation.
65
- """
66
-
67
- def __init__(self, translated_text):
68
- """Initialize MockRivaNMTClient.
69
-
70
- Args:
71
- translated_text (str): Text to be returned as translation.
72
- """
73
- self.translated_text = translated_text
74
-
75
- def translate(self, arg1, arg2, arg3, arg4):
76
- """Mock translation method.
77
-
78
- Args:
79
- arg1: Source language.
80
- arg2: Target language.
81
- arg3: Text to translate.
82
- arg4: Additional options.
83
-
84
- Returns:
85
- MockTranslations: A mock translations object containing the pre-defined translated text.
86
- """
87
- return MockTranslations(self.translated_text)
88
-
89
-
90
- @pytest.mark.asyncio()
91
- async def test_riva_nmt_service(mocker):
92
- """Test the RivaNMTService functionality.
93
-
94
- Tests translation service behavior including successful translations
95
- and error handling for both STT and LLM outputs.
96
-
97
- Args:
98
- mocker: Pytest mocker fixture for mocking dependencies.
99
-
100
- The test verifies:
101
- - STT output translation
102
- - LLM output translation
103
- - Empty input handling
104
- - Missing language handling
105
- - Error frame generation
106
- - Frame sequence correctness
107
- """
108
- testcases = {
109
- "Success: STT output translated": {
110
- "source_language": Language.ES_US,
111
- "target_language": Language.EN_US,
112
- "input_frames": [TranscriptionFrame("Hola, por favor preséntate.", "", "")],
113
- "translated_text": "Hello, please introduce yourself.",
114
- "result_frame_name": "LLMMessagesFrame",
115
- "result_frame": LLMMessagesFrame([{"role": "system", "content": "Hello, please introduce yourself."}]),
116
- },
117
- "Success: LLM output translated": {
118
- "source_language": Language.EN_US,
119
- "target_language": Language.ES_US,
120
- "input_frames": [
121
- LLMFullResponseStartFrame(),
122
- TextFrame("Hello there!"),
123
- TextFrame("Im an artificial intelligence model known as Llama."),
124
- LLMFullResponseEndFrame(),
125
- ],
126
- "translated_text": "Hola Im un modelo de inteligencia artificial conocido como Llama",
127
- "result_frame_name": "TextFrame",
128
- "result_frame": TextFrame("Hola Im un modelo de inteligencia artificial conocido como Llama."),
129
- },
130
- "Fail due to empty input text": {
131
- "source_language": Language.ES_US,
132
- "target_language": Language.EN_US,
133
- "input_frames": [TranscriptionFrame("", "", "")],
134
- "translated_text": None,
135
- "result_frame_name": "ErrorFrame",
136
- "result_frame": ErrorFrame(
137
- f"Error while translating the text from {Language.ES_US} to {Language.EN_US}, "
138
- "Error: No input text provided for the translation..",
139
- ),
140
- },
141
- "Fail due to no source language provided": {
142
- "source_language": None,
143
- "target_language": Language.EN_US,
144
- "input_frames": [TranscriptionFrame("Hola, por favor preséntate.", "", "")],
145
- "translated_text": None,
146
- "error": Exception("No source language provided for the translation.."),
147
- },
148
- "Fail due to no target language provided": {
149
- "source_language": Language.ES_US,
150
- "target_language": None,
151
- "input_frames": [TranscriptionFrame("Hola, por favor preséntate.", "", "")],
152
- "translated_text": None,
153
- "error": Exception("No target language provided for the translation.."),
154
- },
155
- }
156
-
157
- for tc_name, tc_data in testcases.items():
158
- logger.info(f"Verifying test case: {tc_name}")
159
-
160
- mocker.patch(
161
- "riva.client.NeuralMachineTranslationClient", return_value=MockRivaNMTClient(tc_data["translated_text"])
162
- )
163
-
164
- try:
165
- nmt_service = RivaNMTService(
166
- source_language=tc_data["source_language"], target_language=tc_data["target_language"]
167
- )
168
- except Exception as e:
169
- assert str(e) == str(tc_data["error"])
170
- continue
171
-
172
- storage1 = FrameStorage()
173
- storage2 = FrameStorage()
174
-
175
- pipeline = Pipeline([storage1, nmt_service, storage2])
176
-
177
- async def test_routine(task: PipelineTask, test_data=tc_data, s1=storage1, s2=storage2):
178
- await task.queue_frames(test_data["input_frames"])
179
-
180
- # Wait for the result frame
181
- if "ErrorFrame" in test_data["result_frame"].name:
182
- await s1.wait_for_frame(ignore_ids(test_data["result_frame"]))
183
- else:
184
- await s2.wait_for_frame(ignore_ids(test_data["result_frame"]))
185
-
186
- await run_interactive_test(pipeline, test_coroutine=test_routine)
187
-
188
- for frame_history_entry in storage1.history:
189
- if frame_history_entry.frame.name.startswith("TextFrame"):
190
- # ignoring input text frames getting stored in storage1
191
- continue
192
- if frame_history_entry.frame.name.startswith(tc_data["result_frame_name"]):
193
- assert frame_history_entry.frame == ignore_ids(tc_data["result_frame"])
194
-
195
- for frame_history_entry in storage2.history:
196
- if frame_history_entry.frame.name.startswith(tc_data["result_frame_name"]):
197
- assert frame_history_entry.frame == ignore_ids(tc_data["result_frame"])
 
tests/unit/test_riva_tts_service.py DELETED
@@ -1,301 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the RivaTTSService.
5
-
6
- This module contains tests for the RivaTTSService class, including initialization,
7
- TTS frame generation, audio frame handling, and integration tests.
8
- """
9
-
10
- import asyncio
11
- import unittest
12
- from unittest.mock import AsyncMock, MagicMock, patch
13
-
14
- import pytest
15
- from pipecat.frames.frames import TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, TTSTextFrame
16
- from pipecat.transcriptions.language import Language
17
-
18
- from nvidia_pipecat.services.riva_speech import RivaTTSService
19
-
20
-
21
- class TestRivaTTSService(unittest.TestCase):
22
- """Test suite for RivaTTSService functionality."""
23
-
24
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
25
- @patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService")
26
- def test_run_tts_yields_audio_frames(self, mock_speech_service, mock_auth):
27
- """Tests that run_tts correctly yields audio frames in sequence.
28
-
29
- Tests the complete flow of text-to-speech conversion including start,
30
- text processing, audio generation, and completion frames.
31
-
32
- Args:
33
- mock_speech_service: Mock for the speech synthesis service.
34
- mock_auth: Mock for the authentication service.
35
-
36
- The test verifies:
37
- - TTSStartedFrame generation
38
- - TTSTextFrame content
39
- - TTSAudioRawFrame generation
40
- - TTSStoppedFrame generation
41
- - Frame sequence order
42
- - Audio data integrity
43
- """
44
- # Arrange
45
- mock_audio_data = b"sample_audio_data"
46
- mock_audio_frame = TTSAudioRawFrame(audio=mock_audio_data, sample_rate=16000, num_channels=1)
47
-
48
- # Create a properly structured mock service
49
- mock_service_instance = MagicMock()
50
- mock_speech_service.return_value = mock_service_instance
51
-
52
- # Set up the mock to return a regular list, not an async generator
53
- # The service expects a list/iterable that it can call iter() on
54
- mock_service_instance.synthesize_online.return_value = [mock_audio_frame]
55
-
56
- # Create an instance of RivaTTSService
57
- service = RivaTTSService(api_key="test_api_key")
58
-
59
- # Ensure the service has the right mock
60
- service._service = mock_service_instance
61
-
62
- # Act
63
- async def run_test():
64
- frames = []
65
- async for frame in service.run_tts("Hello, world!"):
66
- frames.append(frame)
67
-
68
- # Assert
69
- self.assertEqual(
70
- len(frames), 4
71
- ) # Should yield 4 frames: TTSStartedFrame, TTSTextFrame, TTSAudioRawFrame, TTSStoppedFrame
72
- self.assertIsInstance(frames[0], TTSStartedFrame)
73
- self.assertIsInstance(frames[1], TTSTextFrame)
74
- self.assertEqual(frames[1].text, "Hello, world!")
75
- self.assertIsInstance(frames[2], TTSAudioRawFrame)
76
- self.assertEqual(frames[2].audio, mock_audio_data)
77
- self.assertIsInstance(frames[3], TTSStoppedFrame)
78
-
79
- # Run the async test
80
- asyncio.run(run_test())
81
-
82
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
83
- @patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService")
84
- def test_push_tts_frames(self, mock_speech_service, mock_auth):
85
- """Tests that _push_tts_frames correctly processes TTS generation.
86
-
87
- Tests the internal frame processing mechanism including metrics
88
- handling and generator processing.
89
-
90
- Args:
91
- mock_speech_service: Mock for the speech synthesis service.
92
- mock_auth: Mock for the authentication service.
93
-
94
- The test verifies:
95
- - Metrics start/stop timing
96
- - Generator processing
97
- - Frame processing order
98
- - Method call sequence
99
- """
100
- # Arrange
101
- mock_audio_frame = TTSAudioRawFrame(audio=b"sample_audio_data", sample_rate=16000, num_channels=1)
102
-
103
- # Create a properly structured mock service
104
- mock_service_instance = MagicMock()
105
- mock_speech_service.return_value = mock_service_instance
106
-
107
- # Return a regular list for synthesize_online
108
- mock_service_instance.synthesize_online.return_value = [mock_audio_frame]
109
-
110
- # Create an instance of RivaTTSService
111
- service = RivaTTSService(api_key="test_api_key")
112
- service._service = mock_service_instance
113
- service.start_processing_metrics = AsyncMock()
114
- service.stop_processing_metrics = AsyncMock()
115
- service.process_generator = AsyncMock()
116
-
117
- # Create a mock for run_tts instead of trying to capture its generator
118
- async def mock_run_tts(text):
119
- # This is the sequence of frames that would be yielded by run_tts
120
- yield TTSStartedFrame()
121
- yield TTSTextFrame(text)
122
- yield mock_audio_frame
123
- yield TTSStoppedFrame()
124
-
125
- # Replace the run_tts method with our mock
126
- with patch.object(service, "run_tts", side_effect=mock_run_tts):
127
- # Act
128
- async def run_test():
129
- await service._push_tts_frames("Hello, world!")
130
-
131
- # Assert
132
- # Check that the necessary methods were called
133
- service.start_processing_metrics.assert_called_once()
134
- service.process_generator.assert_called_once()
135
- service.stop_processing_metrics.assert_called_once()
136
-
137
- # Verify call order using the call_args.called_before method
138
- assert service.start_processing_metrics.call_args.called_before(service.process_generator.call_args)
139
- assert service.process_generator.call_args.called_before(service.stop_processing_metrics.call_args)
140
-
141
- # Run the async test
142
- asyncio.run(run_test())
143
-
144
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
145
- @patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService")
146
- def test_init_with_different_parameters(self, mock_speech_service, mock_auth):
147
- """Tests initialization with various configuration parameters.
148
-
149
- Tests service initialization with different combinations of
150
- configuration options.
151
-
152
- Args:
153
- mock_speech_service: Mock for the speech synthesis service.
154
- mock_auth: Mock for the authentication service.
155
-
156
- The test verifies:
157
- - Parameter assignment
158
- - Default value handling
159
- - Custom parameter validation
160
- - Authentication configuration
161
- - Service initialization
162
- """
163
- # Define test parameters - users should be able to use any values
164
- test_api_key = "test_api_key"
165
- test_server = "custom_server:50051"
166
- test_voice_id = "English-US.Male-1"
167
- test_sample_rate = 22050
168
- test_language = Language.ES_ES
169
- test_zero_shot_quality = 10
170
- test_model = "custom-tts-model"
171
- test_dictionary = {"word": "pronunciation"}
172
- test_audio_prompt_file = "test_audio.wav"
173
- # Test initialization with different parameters
174
- service = RivaTTSService(
175
- api_key=test_api_key,
176
- server=test_server,
177
- voice_id=test_voice_id,
178
- sample_rate=test_sample_rate,
179
- language=test_language,
180
- zero_shot_quality=test_zero_shot_quality,
181
- model=test_model,
182
- custom_dictionary=test_dictionary,
183
- zero_shot_audio_prompt_file=test_audio_prompt_file,
184
- use_ssl=True,
185
- )
186
-
187
- # Verify the parameters were set correctly
188
- self.assertEqual(service._api_key, test_api_key)
189
- self.assertEqual(service._voice_id, test_voice_id)
190
- self.assertEqual(service._sample_rate, test_sample_rate)
191
- self.assertEqual(service._language_code, test_language)
192
- self.assertEqual(service._zero_shot_quality, test_zero_shot_quality)
193
- self.assertEqual(service._model_name, test_model)
194
- self.assertEqual(service._custom_dictionary, test_dictionary)
195
- self.assertEqual(service._zero_shot_audio_prompt_file, test_audio_prompt_file)
196
-
197
- # Verify Auth was called with correct parameters
198
- mock_auth.assert_called_with(
199
- None,
200
- True, # use_ssl=True
201
- test_server,
202
- [["function-id", "0149dedb-2be8-4195-b9a0-e57e0e14f972"], ["authorization", f"Bearer {test_api_key}"]],
203
- )
204
-
205
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
206
- def test_riva_error_handling(self, mock_auth):
207
- """Tests error handling when Riva service initialization fails.
208
-
209
- Tests the service's behavior when encountering initialization errors.
210
-
211
- Args:
212
- mock_auth: Mock for the authentication service.
213
-
214
- The test verifies:
215
- - Exception propagation
216
- - Error message formatting
217
- - Cleanup behavior
218
- - Service state after error
219
- """
220
- # Test error handling when Riva service initialization fails
221
- mock_auth.side_effect = Exception("Connection failed")
222
-
223
- # Assert that exception is raised and propagated
224
- with self.assertRaises(Exception) as context:
225
- RivaTTSService(api_key="test_api_key")
226
-
227
- self.assertTrue("Missing module: Connection failed" in str(context.exception))
228
-
229
- @patch("nvidia_pipecat.services.riva_speech.riva.client.Auth")
230
- @patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService")
231
- def test_can_generate_metrics(self, mock_speech_service, mock_auth):
232
- """Tests that the service reports capability to generate metrics.
233
-
234
- Tests the metrics generation capability reporting functionality.
235
-
236
- Args:
237
- mock_speech_service: Mock for the speech synthesis service.
238
- mock_auth: Mock for the authentication service.
239
-
240
- The test verifies:
241
- - Metrics capability reporting
242
- - Consistency of capability flag
243
- """
244
- # Test that the service can generate metrics
245
- service = RivaTTSService(api_key="test_api_key")
246
- self.assertTrue(service.can_generate_metrics())
247
-
248
-
249
- @pytest.mark.asyncio
250
- async def test_riva_tts_integration():
251
- """Tests integration of RivaTTSService components.
252
-
253
- Tests the complete flow of the TTS service in an integrated environment.
254
-
255
- The test verifies:
256
- - Service initialization
257
- - Frame generation sequence
258
- - Audio chunk processing
259
- - Frame content validation
260
- - Service completion
261
- """
262
- # Use parentheses for multiline with statements instead of backslashes
263
- with (
264
- patch("nvidia_pipecat.services.riva_speech.riva.client.Auth"),
265
- patch("nvidia_pipecat.services.riva_speech.riva.client.SpeechSynthesisService") as mock_service,
266
- ):
267
- # Setup mock responses
268
- mock_instance = mock_service.return_value
269
-
270
- # Create audio frames for the response
271
- audio_frame1 = TTSAudioRawFrame(audio=b"audio_chunk_1", sample_rate=16000, num_channels=1)
272
- audio_frame2 = TTSAudioRawFrame(audio=b"audio_chunk_2", sample_rate=16000, num_channels=1)
273
-
274
- # Return a list of frames, not an async generator
275
- mock_instance.synthesize_online.return_value = [audio_frame1, audio_frame2]
276
-
277
- # Initialize service and call its methods
278
- service = RivaTTSService(api_key="test_api_key")
279
- service._service = mock_instance
280
-
281
- # Simulate running the service
282
- collected_frames = []
283
- async for frame in service.run_tts("Test sentence for TTS"):
284
- collected_frames.append(frame)
285
-
286
- # Verify the expected frames were produced
287
- assert len(collected_frames) == 5 # Started, TextFrame, 2 audio chunks, stopped
288
- assert isinstance(collected_frames[0], TTSStartedFrame)
289
- assert isinstance(collected_frames[1], TTSTextFrame)
290
- assert collected_frames[1].text == "Test sentence for TTS"
291
- assert isinstance(collected_frames[2], TTSAudioRawFrame)
292
- assert isinstance(collected_frames[3], TTSAudioRawFrame)
293
- assert isinstance(collected_frames[4], TTSStoppedFrame)
294
-
295
- # Verify audio content
296
- assert collected_frames[2].audio == b"audio_chunk_1"
297
- assert collected_frames[3].audio == b"audio_chunk_2"
298
-
299
-
300
- if __name__ == "__main__":
301
- unittest.main()
 
tests/unit/test_speech_planner.py DELETED
@@ -1,546 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the SpeechPlanner service.
5
-
6
- This module contains tests for the SpeechPlanner class, focusing on core functionalities:
7
- - Service initialization with various parameters
8
- - Frame processing for different frame types
9
- - Speech completion detection and LLM integration
10
- - Chat history management with context windows
11
- - Interruption handling and VAD integration
12
- - Label preprocessing and classification logic
13
- """
14
-
15
- import tempfile
16
- from datetime import datetime
17
- from unittest.mock import AsyncMock, Mock, patch
18
-
19
- import pytest
20
- import yaml
21
- from pipecat.frames.frames import (
22
- InterimTranscriptionFrame,
23
- StartInterruptionFrame,
24
- StopInterruptionFrame,
25
- TranscriptionFrame,
26
- )
27
- from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
28
-
29
- from nvidia_pipecat.services.speech_planner import SpeechPlanner
30
-
31
-
32
- class MockBaseMessageChunk:
33
- """Mock for BaseMessageChunk that mimics the structure."""
34
-
35
- def __init__(self, content=""):
36
- """Initialize with content.
37
-
38
- Args:
39
- content: The text content of the chunk
40
- """
41
- self.content = content
42
-
43
-
44
- @pytest.fixture
45
- def mock_prompt_file():
46
- """Create a temporary YAML prompt file for testing."""
47
- prompt_data = {
48
- "configurations": {"using_chat_history": False},
49
- "prompts": {
50
- "completion_prompt": (
51
- "Evaluate whether the following user speech is sufficient:\n"
52
- "1. Label1: Complete and coherent thought\n"
53
- "2. Label2: Incomplete speech\n"
54
- "3. Label3: User commands\n"
55
- "4. Label4: Acknowledgments\n"
56
- "User Speech: {transcript}\n"
57
- "Only return Label1 or Label2 or Label3 or Label4."
58
- )
59
- },
60
- }
61
-
62
- with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
63
- yaml.dump(prompt_data, f)
64
- return f.name
65
-
66
-
67
- @pytest.fixture
68
- def mock_context():
69
- """Create a mock OpenAI LLM context."""
70
- context = Mock(spec=OpenAILLMContext)
71
- context.get_messages.return_value = [
72
- {"role": "user", "content": "Hello"},
73
- {"role": "assistant", "content": "Hi there!"},
74
- {"role": "user", "content": "How are you?"},
75
- {"role": "assistant", "content": "I'm doing well!"},
76
- ]
77
- return context
78
-
79
-
80
- @pytest.fixture
81
- def speech_planner(mock_prompt_file, mock_context):
82
- """Create a SpeechPlanner instance for testing."""
83
- with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
84
- mock_chat.return_value = AsyncMock()
85
- planner = SpeechPlanner(
86
- prompt_file=mock_prompt_file,
87
- model="test-model",
88
- api_key="test-key",
89
- base_url="http://test-url",
90
- context=mock_context,
91
- context_window=2,
92
- )
93
- return planner
94
-
95
-
96
- class TestSpeechPlannerInitialization:
97
- """Test SpeechPlanner initialization."""
98
-
99
- def test_init_with_default_params(self, mock_prompt_file, mock_context):
100
- """Test initialization with default parameters."""
101
- with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
102
- mock_chat.return_value = AsyncMock()
103
-
104
- planner = SpeechPlanner(prompt_file=mock_prompt_file, context=mock_context)
105
-
106
- assert planner.model_name == "nvdev/google/gemma-2b-it"
107
- assert planner.context == mock_context
108
- assert planner.context_window == 1
109
- assert planner.user_speaking is None
110
- assert planner.current_prediction is None
111
-
112
- def test_init_with_custom_params(self, mock_prompt_file, mock_context):
113
- """Test initialization with custom parameters."""
114
- with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
115
- mock_chat.return_value = AsyncMock()
116
-
117
- params = SpeechPlanner.InputParams(temperature=0.7, max_tokens=100, top_p=0.9)
118
-
119
- planner = SpeechPlanner(
120
- prompt_file=mock_prompt_file,
121
- model="custom-model",
122
- api_key="custom-key",
123
- base_url="http://custom-url",
124
- context=mock_context,
125
- params=params,
126
- context_window=3,
127
- )
128
-
129
- assert planner.model_name == "custom-model"
130
- assert planner.context_window == 3
131
- assert planner._settings["temperature"] == 0.7
132
- assert planner._settings["max_tokens"] == 100
133
- assert planner._settings["top_p"] == 0.9
134
-
135
- def test_init_loads_prompts(self, mock_prompt_file, mock_context):
136
- """Test that initialization properly loads prompts from file."""
137
- with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
138
- mock_chat.return_value = AsyncMock()
139
-
140
- planner = SpeechPlanner(prompt_file=mock_prompt_file, context=mock_context)
141
-
142
- assert "configurations" in planner.prompts
143
- assert "prompts" in planner.prompts
144
- assert "completion_prompt" in planner.prompts["prompts"]
145
- assert planner.prompts["configurations"]["using_chat_history"] is False
146
-
147
-
148
- ### Adding tests for preprocess_pred function
149
-
150
-
151
- class TestPreprocessPred:
152
- """Test the preprocess_pred function logic through end-to-end processing."""
153
-
154
- @pytest.mark.asyncio
155
- async def test_preprocess_pred_label1_complete(self, speech_planner):
156
- """Test preprocess_pred with Label1 returns Complete via end-to-end processing."""
157
- frame = TranscriptionFrame("Hello there", "user1", datetime.now())
158
- test_cases = ["Label1", "Label 1", "The answer is Label1.", "I think this is Label 1 based on analysis."]
159
-
160
- for case in test_cases:
161
- # Mock the LLM to return the specific prediction
162
- mock_chunks = [MockBaseMessageChunk(case)]
163
-
164
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
165
- mock_stream.return_value.__aiter__.return_value = mock_chunks
166
-
167
- await speech_planner._process_complete_context(frame)
168
- assert speech_planner.current_prediction == "Complete", f"Failed for case: {case}"
169
-
170
- @pytest.mark.asyncio
171
- async def test_preprocess_pred_label2_incomplete(self, speech_planner):
172
- """Test preprocess_pred with Label2 returns Incomplete via end-to-end processing."""
173
- frame = TranscriptionFrame("Hello", "user1", datetime.now())
174
- test_cases = ["Label2", "Label 2", "The answer is Label2.", "This should be Label 2."]
175
-
176
- for case in test_cases:
177
- # Mock the LLM to return the specific prediction
178
- mock_chunks = [MockBaseMessageChunk(case)]
179
-
180
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
181
- mock_stream.return_value.__aiter__.return_value = mock_chunks
182
-
183
- await speech_planner._process_complete_context(frame)
184
- assert speech_planner.current_prediction == "Incomplete", f"Failed for case: {case}"
185
-
186
- @pytest.mark.asyncio
187
- async def test_preprocess_pred_label3_complete(self, speech_planner):
188
- """Test preprocess_pred with Label3 returns Complete via end-to-end processing."""
189
- frame = TranscriptionFrame("Stop that", "user1", datetime.now())
190
- test_cases = ["Label3", "Label 3", "The answer is Label3.", "I classify this as Label 3."]
191
-
192
- for case in test_cases:
193
- # Mock the LLM to return the specific prediction
194
- mock_chunks = [MockBaseMessageChunk(case)]
195
-
196
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
197
- mock_stream.return_value.__aiter__.return_value = mock_chunks
198
-
199
- await speech_planner._process_complete_context(frame)
200
- assert speech_planner.current_prediction == "Complete", f"Failed for case: {case}"
201
-
202
- @pytest.mark.asyncio
203
- async def test_preprocess_pred_label4_complete(self, speech_planner):
204
- """Test preprocess_pred with Label4 returns Complete via end-to-end processing."""
205
- frame = TranscriptionFrame("Okay", "user1", datetime.now())
206
- test_cases = ["Label4", "Label 4", "The answer is Label4.", "This is Label 4 category."]
207
-
208
- for case in test_cases:
209
- # Mock the LLM to return the specific prediction
210
- mock_chunks = [MockBaseMessageChunk(case)]
211
-
212
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
213
- mock_stream.return_value.__aiter__.return_value = mock_chunks
214
-
215
- await speech_planner._process_complete_context(frame)
216
- assert speech_planner.current_prediction == "Complete", f"Failed for case: {case}"
217
-
218
- @pytest.mark.asyncio
219
- async def test_preprocess_pred_unrecognized_incomplete(self, speech_planner):
220
- """Test preprocess_pred with unrecognized labels returns Incomplete via end-to-end processing."""
221
- frame = TranscriptionFrame("Unknown input", "user1", datetime.now())
222
- test_cases = ["Label5", "Unknown", "No label found", "", "Some random text"]
223
-
224
- for case in test_cases:
225
- # Mock the LLM to return the specific prediction
226
- mock_chunks = [MockBaseMessageChunk(case)]
227
-
228
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
229
- mock_stream.return_value.__aiter__.return_value = mock_chunks
230
-
231
- await speech_planner._process_complete_context(frame)
232
- assert speech_planner.current_prediction == "Incomplete", f"Failed for case: {case}"
233
-
234
-
235
- ### Adding tests for chat_history management
236
- class TestChatHistory:
237
- """Test chat history management."""
238
-
239
- def test_get_chat_history_empty_messages(self, mock_prompt_file):
240
- """Test get_chat_history with empty message list."""
241
- context = Mock(spec=OpenAILLMContext)
242
- context.get_messages.return_value = []
243
-
244
- with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
245
- mock_chat.return_value = AsyncMock()
246
-
247
- planner = SpeechPlanner(prompt_file=mock_prompt_file, context=context, context_window=2)
248
-
249
- history = planner.get_chat_history()
250
- assert history == []
251
-
252
- def test_get_chat_history_with_context_window(self, mock_prompt_file):
253
- """Test get_chat_history respects context_window setting."""
254
- messages = [
255
- {"role": "user", "content": "First user message"},
256
- {"role": "assistant", "content": "First assistant response"},
257
- {"role": "user", "content": "Second user message"},
258
- {"role": "assistant", "content": "Second assistant response"},
259
- {"role": "user", "content": "Third user message"},
260
- {"role": "assistant", "content": "Third assistant response"},
261
- ]
262
-
263
- context = Mock(spec=OpenAILLMContext)
264
- context.get_messages.return_value = messages
265
-
266
- with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
267
- mock_chat.return_value = AsyncMock()
268
-
269
- # Test with context_window=1 (should get last 2 messages)
270
- planner = SpeechPlanner(prompt_file=mock_prompt_file, context=context, context_window=1)
271
-
272
- history = planner.get_chat_history()
273
- assert len(history) == 2
274
- assert history[0]["content"] == "Third user message"
275
- assert history[1]["content"] == "Third assistant response"
276
-
277
- def test_get_chat_history_starts_with_user(self, mock_prompt_file):
278
- """Test get_chat_history starts with user message."""
279
- messages = [
280
- {"role": "user", "content": "User message 1"},
281
- {"role": "assistant", "content": "Assistant response 1"},
282
- {"role": "user", "content": "User message 2"},
283
- {"role": "assistant", "content": "Assistant response 2"},
284
- ]
285
-
286
- context = Mock(spec=OpenAILLMContext)
287
- context.get_messages.return_value = messages
288
-
289
- with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat:
290
- mock_chat.return_value = AsyncMock()
291
-
292
- planner = SpeechPlanner(prompt_file=mock_prompt_file, context=context, context_window=2)
293
-
294
- history = planner.get_chat_history()
295
- assert len(history) > 0
296
- assert history[0]["role"] == "user"
297
-
298
-
299
- class TestFrameProcessing:
300
- """Test frame processing functionality by testing the logic directly."""
301
-
302
- def test_user_speaking_state_changes(self, speech_planner):
303
- """Test that user speaking state changes correctly."""
304
- # Test UserStartedSpeakingFrame logic
305
- assert speech_planner.user_speaking is None
306
- speech_planner.user_speaking = True
307
- assert speech_planner.user_speaking is True
308
-
309
- # Test UserStoppedSpeakingFrame logic
310
- speech_planner.user_speaking = False
311
- assert speech_planner.user_speaking is False
312
-
313
- def test_bot_speaking_timestamp_tracking(self, speech_planner):
314
- """Test bot speaking timestamp tracking."""
315
- # Initially no timestamp
316
- assert speech_planner.latest_bot_started_speaking_frame_timestamp is None
317
-
318
- # Set timestamp (simulating BotStartedSpeakingFrame)
319
- test_time = datetime.now()
320
- speech_planner.latest_bot_started_speaking_frame_timestamp = test_time
321
- assert speech_planner.latest_bot_started_speaking_frame_timestamp == test_time
322
-
323
- # Clear timestamp (simulating BotStoppedSpeakingFrame)
324
- speech_planner.latest_bot_started_speaking_frame_timestamp = None
325
- assert speech_planner.latest_bot_started_speaking_frame_timestamp is None
326
-
327
- def test_frame_state_management(self, speech_planner):
328
- """Test frame state management without full processing."""
329
- # Test last_frame tracking
330
- frame = InterimTranscriptionFrame("Hello", "user1", datetime.now())
331
- speech_planner.last_frame = frame
332
- assert speech_planner.last_frame == frame
333
-
334
- # Test clearing last_frame
335
- speech_planner.last_frame = None
336
- assert speech_planner.last_frame is None
337
-
338
- # Test current_prediction state
339
- speech_planner.current_prediction = "Complete"
340
- assert speech_planner.current_prediction == "Complete"
341
-
342
- speech_planner.current_prediction = "Incomplete"
343
- assert speech_planner.current_prediction == "Incomplete"
344
-
345
- def test_transcription_frame_conditions(self, speech_planner):
346
- """Test the conditions for processing transcription frames."""
347
- # Set up conditions for processing
348
- speech_planner.user_speaking = False
349
- speech_planner.current_prediction = "Incomplete"
350
-
351
- # These conditions should allow processing
352
- assert speech_planner.user_speaking is False
353
- assert speech_planner.current_prediction == "Incomplete"
354
-
355
- # Test conditions that would prevent processing
356
- speech_planner.user_speaking = True
357
- assert speech_planner.user_speaking is True # Should prevent processing
358
-
359
- @pytest.mark.asyncio
360
- async def test_cancel_current_task_helper(self, speech_planner):
361
- """Test the _cancel_current_task helper method."""
362
- # Test with no current task
363
- speech_planner._current_task = None
364
- await speech_planner._cancel_current_task()
365
- assert speech_planner._current_task is None
366
-
367
- # Test with completed task
368
- completed_task = Mock()
369
- completed_task.done.return_value = True
370
- completed_task.cancelled.return_value = False
371
- speech_planner._current_task = completed_task
372
-
373
- await speech_planner._cancel_current_task()
374
- assert speech_planner._current_task is None
375
-
376
- # Test with cancelled task
377
- cancelled_task = Mock()
378
- cancelled_task.done.return_value = False
379
- cancelled_task.cancelled.return_value = True
380
- speech_planner._current_task = cancelled_task
381
-
382
- await speech_planner._cancel_current_task()
383
- assert speech_planner._current_task is None
384
-
385
- # Test with active task that needs cancellation
386
- active_task = Mock()
387
- active_task.done.return_value = False
388
- active_task.cancelled.return_value = False
389
- speech_planner._current_task = active_task
390
-
391
- with patch.object(speech_planner, "cancel_task", new_callable=AsyncMock) as mock_cancel:
392
- await speech_planner._cancel_current_task()
393
- mock_cancel.assert_called_once_with(active_task)
394
- assert speech_planner._current_task is None
395
-
396
-
397
- class TestCompletionDetection:
398
- """Test speech completion detection."""
399
-
400
- @pytest.mark.asyncio
401
- async def test_process_complete_context_with_complete_prediction(self, speech_planner):
402
- """Test _process_complete_context with complete prediction."""
403
- frame = TranscriptionFrame("Hello there", "user1", datetime.now())
404
-
405
- # Mock the LLM response
406
- mock_chunks = [MockBaseMessageChunk("Label1")]
407
-
408
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
409
- mock_stream.return_value.__aiter__.return_value = mock_chunks
410
- with patch.object(speech_planner, "push_frame", new_callable=AsyncMock) as mock_push:
411
- await speech_planner._process_complete_context(frame)
412
-
413
- assert speech_planner.current_prediction == "Complete"
414
-
415
- # Should push start/stop interruption frames and transcription
416
- assert mock_push.call_count >= 3
417
-
418
- # Verify the correct frames were pushed
419
- call_args = [args[0][0] for args in mock_push.call_args_list]
420
- assert any(isinstance(f, StartInterruptionFrame) for f in call_args)
421
- assert any(isinstance(f, StopInterruptionFrame) for f in call_args)
422
- assert any(isinstance(f, TranscriptionFrame) for f in call_args)
423
-
424
- @pytest.mark.asyncio
425
- async def test_process_complete_context_with_incomplete_prediction(self, speech_planner):
426
- """Test _process_complete_context with incomplete prediction."""
427
- frame = TranscriptionFrame("Hello", "user1", datetime.now())
428
-
429
- # Mock the LLM response with Label2 (incomplete)
430
- mock_chunks = [MockBaseMessageChunk("Label2")]
431
-
432
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
433
- mock_stream.return_value.__aiter__.return_value = mock_chunks
434
- with patch.object(speech_planner, "push_frame", new_callable=AsyncMock) as mock_push:
435
- await speech_planner._process_complete_context(frame)
436
-
437
- assert speech_planner.current_prediction == "Incomplete"
438
- # Should not push any frames for incomplete prediction
439
- mock_push.assert_not_called()
440
-
441
- @pytest.mark.asyncio
442
- async def test_process_complete_context_with_error(self, speech_planner):
443
- """Test _process_complete_context handles errors gracefully with proper logging."""
444
- frame = TranscriptionFrame("Hello", "user1", datetime.now())
445
-
446
- # Mock an exception during processing
447
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
448
- mock_stream.side_effect = Exception("LLM service error")
449
-
450
- # Patch logger to verify warning is logged with stack trace
451
- with patch("nvidia_pipecat.services.speech_planner.logger") as mock_logger:
452
- await speech_planner._process_complete_context(frame)
453
-
454
- # Should default to "Complete" on error
455
- assert speech_planner.current_prediction == "Complete"
456
-
457
- # Should log warning with stack trace information
458
- mock_logger.warning.assert_called_once()
459
- call_args = mock_logger.warning.call_args
460
- assert "Disabling Smart EOU detection due to error" in call_args[0][0]
461
- assert "LLM service error" in call_args[0][0]
462
- assert call_args[1]["exc_info"] is True
463
-
464
- @pytest.mark.asyncio
465
- async def test_process_complete_context_with_chunk_content_error(self, speech_planner):
466
- """Test _process_complete_context handles chunk content errors gracefully."""
467
- frame = TranscriptionFrame("Hello", "user1", datetime.now())
468
-
469
- # Mock chunks where one has content that causes concatenation error
470
- class BadChunk:
471
- def __init__(self):
472
- # Use a number instead of None to pass the `if not chunk.content:` check
473
- # but still cause TypeError during string concatenation
474
- self.content = 42
475
-
476
- mock_chunks = [MockBaseMessageChunk("Label"), BadChunk()]
477
-
478
- with patch.object(speech_planner, "_stream_chat_completions", new_callable=AsyncMock) as mock_stream:
479
- mock_stream.return_value.__aiter__.return_value = mock_chunks
480
-
481
- # Patch logger to verify debug message is logged for chunk error
482
- with patch("nvidia_pipecat.services.speech_planner.logger") as mock_logger:
483
- await speech_planner._process_complete_context(frame)
484
-
485
- # Should still get a prediction (either from first chunk or default "Complete")
486
- assert speech_planner.current_prediction is not None
487
-
488
- # Should log debug message for chunk content error
489
- debug_calls = [
490
- call for call in mock_logger.debug.call_args_list if "Failed to append chunk content" in str(call)
491
- ]
492
- assert len(debug_calls) >= 1, (
493
- f"Expected debug log for chunk error, got calls: {mock_logger.debug.call_args_list}"
494
- )
495
-
496
-
497
- class TestClientCreation:
498
- """Test client creation functionality."""
499
-
500
- def test_create_client_with_params(self, mock_prompt_file, mock_context):
501
- """Test create_client method with parameters."""
502
- with patch("nvidia_pipecat.services.speech_planner.ChatNVIDIA") as mock_chat_class:
503
- mock_client = AsyncMock()
504
- mock_chat_class.return_value = mock_client
505
-
506
- planner = SpeechPlanner(
507
- prompt_file=mock_prompt_file,
508
- context=mock_context,
509
- model="test-model",
510
- api_key="test-key",
511
- base_url="http://test-url",
512
- )
513
-
514
- client = planner.create_client(api_key="custom-key", base_url="http://custom-url")
515
-
516
- # Verify ChatNVIDIA was called with correct parameters
517
- mock_chat_class.assert_called_with(base_url="http://custom-url", model="test-model", api_key="custom-key")
518
- assert client == mock_client
519
-
520
-
521
- @pytest.mark.asyncio
522
- async def test_get_chat_completions(speech_planner):
523
- """Test get_chat_completions method."""
524
- messages = [{"role": "user", "content": "Test message"}]
525
-
526
- # Mock the client's astream method to return an async iterator
527
- mock_chunks = [MockBaseMessageChunk("Response chunk")]
528
-
529
- async def mock_astream(*args, **kwargs):
530
- for chunk in mock_chunks:
531
- yield chunk
532
-
533
- speech_planner._client.astream = mock_astream
534
-
535
- result = await speech_planner.get_chat_completions(messages)
536
-
537
- # The result should be the return value from astream
538
- assert result is not None
539
-
540
- # Convert result to list to verify contents
541
- result_list = []
542
- async for chunk in result:
543
- result_list.append(chunk)
544
-
545
- assert len(result_list) == 1
546
- assert result_list[0].content == "Response chunk"
 
tests/unit/test_traced_processor.py DELETED
@@ -1,159 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests tracing."""
5
-
6
- import pytest
7
- from opentelemetry import trace
8
- from pipecat.frames.frames import EndFrame, ErrorFrame, Frame, TextFrame
9
- from pipecat.pipeline.pipeline import Pipeline
10
- from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
11
-
12
- from nvidia_pipecat.utils.tracing import AttachmentStrategy, traceable, traced
13
- from tests.unit.utils import ignore_ids, run_test
14
-
15
-
16
- @pytest.mark.asyncio
17
- async def test_traced_processor_basic_usage():
18
- """Tests basic tracing functionality in a pipeline processor.
19
-
20
- Tests the @traceable and @traced decorators with different attachment
21
- strategies in a simple pipeline configuration.
22
-
23
- The test verifies:
24
- - Span creation and attachment
25
- - Event recording
26
- - Nested span handling
27
- - Generator tracing
28
- - Frame processing
29
- """
30
- tracer = trace.get_tracer(__name__)
31
-
32
- @traceable
33
- class TestProcessor(FrameProcessor):
34
- """Example processor demonstrating how to use the tracing utilities."""
35
-
36
- @traced(attachment_strategy=AttachmentStrategy.NONE)
37
- async def process_frame(self, frame, direction):
38
- """Process a frame with tracing.
39
-
40
- Args:
41
- frame: The frame to process.
42
- direction: The direction of frame flow.
43
-
44
- The method demonstrates:
45
- - Basic span creation
46
- - Event recording
47
- - Nested span handling
48
- - Multiple tracing strategies
49
- """
50
- await super().process_frame(frame, direction)
51
- trace.get_current_span().add_event("Before inner")
52
- with tracer.start_as_current_span("inner") as span:
53
- span.add_event("inner event")
54
- await self.child()
55
- await self.linked()
56
- await self.none()
57
- trace.get_current_span().add_event("After inner")
58
- async for f in self.generator():
59
- print(f"{f}")
60
- await super().push_frame(frame, direction)
61
-
62
- @traced
63
- async def child(self):
64
- """Example method with child attachment strategy.
65
-
66
- This span is attached as CHILD, meaning it will be attached to
67
- the class span if no parent exists, or to its parent otherwise.
68
- """
69
- trace.get_current_span().add_event("child")
70
-
71
- @traced(attachment_strategy=AttachmentStrategy.LINK)
72
- async def linked(self):
73
- """Example method with link attachment strategy.
74
-
75
- This span is attached as LINK, meaning it will be attached to
76
- the class span but linked to its parent.
77
- """
78
- trace.get_current_span().add_event("linked")
79
-
80
- @traced(attachment_strategy=AttachmentStrategy.NONE)
81
- async def none(self):
82
- """Example method with no attachment strategy.
83
-
84
- This span is attached as NONE, meaning it will be attached to
85
- the class span even if nested under another span.
86
- """
87
- trace.get_current_span().add_event("none")
88
-
89
- @traced
90
- async def generator(self):
91
- """Example generator method with tracing.
92
-
93
- Demonstrates tracing in a generator context.
94
-
95
- Yields:
96
- TextFrame: Text frames with sample content.
97
- """
98
- yield TextFrame("Hello, ")
99
- trace.get_current_span().add_event("generated!")
100
- yield TextFrame("World")
101
-
102
- processor = TestProcessor()
103
- pipeline = Pipeline([processor])
104
-
105
- await run_test(
106
- pipeline,
107
- frames_to_send=[],
108
- expected_down_frames=[],
109
- )
110
-
111
-
112
- @pytest.mark.asyncio
113
- async def test_wrong_usage() -> None:
114
- """Test that error message is raised when a processor is not traceable."""
115
-
116
- class TestProcessor(FrameProcessor):
117
- @traced
118
- async def process_frame(self, frame: Frame, direction: FrameDirection):
119
- await super().process_frame(frame, direction)
120
- await super().push_frame(frame, direction)
121
-
122
- class HasSeenError(FrameProcessor):
123
- def __init__(self, **kwargs):
124
- super().__init__(**kwargs)
125
- self.seen_error = False
126
-
127
- async def process_frame(self, frame: Frame, direction: FrameDirection):
128
- await super().process_frame(frame, direction)
129
- if isinstance(frame, ErrorFrame):
130
- self.seen_error = True
131
- elif isinstance(frame, EndFrame):
132
- assert self.seen_error
133
- await super().push_frame(frame, direction)
134
-
135
- seen_error = HasSeenError()
136
- processor = TestProcessor()
137
- pipeline = Pipeline([seen_error, processor])
138
- await run_test(
139
- pipeline,
140
- frames_to_send=[],
141
- expected_down_frames=[],
142
- expected_up_frames=[
143
- ignore_ids(ErrorFrame("@traced annotation can only be used in classes inheriting from Traceable"))
144
- ],
145
- )
146
-
147
-
148
- @pytest.mark.asyncio
149
- async def test_no_processor() -> None:
150
- """Test that a processor can be used without a pipeline."""
151
-
152
- @traceable
153
- class TestTraceable:
154
- @traced
155
- async def traced_test(self):
156
- trace.get_current_span().add_event("I can use it as another utility function as well.")
157
-
158
- test = TestTraceable()
159
- await test.traced_test()
 
tests/unit/test_transcription_sync_processors.py DELETED
@@ -1,262 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for transcript synchronization processors.
5
-
6
- This module contains tests that verify the behavior of transcript synchronization processors,
7
- including both user and bot transcript synchronization with different TTS providers.
8
- The tests ensure proper handling of speech events, transcriptions, and TTS frames.
9
- """
10
-
11
- import pytest
12
- from pipecat.frames.frames import (
13
- BotStartedSpeakingFrame,
14
- BotStoppedSpeakingFrame,
15
- InterimTranscriptionFrame,
16
- StartInterruptionFrame,
17
- TranscriptionFrame,
18
- TTSStartedFrame,
19
- TTSStoppedFrame,
20
- TTSTextFrame,
21
- UserStartedSpeakingFrame,
22
- UserStoppedSpeakingFrame,
23
- )
24
- from pipecat.tests.utils import SleepFrame
25
- from pipecat.utils.time import time_now_iso8601
26
-
27
- from nvidia_pipecat.frames.transcripts import (
28
- BotUpdatedSpeakingTranscriptFrame,
29
- UserStoppedSpeakingTranscriptFrame,
30
- UserUpdatedSpeakingTranscriptFrame,
31
- )
32
- from nvidia_pipecat.processors.transcript_synchronization import (
33
- BotTranscriptSynchronization,
34
- UserTranscriptSynchronization,
35
- )
36
- from tests.unit.utils import ignore_ids, run_test
37
-
38
-
39
- @pytest.mark.asyncio()
40
- async def test_user_transcript_synchronization_processor():
41
- """Test the UserTranscriptSynchronization processor functionality.
42
-
43
- Tests the complete flow of user speech transcription synchronization,
44
- including interim and final transcriptions.
45
-
46
- The test verifies:
47
- - User speech start/stop handling
48
- - Interim transcription processing
49
- - Speaking transcript updates
50
- - Final transcript generation
51
- - Frame sequence ordering
52
- - Multiple speech segment handling
53
- """
54
- user_id = ""
55
- interim_transcript_frames = [
56
- InterimTranscriptionFrame("Hi", user_id, time_now_iso8601()),
57
- InterimTranscriptionFrame("Hi there!", user_id, time_now_iso8601()),
58
- InterimTranscriptionFrame("How are", user_id, time_now_iso8601()),
59
- InterimTranscriptionFrame("How are you?", user_id, time_now_iso8601()),
60
- ]
61
- finale_transcript_frame1 = TranscriptionFrame("Hi there!", user_id, time_now_iso8601())
62
- finale_transcript_frame2 = TranscriptionFrame("How are you?", user_id, time_now_iso8601())
63
-
64
- frames_to_send = [
65
- UserStartedSpeakingFrame(),
66
- interim_transcript_frames[0],
67
- interim_transcript_frames[1],
68
- SleepFrame(0.1),
69
- UserStoppedSpeakingFrame(),
70
- finale_transcript_frame1,
71
- SleepFrame(0.1),
72
- UserStartedSpeakingFrame(),
73
- interim_transcript_frames[0],
74
- interim_transcript_frames[1],
75
- finale_transcript_frame1,
76
- interim_transcript_frames[2],
77
- interim_transcript_frames[3],
78
- finale_transcript_frame2,
79
- SleepFrame(0.1),
80
- UserStoppedSpeakingFrame(),
81
- ]
82
-
83
- expected_down_frames = [
84
- ignore_ids(UserStartedSpeakingFrame()),
85
- ignore_ids(UserUpdatedSpeakingTranscriptFrame("user started speaking")),
86
- ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi")),
87
- ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi there!")),
88
- ignore_ids(interim_transcript_frames[0]),
89
- ignore_ids(interim_transcript_frames[1]),
90
- ignore_ids(UserStoppedSpeakingFrame()),
91
- ignore_ids(UserStoppedSpeakingTranscriptFrame("Hi there!")),
92
- ignore_ids(finale_transcript_frame1),
93
- ignore_ids(UserStartedSpeakingFrame()),
94
- ignore_ids(UserUpdatedSpeakingTranscriptFrame("user started speaking")),
95
- ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi")),
96
- ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi there!")),
97
- ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi there! How are")),
98
- ignore_ids(UserUpdatedSpeakingTranscriptFrame("Hi there! How are you?")),
99
- ignore_ids(interim_transcript_frames[0]),
100
- ignore_ids(interim_transcript_frames[1]),
101
- ignore_ids(finale_transcript_frame1),
102
- ignore_ids(interim_transcript_frames[2]),
103
- ignore_ids(interim_transcript_frames[3]),
104
- ignore_ids(finale_transcript_frame2),
105
- ignore_ids(UserStoppedSpeakingFrame()),
106
- ignore_ids(UserStoppedSpeakingTranscriptFrame("Hi there! How are you?")),
107
- ]
108
-
109
- await run_test(
110
- UserTranscriptSynchronization("user started speaking"),
111
- frames_to_send=frames_to_send,
112
- expected_down_frames=expected_down_frames,
113
- )
114
-
115
-
116
- @pytest.mark.asyncio()
117
- async def test_bot_transcript_synchronization_processor_with_riva_tts():
118
- """Test the BotTranscriptSynchronization processor with Riva TTS.
119
-
120
- Tests the synchronization of bot transcripts when using Riva TTS,
121
- including speech events and interruption handling.
122
-
123
- The test verifies:
124
- - Bot speech start/stop handling
125
- - TTS text frame processing
126
- - Speaking transcript updates
127
- - Interruption handling
128
- - Frame sequence ordering
129
- - Multiple sentence handling
130
- """
131
- tts_text_frames = [
132
- TTSTextFrame("Welcome user!"),
133
- TTSTextFrame("How are you?"),
134
- TTSTextFrame("Did you have a nice day?"),
135
- ]
136
-
137
- frames_to_send = [
138
- TTSStartedFrame(), # Bot sentence transcript 1
139
- tts_text_frames[0],
140
- SleepFrame(0.1), # Give time for transcript to be buffered
141
- BotStartedSpeakingFrame(), # Start playing sentence 1
142
- TTSStoppedFrame(),
143
- BotStoppedSpeakingFrame(), # End of playing sentence 1
144
- SleepFrame(0.1),
145
- TTSStartedFrame(), # Bot sentence transcript 2
146
- tts_text_frames[1],
147
- SleepFrame(0.1), # Give time for transcript to be buffered
148
- BotStartedSpeakingFrame(), # Start playing sentence 2
149
- TTSStoppedFrame(),
150
- BotStoppedSpeakingFrame(), # End of playing sentence 2
151
- SleepFrame(0.1),
152
- TTSStartedFrame(), # Bot sentence transcript 3
153
- tts_text_frames[2],
154
- SleepFrame(0.1), # Give time for transcript to be buffered
155
- BotStartedSpeakingFrame(), # Start playing sentence 3
156
- TTSStoppedFrame(),
157
- BotStoppedSpeakingFrame(), # End of playing sentence 3
158
- SleepFrame(0.1),
159
- StartInterruptionFrame(), # User interrupts
160
- TTSStartedFrame(), # Bot sentence 1 again
161
- tts_text_frames[0],
162
- SleepFrame(0.1), # Give time for transcript to be buffered
163
- BotStartedSpeakingFrame(), # Start playing sentence 1
164
- TTSStoppedFrame(),
165
- BotStoppedSpeakingFrame(), # End of playing sentence 1
166
- ]
167
-
168
- expected_down_frames = [
169
- ignore_ids(TTSStartedFrame()),
170
- ignore_ids(tts_text_frames[0]),
171
- ignore_ids(BotStartedSpeakingFrame()),
172
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user!")),
173
- ignore_ids(BotStoppedSpeakingFrame()),
174
- ignore_ids(TTSStoppedFrame()),
175
- ignore_ids(TTSStartedFrame()),
176
- ignore_ids(tts_text_frames[1]),
177
- ignore_ids(BotStartedSpeakingFrame()),
178
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("How are you?")),
179
- ignore_ids(BotStoppedSpeakingFrame()),
180
- ignore_ids(TTSStoppedFrame()),
181
- ignore_ids(TTSStartedFrame()),
182
- ignore_ids(tts_text_frames[2]),
183
- ignore_ids(BotStartedSpeakingFrame()),
184
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("Did you have a nice day?")),
185
- ignore_ids(BotStoppedSpeakingFrame()),
186
- ignore_ids(TTSStoppedFrame()),
187
- ignore_ids(StartInterruptionFrame()),
188
- ignore_ids(TTSStartedFrame()),
189
- ignore_ids(tts_text_frames[0]),
190
- ignore_ids(BotStartedSpeakingFrame()),
191
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user!")),
192
- ignore_ids(BotStoppedSpeakingFrame()),
193
- ignore_ids(TTSStoppedFrame()),
194
- ]
195
-
196
- await run_test(
197
- BotTranscriptSynchronization(),
198
- frames_to_send=frames_to_send,
199
- expected_down_frames=expected_down_frames,
200
- )
201
-
202
-
203
- @pytest.mark.asyncio()
204
- async def test_bot_transcript_synchronization_processor_with_elevenlabs_tts():
205
- """Test the BotTranscriptSynchronization processor with ElevenLabs TTS.
206
-
207
- Tests the synchronization of bot transcripts when using ElevenLabs TTS,
208
- including partial text handling and concatenation.
209
-
210
- The test verifies:
211
- - Bot speech start/stop handling
212
- - Partial TTS text processing
213
- - Speaking transcript updates
214
- - Text concatenation
215
- - Frame sequence ordering
216
- - Complete transcript assembly
217
- """
218
- tts_text_frames = [
219
- TTSTextFrame("Welcome"),
220
- TTSTextFrame("user!"),
221
- TTSTextFrame("How"),
222
- TTSTextFrame("are"),
223
- TTSTextFrame("you?"),
224
- ]
225
-
226
- frames_to_send = [
227
- TTSStartedFrame(),
228
- tts_text_frames[0],
229
- SleepFrame(0.1),
230
- BotStartedSpeakingFrame(),
231
- tts_text_frames[1],
232
- tts_text_frames[2],
233
- tts_text_frames[3],
234
- tts_text_frames[4],
235
- SleepFrame(0.1),
236
- TTSStoppedFrame(),
237
- SleepFrame(0.1),
238
- BotStoppedSpeakingFrame(),
239
- ]
240
-
241
- expected_down_frames = [
242
- ignore_ids(TTSStartedFrame()),
243
- ignore_ids(tts_text_frames[0]),
244
- ignore_ids(BotStartedSpeakingFrame()),
245
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome")),
246
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user!")),
247
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user! How")),
248
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user! How are")),
249
- ignore_ids(BotUpdatedSpeakingTranscriptFrame("Welcome user! How are you?")),
250
- ignore_ids(tts_text_frames[1]),
251
- ignore_ids(tts_text_frames[2]),
252
- ignore_ids(tts_text_frames[3]),
253
- ignore_ids(tts_text_frames[4]),
254
- ignore_ids(TTSStoppedFrame()),
255
- ignore_ids(BotStoppedSpeakingFrame()),
256
- ]
257
-
258
- await run_test(
259
- BotTranscriptSynchronization(),
260
- frames_to_send=frames_to_send,
261
- expected_down_frames=expected_down_frames,
262
- )
 
tests/unit/test_user_presence.py DELETED
@@ -1,157 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Unit tests for the user presence frame processor."""
5
-
6
- import pytest
7
- from pipecat.frames.frames import (
8
- FilterControlFrame,
9
- LLMUpdateSettingsFrame,
10
- StartInterruptionFrame,
11
- TextFrame,
12
- TTSSpeakFrame,
13
- TTSStartedFrame,
14
- UserStartedSpeakingFrame,
15
- )
16
- from pipecat.tests.utils import SleepFrame
17
-
18
- from nvidia_pipecat.frames.action import FinishedPresenceUserActionFrame, StartedPresenceUserActionFrame
19
- from nvidia_pipecat.processors.user_presence import UserPresenceProcesssor
20
- from tests.unit.utils import ignore, run_test
21
-
22
-
23
- @pytest.mark.asyncio
24
- async def test_user_presence_start():
25
- """Tests user presence start handling with welcome message.
26
-
27
- Tests that the processor correctly handles user presence start events
28
- and generates appropriate welcome messages.
29
-
30
- Args:
31
- None
32
-
33
- Returns:
34
- None
35
-
36
- The test verifies:
37
- - StartedPresenceUserActionFrame processing
38
- - Welcome message generation
39
- - UserStartedSpeakingFrame handling
40
- - Frame sequence ordering
41
- - Message content accuracy
42
- """
43
- user_presence_bot = UserPresenceProcesssor(welcome_msg="Hey there!")
44
- frames_to_send = [StartedPresenceUserActionFrame(action_id=123), SleepFrame(0.01), UserStartedSpeakingFrame()]
45
- expected_down_frames = [
46
- ignore(StartedPresenceUserActionFrame(action_id=123), "ids", "timestamps"),
47
- ignore(TTSSpeakFrame("Hey there!"), "ids"),
48
- ignore(UserStartedSpeakingFrame(), "ids", "timestamps"),
49
- ]
50
-
51
- await run_test(user_presence_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
52
-
53
-
54
- @pytest.mark.asyncio
55
- async def test_user_presence_finished():
56
- """Tests user presence finish handling with farewell message.
57
-
58
- Tests that the processor correctly handles user presence finish events
59
- and generates appropriate farewell messages.
60
-
61
- Args:
62
- None
63
-
64
- Returns:
65
- None
66
-
67
- The test verifies:
68
- - StartedPresenceUserActionFrame processing
69
- - FinishedPresenceUserActionFrame handling
70
- - Farewell message generation
71
- - StartInterruptionFrame sequencing
72
- - Frame ordering
73
- - Message content accuracy
74
- """
75
- user_presence_bot = UserPresenceProcesssor(farewell_msg="Bye bye!")
76
- frames_to_send = [
77
- StartedPresenceUserActionFrame(action_id=123),
78
- SleepFrame(0.5),
79
- FinishedPresenceUserActionFrame(action_id=123),
80
- SleepFrame(0.5),
81
- ]
82
-
83
- expected_down_frames = [
84
- ignore(StartedPresenceUserActionFrame(action_id=123), "ids", "timestamps"),
85
- ignore(TTSSpeakFrame("Hello"), "ids"),
86
- # WAR: The StartInterruptionFrame is sent in response to the FinishedPresenceUserActionFrame.
87
- # However, the test framework consistently logs it in the sequence prior to the FinishedPresenceUserActionFrame.
88
- ignore(StartInterruptionFrame(), "ids", "timestamps"),
89
- ignore(FinishedPresenceUserActionFrame(action_id=123), "ids", "timestamps"),
90
- ignore(TTSSpeakFrame("Bye bye!"), "ids"),
91
- ]
92
-
93
- await run_test(user_presence_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
94
-
95
-
96
- @pytest.mark.asyncio
97
- async def test_user_presence():
98
- """Tests behavior without presence frames.
99
-
100
- Tests that no welcome/farewell messages are sent when no presence
101
- frames are received.
102
-
103
- Args:
104
- None
105
-
106
- Returns:
107
- None
108
-
109
- The test verifies:
110
- - No messages without presence frames
111
- - UserStartedSpeakingFrame handling
112
- - Frame filtering behavior
113
- """
114
- user_presence_bot = UserPresenceProcesssor(welcome_msg="Hello", farewell_msg="Bye bye!")
115
- frames_to_send = [UserStartedSpeakingFrame()]
116
- expected_down_frames = []
117
-
118
- await run_test(user_presence_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
119
-
120
-
121
- @pytest.mark.asyncio
122
- async def test_user_presence_system_frames():
123
- """Tests system and control frame handling.
124
-
125
- Tests that system and control frames are processed regardless of
126
- user presence state.
127
-
128
- Args:
129
- None
130
-
131
- Returns:
132
- None
133
-
134
- The test verifies:
135
- - TTSStartedFrame processing
136
- - FilterControlFrame handling
137
- - LLMUpdateSettingsFrame processing
138
- - TextFrame filtering
139
- - Frame passthrough behavior
140
- - Frame sequence preservation
141
- """
142
- user_presence_bot = UserPresenceProcesssor(welcome_msg="Hello", farewell_msg="Bye bye!")
143
-
144
- frames_to_send = [
145
- TTSStartedFrame(),
146
- FilterControlFrame(),
147
- LLMUpdateSettingsFrame(settings="ABC"),
148
- TextFrame("How are you?"),
149
- ]
150
-
151
- expected_down_frames = [
152
- ignore(TTSStartedFrame(), "ids", "timestamps"),
153
- ignore(FilterControlFrame(), "ids", "timestamps"),
154
- ignore(LLMUpdateSettingsFrame(settings="ABC"), "ids", "timestamps"),
155
- ]
156
-
157
- await run_test(user_presence_bot, frames_to_send=frames_to_send, expected_down_frames=expected_down_frames)
 
tests/unit/test_utils.py DELETED
@@ -1,95 +0,0 @@
1
- """Unit tests for utility processors."""
2
-
3
- import pytest
4
- from pipecat.frames.frames import Frame, TextFrame
5
-
6
- from nvidia_pipecat.processors.utils import FrameBlockingProcessor
7
- from tests.unit.utils import ignore_ids, run_test
8
-
9
-
10
- class TestFrame(Frame):
11
- """Test frame for testing FrameBlockingProcessor."""
12
-
13
- pass
14
-
15
-
16
- class TestResetFrame(Frame):
17
- """Test frame for testing FrameBlockingProcessor reset."""
18
-
19
- pass
20
-
21
-
22
- @pytest.mark.asyncio()
23
- async def test_frame_blocking_processor():
24
- """Test that FrameBlockingProcessor blocks frames after threshold.
25
-
26
- Verifies that:
27
- - Frames are passed through until threshold is reached
28
- - Frames are blocked after threshold
29
- - Non-matching frame types are always passed through
30
- """
31
- # Create processor that blocks after 2 TextFrames
32
- processor = FrameBlockingProcessor(block_after_frame=2, frame_type=TextFrame)
33
-
34
- # Create test frames - mix of TextFrames and TestFrames
35
- frames_to_send = [
36
- TestFrame(), # Should pass through
37
- TextFrame("First"), # Should pass through
38
- TextFrame("Second"), # Should pass through
39
- TextFrame("Third"), # Should be blocked
40
- TestFrame(), # Should pass through
41
- TextFrame("Fourth"), # Should be blocked
42
- ]
43
-
44
- # Expected frames - all TestFrames and first 2 TextFrames
45
- expected_down_frames = [
46
- ignore_ids(TestFrame()),
47
- ignore_ids(TextFrame("First")),
48
- ignore_ids(TextFrame("Second")),
49
- ignore_ids(TestFrame()),
50
- ]
51
-
52
- await run_test(
53
- processor,
54
- frames_to_send=frames_to_send,
55
- expected_down_frames=expected_down_frames,
56
- )
57
-
58
-
59
- @pytest.mark.asyncio()
60
- async def test_frame_blocking_processor_with_reset():
61
- """Test that FrameBlockingProcessor resets counter on reset frame type.
62
-
63
- Verifies that:
64
- - Counter resets when reset frame type is received
65
- - Frames are blocked after threshold
66
- - Counter can be reset multiple times
67
- """
68
- # Create processor that blocks after 2 TextFrames and resets on TestResetFrame
69
- processor = FrameBlockingProcessor(block_after_frame=2, frame_type=TextFrame, reset_frame_type=TestResetFrame)
70
-
71
- # Create test frames - mix of TextFrames and TestResetFrame
72
- frames_to_send = [
73
- TextFrame("First"), # Should pass through
74
- TextFrame("Second"), # Should pass through
75
- TextFrame("Third"), # Should be blocked
76
- TestResetFrame(), # Should reset counter and pass through
77
- TextFrame("Fourth"), # Should pass through (counter reset)
78
- TextFrame("Fifth"), # Should pass through
79
- TextFrame("Sixth"), # Should be blocked
80
- ]
81
-
82
- # Expected frames - all TestResetFrame and TextFrames except blocked ones
83
- expected_down_frames = [
84
- ignore_ids(TextFrame("First")),
85
- ignore_ids(TextFrame("Second")),
86
- ignore_ids(TestResetFrame()),
87
- ignore_ids(TextFrame("Fourth")),
88
- ignore_ids(TextFrame("Fifth")),
89
- ]
90
-
91
- await run_test(
92
- processor,
93
- frames_to_send=frames_to_send,
94
- expected_down_frames=expected_down_frames,
95
- )
 
tests/unit/utils.py DELETED
@@ -1,428 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: BSD 2-Clause License
3
-
4
- """Utility functions and classes for testing pipelines and processors.
5
-
6
- This module provides various test utilities including frame processors, test runners,
7
- and helper functions for frame comparison in tests.
8
- """
9
-
10
- import asyncio
11
- from asyncio.tasks import sleep
12
- from collections.abc import Sequence
13
- from copy import copy
14
- from dataclasses import dataclass
15
- from datetime import datetime, timedelta
16
- from typing import Any
17
- from unittest.mock import ANY
18
-
19
- import numpy as np
20
- from pipecat.frames.frames import EndFrame, Frame, InputAudioRawFrame, StartFrame
21
- from pipecat.pipeline.pipeline import Pipeline
22
- from pipecat.pipeline.runner import PipelineRunner
23
- from pipecat.pipeline.task import PipelineParams, PipelineTask
24
- from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
25
- from pipecat.tests.utils import QueuedFrameProcessor
26
- from pipecat.tests.utils import run_test as run_pipecat_test
27
-
28
- from nvidia_pipecat.frames.action import ActionFrame
29
-
30
-
31
- def ignore(frame: Frame, *args: str):
32
- """Return a copy of the frames with attributes listed in *args set to unittest.mock.ANY.
33
-
34
- Any attribute listed in args will be ignored when using comparisons against the returned frame.
35
-
36
- Args:
37
- frame (Frame): Frame to create a copy from and set selected attributes to unittest.mock.ANY.
38
- *args (str): Attribute names. Special values to ignore common sets of attributes:
39
- 'ids': ignore standard frame attributes related to frame ids
40
- 'all_ids': ignore standard frame attributes related to frame and action ids
41
- 'timestamps': ignore standard frame attributes containing timestamps
42
-
43
- Returns:
44
- Frame: A copy of the input frame with specified attributes set to ANY.
45
-
46
- Raises:
47
- ValueError: If an attribute name is not found in the frame.
48
- """
49
- new_frame = copy(frame)
50
- for arg in args:
51
- is_updated = False
52
- if arg == "all_ids" or arg == "ids":
53
- new_frame.id = ANY
54
- new_frame.name = ANY
55
- is_updated = True
56
- if arg == "all_ids":
57
- new_frame.action_id = ANY
58
- is_updated = True
59
- if arg == "timestamps":
60
- new_frame.pts = ANY
61
- new_frame.action_started_at = ANY
62
- new_frame.action_finished_at = ANY
63
- new_frame.action_updated_at = ANY
64
- is_updated = True
65
-
66
- if hasattr(new_frame, arg):
67
- setattr(new_frame, arg, ANY)
68
- elif not is_updated:
69
- raise ValueError(
70
- f"Frame {frame.__class__.__name__} has not attribute '{arg}' to ignore. Did you misspell it?"
71
- )
72
- return new_frame
73
-
74
-
75
- def ignore_ids(frame: Frame) -> Frame:
76
- """Return a copy of the frame that matches any id and name.
77
-
78
- This is useful if you do not want to assert a specific ID and name.
79
-
80
- Args:
81
- frame (Frame): The frame to create a copy from.
82
-
83
- Returns:
84
- Frame: A copy of the input frame with ID and name set to ANY.
85
- """
86
- new_frame = copy(frame)
87
- new_frame.id = ANY
88
- new_frame.name = ANY
89
- if isinstance(frame, ActionFrame):
90
- new_frame.action_id = ANY
91
- return new_frame
92
-
93
-
94
- def ignore_timestamps(frame: Frame) -> Frame:
95
- """Return a copy of the frame that matches frames ignoring any timestamps.
96
-
97
- Args:
98
- frame (Frame): The frame to create a copy from.
99
-
100
- Returns:
101
- Frame: A copy of the input frame with all timestamp fields set to ANY.
102
- """
103
- new_frame = copy(frame)
104
- new_frame.pts = ANY
105
- new_frame.action_started_at = ANY
106
- new_frame.action_finished_at = ANY
107
- new_frame.action_updated_at = ANY
108
- return new_frame
109
-
110
-
111
- @dataclass
112
- class FrameHistoryEntry:
113
- """Storing a frame and the frame direction.
114
-
115
- Attributes:
116
- frame (Frame): The stored frame.
117
- direction (FrameDirection): The direction of the frame in the pipeline.
118
- """
119
-
120
- frame: Frame
121
- direction: FrameDirection
122
-
123
-
124
- class FrameStorage(FrameProcessor):
125
- """A frame processor that stores all received frames in memory for inspection.
126
-
127
- This processor maintains a history of all frames that pass through it, along with their
128
- direction, allowing for later inspection and verification in tests.
129
-
130
- Attributes:
131
- history (list[FrameHistoryEntry]): List of all frames that have passed through the processor.
132
- """
133
-
134
- def __init__(self, **kwargs) -> None:
135
- """Initialize the FrameStorage processor.
136
-
137
- Args:
138
- **kwargs: Additional keyword arguments passed to the parent FrameProcessor.
139
- """
140
- super().__init__(**kwargs)
141
- self.history: list[FrameHistoryEntry] = []
142
-
143
- async def process_frame(self, frame: Frame, direction: FrameDirection):
144
- """Process an incoming frame by storing it in history and forwarding it.
145
-
146
- Args:
147
- frame (Frame): The frame to process.
148
- direction (FrameDirection): The direction the frame is traveling in the pipeline.
149
- """
150
- await super().process_frame(frame, direction)
151
-
152
- self.history.append(FrameHistoryEntry(frame, direction))
153
- await self.push_frame(frame, direction)
154
-
155
- def frames_of_type(self, t) -> list[Frame]:
156
- """Get all frames of a specific type from the history.
157
-
158
- Args:
159
- t: The type of frames to filter for.
160
-
161
- Returns:
162
- list[Frame]: List of all frames matching the specified type.
163
- """
164
- return [e.frame for e in self.history if isinstance(e.frame, t)]
165
-
166
- async def wait_for_frame(self, frame: Frame, timeout: timedelta = timedelta(seconds=5.0)) -> None:
167
- """Block until a matching frame is found in history.
168
-
169
- Args:
170
- frame (Frame): The frame to wait for.
171
- timeout (timedelta, optional): Maximum time to wait. Defaults to 5 seconds.
172
-
173
- Raises:
174
- TimeoutError: If the frame is not found within the timeout period.
175
- """
176
- candidates = {}
177
-
178
- def is_same_frame(frame_a, frame_b) -> bool:
179
- if type(frame_a) is not type(frame_b):
180
- return False
181
-
182
- if frame_a == frame_b:
183
- return True
184
- else:
185
- candidates[frame_a.id] = frame_a.__repr__()
186
- return False
187
-
188
- found = False
189
- start_time = datetime.now()
190
- while not found:
191
- found = any([is_same_frame(entry.frame, frame) for entry in self.history])
192
- if not found:
193
- if datetime.now() - start_time > timeout:
194
- raise TimeoutError(
195
- "Frame not found until timeout reached.\n"
196
- f"EXPECTED:\n{frame.__repr__()}\n"
197
- f"FOUND\n{'\n'.join(candidates.values())}"
198
- )
199
- await asyncio.sleep(0.01)
200
-
201
-
202
- class SinusWaveProcessor(FrameProcessor):
203
- """A frame processor that generates a sine wave audio signal.
204
-
205
- This processor generates a continuous sine wave at 440 Hz (A4 note) when started,
206
- and outputs it as audio frames. It is useful for testing audio processing pipelines.
207
-
208
- Attributes:
209
- duration (timedelta): The total duration of the sine wave to generate.
210
- chunk_duration (float): Duration of each audio chunk in seconds.
211
- audio_frame_count (int): Total number of audio frames to generate.
212
- audio_task (asyncio.Task | None): Task handling the audio generation.
213
- """
214
-
215
- def __init__(
216
- self,
217
- *,
218
- duration: timedelta,
219
- **kwargs,
220
- ):
221
- """Initialize the SinusWaveProcessor.
222
-
223
- Args:
224
- duration (timedelta): The total duration of the sine wave to generate.
225
- **kwargs: Additional keyword arguments passed to the parent FrameProcessor.
226
- """
227
- super().__init__(**kwargs)
228
- self.duration = duration
229
- self.audio_task: asyncio.Task | None = None
230
-
231
- self.chunk_duration = 0.02 # 20 milliseconds
232
- self.audio_frame_count = round(self.duration.total_seconds() / self.chunk_duration)
233
-
234
- async def _cancel_audio_task(self):
235
- """Cancel the running audio generation task if it exists."""
236
- if self.audio_task and not self.audio_task.done():
237
- await self.cancel_task(self.audio_task)
238
- self.audio_task = None
239
-
240
- async def stop(self):
241
- """Stop the audio generation by canceling the audio task."""
242
- await self._cancel_audio_task()
243
-
244
- async def cleanup(self):
245
- """Clean up resources by stopping audio generation and calling parent cleanup."""
246
- await super().cleanup()
247
- await self._cancel_audio_task()
248
-
249
- async def process_frame(self, frame: Frame, direction: FrameDirection):
250
- """Process incoming frames to start/stop audio generation.
251
-
252
- Starts audio generation on StartFrame and stops on EndFrame.
253
-
254
- Args:
255
- frame (Frame): The frame to process.
256
- direction (FrameDirection): The direction the frame is traveling in the pipeline.
257
- """
258
- await super().process_frame(frame, direction)
259
- if isinstance(frame, StartFrame):
260
- self.audio_task = self.create_task(self.start())
261
- if isinstance(frame, EndFrame):
262
- await self.stop()
263
- await super().push_frame(frame, direction)
264
-
265
- async def start(self):
266
- """Start generating and outputting sine wave audio frames.
267
-
268
- Generates a continuous 440 Hz sine wave, split into 20ms chunks,
269
- and outputs them as InputAudioRawFrame instances.
270
- """
271
- sample_rate = 16000 # Hz
272
- frequency = 440 # Hz (A4 tone)
273
-
274
- phase_offset = 0
275
- for _ in range(self.audio_frame_count):
276
- chunk_samples = int(sample_rate * self.chunk_duration)
277
-
278
- # Generate the time axis
279
- t = np.arange(chunk_samples) / sample_rate
280
-
281
- # Create the sine wave with the given phase offset
282
- sine_wave = 0.5 * np.sin(2 * np.pi * frequency * t + phase_offset)
283
-
284
- # Calculate the new phase offset for the next chunk
285
- phase_offset = (2 * np.pi * frequency * self.chunk_duration + phase_offset) % (2 * np.pi)
286
-
287
- sine_wave_pcm = (sine_wave * 32767).astype(np.int16)
288
-
289
- pipecat_frame = InputAudioRawFrame(audio=sine_wave_pcm.tobytes(), sample_rate=sample_rate, num_channels=1)
290
- await super().push_frame(pipecat_frame)
291
-
292
- await sleep(self.chunk_duration)
293
-
294
-
295
- async def run_test(
296
- processor: FrameProcessor,
297
- *,
298
- frames_to_send: Sequence[Frame],
299
- expected_down_frames: Sequence[Frame],
300
- expected_up_frames: Sequence[Frame] = [],
301
- ignore_start: bool = True,
302
- start_metadata: dict[str, Any] | None = None,
303
- send_end_frame: bool = True,
304
- ) -> tuple[Sequence[Frame], Sequence[Frame]]:
305
- """Run a test on a frame processor with predefined input and expected output frames.
306
-
307
- Args:
308
- processor (FrameProcessor): The processor to test.
309
- frames_to_send (Sequence[Frame]): Frames to send through the processor.
310
- expected_down_frames (Sequence[Frame]): Expected frames in downstream direction.
311
- expected_up_frames (Sequence[Frame], optional): Expected frames in upstream direction.
312
- Defaults to [].
313
- ignore_start (bool, optional): Whether to ignore start frames. Defaults to True.
314
- start_metadata (dict[str, Any], optional): Metadata to include in start frame.
315
- Defaults to None.
316
- send_end_frame (bool, optional): Whether to send an end frame. Defaults to True.
317
-
318
- Returns:
319
- tuple[Sequence[Frame], Sequence[Frame]]: Tuple of (received downstream frames, received upstream frames).
320
-
321
- Raises:
322
- AssertionError: If received frames don't match expected frames.
323
- """
324
- if start_metadata is None:
325
- start_metadata = {}
326
- received_down_frames, received_up_frames = await run_pipecat_test(
327
- processor,
328
- frames_to_send=frames_to_send,
329
- expected_down_frames=[f.__class__ for f in expected_down_frames],
330
- expected_up_frames=[f.__class__ for f in expected_up_frames],
331
- ignore_start=ignore_start,
332
- start_metadata=start_metadata,
333
- send_end_frame=send_end_frame,
334
- )
335
-
336
- for real, expected in zip(received_up_frames, expected_up_frames, strict=True):
337
- assert real == expected, f"Frame mismatch: \nreal: {repr(real)} \nexpected: {repr(expected)}"
338
-
339
- for real, expected in zip(received_down_frames, expected_down_frames, strict=True):
340
- assert real == expected, f"Frame mismatch: \nreal: {repr(real)} \nexpected: {repr(expected)}"
341
-
342
- return received_down_frames, received_up_frames
343
-
344
-
345
- async def run_interactive_test(
346
- processor: FrameProcessor,
347
- *,
348
- test_coroutine,
349
- start_metadata: dict[str, Any] | None = None,
350
- ignore_start: bool = True,
351
- send_end_frame: bool = True,
352
- ) -> tuple[Sequence[Frame], Sequence[Frame]]:
353
- """Run an interactive test on a frame processor with a custom test coroutine.
354
-
355
- This function allows for more complex testing scenarios where frames need to be
356
- sent and received dynamically during the test.
357
-
358
- Args:
359
- processor (FrameProcessor): The processor to test.
360
- test_coroutine: Coroutine function that implements the test logic.
361
- start_metadata (dict[str, Any], optional): Metadata to include in start frame.
362
- Defaults to None.
363
- ignore_start (bool, optional): Whether to ignore start frames. Defaults to True.
364
- send_end_frame (bool, optional): Whether to send an end frame. Defaults to True.
365
-
366
- Returns:
367
- tuple[Sequence[Frame], Sequence[Frame]]: Tuple of (received downstream frames, received upstream frames).
368
- """
369
- if start_metadata is None:
370
- start_metadata = {}
371
- received_up = asyncio.Queue()
372
- received_down = asyncio.Queue()
373
- source = QueuedFrameProcessor(
374
- queue=received_up,
375
- queue_direction=FrameDirection.UPSTREAM,
376
- ignore_start=ignore_start,
377
- )
378
- sink = QueuedFrameProcessor(
379
- queue=received_down,
380
- queue_direction=FrameDirection.DOWNSTREAM,
381
- ignore_start=ignore_start,
382
- )
383
-
384
- pipeline = Pipeline([source, processor, sink])
385
-
386
- task = PipelineTask(pipeline, params=PipelineParams(start_metadata=start_metadata))
387
-
388
- async def run_test():
389
- # Just give a little head start to the runner.
390
- await asyncio.sleep(0.01)
391
- await test_coroutine(task)
392
-
393
- if send_end_frame:
394
- await task.queue_frame(EndFrame())
395
-
396
- # await asyncio.sleep(1.5)
397
- # debug = ""
398
- # for t in processor.get_task_manager().current_tasks():
399
- # debug += f"\nMy Task {t.get_name()} is still running."
400
-
401
- # assert False, debug
402
- # print(debug)
403
-
404
- runner = PipelineRunner()
405
- await asyncio.gather(runner.run(task), run_test())
406
-
407
- #
408
- # Down frames
409
- #
410
- received_down_frames: list[Frame] = []
411
- while not received_down.empty():
412
- frame = await received_down.get()
413
- if not isinstance(frame, EndFrame) or not send_end_frame:
414
- received_down_frames.append(frame)
415
-
416
- print("received DOWN frames =", received_down_frames)
417
-
418
- #
419
- # Up frames
420
- #
421
- received_up_frames: list[Frame] = []
422
- while not received_up.empty():
423
- frame = await received_up.get()
424
- received_up_frames.append(frame)
425
-
426
- print("received UP frames =", received_up_frames)
427
-
428
- return (received_down_frames, received_up_frames)