File size: 27,789 Bytes
766564c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
913c94a
 
 
 
 
 
 
 
 
 
 
766564c
913c94a
766564c
ba4a241
913c94a
 
 
 
 
7c59b7a
913c94a
 
 
 
 
766564c
9453a6f
766564c
9453a6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
766564c
 
9453a6f
766564c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa3fe1c
 
 
766564c
fa3fe1c
7c59b7a
 
 
 
ba4a241
766564c
7c59b7a
766564c
ba4a241
7c59b7a
 
 
 
 
 
 
 
 
 
 
ba4a241
766564c
 
ba4a241
766564c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba4a241
766564c
 
e441a1a
 
766564c
ba4a241
766564c
 
ba4a241
766564c
 
 
 
 
ba4a241
766564c
ba4a241
766564c
ba4a241
766564c
 
 
 
 
 
 
 
 
913c94a
766564c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba4a241
766564c
 
ba4a241
766564c
 
ba4a241
766564c
 
 
 
 
e441a1a
 
 
 
766564c
030e33b
 
766564c
 
 
 
 
e441a1a
 
766564c
 
e441a1a
 
766564c
 
 
 
 
 
 
 
 
 
913c94a
 
766564c
913c94a
 
 
 
 
 
766564c
 
 
 
 
e441a1a
 
 
766564c
e441a1a
 
766564c
 
 
ba4a241
766564c
 
 
 
ba4a241
766564c
 
 
 
913c94a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
766564c
 
913c94a
766564c
ba4a241
766564c
 
 
 
 
 
 
 
 
 
 
 
 
030e33b
766564c
 
 
 
 
 
 
 
 
 
 
 
 
 
ba4a241
766564c
ba4a241
766564c
ba4a241
766564c
 
 
 
 
 
 
 
 
ba4a241
766564c
 
 
 
 
913c94a
 
 
 
 
 
 
 
 
 
 
 
766564c
 
 
030e33b
913c94a
766564c
ba4a241
913c94a
 
 
766564c
 
ba4a241
766564c
 
 
 
ba4a241
766564c
6bf9bbb
913c94a
 
 
 
 
 
030e33b
 
 
 
 
 
913c94a
 
766564c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bf9bbb
 
 
 
 
030e33b
6bf9bbb
 
 
 
 
 
 
 
030e33b
 
 
 
 
 
 
6bf9bbb
 
 
 
913c94a
 
 
 
 
 
 
 
 
6bf9bbb
913c94a
 
 
 
 
 
 
 
 
 
 
 
 
 
6bf9bbb
913c94a
 
 
 
 
 
 
6bf9bbb
 
 
 
 
 
 
 
 
030e33b
6bf9bbb
913c94a
6bf9bbb
ba4a241
6bf9bbb
913c94a
 
6bf9bbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
030e33b
6bf9bbb
 
 
030e33b
6bf9bbb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
#!/usr/bin/env python3
"""
Speaker Diarization module using Sherpa-ONNX
Integrates seamlessly with VoxSum ASR pipeline
Enhanced with adaptive clustering and quality validation

OPTIMIZED MODEL: 3dspeaker_campplus_zh_en_advanced
- Performance: F1=0.500, Accuracy=0.500 
- Speed: 60.5ms average (2x faster than baseline)
- Size: 27MB (compact for production)
- Languages: Chinese/Taiwanese + English support
- Architecture: CAM++ multilingual advanced
"""

import os
import numpy as np
try:
    import sherpa_onnx  # type: ignore
except Exception:  # pragma: no cover
    class _SherpaStub:  # minimal stub to allow tests without the dependency
        class SpeakerEmbeddingExtractorConfig:  # noqa: D401
            def __init__(self, *args, **kwargs):
                pass
        class SpeakerEmbeddingExtractor:
            def __init__(self, *args, **kwargs):
                raise RuntimeError("sherpa_onnx not installed; real embedding extraction unavailable")
    sherpa_onnx = _SherpaStub()  # type: ignore
