|
|
""" |
|
|
Diarisation Améliorée avec Clustering Adaptatif et Validation de Qualité |
|
|
Vendored copy for importability from src/. |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
from sklearn.cluster import AgglomerativeClustering |
|
|
from sklearn.metrics import silhouette_score |
|
|
from typing import List, Dict, Tuple, Any |
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class ImprovedDiarization:
    """Improved diarization with adaptive clustering and quality validation."""

    def __init__(self):
        # Speakers accumulating less total speech time (seconds) than this
        # are treated as spurious and reassigned to an adjacent speaker.
        self.min_speaker_duration = 3.0
        # Upper bound on the candidate speaker counts explored by the
        # sklearn fallback search.
        self.max_speakers = 10
        # Reserved minimum acceptable quality; not consulted directly in
        # this class (kept for callers/configuration).
        self.quality_threshold = 0.3

    def adaptive_clustering(self, embeddings: np.ndarray) -> Tuple[int, float, np.ndarray]:
        """Automatically determine the optimal number of speakers.

        Uses a FAISS-accelerated k-means sweep when faiss is installed and
        falls back to the sklearn agglomerative search otherwise.

        Args:
            embeddings: Segment embeddings, shape (n_segments, dim).

        Returns:
            Tuple ``(n_speakers, score, labels)`` with one cluster id per
            embedding in ``labels``.
        """
        try:
            import faiss  # noqa: F401 -- availability probe only
            has_faiss = True
        except ImportError:
            has_faiss = False

        # Degenerate input: zero or one segment -> single trivial speaker.
        if len(embeddings) < 2:
            return 1, 1.0, np.zeros(len(embeddings))

        if has_faiss:
            return self._adaptive_faiss(embeddings)
        return self._adaptive_sklearn(embeddings)

    def _adaptive_faiss(self, embeddings: np.ndarray) -> Tuple[int, float, np.ndarray]:
        """Search the best speaker count via FAISS Kmeans (very fast on CPU)."""
        import faiss

        n_samples, dim = embeddings.shape
        data = embeddings.astype(np.float32)  # faiss requires float32
        best_score, best_k, best_labels = -1.0, 2, None
        # Heuristic cap: at most 8 speakers, roughly 1 per 10 segments.
        max_k = min(8, max(2, n_samples // 10))
        for k in range(2, max_k + 1):
            kmeans = faiss.Kmeans(dim, k, niter=20, verbose=False, seed=42)
            kmeans.train(data)
            _, labels = kmeans.index.search(data, 1)
            labels = labels.ravel()
            # Silhouette is undefined when only one effective cluster remains.
            sil = silhouette_score(embeddings, labels) if len(set(labels)) > 1 else -1
            unique, counts = np.unique(labels, return_counts=True)
            # Penalize very unbalanced clusterings (one dominant cluster).
            balance = min(counts) / max(counts)
            adjusted = sil * (0.7 + 0.3 * balance)
            if adjusted > best_score:
                best_score, best_k, best_labels = adjusted, k, labels
        if best_labels is None:
            # BUG FIX: previously `None` labels could be returned when every
            # candidate was rejected, crashing callers that index labels.
            # Fall back to a single speaker instead.
            return 1, -1.0, np.zeros(n_samples, dtype=int)
        return best_k, best_score, best_labels

    def _adaptive_sklearn(self, embeddings: np.ndarray) -> Tuple[int, float, np.ndarray]:
        """Original sklearn search logic (kept as fallback when faiss is absent)."""
        if len(embeddings) < 2:
            return 1, 1.0, np.zeros(len(embeddings))

        best_score = -1
        best_n_speakers = 2
        best_labels = None

        # Large inputs: fewer linkage configurations and a tighter speaker
        # cap to keep the quadratic agglomerative clustering affordable.
        if len(embeddings) > 100:
            configurations = [
                ('euclidean', 'ward'),
                ('cosine', 'average'),
            ]
            max_test_speakers = min(6, len(embeddings) - 1)
        else:
            configurations = [
                ('euclidean', 'ward'),
                ('cosine', 'average'),
                ('cosine', 'complete'),
                ('euclidean', 'complete'),
            ]
            max_test_speakers = min(self.max_speakers, len(embeddings) - 1)

        for n_speakers in range(2, max_test_speakers + 1):
            for metric, linkage in configurations:
                try:
                    clustering = AgglomerativeClustering(
                        n_clusters=n_speakers,
                        metric=metric,
                        linkage=linkage
                    )
                    labels = clustering.fit_predict(embeddings)

                    # Silhouette is O(n^2); score on a random subsample for
                    # large inputs.
                    if len(embeddings) > 300:
                        sample_size = min(300, len(embeddings))
                        indices = np.random.choice(len(embeddings), sample_size, replace=False)
                        score = silhouette_score(embeddings[indices], labels[indices], metric=metric)
                    else:
                        score = silhouette_score(embeddings, labels, metric=metric)

                    unique, counts = np.unique(labels, return_counts=True)
                    balance_ratio = min(counts) / max(counts)
                    # Same balance penalty as the FAISS path.
                    adjusted_score = score * (0.7 + 0.3 * balance_ratio)

                    logger.debug(f"n_speakers={n_speakers}, metric={metric}, linkage={linkage}: "
                                 f"score={score:.3f}, balance={balance_ratio:.3f}, "
                                 f"adjusted={adjusted_score:.3f}")

                    if adjusted_score > best_score:
                        best_score = adjusted_score
                        best_n_speakers = n_speakers
                        best_labels = labels.copy()

                    # NOTE(review): this `break` only exits the configuration
                    # loop, not the speaker loop, although the log message
                    # suggests the latter was intended — kept as-is to
                    # preserve behavior.
                    if len(embeddings) > 200 and n_speakers > 3 and adjusted_score < best_score * 0.9:
                        logger.debug(f"Early stopping at {n_speakers} speakers (score degrading)")
                        break

                except Exception as e:
                    logger.warning(f"Clustering failed for n={n_speakers}, "
                                   f"metric={metric}, linkage={linkage}: {e}")
                    continue

        if best_labels is None:
            # BUG FIX: all clustering attempts failed; return a single-speaker
            # labeling instead of `None` (which crashed downstream indexing).
            return 1, -1.0, np.zeros(len(embeddings), dtype=int)
        return best_n_speakers, best_score, best_labels

    def validate_clustering_quality(self, embeddings: np.ndarray, labels: np.ndarray) -> Dict[str, Any]:
        """Validate clustering quality.

        Returns:
            Dict with silhouette score, cluster balance, inter/intra
            separation ratio, per-cluster sizes, a coarse 'quality' grade
            and a human-readable 'reason'.
        """
        # A single cluster cannot be graded (silhouette is undefined).
        if len(np.unique(labels)) == 1:
            return {
                'silhouette_score': -1.0,
                'cluster_balance': 1.0,
                'quality': 'poor',
                'reason': 'single_cluster'
            }

        try:
            sil_score = silhouette_score(embeddings, labels)

            unique, counts = np.unique(labels, return_counts=True)
            cluster_balance = min(counts) / max(counts)

            n_samples = len(embeddings)

            # Subsample large inputs so the pairwise-distance pass stays
            # cheap. BUG FIX: use a dedicated RandomState rather than
            # np.random.seed(42), which silently reseeded the process-wide
            # numpy RNG (same legacy generator, so the drawn indices are
            # identical).
            if n_samples > 50:
                rng = np.random.RandomState(42)
                indices = rng.choice(n_samples, size=min(100, n_samples), replace=False)
                sample_embeddings = embeddings[indices]
                sample_labels = labels[indices]
            else:
                sample_embeddings = embeddings
                sample_labels = labels

            from scipy.spatial.distance import pdist, squareform
            distances = pdist(sample_embeddings, metric='euclidean')
            dist_matrix = squareform(distances)

            # Split pairwise distances into same-cluster (intra) and
            # cross-cluster (inter) sets, vectorized over the upper triangle
            # (same values the original Python double loop collected).
            row_idx, col_idx = np.triu_indices(len(sample_embeddings), k=1)
            same_cluster = np.asarray(sample_labels)[row_idx] == np.asarray(sample_labels)[col_idx]
            pair_dists = dist_matrix[row_idx, col_idx]
            intra_distances = pair_dists[same_cluster]
            inter_distances = pair_dists[~same_cluster]

            # BUG FIX: also guard empty inter-distances — np.mean([]) is NaN.
            if len(intra_distances) and len(inter_distances):
                separation_ratio = np.mean(inter_distances) / np.mean(intra_distances)
            else:
                separation_ratio = 1.0

            quality = 'excellent' if sil_score > 0.7 and cluster_balance > 0.5 else \
                      'good' if sil_score > 0.5 and cluster_balance > 0.3 else \
                      'fair' if sil_score > 0.3 else 'poor'

            return {
                'silhouette_score': sil_score,
                'cluster_balance': cluster_balance,
                'separation_ratio': separation_ratio,
                # Native ints so the report is JSON-serializable.
                'cluster_distribution': {int(u): int(c) for u, c in zip(unique, counts)},
                'quality': quality,
                'reason': f"sil_score={sil_score:.3f}, balance={cluster_balance:.3f}"
            }

        except Exception as e:
            logger.error(f"Quality validation failed: {e}")
            return {
                'silhouette_score': -1.0,
                'cluster_balance': 0.0,
                'quality': 'error',
                'reason': str(e)
            }

    def refine_speaker_assignments(self, utterances: List[Dict],
                                   min_duration: float = None) -> List[Dict]:
        """Refine speaker assignments.

        Speakers whose accumulated speech time is below ``min_duration``
        (default: ``self.min_speaker_duration``) are considered weak; each
        of their segments is handed to the temporally nearest non-weak
        speaker.

        NOTE: segment dicts are modified in place; the returned list shares
        them with the input.
        """
        if min_duration is None:
            min_duration = self.min_speaker_duration

        # Total speech time per speaker.
        speaker_durations = {}
        for utt in utterances:
            speaker = utt['speaker']
            duration = utt['end'] - utt['start']
            speaker_durations[speaker] = speaker_durations.get(speaker, 0) + duration

        logger.info(f"Speaker durations: {speaker_durations}")

        weak_speakers = {s for s, d in speaker_durations.items() if d < min_duration}

        if not weak_speakers:
            return utterances

        logger.info(f"Weak speakers to reassign: {weak_speakers}")

        refined_utterances = []
        for utt in utterances:
            if utt['speaker'] in weak_speakers:
                new_speaker = self._find_dominant_adjacent_speaker(utt, utterances, weak_speakers)
                utt['speaker'] = new_speaker
                logger.debug(f"Reassigned segment [{utt['start']:.1f}-{utt['end']:.1f}s] "
                             f"to speaker {new_speaker}")
            refined_utterances.append(utt)

        return refined_utterances

    def _find_dominant_adjacent_speaker(self, target_utt: Dict,
                                        all_utterances: List[Dict],
                                        exclude_speakers: set) -> int:
        """Find the nearest (in time) non-excluded speaker for reassignment.

        Overlapping segments count as distance 0. Falls back to the first
        non-excluded speaker found, or 0 when none exists.
        """
        target_start = target_utt['start']
        target_end = target_utt['end']

        candidates = []
        for utt in all_utterances:
            if utt['speaker'] in exclude_speakers:
                continue

            # Temporal gap between the candidate and the target segment.
            if utt['end'] <= target_start:
                distance = target_start - utt['end']
            elif utt['start'] >= target_end:
                distance = utt['start'] - target_end
            else:
                distance = 0

            candidates.append((utt['speaker'], distance))

        if not candidates:
            # No adjacent candidate: take any non-excluded speaker, else 0.
            for utt in all_utterances:
                if utt['speaker'] not in exclude_speakers:
                    return utt['speaker']
            return 0

        return min(candidates, key=lambda x: x[1])[0]

    def merge_consecutive_same_speaker(self, utterances: List[Dict],
                                       max_gap: float = 1.0) -> List[Dict]:
        """Merge consecutive segments from the same speaker.

        Assumes ``utterances`` are in chronological order; adjacent segments
        of the same speaker separated by at most ``max_gap`` seconds are
        concatenated. The input list is not modified (merged entries are
        copies).
        """
        if not utterances:
            return utterances

        merged = []
        current = utterances[0].copy()

        for next_utt in utterances[1:]:
            if (current['speaker'] == next_utt['speaker'] and
                next_utt['start'] - current['end'] <= max_gap):

                current['text'] = current['text'].strip() + ' ' + next_utt['text'].strip()
                current['end'] = next_utt['end']

                logger.debug(f"Merged segments: [{current['start']:.1f}-{current['end']:.1f}s] "
                             f"Speaker {current['speaker']}")
            else:
                merged.append(current)
                current = next_utt.copy()

        merged.append(current)

        return merged

    def diarize_with_quality_control(self, embeddings: np.ndarray,
                                     utterances: List[Dict]) -> Tuple[List[Dict], Dict[str, Any]]:
        """Full diarization with quality control.

        Returns:
            (utterances_with_speakers, quality_metrics)
        """
        # Trivial input: everything is one speaker, no clustering needed.
        if len(embeddings) < 2:
            for utt in utterances:
                utt['speaker'] = 0
            return utterances, {'quality': 'trivial', 'n_speakers': 1}

        n_speakers, clustering_score, labels = self.adaptive_clustering(embeddings)

        quality_metrics = self.validate_clustering_quality(embeddings, labels)
        quality_metrics['n_speakers'] = n_speakers
        quality_metrics['clustering_score'] = clustering_score

        logger.info(f"Adaptive clustering: {n_speakers} speakers, "
                    f"score={clustering_score:.3f}, quality={quality_metrics['quality']}")

        # Assumes utterances and embeddings are aligned 1:1 — TODO confirm
        # at the call site.
        for i, utt in enumerate(utterances):
            utt['speaker'] = int(labels[i])

        # Post-processing (skip only when validation itself errored out).
        if quality_metrics['quality'] not in ['error']:
            utterances = self.refine_speaker_assignments(utterances)
            utterances = self.merge_consecutive_same_speaker(utterances)

        return utterances, quality_metrics
|
|
|
|
|
|
|
|
def enhance_diarization_pipeline(embeddings: np.ndarray,
                                 utterances: List[Dict]) -> Tuple[List[Dict], Dict[str, Any]]:
    """Improved diarization pipeline — main entry point.

    Args:
        embeddings: Embeddings of the audio segments (n_segments, 512).
        utterances: List of transcribed segments.

    Returns:
        (utterances_with_speakers, quality_report)
    """
    improved_diarizer = ImprovedDiarization()

    diarized_utterances, quality_metrics = improved_diarizer.diarize_with_quality_control(
        embeddings, utterances
    )

    quality_report = {
        'success': quality_metrics['quality'] not in ['error', 'poor'],
        'confidence': 'high' if quality_metrics['quality'] in ['excellent', 'good'] else 'low',
        'metrics': quality_metrics,
        'recommendations': []
    }

    # Attach at most one actionable recommendation.
    # BUG FIX: the 'trivial' path returns metrics without 'silhouette_score'
    # or 'cluster_balance', so direct indexing raised KeyError — use .get()
    # with benign defaults instead.
    if quality_metrics['quality'] == 'poor':
        quality_report['recommendations'].append(
            "Consider using single-speaker mode - clustering quality too low"
        )
    elif quality_metrics.get('silhouette_score', 1.0) < 0.3:
        quality_report['recommendations'].append(
            "Low speaker differentiation - verify audio quality"
        )
    elif quality_metrics.get('cluster_balance', 1.0) < 0.2:
        quality_report['recommendations'].append(
            "Unbalanced speaker distribution - check audio content"
        )

    return diarized_utterances, quality_report
|
|
|