"""
Improved Diarization with Adaptive Clustering and Quality Validation
Vendored copy for importability from src/.
"""
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score
from typing import List, Dict, Tuple, Any, Optional
import logging
logger = logging.getLogger(__name__)
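
# Utterance segments are plain dicts; the schema assumed throughout this
# module (inferred from its own usage) is:
#   {'start': float, 'end': float, 'text': str, 'speaker': int}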

class ImprovedDiarization:
    """Improved diarization with adaptive clustering and quality validation."""

    def __init__(self):
        self.min_speaker_duration = 3.0  # Minimum duration per speaker (seconds)
        self.max_speakers = 10
        self.quality_threshold = 0.3  # Minimum quality threshold

    def adaptive_clustering(self, embeddings: np.ndarray) -> Tuple[int, float, np.ndarray]:
        """
        Automatically determines the optimal number of speakers
        (FAISS-optimized version; falls back to sklearn if faiss is unavailable).
        """
        try:
            import faiss  # noqa: F401 (availability check only)
            HAS_FAISS = True
        except ImportError:
            HAS_FAISS = False
        if len(embeddings) < 2:
            return 1, 1.0, np.zeros(len(embeddings))
        if HAS_FAISS:
            return self._adaptive_faiss(embeddings)
        else:
            return self._adaptive_sklearn(embeddings)

    def _adaptive_faiss(self, embeddings: np.ndarray) -> Tuple[int, float, np.ndarray]:
        """Searches for the best k via FAISS Kmeans (very fast on CPU)."""
        import faiss
        n_samples, dim = embeddings.shape
        best_score, best_k, best_labels = -1, 2, None
        max_k = min(8, max(2, n_samples // 10))  # Reduced for memory efficiency
        for k in range(2, max_k + 1):
            kmeans = faiss.Kmeans(dim, k, niter=20, verbose=False, seed=42)
            kmeans.train(embeddings.astype(np.float32))
            _, labels = kmeans.index.search(embeddings.astype(np.float32), 1)
            labels = labels.ravel()
            # silhouette_score requires 2 <= n_labels <= n_samples - 1
            if 1 < len(set(labels)) < n_samples:
                sil = silhouette_score(embeddings, labels)
            else:
                sil = -1
            unique, counts = np.unique(labels, return_counts=True)
            balance = min(counts) / max(counts)
            adjusted = sil * (0.7 + 0.3 * balance)
            if best_labels is None or adjusted > best_score:
                best_score, best_k, best_labels = adjusted, k, labels
        return best_k, best_score, best_labels

    def _adaptive_sklearn(self, embeddings: np.ndarray) -> Tuple[int, float, np.ndarray]:
        """Original sklearn logic (kept as a fallback)."""
        if len(embeddings) < 2:
            return 1, 1.0, np.zeros(len(embeddings))
        best_score = -1
        best_n_speakers = 2
        best_labels = None
        # Reduced configurations for faster processing on large datasets
        if len(embeddings) > 100:
            # For large datasets, use the faster configurations only
            configurations = [
                ('euclidean', 'ward'),
                ('cosine', 'average'),
            ]
            max_test_speakers = min(6, len(embeddings) - 1)  # Limit search space
        else:
            # Full search for smaller datasets
            configurations = [
                ('euclidean', 'ward'),
                ('cosine', 'average'),
                ('cosine', 'complete'),
                ('euclidean', 'complete'),
            ]
            max_test_speakers = min(self.max_speakers, len(embeddings) - 1)
        for n_speakers in range(2, max_test_speakers + 1):
            for metric, linkage in configurations:
                try:
                    clustering = AgglomerativeClustering(
                        n_clusters=n_speakers,
                        metric=metric,
                        linkage=linkage
                    )
                    labels = clustering.fit_predict(embeddings)
                    # Silhouette score (sampled on large datasets for speed)
                    if len(embeddings) > 300:
                        sample_size = min(300, len(embeddings))
                        indices = np.random.choice(len(embeddings), sample_size, replace=False)
                        score = silhouette_score(embeddings[indices], labels[indices], metric=metric)
                    else:
                        score = silhouette_score(embeddings, labels, metric=metric)
                    # Bonus for a balanced distribution
                    unique, counts = np.unique(labels, return_counts=True)
                    balance_ratio = min(counts) / max(counts)
                    adjusted_score = score * (0.7 + 0.3 * balance_ratio)
                    logger.debug(f"n_speakers={n_speakers}, metric={metric}, linkage={linkage}: "
                                 f"score={score:.3f}, balance={balance_ratio:.3f}, "
                                 f"adjusted={adjusted_score:.3f}")
                    if adjusted_score > best_score:
                        best_score = adjusted_score
                        best_n_speakers = n_speakers
                        best_labels = labels.copy()
                    # Early stopping on large datasets if the score is degrading
                    # (note: this skips only the remaining configurations for
                    # the current n_speakers, not the outer search)
                    if len(embeddings) > 200 and n_speakers > 3 and adjusted_score < best_score * 0.9:
                        logger.debug(f"Early stopping at {n_speakers} speakers (score degrading)")
                        break
                except Exception as e:
                    logger.warning(f"Clustering failed for n={n_speakers}, "
                                   f"metric={metric}, linkage={linkage}: {e}")
                    continue
        if best_labels is None:
            # Every configuration failed: fall back to a single cluster
            return 1, -1.0, np.zeros(len(embeddings), dtype=int)
        return best_n_speakers, best_score, best_labels

    def validate_clustering_quality(self, embeddings: np.ndarray, labels: np.ndarray) -> Dict[str, Any]:
        """Validates the quality of the clustering."""
        if len(np.unique(labels)) == 1:
            return {
                'silhouette_score': -1.0,
                'cluster_balance': 1.0,
                'quality': 'poor',
                'reason': 'single_cluster'
            }
        try:
            # Silhouette score
            sil_score = silhouette_score(embeddings, labels)
            # Cluster distribution
            unique, counts = np.unique(labels, return_counts=True)
            cluster_balance = min(counts) / max(counts)
            # Intra- vs inter-cluster distances. Subsample points on large
            # datasets to avoid the O(n^2) pairwise-distance cost.
            n_samples = len(embeddings)
            if n_samples > 50:
                np.random.seed(42)  # Reproducible sampling
                indices = np.random.choice(n_samples, size=min(100, n_samples), replace=False)
                sample_embeddings = embeddings[indices]
                sample_labels = labels[indices]
            else:
                sample_embeddings = embeddings
                sample_labels = labels
            # Vectorized distance calculation
            from scipy.spatial.distance import pdist, squareform
            dist_matrix = squareform(pdist(sample_embeddings, metric='euclidean'))
            # Split pairwise distances into intra- and inter-cluster sets via a
            # same-label mask restricted to the upper triangle (each pair once)
            same_cluster = sample_labels[:, None] == sample_labels[None, :]
            upper = np.triu(np.ones_like(same_cluster, dtype=bool), k=1)
            intra_distances = dist_matrix[same_cluster & upper]
            inter_distances = dist_matrix[~same_cluster & upper]
            if len(intra_distances) and len(inter_distances):
                separation_ratio = np.mean(inter_distances) / np.mean(intra_distances)
            else:
                separation_ratio = 1.0
            # Overall assessment
            quality = 'excellent' if sil_score > 0.7 and cluster_balance > 0.5 else \
                      'good' if sil_score > 0.5 and cluster_balance > 0.3 else \
                      'fair' if sil_score > 0.3 else 'poor'
            return {
                'silhouette_score': sil_score,
                'cluster_balance': cluster_balance,
                'separation_ratio': separation_ratio,
                # Native ints keep the report JSON-serializable
                'cluster_distribution': {int(u): int(c) for u, c in zip(unique, counts)},
                'quality': quality,
                'reason': f"sil_score={sil_score:.3f}, balance={cluster_balance:.3f}"
            }
        except Exception as e:
            logger.error(f"Quality validation failed: {e}")
            return {
                'silhouette_score': -1.0,
                'cluster_balance': 0.0,
                'quality': 'error',
                'reason': str(e)
            }

    def refine_speaker_assignments(self, utterances: List[Dict],
                                   min_duration: Optional[float] = None) -> List[Dict]:
        """Refines the speaker assignments."""
        if min_duration is None:
            min_duration = self.min_speaker_duration
        # Compute the total duration per speaker
        speaker_durations = {}
        for utt in utterances:
            speaker = utt['speaker']
            duration = utt['end'] - utt['start']
            speaker_durations[speaker] = speaker_durations.get(speaker, 0) + duration
        logger.info(f"Speaker durations: {speaker_durations}")
        # Identify speakers with insufficient total duration
        weak_speakers = {s for s, d in speaker_durations.items() if d < min_duration}
        if not weak_speakers:
            return utterances
        logger.info(f"Weak speakers to reassign: {weak_speakers}")
        # Reassign the segments of the weak speakers
        refined_utterances = []
        for utt in utterances:
            if utt['speaker'] in weak_speakers:
                # Find the dominant adjacent speaker
                new_speaker = self._find_dominant_adjacent_speaker(utt, utterances, weak_speakers)
                utt['speaker'] = new_speaker
                logger.debug(f"Reassigned segment [{utt['start']:.1f}-{utt['end']:.1f}s] "
                             f"to speaker {new_speaker}")
            refined_utterances.append(utt)
        return refined_utterances

    def _find_dominant_adjacent_speaker(self, target_utt: Dict,
                                        all_utterances: List[Dict],
                                        exclude_speakers: set) -> int:
        """Finds the temporally nearest non-excluded speaker for reassignment."""
        # Locate adjacent segments
        target_start = target_utt['start']
        target_end = target_utt['end']
        candidates = []
        for utt in all_utterances:
            if utt['speaker'] in exclude_speakers:
                continue
            # Temporal distance to the target segment
            if utt['end'] <= target_start:
                distance = target_start - utt['end']
            elif utt['start'] >= target_end:
                distance = utt['start'] - target_end
            else:
                distance = 0  # Overlap
            candidates.append((utt['speaker'], distance))
        if not candidates:
            # Fallback: first non-excluded speaker
            for utt in all_utterances:
                if utt['speaker'] not in exclude_speakers:
                    return utt['speaker']
            return 0  # Last-resort fallback
        # Return the closest speaker
        return min(candidates, key=lambda x: x[1])[0]

    def merge_consecutive_same_speaker(self, utterances: List[Dict],
                                       max_gap: float = 1.0) -> List[Dict]:
        """Merges consecutive segments from the same speaker."""
        if not utterances:
            return utterances
        merged = []
        current = utterances[0].copy()
        for next_utt in utterances[1:]:
            # Same speaker and an acceptable gap
            if (current['speaker'] == next_utt['speaker'] and
                    next_utt['start'] - current['end'] <= max_gap):
                # Merge the texts
                current['text'] = current['text'].strip() + ' ' + next_utt['text'].strip()
                current['end'] = next_utt['end']
                logger.debug(f"Merged segments: [{current['start']:.1f}-{current['end']:.1f}s] "
                             f"Speaker {current['speaker']}")
            else:
                # Finalize the current segment
                merged.append(current)
                current = next_utt.copy()
        # Append the last segment
        merged.append(current)
        return merged

    def diarize_with_quality_control(self, embeddings: np.ndarray,
                                     utterances: List[Dict]) -> Tuple[List[Dict], Dict[str, Any]]:
        """
        Full diarization with quality control.

        Returns:
            (utterances_with_speakers, quality_metrics)
        """
        if len(embeddings) < 2:
            # Trivial case: a single segment
            for utt in utterances:
                utt['speaker'] = 0
            return utterances, {'quality': 'trivial', 'n_speakers': 1}
        # Adaptive clustering
        n_speakers, clustering_score, labels = self.adaptive_clustering(embeddings)
        # Quality validation
        quality_metrics = self.validate_clustering_quality(embeddings, labels)
        quality_metrics['n_speakers'] = n_speakers
        quality_metrics['clustering_score'] = clustering_score
        logger.info(f"Adaptive clustering: {n_speakers} speakers, "
                    f"score={clustering_score:.3f}, quality={quality_metrics['quality']}")
        # Apply the cluster labels to the utterances
        for i, utt in enumerate(utterances):
            utt['speaker'] = int(labels[i])
        # Refine the assignments
        if quality_metrics['quality'] != 'error':
            utterances = self.refine_speaker_assignments(utterances)
            utterances = self.merge_consecutive_same_speaker(utterances)
        return utterances, quality_metrics


def enhance_diarization_pipeline(embeddings: np.ndarray,
                                 utterances: List[Dict]) -> Tuple[List[Dict], Dict[str, Any]]:
    """
    Improved diarization pipeline - main entry point.

    Args:
        embeddings: Embeddings of the audio segments (n_segments, 512)
        utterances: List of segments with transcription

    Returns:
        (utterances_with_speakers, quality_report)
    """
    improved_diarizer = ImprovedDiarization()
    # Diarization with quality control
    diarized_utterances, quality_metrics = improved_diarizer.diarize_with_quality_control(
        embeddings, utterances
    )
    # Detailed quality report
    quality_report = {
        'success': quality_metrics['quality'] not in ['error', 'poor'],
        'confidence': 'high' if quality_metrics['quality'] in ['excellent', 'good'] else 'low',
        'metrics': quality_metrics,
        'recommendations': []
    }
    # Quality-based recommendations (.get guards the trivial single-segment
    # case, where the silhouette/balance metrics are absent)
    if quality_metrics['quality'] == 'poor':
        quality_report['recommendations'].append(
            "Consider using single-speaker mode - clustering quality too low"
        )
    elif quality_metrics.get('silhouette_score', 1.0) < 0.3:
        quality_report['recommendations'].append(
            "Low speaker differentiation - verify audio quality"
        )
    elif quality_metrics.get('cluster_balance', 1.0) < 0.2:
        quality_report['recommendations'].append(
            "Unbalanced speaker distribution - check audio content"
        )
    return diarized_utterances, quality_report
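

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: two synthetic
    # speaker clusters plus dummy utterances run through the full pipeline.
    # The embedding dimension, timings, and texts below are illustrative
    # assumptions, not values prescribed by VoxSum.
    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(0)
    n_per_speaker, dim = 20, 64
    # Two well-separated Gaussian clusters stand in for speaker embeddings
    embeddings = np.vstack([
        rng.normal(loc=0.0, scale=0.1, size=(n_per_speaker, dim)),
        rng.normal(loc=1.0, scale=0.1, size=(n_per_speaker, dim)),
    ])
    utterances = [
        {'start': i * 2.0, 'end': i * 2.0 + 1.8, 'text': f"segment {i}", 'speaker': -1}
        for i in range(2 * n_per_speaker)
    ]
    diarized, report = enhance_diarization_pipeline(embeddings, utterances)
    print(f"Detected speakers: {report['metrics'].get('n_speakers')}")
    print(f"Quality: {report['metrics']['quality']}, success={report['success']}")
    for rec in report['recommendations']:
        print(f"Recommendation: {rec}")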