# vad_demo / app.py
# Gabriel Bibbó
# adjust app.py
# 9b9394f
import gradio as gr
import numpy as np
import torch
import time
import warnings
from dataclasses import dataclass
from typing import List, Tuple, Dict
import threading
import queue
import os
import requests
from pathlib import Path
import base64
# Suppress warnings
warnings.filterwarnings('ignore')
# Function to convert image to base64
def image_to_base64(image_path):
try:
with open(image_path, "rb") as img_file:
return base64.b64encode(img_file.read()).decode('utf-8')
except Exception as e:
print(f"Error loading image {image_path}: {e}")
return None
# Load logos as base64
def load_logos():
logos = {}
logo_files = {
'ai4s': 'ai4s_banner.png',
'surrey': 'surrey_logo.png',
'epsrc': 'EPSRC_logo.png',
'cvssp': 'CVSSP_logo.png'
}
for key, filename in logo_files.items():
if os.path.exists(filename):
logos[key] = image_to_base64(filename)
else:
print(f"Logo file {filename} not found")
logos[key] = None
return logos
# Optional imports with fallbacks.
# Each optional dependency is imported inside try/except ImportError. The
# matching *_AVAILABLE flag records whether the import succeeded, so the rest
# of the module can pick a fallback code path instead of failing at import time.
try:
    import librosa
    LIBROSA_AVAILABLE = True
    print("✅ Librosa available")
except ImportError:
    LIBROSA_AVAILABLE = False
    print("⚠️ Librosa not available, using scipy fallback")
try:
    import webrtcvad
    WEBRTC_AVAILABLE = True
    print("✅ WebRTC VAD available")
except ImportError:
    WEBRTC_AVAILABLE = False
    print("⚠️ WebRTC VAD not available, using fallback")
try:
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    PLOTLY_AVAILABLE = True
    print("✅ Plotly available")
except ImportError:
    PLOTLY_AVAILABLE = False
    print("⚠️ Plotly not available")
# PANNs imports - UPDATED to include SoundEventDetection.
# First try the full API (AudioTagging + SoundEventDetection). If that fails,
# retry with AudioTagging only. PANNS_SED_AVAILABLE tells the PANNs wrapper
# whether framewise (SED) inference can be attempted.
# NOTE: `labels` is only defined when one of these imports succeeds.
try:
    from panns_inference import AudioTagging, SoundEventDetection, labels
    PANNS_AVAILABLE = True
    PANNS_SED_AVAILABLE = True
    print("✅ PANNs available with SoundEventDetection")
except ImportError:
    try:
        from panns_inference import AudioTagging, labels
        PANNS_AVAILABLE = True
        PANNS_SED_AVAILABLE = False
        print("✅ PANNs available (AudioTagging only)")
    except ImportError:
        PANNS_AVAILABLE = False
        PANNS_SED_AVAILABLE = False
        print("⚠️ PANNs not available, using fallback")
# Transformers for AST (Audio Spectrogram Transformer).
try:
    from transformers import ASTForAudioClassification, ASTFeatureExtractor
    import transformers  # NOTE(review): the module object itself is not referenced in the visible code
    AST_AVAILABLE = True
    print("✅ AST (Transformers) available")
except ImportError:
    AST_AVAILABLE = False
    print("⚠️ AST not available, using fallback")
print("🚀 Creating VAD Demo...")
# ===== HELPER FUNCTIONS FOR CORRECTED MODELS =====
def safe_resample(x, sr_in, sr_out):
    """Resample a 1-D signal from ``sr_in`` Hz to ``sr_out`` Hz, returning float32.

    Uses librosa when it is available, otherwise linear interpolation.
    Resampling is best-effort: on any failure a message is printed and the
    input is returned unresampled (as float32) rather than raising.

    Args:
        x: 1-D array (or array-like) of samples.
        sr_in: Sample rate of ``x`` in Hz.
        sr_out: Target sample rate in Hz.

    Returns:
        ``np.float32`` array at ``sr_out`` (or the unresampled input on error).
    """
    x = np.asarray(x)
    if sr_in == sr_out:
        return x.astype(np.float32)
    if x.size == 0:
        # Nothing to resample; avoids a spurious error from np.interp/librosa.
        return x.astype(np.float32)
    try:
        if LIBROSA_AVAILABLE:
            result = librosa.resample(x.astype(float), orig_sr=sr_in, target_sr=sr_out)
            return result.astype(np.float32)
        # Fallback: linear interpolation onto an evenly spaced output grid
        # covering the same duration.
        dur = len(x) / sr_in
        n_out = max(1, int(round(dur * sr_out)))
        xi = np.arange(len(x))
        xo = np.linspace(0, len(x) - 1, num=n_out)
        return np.interp(xo, xi, x).astype(np.float32)
    except Exception as e:
        print(f"⚠️ Resample error ({sr_in}→{sr_out}Hz): {e}")
        # Return input as fallback
        return x.astype(np.float32)
# ===== DATA STRUCTURES =====
@dataclass
class VADResult:
    """One speech/non-speech prediction returned by a model's ``predict``."""
    probability: float  # speech score produced by the model for this window
    is_speech: bool  # the model's own thresholded decision
    model_name: str  # model label; may carry a suffix such as " (fallback)" or " (error)"
    processing_time: float  # wall-clock seconds spent inside predict() (time.time() delta)
    timestamp: float  # time position in seconds supplied by the caller
@dataclass
class OnsetOffset:
    """A detected speech segment for one model."""
    onset_time: float  # segment start in seconds
    offset_time: float  # segment end in seconds
    model_name: str  # base model name (any "(...)" suffix stripped)
    confidence: float  # mean probability of the points falling inside the segment
# ===== MODEL IMPLEMENTATIONS =====
class OptimizedSileroVAD:
    """Silero VAD wrapper that scores exactly one 512-sample window per call.

    The model is fetched through ``torch.hub`` at construction time. If that
    fails, ``self.model`` stays ``None`` and :meth:`predict` reports the model
    as unavailable.
    """

    def __init__(self):
        self.model = None
        self.sample_rate = 16000
        self.model_name = "Silero-VAD"
        self.load_model()

    def load_model(self):
        """Load ``snakers4/silero-vad`` via torch.hub and put it in eval mode."""
        try:
            hub_model, _hub_utils = torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
                onnx=False
            )
            self.model = hub_model
            self.model.eval()
            print(f"✅ {self.model_name} loaded successfully")
        except Exception as e:
            print(f"❌ Error loading {self.model_name}: {e}")
            self.model = None

    @staticmethod
    def _fit_window(samples, size):
        """Return exactly ``size`` samples: centre-crop longer input, reflect-pad shorter input."""
        count = len(samples)
        if count > size:
            # Centre crop keeps the middle of the chunk and drops the edges.
            offset = (count - size) // 2
            return samples[offset:offset + size]
        if count < size:
            # Split the padding between both sides (the extra sample goes right).
            missing = size - count
            left = missing // 2
            return np.pad(samples, (left, missing - left), 'reflect')
        return samples

    def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
        """Score ``audio`` (downmixed to mono) and flag speech above 0.5."""
        t_begin = time.time()
        if self.model is None or len(audio) == 0:
            return VADResult(0.0, False, f"{self.model_name} (unavailable)", time.time() - t_begin, timestamp)
        try:
            mono = audio.mean(axis=1) if len(audio.shape) > 1 else audio
            window = self._fit_window(mono, 512)
            batch = torch.FloatTensor(window).unsqueeze(0)
            with torch.no_grad():
                score = self.model(batch, self.sample_rate).item()
            return VADResult(score, score > 0.5, self.model_name, time.time() - t_begin, timestamp)
        except Exception as e:
            print(f"Error in {self.model_name}: {e}")
            return VADResult(0.0, False, self.model_name, time.time() - t_begin, timestamp)
