"""Streaming transcription service.

Exposes the ASR and speaker-diarization pipelines as generators of
JSON-serializable events, suitable for streaming results to clients
incrementally.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import soundfile as sf
from fastapi import HTTPException

from src.asr import transcribe_file
from src.diarization import (
    get_diarization_stats,
    init_speaker_embedding_extractor,
    merge_consecutive_utterances,
    merge_transcription_with_diarization,
    perform_speaker_diarization_on_utterances,
)
from src.utils import sensevoice_models

from ..core.config import get_settings
from ..models.transcription import DiarizationOptions, TranscriptionRequest

settings = get_settings()


def _serialize_utterance(utt: Tuple[float, float, str], speaker: Optional[int] = None) -> Dict[str, object]:
    """Convert a (start, end, text) utterance tuple into a JSON-safe dict."""
    start, end, text = utt
    payload: Dict[str, object] = {
        "start": round(float(start), 3),
        "end": round(float(end), 3),
        "text": text,
    }
    if speaker is not None:
        payload["speaker"] = int(speaker)
    return payload
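
# For illustration, with hypothetical values:
#   _serialize_utterance((0.0, 1.5, "hello"), speaker=0)
#   -> {"start": 0.0, "end": 1.5, "text": "hello", "speaker": 0}
# Without a speaker argument, the "speaker" key is omitted from the payload.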


def _prepare_model_name(options: TranscriptionRequest) -> str:
    """Resolve a SenseVoice model alias to its concrete name; other backends pass through."""
    if options.backend == "sensevoice":
        # Fall back to the raw name when it is not a registered alias.
        return sensevoice_models.get(options.model_name, options.model_name)
    return options.model_name


def iter_transcription_events(
    audio_path: Path,
    audio_url: str,
    options: TranscriptionRequest,
) -> Iterable[Dict[str, object]]:
    """Yield streaming events ("ready", "status", "utterance", "complete") for one transcription run."""
    model_name = _prepare_model_name(options)

    try:
        generator = transcribe_file(
            audio_path=str(audio_path),
            vad_threshold=options.vad_threshold,
            model_name=model_name,
            backend=options.backend,
            language=options.language,
            textnorm=options.textnorm,
        )

        yield {
            "type": "ready",
            "audioUrl": audio_url,
            "backend": options.backend,
            "model": model_name,
        }

        yield {
            "type": "status",
            "message": "Transcribing audio...",
        }

        final_utterances: List[Tuple[float, float, str]] = []

        for current_utterance, all_utterances, progress in generator:
            if current_utterance:
                start, end, text = current_utterance
                yield {
                    "type": "utterance",
                    "utterance": _serialize_utterance((start, end, text)),
                    "index": len(all_utterances) - 1,
                    "progress": round(progress, 1),
                }
            # Keep a snapshot of the full utterance list from the latest update.
            final_utterances = list(all_utterances)

        diarization_payload = None
        if options.diarization.enable:
            yield {
                "type": "status",
                "message": "Performing speaker diarization...",
            }
            diarization_gen = _run_diarization(audio_path, final_utterances, options.diarization)
            for event in diarization_gen:
                if event["type"] == "progress":
                    yield event
                elif event["type"] == "result":
                    diarization_payload = event["payload"]
                    break

        transcript_text = "\n".join(utt[2] for utt in final_utterances)

        yield {
            "type": "complete",
            "utterances": [_serialize_utterance(utt) for utt in final_utterances],
            "transcript": transcript_text,
            "diarization": diarization_payload,
        }

    except Exception as exc:
        # Surface any pipeline failure to the HTTP layer as a 500.
        raise HTTPException(status_code=500, detail=f"Transcription failed: {exc}") from exc
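
# For orientation only: a minimal sketch of how an endpoint might stream these
# events as server-sent events. The router object, endpoint path, and upload
# staging below are assumptions about the surrounding app, not part of this module:
#
#     import json
#     from fastapi.responses import StreamingResponse
#
#     @router.post("/transcribe")  # hypothetical route
#     def transcribe(options: TranscriptionRequest):
#         audio_path = Path("...")  # however the app stages the uploaded audio
#         def sse():
#             for event in iter_transcription_events(audio_path, "/audio/...", options):
#                 yield f"data: {json.dumps(event)}\n\n"
#         return StreamingResponse(sse(), media_type="text/event-stream")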


def _run_diarization(
    audio_path: Path,
    utterances: List[Tuple[float, float, str]],
    options: DiarizationOptions,
):
    """Yield diarization progress events followed by one "result" event (payload may be None)."""
    if not utterances:
        yield {"type": "result", "payload": None}
        return

    extractor_result = init_speaker_embedding_extractor(
        cluster_threshold=options.cluster_threshold,
        num_speakers=options.num_speakers,
    )
    if not extractor_result:
        yield {"type": "result", "payload": None}
        return

    embedding_extractor, config_dict = extractor_result

    audio, sample_rate = sf.read(str(audio_path), dtype="float32")
    if audio.ndim > 1:
        # Downmix multi-channel audio to mono.
        audio = audio.mean(axis=1)

    if sample_rate != 16000:
        # The diarization pipeline expects 16 kHz input; resample if necessary.
        from scipy.signal import resample

        target_num_samples = int(len(audio) * 16000 / sample_rate)
        audio = resample(audio, target_num_samples)
        sample_rate = 16000

    diarization_gen = perform_speaker_diarization_on_utterances(
        audio=audio,
        sample_rate=sample_rate,
        utterances=utterances,
        embedding_extractor=embedding_extractor,
        config_dict=config_dict,
        progress_callback=None,
    )

    diarization_segments = None
    try:
        while True:
            item = next(diarization_gen)
            if isinstance(item, float):
                # Floats yielded by the generator are progress fractions in [0, 1].
                yield {"type": "progress", "stage": "diarization", "progress": round(item * 100, 1)}
            else:
                # The first non-float item is the final list of diarization segments.
                diarization_segments = item
                break
    except StopIteration as e:
        # If the generator returns its result instead of yielding it, pick it up here.
        diarization_segments = e.value

    if not diarization_segments:
        yield {"type": "result", "payload": None}
        return

    merged = merge_transcription_with_diarization(utterances, diarization_segments)
    merged = merge_consecutive_utterances(merged, max_gap=1.0)
    stats = get_diarization_stats(merged)

    yield {
        "type": "result",
        "payload": {
            "utterances": [
                _serialize_utterance((start, end, text), speaker)
                for start, end, text, speaker in merged
            ],
            "stats": stats,
        },
    }
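
# Shape of the final "result" event, with hypothetical values; the keys inside
# "stats" come from get_diarization_stats and are not reproduced here:
#
#     {
#         "type": "result",
#         "payload": {
#             "utterances": [
#                 {"start": 0.0, "end": 1.5, "text": "hello", "speaker": 0},
#             ],
#             "stats": {...},
#         },
#     }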