from pathlib import Path
from typing import List, Tuple, Optional, Callable, Dict, Any, Generator
import logging
from .utils import get_writable_model_dir, num_vcpus
try:  # Optional dependency
    from huggingface_hub import hf_hub_download  # type: ignore
except Exception:  # pragma: no cover
    def hf_hub_download(*args, **kwargs):  # minimal stub
        raise RuntimeError("huggingface_hub not installed; model download unavailable")
import shutil
try:  # Optional dependency
    from sklearn.metrics import silhouette_score  # type: ignore
except Exception:  # pragma: no cover
    def silhouette_score(*args, **kwargs):
        return -1.0

# Import the improved diarization pipeline (robust: search repo tree)
try:
    from importlib import import_module
    # Try direct import first (works when repo root is in PYTHONPATH)
    try:
        mod = import_module('improved_diarization')
    except Exception:
        # Search up to 6 parent directories for improved_diarization.py
        repo_root = None
        current = Path(__file__).resolve()
        for parent in list(current.parents)[:6]:
            candidate = parent / 'improved_diarization.py'
            if candidate.exists():
                repo_root = parent
                break

        if repo_root is None:
            # Fallback to CWD
            cwd_candidate = Path.cwd() / 'improved_diarization.py'
            if cwd_candidate.exists():
                repo_root = Path.cwd()

        if repo_root is not None:
            import sys
            sys.path.insert(0, str(repo_root))
            mod = import_module('improved_diarization')
        else:
            raise ImportError('improved_diarization module not found in repository tree')

    enhance_diarization_pipeline = getattr(mod, 'enhance_diarization_pipeline')
    ENHANCED_DIARIZATION_AVAILABLE = True
    print("✅ Enhanced diarization pipeline loaded successfully")
except Exception as e:
    ENHANCED_DIARIZATION_AVAILABLE = False
    logging.warning(f"Enhanced diarization not available - using fallback: {e}")

logger = logging.getLogger(__name__)

# Speaker colors for UI visualization
SPEAKER_COLORS = [
    "#FF6B6B",  # Red
    "#4ECDC4",  # Teal  
    "#45B7D1",  # Blue
    "#96CEB4",  # Green
    "#FFEAA7",  # Yellow
    "#DDA0DD",  # Plum
    "#FFB347",  # Orange
    "#87CEEB",  # Sky Blue
    "#F0E68C",  # Khaki
    "#FF69B4",  # Hot Pink
]

def get_speaker_color(speaker_id: int) -> str:
    """Get consistent color for speaker ID"""
    return SPEAKER_COLORS[speaker_id % len(SPEAKER_COLORS)]

def download_diarization_models():
    """
    Download required models for speaker diarization if not present
    Only downloads embedding model - we'll use Silero VAD for segmentation
    Returns tuple (embedding_model_path, success)
    """
    # Use a writable cache directory (works on HF Spaces and local)
    cache_dir = get_writable_model_dir()
    models_dir = cache_dir / "diarization"
    models_dir.mkdir(parents=True, exist_ok=True)

    # Model info
    repo_id = "csukuangfj/speaker-embedding-models"
    filename = "3dspeaker_speech_campplus_sv_zh_en_16k-common_advanced.onnx"
    embedding_model = models_dir / filename
    logger.info(f"Model cache directory: {models_dir}")
    try:
        # Download using huggingface_hub if not present
        if not embedding_model.exists():
            logger.info("📥 Downloading eres2netv2 Chinese speaker model from HuggingFace (29MB)...")
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=models_dir,
                local_dir=models_dir,
                local_dir_use_symlinks=False,
                resume_download=True
            )
            # Move/copy to expected location if needed
            if Path(downloaded_path) != embedding_model:
                shutil.copy(downloaded_path, embedding_model)
            logger.info("✅ eres2netv2 Chinese embedding model downloaded!")
        return str(embedding_model), True
    except Exception as e:
        logger.error(f"❌ Failed to download diarization models: {e}")
        return None, False