class OptimizedWebRTCVAD:
    """WebRTC VAD wrapper: the score is the fraction of 30 ms frames flagged as speech.

    If the ``webrtcvad`` package is missing or the detector cannot be created,
    :meth:`predict` falls back to a raw-energy heuristic.
    """

    def __init__(self):
        self.model_name = "WebRTC-VAD"
        self.sample_rate = 16000
        self.frame_duration = 30  # Only 10, 20, or 30 ms are supported
        self.frame_size = int(self.sample_rate * self.frame_duration / 1000)  # 480 samples for 30ms
        self.vad = None
        if WEBRTC_AVAILABLE:
            try:
                self.vad = webrtcvad.Vad(3)  # aggressiveness mode 3
                print(f"✅ {self.model_name} loaded successfully (frame size: {self.frame_size} samples)")
            except Exception as e:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
                # and hid why the detector was unavailable.
                print(f"❌ Error loading {self.model_name}: {e}")
                self.vad = None

    def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
        """Score ``audio`` (float samples, expected in [-1, 1]) at 16 kHz.

        Returns:
            VADResult whose probability is speech_frames / total_frames and
            whose is_speech flag is probability > 0.3. In fallback mode the
            probability is sum(audio**2) / 0.01, capped at 1.0.
        """
        start_time = time.time()
        if self.vad is None or len(audio) == 0:
            energy = np.sum(audio ** 2) if len(audio) > 0 else 0
            threshold = 0.01
            probability = float(min(energy / threshold, 1.0))
            is_speech = bool(energy > threshold)
            return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
        try:
            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)
            # Clip to [-1, 1] before scaling so the int16 cast cannot wrap around.
            audio_clipped = np.clip(audio, -1.0, 1.0)
            audio_int16 = (audio_clipped * 32767).astype(np.int16)
            # Ensure we have enough samples for at least one frame
            if len(audio_int16) < self.frame_size:
                audio_int16 = np.pad(audio_int16, (0, self.frame_size - len(audio_int16)), 'constant')
            speech_frames = 0
            total_frames = 0
            # Only whole frames are scored; a trailing partial frame is ignored.
            for i in range(0, len(audio_int16) - self.frame_size + 1, self.frame_size):
                frame = audio_int16[i:i + self.frame_size].tobytes()
                if self.vad.is_speech(frame, self.sample_rate):
                    speech_frames += 1
                total_frames += 1
            probability = speech_frames / max(total_frames, 1)
            is_speech = probability > 0.3
            return VADResult(probability, is_speech, self.model_name, time.time() - start_time, timestamp)
        except Exception as e:
            print(f"Error in {self.model_name}: {e}")
            return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
class OptimizedEPANNs:
    """E-PANNs speech scorer for a single audio chunk.

    When a PANNs ``AudioTagging`` model can be created it is used as backend.
    The chunk score is the highest clipwise probability among speech-related
    labels. Otherwise hand-crafted spectral/energy features are used.
    """

    # Lower-case substrings marking a PANNs label as speech-related.
    SPEECH_KEYWORDS = (
        'speech', 'voice', 'talk', 'conversation', 'speaking',
        'male speech', 'female speech', 'child speech',
        'narration', 'monologue', 'speech synthesizer'
    )
    # Peak amplitude below which a chunk is treated as digital silence.
    SILENCE_PEAK = 1e-6

    def __init__(self):
        self.model_name = "E-PANNs"
        self.sample_rate = 32000
        self.win_s = 1.0  # minimum analysis window in seconds; shorter chunks are zero-padded
        print(f"✅ {self.model_name} initialized")
        # Try to load PANNs AudioTagging as backend for E-PANNs
        self.at_model = None
        self._speech_indices = []
        if PANNS_AVAILABLE:
            try:
                self.at_model = AudioTagging(checkpoint_path=None, device='cpu')
                print(f"✅ {self.model_name} using PANNs AT backend")
            except Exception as e:
                print(f"⚠️ {self.model_name} PANNs AT unavailable: {e}")
        if self.at_model is not None:
            # The label list is static, so resolve speech classes once, not per call.
            self._speech_indices = [
                i for i, lbl in enumerate(labels)
                if any(word in lbl.lower() for word in self.SPEECH_KEYWORDS)
            ]

    def _panns_score(self, x):
        """Max clipwise probability over speech labels (over all labels if none matched)."""
        clipwise_output, _ = self.at_model.inference(x[np.newaxis, :])
        if self._speech_indices:
            return float(np.max(clipwise_output[0, self._speech_indices]))
        return float(np.max(clipwise_output[0]))

    def _feature_score(self, x):
        """Heuristic speech score from spectral/energy features (used without PANNs)."""
        if float(np.max(np.abs(x))) < self.SILENCE_PEAK:
            # Digital silence. With ref=np.max, power_to_db maps an all-zero
            # spectrogram to 0 dB, which is the *highest* energy score, so
            # silence used to be reported as speech.
            return 0.0
        if LIBROSA_AVAILABLE:
            mel_spec = librosa.feature.melspectrogram(y=x, sr=self.sample_rate, n_mels=64)
            # dB relative to the chunk's own peak (scale-invariant level measure).
            energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
            spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=x, sr=self.sample_rate))
            energy_score = np.clip((energy + 80) / 40, 0, 1)
            centroid_score = np.clip((spectral_centroid - 200) / 3000, 0, 1)
            return energy_score * 0.7 + centroid_score * 0.3
        energy = np.sum(x ** 2) / len(x)
        return min(energy * 50, 1.0)

    def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
        """Score one chunk; ``audio`` is assumed to be at 16 kHz (it is resampled from 16000).

        Returns:
            VADResult with probability clipped to [0, 1] and is_speech = probability > 0.4.
        """
        start_time = time.time()
        try:
            if len(audio) == 0:
                return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)
            # The caller passes the chunk for this timestamp; no window extraction here.
            x = safe_resample(audio, 16000, self.sample_rate)
            # Zero-pad (never repeat) up to the minimum window length.
            min_samples = int(self.sample_rate * self.win_s)
            if len(x) < min_samples:
                x = np.pad(x, (0, min_samples - len(x)), mode='constant')
            if self.at_model is not None:
                speech_score = self._panns_score(x)
            else:
                speech_score = self._feature_score(x)
            probability = float(np.clip(speech_score, 0, 1))
            is_speech = probability > 0.4
            return VADResult(probability, is_speech, self.model_name, time.time() - start_time, timestamp)
        except Exception as e:
            print(f"❌ E-PANNs ERROR: {e}")
            import traceback
            traceback.print_exc()
            return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
