File size: 9,405 Bytes
bdecbe9
645334b
7f613e2
 
 
 
a991705
 
bdecbe9
 
 
 
 
 
 
 
 
cea0ade
458a087
d7eb8e2
ae32743
f56d7bc
bdecbe9
8537a08
bdecbe9
 
 
 
 
 
 
 
ae32743
 
bdecbe9
 
 
 
d7eb8e2
bdecbe9
 
ae32743
 
8537a08
 
 
 
d7eb8e2
 
 
 
 
 
 
1cc0f11
 
 
 
 
 
 
bdecbe9
 
 
 
 
 
 
 
 
d7eb8e2
 
 
 
 
 
 
 
 
 
 
 
 
bdecbe9
 
 
 
1cc0f11
 
 
 
 
bdecbe9
 
 
 
 
 
 
 
 
 
 
d7eb8e2
ae32743
d7eb8e2
 
 
a8cd8c4
d7eb8e2
 
 
eaa333d
 
d7eb8e2
a8cd8c4
d7eb8e2
 
ae32743
 
 
 
f56d7bc
 
ae32743
 
f56d7bc
ae32743
 
 
 
 
f56d7bc
eaa333d
d7eb8e2
bdecbe9
 
 
 
 
6802a5d
458a087
6802a5d
bdecbe9
98ab6e8
bdecbe9
d7eb8e2
bdecbe9
e47bd9f
 
 
 
bdecbe9
 
 
 
 
 
 
d7eb8e2
ae32743
eaa333d
d7eb8e2
ae32743
d7eb8e2
ae32743
 
 
 
 
 
 
 
 
 
 
 
d7eb8e2
 
 
 
bdecbe9
 
 
 
 
d7eb8e2
bdecbe9
 
 
 
 
d7eb8e2
 
bdecbe9
 
 
 
 
e47bd9f
bdecbe9
d7eb8e2
 
bdecbe9
 
 
b6ba9c1
bdecbe9
cea0ade
 
 
 
 
 
d7eb8e2
cea0ade
 
d7eb8e2
cea0ade
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import os
os.environ["XDG_CONFIG_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["HF_HOME"] = "/tmp/huggingface"  # pour les modèles/datasets
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"

import streamlit as st
import tempfile
import pandas as pd
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
import torch
import librosa
import numpy as np
import evaluate
import tempfile
from huggingface_hub import snapshot_download
from transformers import pipeline
import openai
from openai import OpenAI


st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")

# Section : Choix du modèle
st.subheader("1. Choix du modèle")
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
    "Whisper Large (baseline)",
    "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
    "Whisper Large + LoRA + Post-processing Mistral 7B",
    "Whisper Large + LoRA + Post-processing GPT-4o"
))

# Section : Lien du dataset
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")

openai_api_key = st.text_input("Clé API OpenAI (pour GPT-4o)", type="password")

if hf_token:
    from huggingface_hub import login
    login(hf_token)

# Section : Choix du split
split_option = st.selectbox(
    "Choix du split à évaluer",
    options=["Tous", "train", "validation", "test"],
    index=0  # par défaut "Tous"
)

# Section : Choix du nombre maximal d'exemples à évaluer
max_examples_option = st.selectbox(
    "Nombre maximum d'audios à traiter",
    options=["1", "5", "10", "Tous"],
    index=3  # par défaut "Tous"
)

# Section : Bouton pour lancer l'évaluation
start_eval = st.button("🚀 Lancer l'évaluation WER")