def init_speaker_embedding_extractor(
    cluster_threshold: float = 0.5,
    num_speakers: int = -1
) -> Optional[Tuple[object, dict]]:
    """
    Initialize speaker embedding extractor (without segmentation)
    We use Silero VAD segments from ASR pipeline instead of PyAnnote
    
    Args:
        cluster_threshold: Clustering threshold (lower = more speakers)
        num_speakers: Number of speakers (-1 for auto-detection)
        
    Returns:
        Tuple of (embedding_extractor, config_dict) or None
    """
    try:
        # Download models if needed (only embedding model now)
        embedding_model, success = download_diarization_models()
        if not success:
            return None

        # Create embedding extractor config
        embedding_config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
            model=embedding_model,
            num_threads=num_vcpus
        )

        # Initialize embedding extractor
        embedding_extractor = sherpa_onnx.SpeakerEmbeddingExtractor(embedding_config)

        # Store clustering parameters separately
        config_dict = {
            'cluster_threshold': cluster_threshold,
            'num_speakers': num_speakers
        }

        return embedding_extractor, config_dict

    except Exception as e:
        logger.error(f"❌ Failed to initialize speaker embedding extractor: {e}")
        return None

def perform_speaker_diarization_on_utterances(
    audio: np.ndarray,
    sample_rate: int,
    utterances: List[Tuple[float, float, str]],
    embedding_extractor: object,
    config_dict: dict,
    progress_callback: Optional[Callable] = None
) -> Generator[float | List[Tuple[float, float, int]], None, List[Tuple[float, float, int]]]:
    """
    Perform speaker diarization using existing ASR utterance segments
    This avoids double segmentation by reusing Silero VAD results
    
    Args:
        audio: Audio samples (float32, mono)
        sample_rate: Sample rate (should be 16kHz for optimal results)
        utterances: ASR utterances from Silero VAD segmentation
        embedding_extractor: Initialized embedding extractor
        config_dict: Configuration dictionary with clustering parameters
        progress_callback: Optional progress callback function
        
    Returns:
        List of (start_time, end_time, speaker_id) tuples
    """
    print(f"🔍 DEBUG: perform_speaker_diarization_on_utterances called with {len(utterances)} utterances")
    
    try:
        # Ensure audio is float32 and mono
        if audio.dtype != np.float32:
            audio = audio.astype(np.float32)
        
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)  # Convert to mono
        
        # Check sample rate
        if sample_rate != 16000:
            warning_msg = f"⚠️ Audio sample rate is {sample_rate}Hz, but 16kHz is optimal for diarization"
            logger.warning(warning_msg)
        
        if not utterances:
            logger.warning("⚠️ No utterances provided for diarization")
            return []
        
        logger.info(f"🎭 Extracting embeddings from {len(utterances)} utterance segments...")
        
        # Extract embeddings for each utterance segment
        embeddings = []
        valid_utterances = []
        
        # Progress tracking for UI
        total_utterances = len(utterances)
        batch_size = max(1, total_utterances // 20)  # Process in batches for progress updates
        
        for i, (start, end, text) in enumerate(utterances):
            if i % batch_size == 0:
                yield i / total_utterances * 0.8
            
            # Extract audio segment
            start_sample = int(start * sample_rate)
            end_sample = int(end * sample_rate)
            
            if i % 50 == 0:  # Reduce debug frequency for large files
                print(f"🔍 DEBUG: Processing utterance {i}/{total_utterances}: [{start:.1f}-{end:.1f}s]")
            
            if start_sample >= len(audio) or end_sample <= start_sample:
                if i % 50 == 0:  # Reduce debug spam
                    print(f"⚠️ DEBUG: Skipping invalid segment {i}: start_sample={start_sample}, end_sample={end_sample}, audio_len={len(audio)}")
                continue  # Skip invalid segments
                
            segment = audio[start_sample:end_sample]
            
            # Skip very short segments (< 0.5 seconds)
            if len(segment) < sample_rate * 0.5:
                continue
            
            try:
                # Extract embedding using Sherpa-ONNX with proper stream API
                if not hasattr(embedding_extractor, "create_stream"):
                    raise RuntimeError("Embedding extractor missing create_stream(); sherpa_onnx not available?")
                stream = embedding_extractor.create_stream()
                if hasattr(stream, "accept_waveform"):
                    stream.accept_waveform(sample_rate, segment)
                if hasattr(stream, "input_finished"):
                    stream.input_finished()
                if not hasattr(embedding_extractor, "compute"):
                    raise RuntimeError("Embedding extractor missing compute(); sherpa_onnx not available?")
                embedding = embedding_extractor.compute(stream)
                
                if embedding is not None and len(embedding) > 0:
                    embeddings.append(embedding)
                    valid_utterances.append((start, end, text))
                    if i % 100 == 0:  # Progress log every 100 segments
                        print(f"✅ Extracted {len(embeddings)} embeddings so far...")
                        
            except Exception as e:
                if i % 50 == 0:  # Reduce error spam
                    print(f"⚠️ Failed to extract embedding for segment {i}: {e}")
                continue
        
        if not embeddings:
            logger.error("❌ No valid embeddings extracted")
            print(f"❌ DEBUG: Failed to extract any embeddings from {len(utterances)} utterances")
            return []
        
        print(f"✅ DEBUG: Extracted {len(embeddings)} embeddings for clustering")
        logger.info(f"✅ Extracted {len(embeddings)} embeddings, performing clustering...")
        
        # Convert embeddings to numpy array
        embeddings_array = np.array(embeddings)
        print(f"✅ DEBUG: Embeddings array shape: {embeddings_array.shape}")
        n_embeddings = embeddings_array.shape[0]

        # Cas très faible nombre de segments: éviter tout clustering complexe
        if n_embeddings < 3:
            print("⚠️ DEBUG: Moins de 3 segments – utilisation d'une heuristique simple sans clustering")
            assignments: List[Tuple[float, float, int]] = []
            if n_embeddings == 1:
                (s, e, _t) = valid_utterances[0]
                assignments.append((s, e, 0))
            elif n_embeddings == 2:
                try:
                    from sklearn.metrics.pairwise import cosine_similarity  # type: ignore
                    sim = float(cosine_similarity(embeddings_array[0:1], embeddings_array[1:2])[0, 0])
                except Exception:
                    a = embeddings_array[0].astype(float)
                    b = embeddings_array[1].astype(float)
                    denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1e-9
                    sim = float(np.dot(a, b) / denom)
                (s1, e1, _t1) = valid_utterances[0]
                (s2, e2, _t2) = valid_utterances[1]
                if sim >= 0.80:
                    assignments.append((s1, e1, 0))
                    assignments.append((s2, e2, 0))
                    print(f"🟢 DEBUG: Deux segments fusionnés en un seul speaker (similarité={sim:.3f})")
                else:
                    assignments.append((s1, e1, 0))
                    assignments.append((s2, e2, 1))
                    print(f"🟦 DEBUG: Deux speakers distincts (similarité={sim:.3f})")
            if progress_callback:
                progress_callback(1.0)
            yield 1.0
            yield assignments
            return
        
        # Use enhanced diarization if available
        if ENHANCED_DIARIZATION_AVAILABLE and n_embeddings >= 3:
            print("🚀 Using enhanced diarization with adaptive clustering...")
            logger.info("🚀 Using enhanced adaptive clustering...")
            
            # Prepare utterances dict format for enhanced pipeline
            utterances_dict = []
            for i, (start, end, text) in enumerate(valid_utterances):
                utterances_dict.append({
                    'start': start,
                    'end': end,
                    'text': text,
                    'index': i
                })
            
            if progress_callback:
                progress_callback(0.9)  # 90% for clustering
            yield 0.9
            
            # Run enhanced diarization
            try:
                enhanced_utterances, quality_report = enhance_diarization_pipeline(
                    embeddings_array, utterances_dict
                )
                
                # Display quality report
                quality = quality_report['metrics']['quality']
                confidence = quality_report['confidence']
                n_speakers = quality_report['metrics']['n_speakers']
                
                quality_msg = f"🎯 Diarization Quality: {confidence} confidence ({quality})"
                if quality in ['excellent', 'good']:
                    logger.info(quality_msg)
                elif quality == 'fair':
                    logger.warning(quality_msg)
                else:
                    logger.error(quality_msg)
                
                print(f"✅ Enhanced diarization quality report:")
                print(f"   - Quality: {quality}")
                print(f"   - Confidence: {confidence}")
                print(f"   - Silhouette score: {quality_report['metrics'].get('silhouette_score', 'N/A'):.3f}")
                print(f"   - Cluster balance: {quality_report['metrics'].get('cluster_balance', 'N/A'):.3f}")
                print(f"   - Speakers detected: {n_speakers}")
                
                if quality_report['recommendations']:
                    logger.info("💡 " + "; ".join(quality_report['recommendations']))
                
                # Convert back to tuple format
                diarization_result = []
                for utt in enhanced_utterances:
                    diarization_result.append((utt['start'], utt['end'], utt['speaker']))

                    # Si l'enhanced pipeline a tout fusionné en un seul segment alors qu'on avait peu de segments
                    # on restaure la granularité originale pour ne pas perdre l'alignement temporel côté UI/tests.
                    if (
                        len(diarization_result) == 1
                        and len(valid_utterances) == n_embeddings
                        and n_embeddings <= 4
                    ):
                        single_speaker = diarization_result[0][2]
                        diarization_result = [
                            (s, e, single_speaker) for (s, e, _t) in valid_utterances
                        ]
                
                if progress_callback:
                    progress_callback(1.0)  # 100% complete
                yield 1.0

                print(f"✅ DEBUG: Enhanced result - {n_speakers} speakers, {len(diarization_result)} segments")
                logger.info(f"🎭 Enhanced clustering completed! Detected {n_speakers} speakers with {confidence} confidence")

                yield diarization_result
                return
                
            except Exception as e:
                logger.error(f"❌ Enhanced diarization failed: {e}")
                print(f"❌ Enhanced diarization failed: {e}")
                # Fall back to original clustering
        
        # Fallback to original clustering
        logger.warning("⚠️ Using fallback clustering")
        print("⚠️ Using fallback clustering")

        gen = faiss_clustering(
            embeddings_array,
            valid_utterances,
            config_dict,
            progress_callback,
        )
        try:
            while True:
                p = next(gen)
                yield p
        except StopIteration as e:
            diarization_result = e.value
        yield diarization_result
        return
        
    except Exception as e:
        error_msg = f"❌ Speaker diarization failed: {e}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return []

def merge_transcription_with_diarization(
    utterances: List[Tuple[float, float, str]],
    diarization: List[Tuple[float, float, int]]
) -> List[Tuple[float, float, str, int]]:
    """
    Merge ASR transcription with speaker diarization results
    
    Args:
        utterances: List of (start, end, text) from ASR
        diarization: List of (start, end, speaker_id) from diarization
        
    Returns:
        List of (start, end, text, speaker_id) tuples
    """
    if not diarization:
        # No diarization available, assign speaker 0 to all
        return [(start, end, text, 0) for start, end, text in utterances]
    
    merged_result = []
    
    for utt_start, utt_end, text in utterances:
        # Find overlapping speaker segments
        best_speaker = 0
        max_overlap = 0.0
        
        for dia_start, dia_end, speaker_id in diarization:
            # Calculate overlap between utterance and diarization segment
            overlap_start = max(utt_start, dia_start)
            overlap_end = min(utt_end, dia_end)
            
            if overlap_end > overlap_start:
                overlap_duration = overlap_end - overlap_start
                if overlap_duration > max_overlap:
                    max_overlap = overlap_duration
                    best_speaker = speaker_id
        
        merged_result.append((utt_start, utt_end, text, best_speaker))
    
    return merged_result

def merge_consecutive_utterances(
    utterances_with_speakers: List[Tuple[float, float, str, int]],
    max_gap: float = 1.0
) -> List[Tuple[float, float, str, int]]:
    """
    Merge consecutive utterances from the same speaker into single utterances
    
    Args:
        utterances_with_speakers: List of (start, end, text, speaker_id) tuples
        max_gap: Maximum gap in seconds between utterances to merge
        
    Returns:
        List of merged (start, end, text, speaker_id) tuples
    """
    if not utterances_with_speakers:
        return utterances_with_speakers
    
    # Sort by start time to ensure correct order
    sorted_utterances = sorted(utterances_with_speakers, key=lambda x: x[0])
    
    merged = []
    current_start, current_end, current_text, current_speaker = sorted_utterances[0]
    
    for i in range(1, len(sorted_utterances)):
        next_start, next_end, next_text, next_speaker = sorted_utterances[i]
        
        # Check if we should merge: same speaker and gap is acceptable
        gap = next_start - current_end
        if current_speaker == next_speaker and gap <= max_gap:
            # Merge the utterances
            current_text = current_text.strip() + ' ' + next_text.strip()
            current_end = next_end
            print(f"✅ DEBUG: Merged consecutive utterances from Speaker {current_speaker}: [{current_start:.1f}-{current_end:.1f}s]")
        else:
            # Finalize current utterance and start new one
            merged.append((current_start, current_end, current_text, current_speaker))
            current_start, current_end, current_text, current_speaker = next_start, next_end, next_text, next_speaker
    
    # Add the last utterance
    merged.append((current_start, current_end, current_text, current_speaker))
    
    print(f"✅ DEBUG: Utterance merging complete: {len(utterances_with_speakers)}{len(merged)} utterances")
    return merged

def format_speaker_transcript(
    utterances_with_speakers: List[Tuple[float, float, str, int]]
) -> str:
    """
    Format transcript with speaker labels
    
    Args:
        utterances_with_speakers: List of (start, end, text, speaker_id)
        
    Returns:
        Formatted transcript string
    """
    if not utterances_with_speakers:
        return ""
    
    formatted_lines = []
    current_speaker = None
    
    for start, end, text, speaker_id in utterances_with_speakers:
        # Add speaker label when speaker changes
        if speaker_id != current_speaker:
            formatted_lines.append(f"\n**Speaker {speaker_id + 1}:**")
            current_speaker = speaker_id
        
        # Add timestamped utterance
        minutes = int(start // 60)
        seconds = int(start % 60)
        formatted_lines.append(f"[{minutes:02d}:{seconds:02d}] {text}")
    
    return "\n".join(formatted_lines)

def get_diarization_stats(
    utterances_with_speakers: List[Tuple[float, float, str, int]]
) -> dict:
    """
    Calculate speaker diarization statistics
    
    Returns:
        Dictionary with speaking time per speaker and other stats
    """
    if not utterances_with_speakers:
        return {}
    
    speaker_times = {}
    speaker_utterances = {}
    total_duration = 0
    
    for start, end, text, speaker_id in utterances_with_speakers:
        duration = end - start
        total_duration += duration
        
        if speaker_id not in speaker_times:
            speaker_times[speaker_id] = 0
            speaker_utterances[speaker_id] = 0
        
        speaker_times[speaker_id] += duration
        speaker_utterances[speaker_id] += 1
    
    # Calculate percentages
    stats = {
        "total_speakers": len(speaker_times),
        "total_duration": total_duration,
        "speakers": {}
    }
    
    for speaker_id in sorted(speaker_times.keys()):
        speaking_time = speaker_times[speaker_id]
        percentage = (speaking_time / total_duration * 100) if total_duration > 0 else 0
        
        stats["speakers"][speaker_id] = {
            "speaking_time": speaking_time,
            "percentage": percentage,
            "utterances": speaker_utterances[speaker_id],
            "avg_utterance_length": speaking_time / speaker_utterances[speaker_id] if speaker_utterances[speaker_id] > 0 else 0
        }
    
    return stats

def faiss_clustering(embeddings: np.ndarray,
                     utterances: list,
                     config_dict: dict,
                     progress_callback=None):
    """
    Clustering via FAISS (K-means) ultra-rapide CPU.
    Retourne la liste (start, end, speaker_id) compatible avec l'ancien code.
    """
    try:
        import faiss
    except ImportError:
        # FAISS absent → on retombe sur AgglomerativeClustering d'origine
        gen = sklearn_fallback_clustering(embeddings, utterances, config_dict, progress_callback)
        try:
            while True:
                p = next(gen)
                yield p
        except StopIteration as e:
            return e.value

    n_samples, dim = embeddings.shape
    n_clusters = config_dict['num_speakers']
    if n_clusters == -1:
        # Si très peu d'échantillons, attribuer tout au locuteur 0
        if n_samples < 3:
            if progress_callback:
                progress_callback(1.0)
            yield 1.0
            return [(s, e, 0) for (s, e, _t) in utterances]
        max_k = min(10, max(2, n_samples // 2))
        best_score, best_k, best_labels = -1.0, 2, None
        emb32 = embeddings.astype(np.float32)
        for k in range(2, max_k + 1):
            if k >= n_samples:  # éviter k == n_samples (silhouette invalide)
                break
            kmeans = faiss.Kmeans(dim, k, niter=25, verbose=False, seed=42)
            kmeans.train(emb32)
            _, lbls = kmeans.index.search(emb32, 1)
            lbls = lbls.ravel()
            uniq = set(lbls)
            if 1 < len(uniq) < n_samples:
                try:
                    sil = silhouette_score(embeddings, lbls)
                except Exception:
                    sil = -1.0
            else:
                sil = -1.0
            if sil > best_score:
                best_score, best_k, best_labels = sil, k, lbls
        if best_labels is None:
            # Fallback trivial: tout un seul locuteur
            if progress_callback:
                progress_callback(1.0)
            yield 1.0
            return [(s, e, 0) for (s, e, _t) in utterances]
        labels = best_labels
    else:
        kmeans = faiss.Kmeans(dim, min(n_clusters, n_samples), niter=20, verbose=False, seed=42)
        kmeans.train(embeddings.astype(np.float32))
        _, labels = kmeans.index.search(embeddings.astype(np.float32), 1)
        labels = labels.ravel()

    if progress_callback:
        progress_callback(1.0)
    yield 1.0

    num_speakers = len(set(labels)) if labels is not None else 1
    print(f"✅ DEBUG: FAISS clustering — {num_speakers} speakers, {len(utterances)} segments")
    logger.info(f"🎭 FAISS clustering completed! Detected {num_speakers} speakers")

    if labels is None:
        return [(s, e, 0) for (s, e, _t) in utterances]
    return [(start, end, int(lbl)) for (start, end, _), lbl in zip(utterances, labels)]


def sklearn_fallback_clustering(embeddings, utterances, config_dict, progress_callback=None):
    """
    Ancienne voie sklearn conservée pour fallback sans FAISS.
    """
    from sklearn.cluster import AgglomerativeClustering
    from sklearn.metrics.pairwise import cosine_similarity

    similarity_matrix = cosine_similarity(embeddings)
    distance_matrix = 1 - similarity_matrix

    n_clusters = config_dict['num_speakers']
    if n_clusters == -1:
        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=config_dict['cluster_threshold'],
            metric='precomputed',
            linkage='average'
        )
    else:
        clustering = AgglomerativeClustering(
            n_clusters=min(n_clusters, len(embeddings)),
            metric='precomputed',
            linkage='average'
        )

    if progress_callback:
        progress_callback(0.9)
    yield 0.9
    labels = clustering.fit_predict(distance_matrix)
    if progress_callback:
        progress_callback(1.0)
    yield 1.0

    return [(start, end, int(lbl)) for (start, end, _), lbl in zip(utterances, labels)]