class OptimizedPANNs:
    """PANNs speech scorer that prefers framewise SoundEventDetection output.

    With the SED model, the score comes from the middle frame of the framewise
    output. Otherwise the clipwise AudioTagging output is used, damped by
    silence/noise classes. If neither model loads, or inference fails, a
    raw-energy heuristic is returned instead.
    """

    # Lower-case substrings marking a PANNs label as speech-related.
    SPEECH_KEYWORDS = (
        'speech', 'voice', 'talk', 'conversation', 'speaking',
        'male speech', 'female speech', 'child speech',
        'narration', 'monologue'
    )
    # Labels used as a "not speech" contrast in AudioTagging mode.
    NOISE_KEYWORDS = ('silence', 'white noise', 'pink noise')

    def __init__(self):
        self.model_name = "PANNs"
        self.sample_rate = 32000
        self.model = None
        self.sed_model = None
        self._speech_indices = []
        self._noise_indices = []
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.load_model()

    @staticmethod
    def _label_indices(keywords):
        """Indices of PANNs ``labels`` whose lower-cased text contains any keyword."""
        return [i for i, lbl in enumerate(labels) if any(word in lbl.lower() for word in keywords)]

    def load_model(self):
        """Load SED (preferred) or AudioTagging; leave both ``None`` on failure."""
        try:
            if PANNS_AVAILABLE:
                # Try to load SED model first for framewise output
                if PANNS_SED_AVAILABLE:
                    try:
                        self.sed_model = SoundEventDetection(checkpoint_path=None, device=self.device)
                        print(f"✅ {self.model_name} SED loaded successfully (framewise mode)")
                    except Exception as e:
                        print(f"⚠️ {self.model_name} SED initialization failed: {e}")
                        self.sed_model = None
                # Load AudioTagging when SED is not usable
                if self.sed_model is None:
                    self.model = AudioTagging(checkpoint_path=None, device=self.device)
                    print(f"✅ {self.model_name} AT loaded successfully")
                # The label list is static: resolve class indices once instead of per call.
                self._speech_indices = self._label_indices(self.SPEECH_KEYWORDS)
                self._noise_indices = self._label_indices(self.NOISE_KEYWORDS)
            else:
                print(f"⚠️ {self.model_name} not available, using fallback")
                self.model = None
                self.sed_model = None
        except Exception as e:
            print(f"❌ Error loading {self.model_name}: {e}")
            self.model = None
            self.sed_model = None

    def _energy_fallback(self, audio, suffix, start_time, timestamp):
        """Raw-energy heuristic shared by the "no model" and "inference failed" paths."""
        if len(audio) > 0:
            energy = np.sum(audio ** 2)
            threshold = 0.01
            probability = float(min(energy / (threshold * 100), 1.0))
            is_speech = bool(energy > threshold)
        else:
            probability = 0.0
            is_speech = False
        return VADResult(probability, is_speech, f"{self.model_name} ({suffix})", time.time() - start_time, timestamp)

    def _to_model_rate(self, audio):
        """Resample 16 kHz audio to the PANNs rate (librosa, else linear interpolation)."""
        if LIBROSA_AVAILABLE:
            return librosa.resample(audio.astype(float),
                                    orig_sr=16000,
                                    target_sr=self.sample_rate)
        resample_factor = self.sample_rate / 16000
        return np.interp(
            np.linspace(0, len(audio) - 1, int(len(audio) * resample_factor)),
            np.arange(len(audio)),
            audio
        )

    def _sed_score(self, x):
        """Speech score from the middle frame of the SED framewise output."""
        framewise_output = self.sed_model.inference(x[np.newaxis, :])
        if hasattr(framewise_output, 'cpu'):
            framewise_output = framewise_output.cpu().numpy()
        if framewise_output.ndim == 3:
            framewise_output = framewise_output[0]  # Remove batch dimension
        # Middle frame corresponds to the centre of the window
        frame_idx = framewise_output.shape[0] // 2
        if frame_idx >= framewise_output.shape[0]:
            return 0.0
        frame = framewise_output[frame_idx]
        if self._speech_indices:
            return float(np.max(frame[self._speech_indices]))
        return float(np.max(frame))

    def _at_score(self, x):
        """Speech score from clipwise AudioTagging output, damped by noise/silence classes."""
        clip_probs, _ = self.model.inference(x[np.newaxis, :])
        if self._speech_indices:
            # Max instead of mean for better detection
            speech_prob = np.max(clip_probs[0, self._speech_indices])
            if self._noise_indices:
                noise_prob = np.mean(clip_probs[0, self._noise_indices])
                speech_prob = speech_prob * (1 - noise_prob * 0.5)
            return speech_prob
        # Fallback if no speech indices found: mean of the 10 highest class scores
        top_indices = np.argsort(clip_probs[0])[-10:]
        return np.mean(clip_probs[0, top_indices])

    def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
        """Score one chunk; ``audio`` is assumed to be 16 kHz. Speech when score > 0.4."""
        start_time = time.time()
        if (self.model is None and self.sed_model is None) or len(audio) == 0:
            return self._energy_fallback(audio, "fallback", start_time, timestamp)
        try:
            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)
            audio_resampled = self._to_model_rate(audio)
            # For short audio, zero-pad (no repeating) to a 1 second minimum
            min_samples = 1 * self.sample_rate
            if len(audio_resampled) < min_samples:
                audio_resampled = np.pad(audio_resampled, (0, min_samples - len(audio_resampled)), mode='constant')
            if self.sed_model is not None:
                speech_prob = self._sed_score(audio_resampled)
            else:
                speech_prob = self._at_score(audio_resampled)
            speech_prob = float(speech_prob)
            return VADResult(speech_prob, speech_prob > 0.4, self.model_name, time.time() - start_time, timestamp)
        except Exception as e:
            print(f"❌ PANNs ERROR: {e}")
            import traceback
            traceback.print_exc()
            return self._energy_fallback(audio, "error", start_time, timestamp)
