Spaces:
Sleeping
Sleeping
File size: 9,405 Bytes
bdecbe9 645334b 7f613e2 a991705 bdecbe9 cea0ade 458a087 d7eb8e2 ae32743 f56d7bc bdecbe9 8537a08 bdecbe9 ae32743 bdecbe9 d7eb8e2 bdecbe9 ae32743 8537a08 d7eb8e2 1cc0f11 bdecbe9 d7eb8e2 bdecbe9 1cc0f11 bdecbe9 d7eb8e2 ae32743 d7eb8e2 a8cd8c4 d7eb8e2 eaa333d d7eb8e2 a8cd8c4 d7eb8e2 ae32743 f56d7bc ae32743 f56d7bc ae32743 f56d7bc eaa333d d7eb8e2 bdecbe9 6802a5d 458a087 6802a5d bdecbe9 98ab6e8 bdecbe9 d7eb8e2 bdecbe9 e47bd9f bdecbe9 d7eb8e2 ae32743 eaa333d d7eb8e2 ae32743 d7eb8e2 ae32743 d7eb8e2 bdecbe9 d7eb8e2 bdecbe9 d7eb8e2 bdecbe9 e47bd9f bdecbe9 d7eb8e2 bdecbe9 b6ba9c1 bdecbe9 cea0ade d7eb8e2 cea0ade d7eb8e2 cea0ade |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
import os
os.environ["XDG_CONFIG_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["HF_HOME"] = "/tmp/huggingface" # pour les modèles/datasets
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
import streamlit as st
import tempfile
import pandas as pd
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
import torch
import librosa
import numpy as np
import evaluate
import tempfile
from huggingface_hub import snapshot_download
from transformers import pipeline
import openai
from openai import OpenAI
st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")
# Section : Choix du modèle
st.subheader("1. Choix du modèle")
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
"Whisper Large (baseline)",
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
"Whisper Large + LoRA + Post-processing Mistral 7B",
"Whisper Large + LoRA + Post-processing GPT-4o"
))
# Section : Lien du dataset
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
openai_api_key = st.text_input("Clé API OpenAI (pour GPT-4o)", type="password")
if hf_token:
from huggingface_hub import login
login(hf_token)
# Section : Choix du split
split_option = st.selectbox(
"Choix du split à évaluer",
options=["Tous", "train", "validation", "test"],
index=0 # par défaut "Tous"
)
# Section : Choix du nombre maximal d'exemples à évaluer
max_examples_option = st.selectbox(
"Nombre maximum d'audios à traiter",
options=["1", "5", "10", "Tous"],
index=3 # par défaut "Tous"
)
# Section : Bouton pour lancer l'évaluation
start_eval = st.button("🚀 Lancer l'évaluation WER")
if start_eval:
st.subheader("🔍 Traitement en cours...")
# 🔹 Télécharger dataset
with st.spinner("Chargement du dataset..."):
try:
dataset_full = load_dataset(dataset_link, split="train", token=hf_token)
# 🔹 Filtrage selon la colonne 'split'
if split_option != "Tous":
dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option)
else:
dataset = dataset_full
if len(dataset) == 0:
st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.")
st.stop()
except Exception as e:
st.error(f"Erreur lors du chargement du dataset : {e}")
st.stop()
# Limiter le nombre d'exemples selon la sélection
if max_examples_option != "Tous":
max_examples = int(max_examples_option)
dataset = dataset.select(range(min(max_examples, len(dataset))))
# 🔹 Charger le modèle choisi
with st.spinner("Chargement du modèle..."):
base_model_name = "openai/whisper-large"
model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
if "LoRA" in model_option:
model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)
processor = WhisperProcessor.from_pretrained(base_model_name)
model.eval()
# Charger le pipeline de Mistral si post-processing demandé
if "Post-processing Mistral" in model_option:
with st.spinner("Chargement du modèle de post-traitement Mistral..."):
postproc_pipe = pipeline(
"text2text-generation",
model="mistralai/Mistral-7B-Instruct-v0.2",
device_map="auto", # ou device=0 si tu veux forcer le GPU
torch_dtype=torch.float16 # optionnel mais plus léger
)
st.success("✅ Modèle Mistral chargé.")
def postprocess_with_llm(text):
prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}"
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
return result.strip()
#fonction process GPT4o
def postprocess_with_gpt4o(text, api_key):
client = OpenAI(api_key=api_key)
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire."},
{"role": "user", "content": f"Corrige ce texte : {text}"}
],
temperature=0.3,
max_tokens=512
)
return response.choices[0].message.content.strip()
# 🔹 Préparer WER metric
wer_metric = evaluate.load("wer")
results = []
# Téléchargement explicite du dossier audio (chemin local vers chaque fichier)
repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token)
for example in dataset:
st.write("Exemple brut :", example)
try:
reference = example["text"]
waveform = example["audio"]["array"]
audio_path = example["audio"]["path"]
waveform = np.expand_dims(waveform, axis=0)
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
pred_ids = model.generate(input_features=inputs.input_features)
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
# === Post-processing conditionnel ===
if "Post-processing Mistral" in model_option:
st.write("⏳ Post-processing avec Mistral...")
postprocessed_prediction = postprocess_with_llm(prediction)
st.write("✅ Terminé.")
final_prediction = postprocessed_prediction
elif "Post-processing GPT-4o" in model_option:
if not openai_api_key:
st.error("Clé API OpenAI requise pour GPT-4o.")
st.stop()
st.write("🤖 Post-processing avec GPT-4o...")
try:
postprocessed_prediction = postprocess_with_gpt4o(prediction, openai_api_key)
except Exception as e:
postprocessed_prediction = f"[Erreur GPT-4o: {e}]"
final_prediction = postprocessed_prediction
else:
postprocessed_prediction = "-"
final_prediction = prediction
# 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
def clean(text):
return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()
ref_clean = clean(reference)
pred_clean = clean(final_prediction)
wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])
results.append({
"Fichier": audio_path,
"Référence": reference,
"Transcription brute": prediction,
"Transcription corrigée": postprocessed_prediction,
"WER": round(wer, 4)
})
except Exception as e:
results.append({
"Fichier": example["audio"].get("path", "unknown"),
"Référence": "Erreur",
"Transcription brute": f"Erreur: {e}",
"Transcription corrigée": "-",
"WER": "-"
})
# 🔹 Générer le tableau de résultats
df = pd.DataFrame(results)
# 🔹 Créer un fichier temporaire pour le CSV
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv:
df.to_csv(tmp_csv.name, index=False)
mean_wer = df[df["WER"] != "-"]["WER"].mean()
st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
# 🔹 Bouton de téléchargement
with open(tmp_csv.name, "rb") as f:
st.download_button(
label="📥 Télécharger les résultats WER (.csv)",
data=f,
file_name="wer_results.csv",
mime="text/csv"
)
|