""" | |
This script implements a Voice Activity Detection (VAD) system using a Convolutional Neural Network (CNN) model. | |
The script uses Gradio for the user interface, NumPy and SciPy for signal | |
processing, and ONNX Runtime for model inference. | |
Modules: | |
- gradio: For creating the web interface. | |
- numpy: For numerical operations. | |
- scipy: For signal processing. | |
- onnxruntime: For running the ONNX model. | |
- threading: For running model inference in a separate thread. | |
- queue: For thread-safe communication between threads. | |
""" | |
import atexit
import queue
import threading
import time

import gradio as gr
import numpy as np
import onnxruntime as ort
from scipy import signal
from scipy.fft import rfft
# Parameters
ORIGINAL_FS = 48_000  # Original sampling frequency in Hz
DECIMATION_FACTOR = 3  # Decimation factor
TARGET_FS = ORIGINAL_FS // DECIMATION_FACTOR  # Target sampling frequency in Hz
FRAME_DURATION_MS = 25  # Frame duration in milliseconds
FRAME_SIZE = int(ORIGINAL_FS * FRAME_DURATION_MS / 1000)  # Frame size in samples
FRAME_STEP = FRAME_SIZE // 2  # 50% overlap

# FFT parameters
DECIMATED_FRAME_SIZE = FRAME_SIZE // DECIMATION_FACTOR
FFT_SIZE = 1
while FFT_SIZE < DECIMATED_FRAME_SIZE:
    FFT_SIZE <<= 1

# Mel spectrogram parameters
N_FILTERS = 40  # Number of mel filters
FREQ_LOW = 300  # Lower frequency bound in Hz
FREQ_HIGH = 8000  # Upper frequency bound in Hz

# Model evaluation parameters
INITIAL_FRAMES = 40  # Initial frames required for the first model evaluation
FRAMES_PER_EVAL = 20  # Run the model every this many new frames after the initial batch

# EWMA parameter
ALPHA = 0.9  # Weight for the current observation in the EWMA

# Filter coefficients
FILTER_COEFFICIENTS = np.array([0.0625, 0.125, 0.25, 0.5, 0.25, 0.125, 0.0625])
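# Note: these seven symmetric taps form a simple low-pass FIR filter; it
# presumably acts as an anti-aliasing filter applied before the 3x decimation.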
# Create window function (Hanning window)
window = np.hanning(DECIMATED_FRAME_SIZE)
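# With the defaults above, the derived values work out to:
#   FRAME_SIZE = 48_000 * 25 / 1000 = 1200 samples per frame (25 ms at 48 kHz)
#   FRAME_STEP = 600 samples (12.5 ms hop, 50% overlap)
#   TARGET_FS = 16_000 Hz, DECIMATED_FRAME_SIZE = 400 samples
#   FFT_SIZE = 512 (next power of two >= 400)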
# Create mel filterbank matrix
def create_mel_filterbank(n_filt, freq_low, freq_high, n_fft, fs):
    """
    Create a mel filterbank matrix for computing the mel spectrogram.
    Args:
        n_filt: Number of mel filters
        freq_low: Lower frequency bound
        freq_high: Upper frequency bound
        n_fft: FFT size
        fs: Sampling frequency
    Returns:
        filterbank: Mel filterbank matrix of shape (n_fft//2+1, n_filt)
    """

    # Convert Hz to mel
    def hz_to_mel(hz):
        return 1125 * np.log(1 + hz / 700)

    # Convert mel to Hz
    def mel_to_hz(mel):
        return 700 * (np.exp(mel / 1125) - 1)

    # Compute points evenly spaced on the mel scale
    lower_mel = hz_to_mel(freq_low)
    higher_mel = hz_to_mel(freq_high)
    mel_points = np.linspace(lower_mel, higher_mel, n_filt + 2)
    # Convert mel points back to Hz
    hz_points = mel_to_hz(mel_points)
    # Convert Hz points to FFT bin indices
    bin_indices = np.floor((n_fft + 1) * hz_points / fs).astype(int)
    # Create filterbank matrix
    filterbank = np.zeros((n_fft // 2 + 1, n_filt))
    for i in range(n_filt):
        # For each filter, create a triangular filter
        for j in range(bin_indices[i], bin_indices[i + 1]):
            filterbank[j, i] = (j - bin_indices[i]) / (
                bin_indices[i + 1] - bin_indices[i]
            )
        for j in range(bin_indices[i + 1], bin_indices[i + 2]):
            filterbank[j, i] = (bin_indices[i + 2] - j) / (
                bin_indices[i + 2] - bin_indices[i + 1]
            )
    return filterbank


# Create mel filterbank
mel_filterbank = create_mel_filterbank(
    N_FILTERS, FREQ_LOW, FREQ_HIGH, FFT_SIZE, TARGET_FS
)
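# Shape check: with FFT_SIZE = 512 and N_FILTERS = 40, mel_filterbank is a
# (257, 40) matrix; multiplying a 257-bin power spectrum by it yields 40 mel
# band energies, i.e. one row of the mel spectrogram image fed to the model.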
# Initialize mel spectrogram image
mel_spectrogram_image = np.zeros((N_FILTERS, N_FILTERS), dtype=np.float32)

# Load the ONNX model
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.intra_op_num_threads = 2
session = ort.InferenceSession("./model/cnn-vad.onnx", session_options)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Create thread-safe queues for ML tasks and results
inference_queue = queue.Queue(maxsize=10)  # Queue for pending inference tasks
result_queue = queue.Queue(maxsize=10)  # Queue for inference results
stop_event = threading.Event()  # Event to signal thread termination
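# Assumed model interface (inferred from how the outputs are used below): the
# network takes a (batch, N_FILTERS, N_FILTERS) mel spectrogram image and
# returns a pair of scores per example, where index 1 is treated as the
# speech probability.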
# ML worker thread function
def ml_worker():
    """Worker thread function for running ML inference."""
    while not stop_event.is_set():
        try:
            # Try to get a task with a timeout so stop_event is checked periodically
            mel_image = inference_queue.get(timeout=0.1)
        except queue.Empty:
            # No pending tasks, just continue
            continue
        try:
            # Run inference
            outputs = session.run(
                [output_name],
                {
                    input_name: mel_image.astype(np.float32).reshape(
                        -1, N_FILTERS, N_FILTERS
                    )
                },
            )[0]
            # Get the speech probability score and put it in the result queue
            current_score = outputs[0][1]
            result_queue.put(current_score)
        except Exception as e:
            print(f"Error in ML worker thread: {e}")
        finally:
            # Always mark the retrieved task as done, even if inference failed
            inference_queue.task_done()
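# The worker thread decouples ONNX inference from the Gradio streaming
# callback: detect() only enqueues mel images and polls result_queue, so the
# audio stream stays responsive even when a model evaluation is slow.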
# Start ML worker thread
ml_thread = threading.Thread(target=ml_worker, daemon=True)
ml_thread.start()

# Initialize stream state
initial_state = {
    "buffer": np.zeros(0, dtype=np.float32),  # Audio buffer
    "frames_processed": 0,  # Number of frames processed
    "ewma_score": 0.0,  # EWMA of speech detection score
    "mel_image": np.zeros(
        (N_FILTERS, N_FILTERS), dtype=np.float32
    ),  # Mel spectrogram image
    "inference_pending": False,  # Flag to track if inference is pending
    "last_inference_time": 0,  # Timestamp of last inference
}
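# Note: gr.State (created in the UI below) keeps a per-session copy of this
# dict, while the queues and worker thread above are shared module-level
# objects across all sessions.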
def resample_audio(audio, orig_sr, target_sr):
    """
    Resample audio from the original sample rate to the target sample rate
    using polyphase filtering (scipy.signal.resample_poly).
    Args:
        audio: Audio data
        orig_sr: Original sample rate
        target_sr: Target sample rate
    Returns:
        Resampled audio
    """
    if orig_sr == target_sr:
        return audio
    # A simple linear-interpolation resampler would be faster but less
    # accurate; polyphase resampling is used here instead.
    return signal.resample_poly(audio, target_sr, orig_sr)
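# For example, resample_audio(y, 44_100, ORIGINAL_FS) upsamples a 44.1 kHz
# microphone chunk to the 48 kHz rate the frame-processing pipeline expects.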
def process_frame(frame, mel_image):
    """
    Process a single frame of audio data.
    Args:
        frame: Audio frame data
        mel_image: Current mel spectrogram image
    Returns:
        Updated mel spectrogram image
    """
    # 1. Apply FIR filter
    filtered_frame = signal.lfilter(FILTER_COEFFICIENTS, [1.0], frame)
    # Clamp values
    filtered_frame = np.clip(filtered_frame, -1.0, 1.0)
    # 2. Decimate the frame
    decimated_frame = filtered_frame[::DECIMATION_FACTOR]
    # Apply window and pad with zeros to FFT_SIZE
    padded_frame = np.zeros(FFT_SIZE)
    padded_frame[:DECIMATED_FRAME_SIZE] = decimated_frame * window
    # 3. Perform FFT
    fft_result = rfft(padded_frame)
    # Compute power spectrum
    power_spectrum = np.abs(fft_result) ** 2
    # 4. Calculate mel spectrogram
    mel_power = np.dot(power_spectrum[: FFT_SIZE // 2 + 1], mel_filterbank)
    # Log mel spectrogram
    mel_power = np.log(mel_power + 1e-8)
    # Update mel spectrogram image (shift up and add new row at bottom)
    mel_image = np.roll(mel_image, -1, axis=0)
    mel_image[-1] = mel_power
    return mel_image
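# Timing of the mel image: each frame covers 25 ms with a 12.5 ms hop, so the
# 40-row mel_image seen by the model spans roughly 40 * 12.5 ms = 0.5 s of the
# most recent audio.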
def detect(state, new_chunk):
    """
    Detects speech in an audio stream using a Voice Activity Detection (VAD) model.
    Args:
        state: Current state dictionary containing:
            - buffer: Audio buffer
            - frames_processed: Number of frames processed so far
            - ewma_score: Exponentially weighted moving average of the speech detection score
            - mel_image: Current mel spectrogram image
            - inference_pending: Flag indicating if an inference is pending
            - last_inference_time: Timestamp of the last inference
        new_chunk: A tuple containing the sample rate (sr) and the audio data (y).
    Returns:
        A tuple containing the updated state, a string indicating whether speech was detected,
        and an HTML element for the visual indicator.
    """
    # Initialize state if it's the first call
    if state is None:
        state = initial_state.copy()
        state["mel_image"] = np.copy(initial_state["mel_image"])
    sr, y = new_chunk
    # If no audio data, return current state
    if y is None or len(y) == 0:
        return (
            state,
            f"No audio detected (Score: {state['ewma_score']:.2f})",
            generate_indicator(state["ewma_score"]),
        )
    # Pre-processing: convert to mono if stereo
    if len(y.shape) > 1:
        y = np.mean(y, axis=1)
    # Convert int16 samples to float32 in [-1.0, 1.0]
    y = y.astype(np.float32) / 32768
    # Resample if necessary
    if sr != ORIGINAL_FS:
        y = resample_audio(y, sr, ORIGINAL_FS)
    # Append new audio to the buffer
    buffer = np.concatenate([state["buffer"], y])
    # Process frames from the buffer
    new_frames_processed = 0
    run_inference = False
    current_time = time.time()
    # Process as many complete frames as possible
    while len(buffer) >= FRAME_SIZE:
        # Extract frame
        frame = buffer[:FRAME_SIZE]
        buffer = buffer[FRAME_STEP:]  # Advance by hop size (50% overlap)
        # Process frame and update the mel spectrogram image
        state["mel_image"] = process_frame(frame, state["mel_image"])
        new_frames_processed += 1
        state["frames_processed"] += 1
        # Determine if we should run inference
        if (
            state["frames_processed"] >= INITIAL_FRAMES
            and (state["frames_processed"] - INITIAL_FRAMES) % FRAMES_PER_EVAL == 0
        ):
            run_inference = True
    # Update buffer in state
    state["buffer"] = buffer
    # Check for completed inference results
    if not result_queue.empty():
        # Get the result from the queue
        current_score = result_queue.get()
        # Update EWMA
        if state["ewma_score"] == 0:  # First evaluation
            state["ewma_score"] = current_score
        else:
            state["ewma_score"] = (
                ALPHA * current_score + (1 - ALPHA) * state["ewma_score"]
            )
        # Mark that we're no longer waiting for inference
        state["inference_pending"] = False
    # Run the VAD model if the criteria are met and no inference is currently pending
    if (
        (
            run_inference
            or (
                state["frames_processed"] >= INITIAL_FRAMES
                and new_frames_processed > 0
                and current_time - state["last_inference_time"] > 0.1  # Rate limiting
            )
        )
        and not state["inference_pending"]
        and not inference_queue.full()
    ):
        # Instead of running inference here, queue it for the worker thread
        inference_queue.put(np.copy(state["mel_image"]))
        state["inference_pending"] = True
        state["last_inference_time"] = current_time
    # Determine the result based on the EWMA score
    if state["frames_processed"] < INITIAL_FRAMES:
        message = (
            f"Building audio context... ({state['frames_processed']}/{INITIAL_FRAMES})"
        )
        indicator = generate_indicator(0)  # Inactive during initialization
    elif state["ewma_score"] > 0.5:
        message = f"Speech Detected (Score: {state['ewma_score']:.2f})"
        indicator = generate_indicator(state["ewma_score"])
    else:
        message = f"No Speech Detected (Score: {state['ewma_score']:.2f})"
        indicator = generate_indicator(state["ewma_score"])
    return state, message, indicator
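# EWMA smoothing: with ALPHA = 0.9 the smoothed score is
#   ewma = 0.9 * current_score + 0.1 * previous_ewma,
# so the indicator reacts quickly to the latest inference while keeping a
# small amount of memory of earlier evaluations.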
def generate_indicator(score):
    """
    Generate an HTML indicator that lights up based on the speech detection score.
    Args:
        score: Speech detection score (0.0 to 1.0)
    Returns:
        HTML string for the visual indicator
    """
    # Set colors based on score
    if score > 0.5:
        # Active - green glow
        brightness = min(100, 50 + int(score * 50))  # 75-100% brightness for active scores
        color = f"rgba(144, 219, 130, {score:.1f})"  # Green with opacity based on score
        glow = f"0 0 {int(score * 20)}px rgba(0, 255, 0, {score:.1f})"  # Glow effect
    else:
        # Inactive - dim red
        brightness = max(10, int(score * 50))  # 10-25% brightness based on score
        color = f"rgba(255, 0, 0, {max(0.1, score):.1f})"  # Red with low opacity
        glow = "none"  # No glow when inactive
    # Create HTML for the indicator
    html = f"""
    <div style="display: flex; justify-content: center; margin-top: 10px; margin-bottom: 10px;">
        <div style="
            width: 80px;
            height: 80px;
            border-radius: 50%;
            background-color: {color};
            box-shadow: {glow};
            transition: all 0.3s ease;
            display: flex;
            justify-content: center;
            align-items: center;
            font-weight: bold;
            color: white;
            filter: brightness({brightness}%);
        ">
            {int(score * 100)}%
        </div>
    </div>
    """
    return html
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Voice Activity Detection")
    gr.Markdown(
        "Speak into your microphone to see the voice activity detection in action."
    )
    # State for maintaining app state between calls
    state = gr.State(None)
    # Audio input
    audio_input = gr.Audio(
        sources=["microphone"],
        streaming=True,
        autoplay=True,
        elem_id="mic_input",
    )
    # Visual indicator and text output
    with gr.Row():
        with gr.Column(scale=1):
            indicator_html = gr.HTML(generate_indicator(0))
        with gr.Column(scale=2):
            text_output = gr.Textbox(label="Detection Result")
    # Set up the processing function
    audio_input.stream(
        detect,
        inputs=[state, audio_input],
        outputs=[state, text_output, indicator_html],
        show_progress=False,
    )
    gr.Markdown("""
    ## How it works
    This app uses a Convolutional Neural Network (CNN) to detect speech in audio.
    - The indicator lights up **green** when speech is detected
    - The indicator turns **red** when no speech is detected
    - The percentage shows the confidence level of speech detection
    - The model and processing pipeline are described in [this paper](https://ieeexplore.ieee.org/abstract/document/8278160)
    """)
# Cleanup function to be called when the app is closed
def cleanup():
    # Signal the worker thread to stop
    stop_event.set()
    # Wait for the thread to finish (with timeout)
    ml_thread.join(timeout=1.0)
    print("ML worker thread stopped")


# Register cleanup handler
atexit.register(cleanup)

# Launch the interface
demo.launch()
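# Note: demo.launch() with defaults is typically sufficient on Hugging Face
# Spaces; for local use inside a container, passing server_name="0.0.0.0"
# (a standard Gradio launch option) would expose the server on all interfaces.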