if start_eval:
    st.subheader("🔍 Traitement en cours...")

    # 🔹 Télécharger dataset
    with st.spinner("Chargement du dataset..."):
        try:

            dataset_full = load_dataset(dataset_link, split="train", token=hf_token)

            # 🔹 Filtrage selon la colonne 'split'
            if split_option != "Tous":
                dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option)
            else:
                dataset = dataset_full
            
            if len(dataset) == 0:
                st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.")
                st.stop()
            
        except Exception as e:
            st.error(f"Erreur lors du chargement du dataset : {e}")
            st.stop()

     # Limiter le nombre d'exemples selon la sélection
    if max_examples_option != "Tous":
        max_examples = int(max_examples_option)
        dataset = dataset.select(range(min(max_examples, len(dataset))))

    # 🔹 Charger le modèle choisi
    with st.spinner("Chargement du modèle..."):
        base_model_name = "openai/whisper-large"
        model = WhisperForConditionalGeneration.from_pretrained(base_model_name)

        if "LoRA" in model_option:
            model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)

        processor = WhisperProcessor.from_pretrained(base_model_name)
        model.eval()

        # Charger le pipeline de Mistral si post-processing demandé
        if "Post-processing Mistral" in model_option:
            with st.spinner("Chargement du modèle de post-traitement Mistral..."):
                postproc_pipe = pipeline(
                    "text2text-generation",
                    model="mistralai/Mistral-7B-Instruct-v0.2",
                    device_map="auto",  # ou device=0 si tu veux forcer le GPU
                    torch_dtype=torch.float16  # optionnel mais plus léger
                )
                st.success("✅ Modèle Mistral chargé.")
                
                def postprocess_with_llm(text):
                    prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}"
                    result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
                    return result.strip()

    
    #fonction process GPT4o
    def postprocess_with_gpt4o(text, api_key):
        client = OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire."},
                {"role": "user", "content": f"Corrige ce texte : {text}"}
            ],
            temperature=0.3,
            max_tokens=512
        )
        return response.choices[0].message.content.strip()
    

    # 🔹 Préparer WER metric
    wer_metric = evaluate.load("wer")

    results = []

    # Téléchargement explicite du dossier audio (chemin local vers chaque fichier)
    repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token)

    for example in dataset:
        st.write("Exemple brut :", example)
        try:
           
            reference = example["text"]
            
            waveform = example["audio"]["array"]
            audio_path = example["audio"]["path"]
            
            waveform = np.expand_dims(waveform, axis=0)
            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")

            with torch.no_grad():
                pred_ids = model.generate(input_features=inputs.input_features)
            prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

            # === Post-processing conditionnel ===
            if "Post-processing Mistral" in model_option:
                st.write("⏳ Post-processing avec Mistral...")
                postprocessed_prediction = postprocess_with_llm(prediction)
                st.write("✅ Terminé.")
                final_prediction = postprocessed_prediction
            
            elif "Post-processing GPT-4o" in model_option:
                if not openai_api_key:
                    st.error("Clé API OpenAI requise pour GPT-4o.")
                    st.stop()
                st.write("🤖 Post-processing avec GPT-4o...")
                try:
                    postprocessed_prediction = postprocess_with_gpt4o(prediction, openai_api_key)
                except Exception as e:
                    postprocessed_prediction = f"[Erreur GPT-4o: {e}]"
                final_prediction = postprocessed_prediction
            
            else:
                postprocessed_prediction = "-"
                final_prediction = prediction

            # 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
            def clean(text):
                return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()

            ref_clean = clean(reference)
            pred_clean = clean(final_prediction)
            wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])

            results.append({
                "Fichier": audio_path,
                "Référence": reference,
                "Transcription brute": prediction,
                "Transcription corrigée": postprocessed_prediction,
                "WER": round(wer, 4)
            })

        except Exception as e:
            results.append({
                "Fichier": example["audio"].get("path", "unknown"),
                "Référence": "Erreur",
                "Transcription brute": f"Erreur: {e}",
                "Transcription corrigée": "-",
                "WER": "-"
            })

    # 🔹 Générer le tableau de résultats
    df = pd.DataFrame(results)

    # 🔹 Créer un fichier temporaire pour le CSV
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv:
        df.to_csv(tmp_csv.name, index=False)

        mean_wer = df[df["WER"] != "-"]["WER"].mean()
        
        st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")




        # 🔹 Bouton de téléchargement
        with open(tmp_csv.name, "rb") as f:
            st.download_button(
                label="📥 Télécharger les résultats WER (.csv)",
                data=f,
                file_name="wer_results.csv",
                mime="text/csv"
            )