class OptimizedAST:
    """Audio Spectrogram Transformer (AST) speech scorer operating on 16 kHz audio.

    Loads ``MIT/ast-finetuned-audioset-10-10-0.4593``. The model runs in FP16 on
    CUDA and with dynamic int8 quantisation of Linear layers on CPU. The
    score is the sigmoid maximum over speech-related labels, damped by
    silence/noise/background labels. No prediction cache is kept.
    """

    # Lower-case substrings marking an AudioSet label as speech-related.
    SPEECH_KEYWORDS = (
        'speech', 'voice', 'talk', 'conversation', 'speaking',
        'male speech', 'female speech', 'child speech',
        'speech synthesizer', 'narration'
    )
    # Labels used to discount the speech score.
    NOISE_KEYWORDS = ('silence', 'white noise', 'background')

    def __init__(self):
        self.model_name = "AST"
        self.sample_rate = 16000  # AST REQUIRES 16kHz
        self.model = None
        self.feature_extractor = None
        self._speech_indices = []
        self._noise_indices = []
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.load_model()

    def load_model(self):
        """Load extractor + model; leave ``self.model`` None on failure."""
        try:
            if AST_AVAILABLE:
                model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
                self.feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)
                self.model = ASTForAudioClassification.from_pretrained(model_name)
                self.model.to(self.device)
                if self.device.type == 'cuda':
                    # FP16 for faster inference on GPU
                    self.model = self.model.half()
                    print(f"✅ {self.model_name} loaded with FP16 optimization")
                else:
                    # Dynamic quantisation of Linear layers for CPU acceleration
                    import torch.nn as nn
                    self.model = torch.quantization.quantize_dynamic(
                        self.model, {nn.Linear}, dtype=torch.qint8
                    )
                    print(f"✅ {self.model_name} loaded with CPU quantization")
                self.model.eval()
                self._index_labels()
            else:
                print(f"⚠️ {self.model_name} not available, using fallback")
                self.model = None
        except Exception as e:
            print(f"❌ Error loading {self.model_name}: {e}")
            self.model = None

    def _index_labels(self):
        """Resolve speech/noise class indices once from the model's label2id map."""
        label2id = self.model.config.label2id
        self._speech_indices = [
            idx for lbl, idx in label2id.items()
            if any(word in lbl.lower() for word in self.SPEECH_KEYWORDS)
        ]
        self._noise_indices = [
            idx for lbl, idx in label2id.items()
            if any(word in lbl.lower() for word in self.NOISE_KEYWORDS)
        ]

    def _spectral_fallback(self, audio, start_time, timestamp):
        """Energy/spectral-centroid heuristic used when the model is unavailable."""
        if len(audio) > 0:
            energy = np.sum(audio ** 2)
            if LIBROSA_AVAILABLE:
                spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
                probability = min((energy * 100 + spectral_centroid / 1000) / 2, 1.0)
            else:
                probability = min(energy * 50, 1.0)
            is_speech = probability > 0.25  # Use AST threshold
        else:
            probability = 0.0
            is_speech = False
        return VADResult(float(probability), bool(is_speech), f"{self.model_name} (fallback)", time.time() - start_time, timestamp)

    def predict(self, audio: np.ndarray, timestamp: float = 0.0, full_audio: np.ndarray = None) -> VADResult:
        """Score one 16 kHz chunk; speech when the score exceeds 0.25.

        ``full_audio`` is accepted for interface compatibility and is not used.
        """
        start_time = time.time()
        if self.model is None or len(audio) == 0:
            return self._spectral_fallback(audio, start_time, timestamp)
        try:
            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)
            # Input is already at 16 kHz, the rate AST expects.
            audio_for_ast = audio.astype(np.float32)
            # Zero-pad to a 1 second minimum
            min_samples = int(1.0 * self.sample_rate)
            if len(audio_for_ast) < min_samples:
                audio_for_ast = np.pad(audio_for_ast, (0, min_samples - len(audio_for_ast)), mode='constant')
            # NOTE(review): ASTFeatureExtractor appears to accept padding/truncation
            # only through **kwargs and to always pad/truncate the fbank to its
            # configured max_length (1024 frames). Confirm against the installed
            # transformers version before relying on these flags.
            inputs = self.feature_extractor(
                audio_for_ast,
                sampling_rate=self.sample_rate,  # Must be 16kHz
                return_tensors="pt",
                padding=False,
                truncation=False
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            if self.device.type == 'cuda' and hasattr(self.model, 'half'):
                # Match the FP16 model's input dtype
                inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits
            probs = torch.sigmoid(logits)
            if self._speech_indices:
                # Max probability among speech classes
                speech_prob = torch.max(probs[0, self._speech_indices]).item()
                if self._noise_indices:
                    noise_prob = torch.mean(probs[0, self._noise_indices]).item()
                    # Reduce speech probability if high noise/silence detected
                    speech_prob = speech_prob * (1 - noise_prob * 0.3)
            else:
                # Energy-based fallback, normalised by length
                energy = np.sum(audio_for_ast ** 2) / len(audio_for_ast)
                speech_prob = min(energy * 20, 1.0)
            speech_prob = float(speech_prob)
            # AST uses a lower threshold (0.25 instead of 0.4)
            return VADResult(speech_prob, speech_prob > 0.25, self.model_name, time.time() - start_time, timestamp)
        except Exception as e:
            print(f"❌ AST ERROR: {e}")
            import traceback
            traceback.print_exc()
            if len(audio) > 0:
                energy = np.sum(audio ** 2) / len(audio)  # Normalize by length
                probability = float(min(energy * 100, 1.0))
                is_speech = bool(energy > 0.001)
            else:
                probability = 0.0
                is_speech = False
            return VADResult(probability, is_speech, f"{self.model_name} (error)", time.time() - start_time, timestamp)
# ===== AUDIO PROCESSOR =====
class AudioProcessor:
def __init__(self, sample_rate=16000):
self.sample_rate = sample_rate
self.chunk_duration = 4.0
self.chunk_size = int(sample_rate * self.chunk_duration)
self.n_fft = 2048
self.hop_length = 256
self.n_mels = 128
self.fmin = 20
self.fmax = 8000
self.base_window = 0.064
self.base_hop = 0.032
# Model-specific window sizes (each model gets appropriate context)
self.model_windows = {
"Silero-VAD": 0.032, # 32ms exactly as required (512 samples)
"WebRTC-VAD": 0.03, # 30ms frames (480 samples)
"E-PANNs": 1.0, # CHANGED from 6.0 to 1.0 for better temporal resolution
"PANNs": 1.0, # CHANGED from 10.0 to 1.0 for better temporal resolution
"AST": 0.96 # OPTIMIZED: Natural window size for AST
}
# Model-specific hop sizes for efficiency - OPTIMIZED for performance
self.model_hop_sizes = {
"Silero-VAD": 0.016, # 16ms hop for Silero (512 samples window)
"WebRTC-VAD": 0.03, # 30ms hop for WebRTC (match frame duration)
"E-PANNs": 0.05, # CHANGED from 0.1 to 0.05 for 20Hz
"PANNs": 0.05, # CHANGED from 0.1 to 0.05 for 20Hz
"AST": 0.1 # IMPROVED: Better resolution (10 Hz) while maintaining performance
}
# Model-specific thresholds for better detection
self.model_thresholds = {
"Silero-VAD": 0.5,
"WebRTC-VAD": 0.5,
"E-PANNs": 0.4,
"PANNs": 0.4,
"AST": 0.25
}
self.delay_compensation = 0.0
self.correlation_threshold = 0.5 # REDUCED: More sensitive delay detection
def process_audio(self, audio):
if audio is None:
return np.array([])
try:
if isinstance(audio, tuple):
sample_rate, audio_data = audio
if sample_rate != self.sample_rate and LIBROSA_AVAILABLE:
audio_data = librosa.resample(audio_data.astype(float),
orig_sr=sample_rate,
target_sr=self.sample_rate)
else:
audio_data = audio
if len(audio_data.shape) > 1:
audio_data = audio_data.mean(axis=1)
if np.max(np.abs(audio_data)) > 0:
audio_data = audio_data / np.max(np.abs(audio_data))
return audio_data
except Exception as e:
print(f"Audio processing error: {e}")
return np.array([])
def compute_high_res_spectrogram(self, audio_data):
try:
if LIBROSA_AVAILABLE and len(audio_data) > 0:
stft = librosa.stft(
audio_data,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.n_fft,
window='hann',
center=True # CAMBIO: True para mejor alineación en bordes
)
power_spec = np.abs(stft) ** 2
mel_basis = librosa.filters.mel(
sr=self.sample_rate,
n_fft=self.n_fft,
n_mels=self.n_mels,
fmin=self.fmin,
fmax=self.fmax
)
mel_spec = np.dot(mel_basis, power_spec)
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
# CORRECTED: Use frames_to_time with only valid parameters
frames = np.arange(mel_spec_db.shape[1])
time_frames = librosa.frames_to_time(
frames, sr=self.sample_rate, hop_length=self.hop_length
)
return mel_spec_db, time_frames
else:
from scipy import signal
f, t, Sxx = signal.spectrogram(
audio_data,
self.sample_rate,
nperseg=self.n_fft,
noverlap=self.n_fft - self.hop_length,
window='hann',
mode='psd'
)
# Ajustar tiempos para alinear con center=False (empezar en 0)
t -= (self.n_fft / 2.0) / self.sample_rate
mel_spec_db = np.zeros((self.n_mels, Sxx.shape[1]))
mel_freqs = np.logspace(
np.log10(self.fmin),
np.log10(min(self.fmax, self.sample_rate/2)),
self.n_mels + 1
)
for i in range(self.n_mels):
f_start = mel_freqs[i]
f_end = mel_freqs[i + 1]
bin_start = int(f_start * len(f) / (self.sample_rate/2))
bin_end = int(f_end * len(f) / (self.sample_rate/2))
if bin_end > bin_start:
mel_spec_db[i, :] = np.mean(Sxx[bin_start:bin_end, :], axis=0)
mel_spec_db = 10 * np.log10(mel_spec_db + 1e-10)
return mel_spec_db, t
except Exception as e:
print(f"Spectrogram computation error: {e}")
dummy_spec = np.zeros((self.n_mels, 200))
dummy_time = np.linspace(0, len(audio_data) / self.sample_rate, 200)
return dummy_spec, dummy_time
def detect_onset_offset_advanced(self, vad_results: List[VADResult],
threshold: float,
apply_delay: float = 0.0,
min_duration: float = 0.12,
total_duration: float = None) -> List[OnsetOffset]:
"""
Cruces exactos de umbral global, con compensación de delay y filtro de duración mínima.
Onset: p[i-1] < thr y p[i] >= thr
Offset: p[i-1] >= thr y p[i] < thr
El instante se obtiene por interpolación lineal entre (t[i-1], p[i-1]) y (t[i], p[i]).
"""
onsets_offsets = []
if len(vad_results) < 2:
return onsets_offsets
if total_duration is None:
total_duration = max([r.timestamp for r in vad_results]) + 0.01 if vad_results else 0.0
# agrupar por modelo
grouped = {}
for r in vad_results:
base = r.model_name.split('(')[0].strip()
# aplica delay al guardar
grouped.setdefault(base, []).append(
VADResult(r.probability, r.is_speech, base, r.processing_time, r.timestamp - apply_delay)
)
for base, rs in grouped.items():
if not rs:
continue
rs.sort(key=lambda r: r.timestamp)
# Add virtual start point if first timestamp > 0
if rs[0].timestamp > 0:
virtual_start = VADResult(
probability=rs[0].probability,
is_speech=rs[0].probability > threshold,
model_name=base,
processing_time=0,
timestamp=0.0
)
rs.insert(0, virtual_start)
# Add virtual end point if last timestamp < total_duration
if rs[-1].timestamp < total_duration - 1e-4:
virtual_end = VADResult(
probability=rs[-1].probability,
is_speech=rs[-1].probability > threshold,
model_name=base,
processing_time=0,
timestamp=total_duration
)
rs.append(virtual_end)
t = np.array([r.timestamp for r in rs], dtype=float)
p = np.array([r.probability for r in rs], dtype=float)
thr = float(threshold)
in_seg = False
onset_t = None
if p[0] > thr:
in_seg = True
onset_t = t[0]
def xcross(t0, p0, t1, p1, thr):
if p1 == p0: return t1
alpha = (thr - p0) / (p1 - p0)
return t0 + alpha * (t1 - t0)
for i in range(1, len(p)):
p0, p1 = p[i-1], p[i]
t0, t1 = t[i-1], t[i]
if (not in_seg) and (p0 < thr) and (p1 >= thr):
onset_t = xcross(t0, p0, t1, p1, thr)
in_seg = True
elif in_seg and (p0 >= thr) and (p1 < thr):
off = xcross(t0, p0, t1, p1, thr)
if off - onset_t >= min_duration: # debounce
mask = (t >= onset_t) & (t <= off)
conf = float(p[mask].mean()) if np.any(mask) else float(max(p0, p1))
onsets_offsets.append(OnsetOffset(max(0.0, float(onset_t)), float(off), base, conf))
in_seg = False
onset_t = None
if in_seg and onset_t is not None:
off = float(t[-1])
if off - onset_t >= min_duration:
mask = (t >= onset_t)
conf = float(p[mask].mean()) if np.any(mask) else float(p[-1])
onsets_offsets.append(OnsetOffset(max(0.0, float(onset_t)), off, base, conf))
return onsets_offsets
def estimate_delay_compensation(self, audio_data, vad_results):
try:
if len(audio_data) == 0 or len(vad_results) == 0:
return 0.0
window_size = int(self.sample_rate * self.base_window)
hop_size = int(self.sample_rate * self.base_hop)
energy_signal = []
for i in range(0, len(audio_data) - window_size + 1, hop_size):
window = audio_data[i:i + window_size]
energy = np.sum(window ** 2)
energy_signal.append(energy)
energy_signal = np.array(energy_signal)
if len(energy_signal) == 0:
return 0.0
energy_signal = (energy_signal - np.mean(energy_signal)) / (np.std(energy_signal) + 1e-8)
vad_times = np.array([r.timestamp for r in vad_results])
vad_probs = np.array([r.probability for r in vad_results])
energy_times = np.arange(len(energy_signal)) * self.base_hop + self.base_window / 2
vad_interp = np.interp(energy_times, vad_times, vad_probs)
vad_interp = (vad_interp - np.mean(vad_interp)) / (np.std(vad_interp) + 1e-8)
if len(energy_signal) > 10 and len(vad_interp) > 10:
correlation = np.correlate(energy_signal, vad_interp, mode='full')
delay_samples = np.argmax(correlation) - len(vad_interp) + 1
delay_seconds = -delay_samples * self.base_hop
if delay_seconds <= 0:
delay_seconds = 0.4
delay_seconds = np.clip(delay_seconds, 0, 1.0)
return delay_seconds
except Exception as e:
print(f"Delay estimation error: {e}")
return 0.4
# ===== ENHANCED VISUALIZATION =====
def _collect_model_curve(vad_results, model_name):
    """Return ``(times, probs)`` for ``model_name``, sorted by timestamp.

    Results are matched on the base name (text before any '('), so entries
    labelled e.g. "PANNs (fallback)" still count as "PANNs".
    """
    points = sorted(
        (r.timestamp, r.probability)
        for r in vad_results
        if r.model_name.split('(')[0].strip() == model_name
    )
    return [t for t, _ in points], [p for _, p in points]


def _add_threshold_marker(fig, time_frames, threshold, row):
    """Draw the dashed threshold line and its label on subplot ``row``'s probability axis."""
    # secondary_y=True is required: when row/col are given, plotly derives
    # xref/yref from the subplot and overrides an explicit yref, so the old
    # yref="y2"/"y4" placed the line on the frequency axis.
    fig.add_shape(
        type="line",
        x0=time_frames[0], x1=time_frames[-1],
        y0=threshold, y1=threshold,
        line=dict(color='cyan', width=2, dash='dash'),
        row=row, col=1,
        secondary_y=True
    )
    fig.add_annotation(
        x=time_frames[-1] * 0.95, y=threshold,
        text=f'Threshold: {threshold:.2f}',
        showarrow=False,
        font=dict(color='cyan', size=10),
        row=row, col=1,
        secondary_y=True
    )


def _add_probability_trace(fig, times, probs, model_name, color, row, common_times):
    """Add one model's probability curve (or a single marker) to subplot ``row``."""
    if len(times) > 1:
        # Hold the first/last probability outside the sampled range instead of dropping to 0.
        interp_probs = np.interp(common_times, times, probs, left=probs[0], right=probs[-1])
        trace = go.Scatter(
            x=common_times,
            y=interp_probs,
            mode='lines',
            line=dict(color=color, width=3),
            name=f'{model_name} Probability',
            hovertemplate='Time: %{x:.2f}s<br>Probability: %{y:.3f}<extra></extra>',
            showlegend=True
        )
    elif len(times) == 1:
        trace = go.Scatter(
            x=times,
            y=probs,
            mode='markers',
            marker=dict(size=8, color=color),
            name=f'{model_name} Probability',
            hovertemplate='Time: %{x:.2f}s<br>Probability: %{y:.3f}<extra></extra>',
            showlegend=True
        )
    else:
        return
    fig.add_trace(trace, row=row, col=1, secondary_y=True)


def _add_event_lines(fig, events, t_max, row):
    """Mark onsets (lime, ▲) and offsets (red, ▼) that fall within [0, t_max]."""
    for event in events:
        if 0 <= event.onset_time <= t_max:
            fig.add_vline(
                x=event.onset_time,
                line=dict(color='lime', width=3),
                annotation_text='▲',
                annotation_position="top",
                row=row, col=1
            )
        if 0 <= event.offset_time <= t_max:
            fig.add_vline(
                x=event.offset_time,
                line=dict(color='red', width=3),
                annotation_text='▼',
                annotation_position="bottom",
                row=row, col=1
            )


def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
                         onsets_offsets: List[OnsetOffset], processor: AudioProcessor,
                         model_a: str, model_b: str, threshold: float):
    """Build a two-row Plotly figure comparing two models over the same spectrogram.

    Each row shows the mel spectrogram on the primary y-axis. On the secondary
    (0-1) axis it shows the model's probability curve and the global threshold.
    Detected onsets/offsets are drawn as vertical lines.

    Returns:
        The figure, ``None`` when Plotly is unavailable, or a placeholder
        figure titled with the error message if plotting fails.
    """
    if not PLOTLY_AVAILABLE:
        return None
    try:
        mel_spec_db, time_frames = processor.compute_high_res_spectrogram(audio_data)
        # NOTE(review): rows are mel bands but are labelled on a linear fmin..fmax axis.
        freq_axis = np.linspace(processor.fmin, processor.fmax, processor.n_mels)
        fig = make_subplots(
            rows=2, cols=1,
            subplot_titles=(f"Model A: {model_a}", f"Model B: {model_b}"),
            vertical_spacing=0.15,
            shared_xaxes=True,
            specs=[[{"secondary_y": True}], [{"secondary_y": True}]]
        )
        for row, name in ((1, model_a), (2, model_b)):
            fig.add_trace(
                go.Heatmap(
                    z=mel_spec_db,
                    x=time_frames,
                    y=freq_axis,
                    colorscale='Viridis',
                    showscale=False,
                    hovertemplate='Time: %{x:.2f}s<br>Freq: %{y:.0f}Hz<br>Power: %{z:.1f}dB<extra></extra>',
                    name=f'Spectrogram {name}'
                ),
                row=row, col=1
            )
        if len(time_frames) > 0:
            # Same global threshold for both models.
            for row in (1, 2):
                _add_threshold_marker(fig, time_frames, threshold, row)
            # Common high-resolution time grid so both curves align.
            common_times = np.linspace(0, time_frames[-1], 1000)
            # Independent lookups: choosing the same model for A and B fills both rows
            # (the old if/elif left row B empty in that case).
            times_a, probs_a = _collect_model_curve(vad_results, model_a)
            times_b, probs_b = _collect_model_curve(vad_results, model_b)
            _add_probability_trace(fig, times_a, probs_a, model_a, 'yellow', 1, common_times)
            _add_probability_trace(fig, times_b, probs_b, model_b, 'orange', 2, common_times)
            t_max = time_frames[-1]
            for row, name in ((1, model_a), (2, model_b)):
                events = [e for e in onsets_offsets if e.model_name.split('(')[0].strip() == name]
                _add_event_lines(fig, events, t_max, row)
        fig.update_layout(
            height=600,
            title_text="Real-Time Speech Visualizer",
            showlegend=True,
            legend=dict(
                x=1.02,
                y=1,
                bgcolor="rgba(255,255,255,0.8)",
                bordercolor="Black",
                borderwidth=1
            ),
            font=dict(size=10),
            margin=dict(l=60, r=120, t=50, b=50),
            plot_bgcolor='black',
            paper_bgcolor='white',
            yaxis2=dict(overlaying='y', side='right', title='Probability', range=[0, 1]),
            yaxis4=dict(overlaying='y3', side='right', title='Probability', range=[0, 1])
        )
        fig.update_xaxes(
            title_text="Time (seconds)",
            row=2, col=1,
            gridcolor='gray',
            gridwidth=1,
            griddash='dot'
        )
        fig.update_yaxes(
            title_text="Frequency (Hz)",
            range=[processor.fmin, processor.fmax],
            gridcolor='gray',
            gridwidth=1,
            griddash='dot',
            secondary_y=False
        )
        fig.update_yaxes(
            title_text="Probability",
            range=[0, 1],
            secondary_y=True
        )
        return fig
    except Exception as e:
        print(f"Visualization error: {e}")
        import traceback
        traceback.print_exc()
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Error'))
        fig.update_layout(title=f"Visualization Error: {str(e)}")
        return fig
