VoxSum / src /diarization.py
Luigi's picture
Consolidate tests under tests/, add LLM default tests with opt-out flag, model selection, README update
913c94a
#!/usr/bin/env python3
"""
Speaker Diarization module using Sherpa-ONNX
Integrates seamlessly with VoxSum ASR pipeline
Enhanced with adaptive clustering and quality validation
OPTIMIZED MODEL: 3dspeaker_campplus_zh_en_advanced
- Performance: F1=0.500, Accuracy=0.500
- Speed: 60.5ms average (2x faster than baseline)
- Size: 27MB (compact for production)
- Languages: Chinese/Taiwanese + English support
- Architecture: CAM++ multilingual advanced
"""
import os
import numpy as np
try:
import sherpa_onnx # type: ignore
except Exception: # pragma: no cover
class _SherpaStub: # minimal stub to allow tests without the dependency
class SpeakerEmbeddingExtractorConfig: # noqa: D401
def __init__(self, *args, **kwargs):
pass
class SpeakerEmbeddingExtractor:
def __init__(self, *args, **kwargs):
raise RuntimeError("sherpa_onnx not installed; real embedding extraction unavailable")
sherpa_onnx = _SherpaStub() # type: ignore
from pathlib import Path
from typing import List, Tuple, Optional, Callable, Dict, Any, Generator
import logging
from .utils import get_writable_model_dir, num_vcpus
try: # Optional dependency
from huggingface_hub import hf_hub_download # type: ignore
except Exception: # pragma: no cover
def hf_hub_download(*args, **kwargs): # minimal stub
raise RuntimeError("huggingface_hub not installed; model download unavailable")
import shutil
try: # Optional dependency
from sklearn.metrics import silhouette_score # type: ignore
except Exception: # pragma: no cover
def silhouette_score(*args, **kwargs):
return -1.0
# Import the improved diarization pipeline (robust: search repo tree)
try:
from importlib import import_module
# Try direct import first (works when repo root is in PYTHONPATH)
try:
mod = import_module('improved_diarization')
except Exception:
# Search up to 6 parent directories for improved_diarization.py
repo_root = None
current = Path(__file__).resolve()
for parent in list(current.parents)[:6]:
candidate = parent / 'improved_diarization.py'
if candidate.exists():
repo_root = parent
break
if repo_root is None:
# Fallback to CWD
cwd_candidate = Path.cwd() / 'improved_diarization.py'
if cwd_candidate.exists():
repo_root = Path.cwd()
if repo_root is not None:
import sys
sys.path.insert(0, str(repo_root))
mod = import_module('improved_diarization')
else:
raise ImportError('improved_diarization module not found in repository tree')
enhance_diarization_pipeline = getattr(mod, 'enhance_diarization_pipeline')
ENHANCED_DIARIZATION_AVAILABLE = True
print("✅ Enhanced diarization pipeline loaded successfully")
except Exception as e:
ENHANCED_DIARIZATION_AVAILABLE = False
logging.warning(f"Enhanced diarization not available - using fallback: {e}")
logger = logging.getLogger(__name__)
# Speaker colors for UI visualization
SPEAKER_COLORS = [
"#FF6B6B", # Red
"#4ECDC4", # Teal
"#45B7D1", # Blue
"#96CEB4", # Green
"#FFEAA7", # Yellow
"#DDA0DD", # Plum
"#FFB347", # Orange
"#87CEEB", # Sky Blue
"#F0E68C", # Khaki
"#FF69B4", # Hot Pink
]
def get_speaker_color(speaker_id: int) -> str:
"""Get consistent color for speaker ID"""
return SPEAKER_COLORS[speaker_id % len(SPEAKER_COLORS)]
def download_diarization_models():
"""
Download required models for speaker diarization if not present
Only downloads embedding model - we'll use Silero VAD for segmentation
Returns tuple (embedding_model_path, success)
"""
# Use a writable cache directory (works on HF Spaces and local)
cache_dir = get_writable_model_dir()
models_dir = cache_dir / "diarization"
models_dir.mkdir(parents=True, exist_ok=True)
# Model info
repo_id = "csukuangfj/speaker-embedding-models"
filename = "3dspeaker_speech_campplus_sv_zh_en_16k-common_advanced.onnx"
embedding_model = models_dir / filename
logger.info(f"Model cache directory: {models_dir}")
try:
# Download using huggingface_hub if not present
if not embedding_model.exists():
logger.info("📥 Downloading eres2netv2 Chinese speaker model from HuggingFace (29MB)...")
downloaded_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
cache_dir=models_dir,
local_dir=models_dir,
local_dir_use_symlinks=False,
resume_download=True
)
# Move/copy to expected location if needed
if Path(downloaded_path) != embedding_model:
shutil.copy(downloaded_path, embedding_model)
logger.info("✅ eres2netv2 Chinese embedding model downloaded!")
return str(embedding_model), True
except Exception as e:
logger.error(f"❌ Failed to download diarization models: {e}")
return None, False
def init_speaker_embedding_extractor(
cluster_threshold: float = 0.5,
num_speakers: int = -1
) -> Optional[Tuple[object, dict]]:
"""
Initialize speaker embedding extractor (without segmentation)
We use Silero VAD segments from ASR pipeline instead of PyAnnote
Args:
cluster_threshold: Clustering threshold (lower = more speakers)
num_speakers: Number of speakers (-1 for auto-detection)
Returns:
Tuple of (embedding_extractor, config_dict) or None
"""
try:
# Download models if needed (only embedding model now)
embedding_model, success = download_diarization_models()
if not success:
return None
# Create embedding extractor config
embedding_config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
model=embedding_model,
num_threads=num_vcpus
)
# Initialize embedding extractor
embedding_extractor = sherpa_onnx.SpeakerEmbeddingExtractor(embedding_config)
# Store clustering parameters separately
config_dict = {
'cluster_threshold': cluster_threshold,
'num_speakers': num_speakers
}
return embedding_extractor, config_dict
except Exception as e:
logger.error(f"❌ Failed to initialize speaker embedding extractor: {e}")
return None
def perform_speaker_diarization_on_utterances(
audio: np.ndarray,
sample_rate: int,
utterances: List[Tuple[float, float, str]],
embedding_extractor: object,
config_dict: dict,
progress_callback: Optional[Callable] = None
) -> Generator[float | List[Tuple[float, float, int]], None, List[Tuple[float, float, int]]]:
"""
Perform speaker diarization using existing ASR utterance segments
This avoids double segmentation by reusing Silero VAD results
Args:
audio: Audio samples (float32, mono)
sample_rate: Sample rate (should be 16kHz for optimal results)
utterances: ASR utterances from Silero VAD segmentation
embedding_extractor: Initialized embedding extractor
config_dict: Configuration dictionary with clustering parameters
progress_callback: Optional progress callback function
Returns:
List of (start_time, end_time, speaker_id) tuples
"""
print(f"🔍 DEBUG: perform_speaker_diarization_on_utterances called with {len(utterances)} utterances")
try:
# Ensure audio is float32 and mono
if audio.dtype != np.float32:
audio = audio.astype(np.float32)
if len(audio.shape) > 1:
audio = audio.mean(axis=1) # Convert to mono
# Check sample rate
if sample_rate != 16000:
warning_msg = f"⚠️ Audio sample rate is {sample_rate}Hz, but 16kHz is optimal for diarization"
logger.warning(warning_msg)
if not utterances:
logger.warning("⚠️ No utterances provided for diarization")
return []
logger.info(f"🎭 Extracting embeddings from {len(utterances)} utterance segments...")
# Extract embeddings for each utterance segment
embeddings = []
valid_utterances = []
# Progress tracking for UI
total_utterances = len(utterances)
batch_size = max(1, total_utterances // 20) # Process in batches for progress updates
for i, (start, end, text) in enumerate(utterances):
if i % batch_size == 0:
yield i / total_utterances * 0.8
# Extract audio segment
start_sample = int(start * sample_rate)
end_sample = int(end * sample_rate)
if i % 50 == 0: # Reduce debug frequency for large files
print(f"🔍 DEBUG: Processing utterance {i}/{total_utterances}: [{start:.1f}-{end:.1f}s]")
if start_sample >= len(audio) or end_sample <= start_sample:
if i % 50 == 0: # Reduce debug spam
print(f"⚠️ DEBUG: Skipping invalid segment {i}: start_sample={start_sample}, end_sample={end_sample}, audio_len={len(audio)}")
continue # Skip invalid segments
segment = audio[start_sample:end_sample]
# Skip very short segments (< 0.5 seconds)
if len(segment) < sample_rate * 0.5:
continue
try:
# Extract embedding using Sherpa-ONNX with proper stream API
if not hasattr(embedding_extractor, "create_stream"):
raise RuntimeError("Embedding extractor missing create_stream(); sherpa_onnx not available?")
stream = embedding_extractor.create_stream()
if hasattr(stream, "accept_waveform"):
stream.accept_waveform(sample_rate, segment)
if hasattr(stream, "input_finished"):
stream.input_finished()
if not hasattr(embedding_extractor, "compute"):
raise RuntimeError("Embedding extractor missing compute(); sherpa_onnx not available?")
embedding = embedding_extractor.compute(stream)
if embedding is not None and len(embedding) > 0:
embeddings.append(embedding)
valid_utterances.append((start, end, text))
if i % 100 == 0: # Progress log every 100 segments
print(f"✅ Extracted {len(embeddings)} embeddings so far...")
except Exception as e:
if i % 50 == 0: # Reduce error spam
print(f"⚠️ Failed to extract embedding for segment {i}: {e}")
continue
if not embeddings:
logger.error("❌ No valid embeddings extracted")
print(f"❌ DEBUG: Failed to extract any embeddings from {len(utterances)} utterances")
return []
print(f"✅ DEBUG: Extracted {len(embeddings)} embeddings for clustering")
logger.info(f"✅ Extracted {len(embeddings)} embeddings, performing clustering...")
# Convert embeddings to numpy array
embeddings_array = np.array(embeddings)
print(f"✅ DEBUG: Embeddings array shape: {embeddings_array.shape}")
n_embeddings = embeddings_array.shape[0]
# Cas très faible nombre de segments: éviter tout clustering complexe
if n_embeddings < 3:
print("⚠️ DEBUG: Moins de 3 segments – utilisation d'une heuristique simple sans clustering")
assignments: List[Tuple[float, float, int]] = []
if n_embeddings == 1:
(s, e, _t) = valid_utterances[0]
assignments.append((s, e, 0))
elif n_embeddings == 2:
try:
from sklearn.metrics.pairwise import cosine_similarity # type: ignore
sim = float(cosine_similarity(embeddings_array[0:1], embeddings_array[1:2])[0, 0])
except Exception:
a = embeddings_array[0].astype(float)
b = embeddings_array[1].astype(float)
denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1e-9
sim = float(np.dot(a, b) / denom)
(s1, e1, _t1) = valid_utterances[0]
(s2, e2, _t2) = valid_utterances[1]
if sim >= 0.80:
assignments.append((s1, e1, 0))
assignments.append((s2, e2, 0))
print(f"🟢 DEBUG: Deux segments fusionnés en un seul speaker (similarité={sim:.3f})")
else:
assignments.append((s1, e1, 0))
assignments.append((s2, e2, 1))
print(f"🟦 DEBUG: Deux speakers distincts (similarité={sim:.3f})")
if progress_callback:
progress_callback(1.0)
yield 1.0
yield assignments
return
# Use enhanced diarization if available
if ENHANCED_DIARIZATION_AVAILABLE and n_embeddings >= 3:
print("🚀 Using enhanced diarization with adaptive clustering...")
logger.info("🚀 Using enhanced adaptive clustering...")
# Prepare utterances dict format for enhanced pipeline
utterances_dict = []
for i, (start, end, text) in enumerate(valid_utterances):
utterances_dict.append({
'start': start,
'end': end,
'text': text,
'index': i
})
if progress_callback:
progress_callback(0.9) # 90% for clustering
yield 0.9
# Run enhanced diarization
try:
enhanced_utterances, quality_report = enhance_diarization_pipeline(
embeddings_array, utterances_dict
)
# Display quality report
quality = quality_report['metrics']['quality']
confidence = quality_report['confidence']
n_speakers = quality_report['metrics']['n_speakers']
quality_msg = f"🎯 Diarization Quality: {confidence} confidence ({quality})"
if quality in ['excellent', 'good']:
logger.info(quality_msg)
elif quality == 'fair':
logger.warning(quality_msg)
else:
logger.error(quality_msg)
print(f"✅ Enhanced diarization quality report:")
print(f" - Quality: {quality}")
print(f" - Confidence: {confidence}")
print(f" - Silhouette score: {quality_report['metrics'].get('silhouette_score', 'N/A'):.3f}")
print(f" - Cluster balance: {quality_report['metrics'].get('cluster_balance', 'N/A'):.3f}")
print(f" - Speakers detected: {n_speakers}")
if quality_report['recommendations']:
logger.info("💡 " + "; ".join(quality_report['recommendations']))
# Convert back to tuple format
diarization_result = []
for utt in enhanced_utterances:
diarization_result.append((utt['start'], utt['end'], utt['speaker']))
# Si l'enhanced pipeline a tout fusionné en un seul segment alors qu'on avait peu de segments
# on restaure la granularité originale pour ne pas perdre l'alignement temporel côté UI/tests.
if (
len(diarization_result) == 1
and len(valid_utterances) == n_embeddings
and n_embeddings <= 4
):
single_speaker = diarization_result[0][2]
diarization_result = [
(s, e, single_speaker) for (s, e, _t) in valid_utterances
]
if progress_callback:
progress_callback(1.0) # 100% complete
yield 1.0
print(f"✅ DEBUG: Enhanced result - {n_speakers} speakers, {len(diarization_result)} segments")
logger.info(f"🎭 Enhanced clustering completed! Detected {n_speakers} speakers with {confidence} confidence")
yield diarization_result
return
except Exception as e:
logger.error(f"❌ Enhanced diarization failed: {e}")
print(f"❌ Enhanced diarization failed: {e}")
# Fall back to original clustering
# Fallback to original clustering
logger.warning("⚠️ Using fallback clustering")
print("⚠️ Using fallback clustering")
gen = faiss_clustering(
embeddings_array,
valid_utterances,
config_dict,
progress_callback,
)
try:
while True:
p = next(gen)
yield p
except StopIteration as e:
diarization_result = e.value
yield diarization_result
return
except Exception as e:
error_msg = f"❌ Speaker diarization failed: {e}"
print(error_msg)
import traceback
traceback.print_exc()
return []
def merge_transcription_with_diarization(
utterances: List[Tuple[float, float, str]],
diarization: List[Tuple[float, float, int]]
) -> List[Tuple[float, float, str, int]]:
"""
Merge ASR transcription with speaker diarization results
Args:
utterances: List of (start, end, text) from ASR
diarization: List of (start, end, speaker_id) from diarization
Returns:
List of (start, end, text, speaker_id) tuples
"""
if not diarization:
# No diarization available, assign speaker 0 to all
return [(start, end, text, 0) for start, end, text in utterances]
merged_result = []
for utt_start, utt_end, text in utterances:
# Find overlapping speaker segments
best_speaker = 0
max_overlap = 0.0
for dia_start, dia_end, speaker_id in diarization:
# Calculate overlap between utterance and diarization segment
overlap_start = max(utt_start, dia_start)
overlap_end = min(utt_end, dia_end)
if overlap_end > overlap_start:
overlap_duration = overlap_end - overlap_start
if overlap_duration > max_overlap:
max_overlap = overlap_duration
best_speaker = speaker_id
merged_result.append((utt_start, utt_end, text, best_speaker))
return merged_result
def merge_consecutive_utterances(
utterances_with_speakers: List[Tuple[float, float, str, int]],
max_gap: float = 1.0
) -> List[Tuple[float, float, str, int]]:
"""
Merge consecutive utterances from the same speaker into single utterances
Args:
utterances_with_speakers: List of (start, end, text, speaker_id) tuples
max_gap: Maximum gap in seconds between utterances to merge
Returns:
List of merged (start, end, text, speaker_id) tuples
"""
if not utterances_with_speakers:
return utterances_with_speakers
# Sort by start time to ensure correct order
sorted_utterances = sorted(utterances_with_speakers, key=lambda x: x[0])
merged = []
current_start, current_end, current_text, current_speaker = sorted_utterances[0]
for i in range(1, len(sorted_utterances)):
next_start, next_end, next_text, next_speaker = sorted_utterances[i]
# Check if we should merge: same speaker and gap is acceptable
gap = next_start - current_end
if current_speaker == next_speaker and gap <= max_gap:
# Merge the utterances
current_text = current_text.strip() + ' ' + next_text.strip()
current_end = next_end
print(f"✅ DEBUG: Merged consecutive utterances from Speaker {current_speaker}: [{current_start:.1f}-{current_end:.1f}s]")
else:
# Finalize current utterance and start new one
merged.append((current_start, current_end, current_text, current_speaker))
current_start, current_end, current_text, current_speaker = next_start, next_end, next_text, next_speaker
# Add the last utterance
merged.append((current_start, current_end, current_text, current_speaker))
print(f"✅ DEBUG: Utterance merging complete: {len(utterances_with_speakers)}{len(merged)} utterances")
return merged
def format_speaker_transcript(
utterances_with_speakers: List[Tuple[float, float, str, int]]
) -> str:
"""
Format transcript with speaker labels
Args:
utterances_with_speakers: List of (start, end, text, speaker_id)
Returns:
Formatted transcript string
"""
if not utterances_with_speakers:
return ""
formatted_lines = []
current_speaker = None
for start, end, text, speaker_id in utterances_with_speakers:
# Add speaker label when speaker changes
if speaker_id != current_speaker:
formatted_lines.append(f"\n**Speaker {speaker_id + 1}:**")
current_speaker = speaker_id
# Add timestamped utterance
minutes = int(start // 60)
seconds = int(start % 60)
formatted_lines.append(f"[{minutes:02d}:{seconds:02d}] {text}")
return "\n".join(formatted_lines)
def get_diarization_stats(
utterances_with_speakers: List[Tuple[float, float, str, int]]
) -> dict:
"""
Calculate speaker diarization statistics
Returns:
Dictionary with speaking time per speaker and other stats
"""
if not utterances_with_speakers:
return {}
speaker_times = {}
speaker_utterances = {}
total_duration = 0
for start, end, text, speaker_id in utterances_with_speakers:
duration = end - start
total_duration += duration
if speaker_id not in speaker_times:
speaker_times[speaker_id] = 0
speaker_utterances[speaker_id] = 0
speaker_times[speaker_id] += duration
speaker_utterances[speaker_id] += 1
# Calculate percentages
stats = {
"total_speakers": len(speaker_times),
"total_duration": total_duration,
"speakers": {}
}
for speaker_id in sorted(speaker_times.keys()):
speaking_time = speaker_times[speaker_id]
percentage = (speaking_time / total_duration * 100) if total_duration > 0 else 0
stats["speakers"][speaker_id] = {
"speaking_time": speaking_time,
"percentage": percentage,
"utterances": speaker_utterances[speaker_id],
"avg_utterance_length": speaking_time / speaker_utterances[speaker_id] if speaker_utterances[speaker_id] > 0 else 0
}
return stats
def faiss_clustering(embeddings: np.ndarray,
utterances: list,
config_dict: dict,
progress_callback=None):
"""
Clustering via FAISS (K-means) ultra-rapide CPU.
Retourne la liste (start, end, speaker_id) compatible avec l'ancien code.
"""
try:
import faiss
except ImportError:
# FAISS absent → on retombe sur AgglomerativeClustering d'origine
gen = sklearn_fallback_clustering(embeddings, utterances, config_dict, progress_callback)
try:
while True:
p = next(gen)
yield p
except StopIteration as e:
return e.value
n_samples, dim = embeddings.shape
n_clusters = config_dict['num_speakers']
if n_clusters == -1:
# Si très peu d'échantillons, attribuer tout au locuteur 0
if n_samples < 3:
if progress_callback:
progress_callback(1.0)
yield 1.0
return [(s, e, 0) for (s, e, _t) in utterances]
max_k = min(10, max(2, n_samples // 2))
best_score, best_k, best_labels = -1.0, 2, None
emb32 = embeddings.astype(np.float32)
for k in range(2, max_k + 1):
if k >= n_samples: # éviter k == n_samples (silhouette invalide)
break
kmeans = faiss.Kmeans(dim, k, niter=25, verbose=False, seed=42)
kmeans.train(emb32)
_, lbls = kmeans.index.search(emb32, 1)
lbls = lbls.ravel()
uniq = set(lbls)
if 1 < len(uniq) < n_samples:
try:
sil = silhouette_score(embeddings, lbls)
except Exception:
sil = -1.0
else:
sil = -1.0
if sil > best_score:
best_score, best_k, best_labels = sil, k, lbls
if best_labels is None:
# Fallback trivial: tout un seul locuteur
if progress_callback:
progress_callback(1.0)
yield 1.0
return [(s, e, 0) for (s, e, _t) in utterances]
labels = best_labels
else:
kmeans = faiss.Kmeans(dim, min(n_clusters, n_samples), niter=20, verbose=False, seed=42)
kmeans.train(embeddings.astype(np.float32))
_, labels = kmeans.index.search(embeddings.astype(np.float32), 1)
labels = labels.ravel()
if progress_callback:
progress_callback(1.0)
yield 1.0
num_speakers = len(set(labels)) if labels is not None else 1
print(f"✅ DEBUG: FAISS clustering — {num_speakers} speakers, {len(utterances)} segments")
logger.info(f"🎭 FAISS clustering completed! Detected {num_speakers} speakers")
if labels is None:
return [(s, e, 0) for (s, e, _t) in utterances]
return [(start, end, int(lbl)) for (start, end, _), lbl in zip(utterances, labels)]
def sklearn_fallback_clustering(embeddings, utterances, config_dict, progress_callback=None):
"""
Ancienne voie sklearn conservée pour fallback sans FAISS.
"""
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
similarity_matrix = cosine_similarity(embeddings)
distance_matrix = 1 - similarity_matrix
n_clusters = config_dict['num_speakers']
if n_clusters == -1:
clustering = AgglomerativeClustering(
n_clusters=None,
distance_threshold=config_dict['cluster_threshold'],
metric='precomputed',
linkage='average'
)
else:
clustering = AgglomerativeClustering(
n_clusters=min(n_clusters, len(embeddings)),
metric='precomputed',
linkage='average'
)
if progress_callback:
progress_callback(0.9)
yield 0.9
labels = clustering.fit_predict(distance_matrix)
if progress_callback:
progress_callback(1.0)
yield 1.0
return [(start, end, int(lbl)) for (start, end, _), lbl in zip(utterances, labels)]