""" This script implements a Voice Activity Detection (VAD) system using a Convolutional Neural Network (CNN) model. The script uses Gradio for the user interface, NumPy and SciPy for signal processing, and ONNX Runtime for model inference. Modules: - gradio: For creating the web interface. - numpy: For numerical operations. - scipy: For signal processing. - onnxruntime: For running the ONNX model. - threading: For running model inference in a separate thread. - queue: For thread-safe communication between threads. """ import atexit import queue import threading import time import gradio as gr import numpy as np import onnxruntime as ort from scipy import signal from scipy.fft import rfft # Parameters ORIGINAL_FS = 48_000 # Original sampling frequency in Hz DECIMATION_FACTOR = 3 # Decimation factor TARGET_FS = ORIGINAL_FS // DECIMATION_FACTOR # Target sampling frequency in Hz FRAME_DURATION_MS = 25 # Frame duration in milliseconds FRAME_SIZE = int(ORIGINAL_FS * FRAME_DURATION_MS / 1000) # Frame size FRAME_STEP = FRAME_SIZE // 2 # 50% overlap # FFT parameters DECIMATED_FRAME_SIZE = FRAME_SIZE // DECIMATION_FACTOR FFT_SIZE = 1 while FFT_SIZE < DECIMATED_FRAME_SIZE: FFT_SIZE <<= 1 # Mel spectrogram parameters N_FILTERS = 40 # Number of mel filters FREQ_LOW = 300 # Lower frequency bound FREQ_HIGH = 8000 # Upper frequency bound # Model evaluation parameters INITIAL_FRAMES = 40 # Initial frames required for first model evaluation FRAMES_PER_EVAL = 20 # Run model every this many new frames after initial # EWMA parameter ALPHA = 0.9 # Weight for current observation in EWMA # Filter coefficients FILTER_COEFFICIENTS = np.array([0.0625, 0.125, 0.25, 0.5, 0.25, 0.125, 0.0625]) # Create window function (Hanning window) window = np.hanning(DECIMATED_FRAME_SIZE) # Create mel filterbank matrix def create_mel_filterbank(n_filt, freq_low, freq_high, n_fft, fs): """ Create a mel filterbank matrix for computing the mel spectrogram. 


# Create mel filterbank matrix
def create_mel_filterbank(n_filt, freq_low, freq_high, n_fft, fs):
    """
    Create a mel filterbank matrix for computing the mel spectrogram.

    Args:
        n_filt: Number of mel filters
        freq_low: Lower frequency bound
        freq_high: Upper frequency bound
        n_fft: FFT size
        fs: Sampling frequency

    Returns:
        filterbank: Mel filterbank matrix of shape (n_fft//2+1, n_filt)
    """

    # Convert Hz to mel
    def hz_to_mel(hz):
        return 1125 * np.log(1 + hz / 700)

    # Convert mel to Hz
    def mel_to_hz(mel):
        return 700 * (np.exp(mel / 1125) - 1)

    # Compute points evenly spaced on the mel scale
    lower_mel = hz_to_mel(freq_low)
    higher_mel = hz_to_mel(freq_high)
    mel_points = np.linspace(lower_mel, higher_mel, n_filt + 2)

    # Convert mel points back to Hz
    hz_points = mel_to_hz(mel_points)

    # Convert Hz points to FFT bin indices
    bin_indices = np.floor((n_fft + 1) * hz_points / fs).astype(int)

    # Create filterbank matrix: one triangular filter per column
    filterbank = np.zeros((n_fft // 2 + 1, n_filt))
    for i in range(n_filt):
        # Rising slope of the triangle
        for j in range(bin_indices[i], bin_indices[i + 1]):
            filterbank[j, i] = (j - bin_indices[i]) / (
                bin_indices[i + 1] - bin_indices[i]
            )
        # Falling slope of the triangle
        for j in range(bin_indices[i + 1], bin_indices[i + 2]):
            filterbank[j, i] = (bin_indices[i + 2] - j) / (
                bin_indices[i + 2] - bin_indices[i + 1]
            )

    return filterbank


# Create mel filterbank
mel_filterbank = create_mel_filterbank(
    N_FILTERS, FREQ_LOW, FREQ_HIGH, FFT_SIZE, TARGET_FS
)

# Initialize mel spectrogram image
mel_spectrogram_image = np.zeros((N_FILTERS, N_FILTERS), dtype=np.float32)

# Load the ONNX model
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.intra_op_num_threads = 2
session = ort.InferenceSession("./model/cnn-vad.onnx", session_options)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Create thread-safe queues for ML tasks and results
inference_queue = queue.Queue(maxsize=10)  # Queue for pending inference tasks
result_queue = queue.Queue(maxsize=10)  # Queue for inference results
stop_event = threading.Event()  # Event to signal thread termination


# ML worker thread function
def ml_worker():
    """Worker thread function for running ML inference."""
    while not stop_event.is_set():
        try:
            # Get a task with a timeout so stop_event is checked periodically
            mel_image = inference_queue.get(timeout=0.1)

            # Run inference
            outputs = session.run(
                [output_name],
                {
                    input_name: mel_image.astype(np.float32).reshape(
                        -1, N_FILTERS, N_FILTERS
                    )
                },
            )[0]

            # Get speech probability score and put it in the result queue
            current_score = outputs[0][1]
            result_queue.put(current_score)

            # Mark task as done
            inference_queue.task_done()
        except queue.Empty:
            # No pending tasks, just continue
            continue
        except Exception as e:
            print(f"Error in ML worker thread: {e}")
            # A task was retrieved before the failure, so still mark it as done
            inference_queue.task_done()


# Start ML worker thread
ml_thread = threading.Thread(target=ml_worker, daemon=True)
ml_thread.start()
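
# `atexit` is imported at the top of this script; one plausible use (a sketch of
# a shutdown hook, not necessarily what the rest of the script does) is to stop
# the worker thread cleanly when the process exits.
def _shutdown_ml_worker():
    """Signal the ML worker thread to stop and wait briefly for it to exit."""
    stop_event.set()
    ml_thread.join(timeout=1.0)


atexit.register(_shutdown_ml_worker)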


# Initialize stream state
initial_state = {
    "buffer": np.zeros(0, dtype=np.float32),  # Audio buffer
    "frames_processed": 0,  # Number of frames processed
    "ewma_score": 0.0,  # EWMA of speech detection score
    "mel_image": np.zeros(
        (N_FILTERS, N_FILTERS), dtype=np.float32
    ),  # Mel spectrogram image
    "inference_pending": False,  # Flag to track if inference is pending
    "last_inference_time": 0,  # Timestamp of last inference
}


def resample_audio(audio, orig_sr, target_sr):
    """
    Resample audio from the original sample rate to the target sample rate
    using polyphase filtering (scipy.signal.resample_poly).

    Args:
        audio: Audio data
        orig_sr: Original sample rate
        target_sr: Target sample rate

    Returns:
        Resampled audio
    """
    if orig_sr == target_sr:
        return audio

    return signal.resample_poly(audio, target_sr, orig_sr)

    # A faster but less accurate alternative is simple nearest-index resampling:
    # duration = len(audio) / orig_sr
    # new_length = int(duration * target_sr)
    # indices = np.linspace(0, len(audio) - 1, new_length).astype(np.int32)
    # return audio[indices]


def process_frame(frame, mel_image):
    """
    Process a single frame of audio data.

    Args:
        frame: Audio frame data
        mel_image: Current mel spectrogram image

    Returns:
        Updated mel spectrogram image
    """
    # 1. Apply the FIR low-pass filter
    filtered_frame = signal.lfilter(FILTER_COEFFICIENTS, [1.0], frame)

    # Clamp values
    filtered_frame = np.clip(filtered_frame, -1.0, 1.0)

    # 2. Decimate the frame
    decimated_frame = filtered_frame[::DECIMATION_FACTOR]

    # Apply window and pad with zeros to FFT_SIZE
    padded_frame = np.zeros(FFT_SIZE)
    padded_frame[:DECIMATED_FRAME_SIZE] = decimated_frame * window

    # 3. Perform FFT
    fft_result = rfft(padded_frame)

    # Compute power spectrum
    power_spectrum = np.abs(fft_result) ** 2

    # 4. Calculate mel spectrogram
    mel_power = np.dot(power_spectrum[: FFT_SIZE // 2 + 1], mel_filterbank)

    # Log mel spectrogram
    mel_power = np.log(mel_power + 1e-8)

    # Update mel spectrogram image (shift rows up and add the new row at the bottom)
    mel_image = np.roll(mel_image, -1, axis=0)
    mel_image[-1] = mel_power

    return mel_image
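

# Illustration only: a tiny self-check of the per-frame pipeline on a synthetic
# 440 Hz tone. The helper name is hypothetical and nothing in the script calls
# it; it just shows the shapes produced at each stage.
def _demo_process_frame():
    t = np.arange(FRAME_SIZE) / ORIGINAL_FS
    tone = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)  # one 25 ms frame
    demo_image = np.zeros((N_FILTERS, N_FILTERS), dtype=np.float32)
    demo_image = process_frame(tone, demo_image)
    # The newest log-mel row is written at the bottom of the rolling image.
    print("mel image:", demo_image.shape, "newest row:", demo_image[-1].shape)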
""" # Initialize state if it's the first call if state is None: state = initial_state.copy() state["mel_image"] = np.copy(initial_state["mel_image"]) sr, y = new_chunk # If no audio data, return current state if y is None or len(y) == 0: return ( state, f"No audio detected (Score: {state['ewma_score']:.2f})", generate_indicator(state["ewma_score"]), ) # Pre-processing: Convert to mono if stereo if len(y.shape) > 1: y = np.mean(y, axis=1) # Convert to float32 [-1.0, 1.0] y = y.astype(np.float32) / 32768 # Resample if necessary if sr != ORIGINAL_FS: y = resample_audio(y, sr, ORIGINAL_FS) # Append new audio to buffer buffer = np.concatenate([state["buffer"], y]) # Process frames from buffer new_frames_processed = 0 run_inference = False current_time = time.time() # Process as many complete frames as possible while len(buffer) >= FRAME_SIZE: # Extract frame frame = buffer[:FRAME_SIZE] buffer = buffer[FRAME_STEP:] # Advance by hop size (50% overlap) # Process frame and update mel spectrogram image state["mel_image"] = process_frame(frame, state["mel_image"]) new_frames_processed += 1 state["frames_processed"] += 1 # Determine if we should run inference if ( state["frames_processed"] >= INITIAL_FRAMES and (state["frames_processed"] - INITIAL_FRAMES) % FRAMES_PER_EVAL == 0 ): run_inference = True # Update buffer in state state["buffer"] = buffer # Check for completed inference results if not result_queue.empty(): # Get the result from the queue current_score = result_queue.get() # Update EWMA if state["ewma_score"] == 0: # First evaluation state["ewma_score"] = current_score else: state["ewma_score"] = ( ALPHA * current_score + (1 - ALPHA) * state["ewma_score"] ) # Mark that we're no longer waiting for inference state["inference_pending"] = False # Run VAD model if criteria are met and no inference is currently pending if ( ( run_inference or ( state["frames_processed"] >= INITIAL_FRAMES and new_frames_processed > 0 and current_time - state["last_inference_time"] > 0.1 # Rate limiting ) ) and not state["inference_pending"] and not inference_queue.full() ): # Instead of running inference here, queue it for the worker thread inference_queue.put(np.copy(state["mel_image"])) state["inference_pending"] = True state["last_inference_time"] = current_time # Determine result based on EWMA score if state["frames_processed"] < INITIAL_FRAMES: message = ( f"Building audio context... ({state['frames_processed']}/{INITIAL_FRAMES})" ) indicator = generate_indicator(0) # Inactive during initialization elif state["ewma_score"] > 0.5: message = f"Speech Detected (Score: {state['ewma_score']:.2f})" indicator = generate_indicator(state["ewma_score"]) else: message = f"No Speech Detected (Score: {state['ewma_score']:.2f})" indicator = generate_indicator(state["ewma_score"]) return state, message, indicator def generate_indicator(score): """ Generate an HTML indicator that lights up based on the speech detection score. 


def generate_indicator(score):
    """
    Generate an HTML indicator that lights up based on the speech detection score.

    Args:
        score: Speech detection score (0.0 to 1.0)

    Returns:
        HTML string for the visual indicator
    """
    # Set colors based on score
    if score > 0.5:
        # Active - green glow
        brightness = min(100, 50 + int(score * 50))  # 75-100% brightness when active
        color = f"rgba(144, 219, 130, {score:.1f})"  # Green with opacity based on score
        glow = f"0 0 {int(score * 20)}px rgba(0, 255, 0, {score:.1f})"  # Glow effect
    else:
        # Inactive - dim red
        brightness = max(10, int(score * 50))  # 10-25% brightness when inactive
        color = f"rgba(255, 0, 0, {max(0.1, score):.1f})"  # Red with low opacity
        glow = "none"  # No glow when inactive

    # Create HTML for the indicator
    html = f"""