# ===== MAIN APPLICATION =====
class VADDemo:
    """Controller behind the Gradio UI: runs up to two VAD models over a clip.

    Holds one shared ``AudioProcessor`` plus one instance of every supported
    model (keyed by display name in ``self.models``), and turns a recorded
    clip into a comparison figure, a one-line status message and a details
    report.
    """

    def __init__(self):
        """Log library availability, then build the processor and all models."""
        print("🎤 Initializing VAD Demo with 5 models...")
        # Debug: Check library availability
        print("\n🔍 **LIBRARY AVAILABILITY CHECK**:")
        print(f" LIBROSA_AVAILABLE: {LIBROSA_AVAILABLE}")
        print(f" WEBRTC_AVAILABLE: {WEBRTC_AVAILABLE}")
        print(f" PLOTLY_AVAILABLE: {PLOTLY_AVAILABLE}")
        print(f" PANNS_AVAILABLE: {PANNS_AVAILABLE}")
        print(f" AST_AVAILABLE: {AST_AVAILABLE}")
        if PANNS_AVAILABLE:
            try:
                print(f" 📊 PANNs labels length: {len(labels) if 'labels' in globals() else 'labels not available'}")
            except Exception:
                # Diagnostics only: a broken label table must not abort startup
                # (and, unlike a bare except, this does not swallow Ctrl-C).
                print(" ❌ PANNs labels not accessible")
        self.processor = AudioProcessor()
        self.models = {
            'Silero-VAD': OptimizedSileroVAD(),
            'WebRTC-VAD': OptimizedWebRTCVAD(),
            'E-PANNs': OptimizedEPANNs(),
            'PANNs': OptimizedPANNs(),
            'AST': OptimizedAST()
        }
        print("\n🎤 VAD Demo initialized successfully")
        print(f"📊 Available models: {list(self.models.keys())}")
        # Report which wrappers expose a non-None ``model`` attribute; the
        # others are reported as running a fallback.
        print("\n🔍 **MODEL STATUS CHECK**:")
        for name, model in self.models.items():
            if getattr(model, 'model', None) is not None:
                print(f" ✅ {name}: Model loaded")
            else:
                print(f" ⚠️ {name}: Using fallback")
        print("")

    @staticmethod
    def _base_name(model_name):
        """Return the part of *model_name* before the first '(' , stripped.

        Used to group results whose ``model_name`` carries a parenthesised
        suffix back under the dropdown key.
        """
        return model_name.split('(')[0].strip()

    def _run_model_windows(self, model_name, audio, threshold, debug_info):
        """Slide *model_name*'s analysis window over *audio* and collect results.

        Windows are placed every hop, centred on the hop position where
        possible and clamped to the clip boundaries. Windows that would be
        more than half padding are skipped. Each result is timestamped at the
        centre of its real (unpadded) samples, and ``is_speech`` is recomputed
        against the global *threshold*. Progress lines are appended to
        *debug_info* in place.

        Returns:
            The list of result objects produced by the model's ``predict``.
        """
        sample_rate = self.processor.sample_rate
        window_size = self.processor.model_windows[model_name]
        hop_size = self.processor.model_hop_sizes[model_name]
        window_samples = int(sample_rate * window_size)
        # A hop that rounds down to 0 samples would make range() raise.
        hop_samples = max(1, int(sample_rate * hop_size))
        model = self.models[model_name]

        debug_info.append(f"\n📊 **{model_name}**:")
        debug_info.append(f" Window: {window_size}s ({window_samples} samples)")
        debug_info.append(f" Hop: {hop_size}s ({hop_samples} samples)")
        debug_info.append(f" Threshold: {threshold}")

        n_samples = len(audio)
        results = []
        for i in range(0, n_samples, hop_samples):
            # Centre the window on sample i, clamped to the clip boundaries.
            start_pos = max(0, i - window_samples // 2)
            end_pos = min(n_samples, start_pos + window_samples)
            real_len = end_pos - start_pos
            # Skip heavily padded windows to avoid skewed predictions; checked
            # before padding so discarded windows are never padded.
            if (window_samples - real_len) / window_samples > 0.5:
                continue
            chunk = audio[start_pos:end_pos]
            if real_len < window_samples:
                # Reflect-pad rather than zero-pad to avoid artificial silence.
                chunk = np.pad(chunk, (0, window_samples - real_len), mode='reflect')
            # Timestamp at the actual centre of the real samples for alignment.
            timestamp = (start_pos + real_len / 2.0) / sample_rate

            window_count = len(results)
            if window_count < 3:  # Log the first 3 windows
                debug_info.append(f" 🔄 Window {window_count}: t={timestamp:.2f}s (center), chunk_size={len(chunk)}")
            result = model.predict(chunk, timestamp)
            if window_count < 3:  # Log the first 3 raw results
                debug_info.append(f" 📈 Result {window_count}: prob={result.probability:.4f}, speech={result.is_speech}")
            # Override the model's own decision with the global slider threshold.
            result.is_speech = result.probability > threshold
            results.append(result)

        debug_info.append(f" 🎯 Total windows processed: {len(results)}")
        if results:
            probs = [r.probability for r in results]
            speech_count = sum(1 for r in results if r.is_speech)
            total_time = sum(r.processing_time for r in results)
            avg_time = total_time / len(results)
            debug_info.append(f" 📊 Summary: {len(results)} results, avg_prob={np.mean(probs):.4f}, speech_ratio={speech_count / len(results) * 100:.1f}%")
            debug_info.append(f" ⏱️ Processing: total={total_time:.3f}s, avg={avg_time:.4f}s/window")
        else:
            debug_info.append(" ❌ No results generated!")
        return results

    @classmethod
    def _format_details(cls, vad_results, onsets_offsets, threshold, debug_info):
        """Build the details text: per-model stats, first 5 events, debug log."""
        model_summaries = {}
        for result in vad_results:
            summary = model_summaries.setdefault(
                cls._base_name(result.model_name),
                {'probs': [], 'speech_chunks': 0, 'total_chunks': 0, 'total_time': 0.0},
            )
            summary['probs'].append(result.probability)
            summary['total_chunks'] += 1
            summary['total_time'] += result.processing_time
            if result.is_speech:
                summary['speech_chunks'] += 1

        # Show global threshold in analysis results
        details_lines = [f"**Analysis Results** (Global Threshold: {threshold:.2f})"]
        for model_name, summary in model_summaries.items():
            avg_prob = np.mean(summary['probs']) if summary['probs'] else 0
            speech_ratio = (summary['speech_chunks'] / summary['total_chunks']) * 100 if summary['total_chunks'] > 0 else 0
            # speech_ratio is a percentage (0-100): green above 50%, yellow above 20%.
            status_icon = "🟢" if speech_ratio > 50 else "🟡" if speech_ratio > 20 else "🔴"
            details_lines.append(f"{status_icon} **{model_name}**: {avg_prob:.3f} avg prob, {speech_ratio:.1f}% speech, {summary['total_time']:.3f}s total time")

        if onsets_offsets:
            details_lines.append(f"\n**Speech Events**: {len(onsets_offsets)} detected")
            for event in onsets_offsets[:5]:  # Show first 5 only
                duration = event.offset_time - event.onset_time if event.offset_time > event.onset_time else 0
                event_model = cls._base_name(event.model_name)
                details_lines.append(f"• {event_model}: {event.onset_time:.2f}s - {event.offset_time:.2f}s ({duration:.2f}s)")

        # Debug info goes at the end, after a blank separator line.
        details_lines.append("")
        details_lines.extend(debug_info)
        return "\n".join(details_lines)

    def process_audio_with_events(self, audio, model_a, model_b, threshold):
        """Run the two selected models over *audio* and build the UI outputs.

        Args:
            audio: Value from the Gradio ``Audio`` component, or ``None`` when
                nothing was recorded. Passed unchanged to
                ``self.processor.process_audio``.
            model_a: Key into ``self.models`` for the top panel.
            model_b: Key into ``self.models`` for the bottom panel. If it equals
                ``model_a``, the model is run only once.
            threshold: Global probability threshold. A window counts as speech
                when its probability is strictly greater than this value.

        Returns:
            ``(figure, status_message, details_text)``. ``figure`` is ``None``
            when there is nothing to plot or an error occurred. Errors are
            reported in the two text fields rather than raised.
        """
        if audio is None:
            return None, "🔇 No audio detected", "Ready to process audio..."
        try:
            processed_audio = self.processor.process_audio(audio)
            if len(processed_audio) == 0:
                return None, "🎵 Processing audio...", "No audio data processed"

            sample_rate = self.processor.sample_rate
            total_duration = len(processed_audio) / sample_rate if sample_rate > 0 else 0.0

            debug_info = [
                "🔍 **DEBUG INFO**",
                f"Audio length: {len(processed_audio)} samples ({total_duration:.2f}s)",
                f"Sample rate: {sample_rate}",
                f"Selected models: {[model_a, model_b]}",
            ]

            # De-duplicate while keeping A-then-B order; a set would make the
            # processing order depend on string hashing.
            selected_models = list(dict.fromkeys([model_a, model_b]))

            vad_results = []
            for model_name in selected_models:
                if model_name in self.models:
                    vad_results.extend(
                        self._run_model_windows(model_name, processed_audio, threshold, debug_info)
                    )

            debug_info.append("\n⏱️ **TEMPORAL ALIGNMENT**:")
            for model_name in selected_models:
                model_results = [r for r in vad_results if self._base_name(r.model_name) == model_name]
                if model_results:
                    delay = self.processor.estimate_delay_compensation(processed_audio, model_results)
                    for r in model_results:
                        r.timestamp += delay
                    debug_info.append(f" Delay compensation = {delay:.3f}s applied to {model_name} timestamps")

            # Timestamps were already shifted above; apply_delay=0.0 presumably
            # prevents a second shift inside the detector.
            onsets_offsets = self.processor.detect_onset_offset_advanced(
                vad_results, threshold, apply_delay=0.0, min_duration=0.12, total_duration=total_duration
            )
            debug_info.append(f"\n🎭 **EVENTS**: {len(onsets_offsets)} onset/offset pairs detected")

            fig = create_realtime_plot(
                processed_audio, vad_results, onsets_offsets,
                self.processor, model_a, model_b, threshold
            )

            total_speech_chunks = sum(1 for r in vad_results if r.is_speech)
            if total_speech_chunks > 0:
                status_msg = f"🎙️ SPEECH DETECTED - {total_speech_chunks} active chunks"
            else:
                status_msg = f"🔇 No speech detected - {len(vad_results)} total results"

            details_text = self._format_details(vad_results, onsets_offsets, threshold, debug_info)
            return fig, status_msg, details_text
        except Exception as e:
            print(f"Processing error: {e}")
            import traceback
            traceback.print_exc()
            error_details = f"❌ Error: {str(e)}\n\nStacktrace:\n{traceback.format_exc()}"
            return None, f"❌ Error: {str(e)}", error_details
# ===== GRADIO INTERFACE =====
def create_interface():
    """Build the Gradio Blocks UI for the VAD demo.

    The UI records microphone audio, lets the user pick two models and a
    global detection threshold, and wires the Analyze button to
    ``demo_app.process_audio_with_events``. The model dropdowns are populated
    from ``demo_app.models`` so they cannot drift from the registered models.

    Returns:
        The (not yet launched) ``gr.Blocks`` interface.
    """
    # Logos are embedded as base64 data URIs; a missing logo falls back to a
    # plain text badge with the same label.
    logos = load_logos()
    logo_info = [
        ('ai4s', 'AI4S'),
        ('surrey', 'University of Surrey'),
        ('epsrc', 'EPSRC'),
        ('cvssp', 'CVSSP')
    ]
    logo_parts = []
    for key, alt_text in logo_info:
        encoded = logos.get(key)
        if encoded:
            logo_parts.append(f'<img src="data:image/png;base64,{encoded}" alt="{alt_text}" style="height: 60px; object-fit: contain;">')
        else:
            logo_parts.append(f'<span style="padding: 10px; background: #333; color: white; border-radius: 5px;">{alt_text}</span>')
    logo_html = (
        '\n<div style="display: flex; justify-content: center; align-items: center; gap: 30px; margin: 20px 0; flex-wrap: wrap;">\n'
        + "".join(logo_parts)
        + "</div>"
    )

    # Single source of truth for the selectable models (same order as VADDemo.models).
    model_choices = list(demo_app.models)

    with gr.Blocks(title="VAD Demo - Voice Activity Detection", theme=gr.themes.Soft()) as interface:
        # Header with logos
        gr.Markdown("""
<div style="text-align: center; margin-bottom: 20px;">
<h1>🎤 VAD Demo - Voice Activity Detection</h1>
<p><strong>Multi-Model Speech Detection Framework</strong></p>
</div>
""")
        # Logos section
        with gr.Row():
            gr.HTML(logo_html)
        # Main interface
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### 🎛️ Controls")
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="numpy",
                    label="Record Audio"
                )
                model_a = gr.Dropdown(
                    choices=model_choices,
                    value="E-PANNs",
                    label="Model A (Top Panel)"
                )
                model_b = gr.Dropdown(
                    choices=model_choices,
                    value="PANNs",
                    label="Model B (Bottom Panel)"
                )
                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Detection Threshold (Global)"
                )
                process_btn = gr.Button("🎤 Analyze", variant="primary", size="lg")
            with gr.Column(scale=3):
                status_display = gr.Textbox(
                    label="Status",
                    value="🔇 Ready to analyze audio",
                    interactive=False,
                    lines=2
                )
        # Results
        gr.Markdown("### 📊 Results")
        with gr.Row():
            plot_output = gr.Plot(label="Speech Detection Visualization")
        with gr.Row():
            details_output = gr.Textbox(
                label="Analysis Details",
                lines=10,
                interactive=False
            )
        # Event handlers
        process_btn.click(
            fn=demo_app.process_audio_with_events,
            inputs=[audio_input, model_a, model_b, threshold_slider],
            outputs=[plot_output, status_display, details_output]
        )
        # Footer
        gr.Markdown("""
---
**Models**: Silero-VAD, WebRTC-VAD, E-PANNs, PANNs, AST | **Research**: WASPAA 2025 | **Institution**: University of Surrey, CVSSP
**Note**: Perfect temporal alignment achieved - prediction curves now start from 0s and align precisely with spectrogram features.
""")
    return interface
# One shared VADDemo instance, created at import time so every model is
# loaded exactly once; create_interface() binds its click handler to it.
demo_app = VADDemo()

# Build the UI and serve it only when this file is executed as a script.
if __name__ == "__main__":
    interface = create_interface()
    interface.launch(debug=False, share=True)