import os
import tempfile
import json
import torchaudio
import gradio as gr

from faster_whisper import WhisperModel
import whisperx
from pyannote.audio import Pipeline as DiarizationPipeline
from transformers import pipeline

# Named-entity recognition via transformers (CamemBERT for French)
ner_pipeline = pipeline("ner", model="Jean-Baptiste/camembert-ner", aggregation_strategy="simple")

# Model configuration
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

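# Speaker diarization (pyannote); the model is gated and needs a Hugging Face token with access granted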
diarization_pipeline = DiarizationPipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=HUGGINGFACE_TOKEN
)

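# Note: this faster-whisper model is loaded but never used below; transcription goes through whisperx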
whisper_model = WhisperModel("large-v2", device="cpu", compute_type="int8")
whisperx_model = whisperx.load_model("large-v2", device="cpu", compute_type="int8")
align_model, metadata = whisperx.load_align_model(language_code="fr", device="cpu")

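# Convert MP3 uploads to WAV; other formats are passed through unchanged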
def convert_to_wav_if_needed(audio_path: str) -> str:
    if audio_path.lower().endswith(".mp3"):
        new_path = audio_path[:-4] + ".wav"
        waveform, sr = torchaudio.load(audio_path)
        torchaudio.save(new_path, waveform, sr)
        return new_path
    return audio_path

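# Diarize the audio and return speaker turns, skipping turns shorter than 1 s and truncating turns longer than 10 s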
def get_speaker_segments(audio_path: str) -> list:
    diarization = diarization_pipeline(audio_path)
    segments = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        start, end = float(turn.start), float(turn.end)
        if end - start < 1.0:
            continue
        if end - start > 10.0:
            end = start + 10.0
        segments.append({"start": start, "end": end, "speaker": speaker})
    return segments

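# Slice the audio per speaker turn, transcribe each slice with whisperx, align words,
# and shift the word timestamps back onto the full-file timeline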
def transcribe_with_alignment(audio_path: str, segments: list) -> list:
    word_segments_all = []
    waveform, sr = torchaudio.load(audio_path)
    for seg in segments:
        start, end, speaker = seg["start"], seg["end"], seg["speaker"]
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_audio_path = tmp.name
        segment_waveform = waveform[:, int(start * sr): int(end * sr)]
        torchaudio.save(temp_audio_path, segment_waveform, sr)
        whisper_result = whisperx_model.transcribe(temp_audio_path)
        aligned = whisperx.align(whisper_result["segments"], align_model, metadata, temp_audio_path, device="cpu")
        for word in aligned.get("word_segments", []):
            # Words whisperx could not align carry no timestamps; skip them
            if "start" not in word or "end" not in word:
                continue
            word["start"] += start
            word["end"] += start
            word["speaker"] = speaker
            word_segments_all.append(word)
        os.remove(temp_audio_path)
    return word_segments_all

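# Concatenate the aligned words and run French NER on the resulting text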
def extract_entities(word_segments: list):
    # whisperx stores each token under the "word" key (fall back to "text" just in case)
    full_text = " ".join(w.get("word", w.get("text", "")) for w in word_segments)
    entities_raw = ner_pipeline(full_text)
    # Cast numpy scores to plain floats so the entities can be serialised to JSON later
    entities = [{**ent, "score": float(ent["score"])} for ent in entities_raw]
    return full_text, entities

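# End-to-end pipeline: conversion, diarization, aligned transcription, NER;
# writes the word-level and metadata JSON files next to the audio file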
def process_pipeline(audio_path: str):
    audio_path = convert_to_wav_if_needed(audio_path)
    segments = get_speaker_segments(audio_path)
    words = transcribe_with_alignment(audio_path, segments)

    # Build output paths from the audio file name (robust to non-.wav inputs)
    aligned_path = os.path.splitext(audio_path)[0] + "_aligned.json"
    with open(aligned_path, "w", encoding="utf-8") as f:
        json.dump(words, f, ensure_ascii=False, indent=2)

    full_text, named_entities = extract_entities(words)

    meta_path = os.path.splitext(audio_path)[0] + "_meta.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump({"text": full_text, "entities": named_entities}, f, ensure_ascii=False, indent=2)

    return full_text, named_entities, aligned_path, meta_path

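# Gradio callback: runs the pipeline and returns a readable error message instead of raising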
def gradio_process(audio_file_path):
    if audio_file_path is None:
        return "", [], None, None
    try:
        texte, ents, aligned_json, meta_json = process_pipeline(audio_file_path)
        return texte, ents, aligned_json, meta_json
    except Exception as e:
        return f"Erreur : {str(e)}", [], None, None

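# Gradio interface: audio file input, run button, transcript and entities display, downloadable JSON files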
with gr.Blocks() as demo:
    gr.Markdown("## Transcription + Diarisation + NER en français")
    gr.Markdown("Cette démo produit :\n- Le texte brut\n- Les entités nommées détectées\n- Les fichiers JSON générés")
    with gr.Row():
        audio_input = gr.File(label="Sélectionnez un fichier audio", type="filepath")
        run_button = gr.Button("Lancer")
    with gr.Row():
        punctuated_output = gr.Textbox(label="Texte brut", lines=10)
        entities_output = gr.JSON(label="Entités nommées")
    with gr.Row():
        aligned_output = gr.File(label="Aligned JSON")
        meta_output = gr.File(label="Meta JSON")
    run_button.click(
        gradio_process,
        inputs=[audio_input],
        outputs=[punctuated_output, entities_output, aligned_output, meta_output]
    )

if __name__ == "__main__":
    demo.launch(share=True)