diarizacion / app.py
Merlintxu's picture
Update app.py
3b3b78f verified
# app.py
import os
import tempfile
import torch
import numpy as np
import datetime
import gc
import whisper
from pyannote.audio import Audio
from pyannote.core import Segment
from sklearn.cluster import AgglomerativeClustering
import gradio as gr
import warnings
from huggingface_hub import hf_hub_download
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# --- Configuraci贸n de Modelos ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Usando dispositivo: {DEVICE}")
# --- Cargar Whisper (intentar con una versi贸n m谩s reciente si es viable) ---
WHISPER_MODEL_NAME = "small" # Empezar con 'small' para Spaces. Probar 'medium' o 'large-v3' si hay recursos.
try:
print(f"Cargando modelo Whisper '{WHISPER_MODEL_NAME}'...")
whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=DEVICE)
print(f"Modelo Whisper '{WHISPER_MODEL_NAME}' cargado exitosamente.")
except Exception as e:
print(f"Error cargando Whisper '{WHISPER_MODEL_NAME}': {e}")
print("Intentando cargar 'base' como fallback...")
WHISPER_MODEL_NAME = "base"
whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=DEVICE)
print(f"Modelo Whisper '{WHISPER_MODEL_NAME}' cargado.")
# --- Cargar modelo de embeddings de Pyannote v3.x ---
# Usar el nuevo modelo de embedding recomendado para pyannote.audio 3.x
EMBEDDING_MODEL_NAME = "pyannote/embedding"
EMBEDDING_REVISION = "main" # O especificar un commit si es necesario
try:
print(f"Cargando modelo de embeddings '{EMBEDDING_MODEL_NAME}'...")
# Importar el pipeline de embedding de pyannote v3
from pyannote.audio import Model
embedding_model = Model.from_pretrained(
EMBEDDING_MODEL_NAME,
use_auth_token=False, # No se necesita token para modelos p煤blicos
revision=EMBEDDING_REVISION
)
embedding_model.to(DEVICE)
print(f"Modelo de embeddings '{EMBEDDING_MODEL_NAME}' cargado.")
except Exception as e:
print(f"Error cargando el modelo de embeddings '{EMBEDDING_MODEL_NAME}': {e}")
print("Intentando con speechbrain como fallback...")
# Fallback al modelo SpeechBrain si el de Pyannote falla
try:
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
embedding_model = PretrainedSpeakerEmbedding(
"speechbrain/spkrec-ecapa-voxceleb",
device=DEVICE
)
print("Modelo de embeddings 'speechbrain/spkrec-ecapa-voxceleb' cargado como fallback.")
except Exception as e_fallback:
print(f"Error cr铆tico cargando modelo de embeddings: {e_fallback}")
raise RuntimeError("No se pudo cargar ning煤n modelo de embeddings.")
audio_processor = Audio()
def time(secs):
return datetime.timedelta(seconds=round(secs))
def convert_to_wav(input_path):
"""Convierte cualquier audio a WAV mono 16kHz usando ffmpeg."""
if input_path.lower().endswith('.wav'):
# Verificar si ya es mono y 16kHz podr铆a ser 煤til, pero para simplificar, convertimos siempre
pass
# Usar un nombre temporal seguro
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpfile:
output_path = tmpfile.name
# Comando ffmpeg para convertir a WAV mono 16kHz
cmd = f"ffmpeg -y -i '{input_path}' -ac 1 -ar 16000 -acodec pcm_s16le '{output_path}'"
print(f"Ejecutando conversi贸n: {cmd}")
os.system(cmd)
if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
raise RuntimeError("La conversi贸n a WAV fall贸 o produjo un archivo vac铆o.")
return output_path
def get_duration(path):
import soundfile as sf
try:
info = sf.info(path)
return info.duration
except Exception as e:
print(f"Error obteniendo duraci贸n con soundfile: {e}")
# Fallback a wave (menos robusto)
import wave
import contextlib
with contextlib.closing(wave.open(path,'r')) as f:
frames = f.getnframes()
rate = f.getframerate()
return frames / float(rate)
def segment_embedding(path, segment, duration):
start = segment["start"]
end = min(duration, segment["end"])
clip = Segment(start, end)
try:
waveform, sample_rate = audio_processor.crop(path, clip)
with torch.no_grad():
# Para modelos Pyannote v3
if hasattr(embedding_model, 'encode'):
# Modelos nuevos de pyannote devuelven diccionarios
output = embedding_model.encode(waveform[None].to(DEVICE))
if isinstance(output, dict) and 'embedding' in output:
embedding = output['embedding']
else:
embedding = output
else:
# Fallback para modelos compatibles con la API antigua o SpeechBrain
embedding = embedding_model(waveform[None].to(DEVICE))
# Asegurar que el embedding sea un tensor y luego numpy
if isinstance(embedding, torch.Tensor):
return embedding.squeeze().cpu().numpy()
else:
# Para embeddings que ya son numpy (ej. SpeechBrain wrapper)
return np.squeeze(embedding)
except Exception as e:
print(f"Error extrayendo embedding para segmento {start}-{end}: {e}")
# Devolver un embedding de ceros en caso de error
return np.zeros(512) # Ajustar tama帽o si se sabe el dim del embedding
def transcribe_and_diarize(audio_file, num_speakers):
"""Funci贸n principal de transcripci贸n y diarizaci贸n."""
temp_files = []
try:
status_update = ""
# --- 1. Conversi贸n ---
status_update += "1. Convirtiendo audio a formato WAV (16kHz, mono)...\n"
yield status_update, ""
wav_path = convert_to_wav(audio_file)
temp_files.append(wav_path) # Para limpieza posterior
# --- 2. Duraci贸n ---
status_update += "2. Obteniendo duraci贸n del audio...\n"
yield status_update, ""
duration = get_duration(wav_path)
if duration > 30 * 60: # Limitar a 30 minutos
yield status_update + "Error: El audio es demasiado largo (m谩ximo 30 minutos).\n", ""
return
# --- 3. Transcripci贸n ---
status_update += f"3. Transcribiendo audio con Whisper (modelo '{WHISPER_MODEL_NAME}')...\n"
yield status_update, ""
# Transcribir en espa帽ol
result = whisper_model.transcribe(wav_path, language='es', task='transcribe', verbose=False)
segments = result["segments"]
if not segments:
yield status_update + "Error: No se detect贸 habla en el audio.\n", ""
return
# --- 4. Diarizaci贸n ---
status_update += "4. Preparando para diarizaci贸n...\n"
yield status_update, ""
# Limitar n煤mero de hablantes
num_speakers = max(2, min(6, int(num_speakers)))
num_speakers = min(num_speakers, len(segments))
if len(segments) <= 1:
segments[0]['speaker'] = 'HABLANTE 1'
status_update += " -> Solo se detect贸 1 segmento de habla. Asignando un hablante.\n"
else:
status_update += " -> Extrayendo embeddings de audio...\n"
yield status_update, ""
# Determinar la dimensi贸n del embedding con una muestra
sample_embedding = segment_embedding(wav_path, segments[0], duration)
embedding_dim = sample_embedding.shape[-1] if hasattr(sample_embedding, 'shape') else 512
print(f"Dimensi贸n del embedding detectada: {embedding_dim}")
embeddings = np.zeros(shape=(len(segments), embedding_dim))
for i, segment in enumerate(segments):
embeddings[i] = segment_embedding(wav_path, segment, duration)
embeddings = np.nan_to_num(embeddings)
status_update += " -> Agrupando hablantes...\n"
yield status_update, ""
# Clustering
clustering = AgglomerativeClustering(n_clusters=num_speakers).fit(embeddings)
labels = clustering.labels_
for i in range(len(segments)):
segments[i]["speaker"] = f'HABLANTE {labels[i] + 1}'
# --- 5. Formateo de salida ---
status_update += "5. Generando transcripci贸n final...\n"
yield status_update, ""
output_text = ""
for (i, segment) in enumerate(segments):
if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
if i != 0:
output_text += '\n\n'
output_text += f"{segment['speaker']} [{time(segment['start'])}]\n\n"
output_text += segment["text"].strip() + ' '
yield status_update + "隆Proceso completado!\n", output_text
except Exception as e:
error_msg = f"Error durante el proceso: {str(e)}"
print(error_msg)
yield f"Error: {error_msg}\n", ""
finally:
# Limpiar archivos temporales
for f in temp_files:
try:
os.remove(f)
print(f"Archivo temporal eliminado: {f}")
except OSError:
pass
# Liberar memoria GPU/CPU (manera m谩s segura)
import sys
if 'whisper_model' in sys.modules.get(__name__, {}).__dict__:
try:
del sys.modules[__name__].whisper_model
print("Modelo Whisper eliminado de la memoria.")
except Exception as e:
print(f"Error al eliminar whisper_model: {e}")
if 'embedding_model' in sys.modules.get(__name__, {}).__dict__:
try:
del sys.modules[__name__].embedding_model
print("Modelo de embeddings eliminado de la memoria.")
except Exception as e:
print(f"Error al eliminar embedding_model: {e}")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# --- Interfaz Gradio ---
with gr.Blocks(title="Diarizaci贸n de Audio en Espa帽ol") as demo:
gr.Markdown("# 馃帳 Diarizaci贸n de Audio en Espa帽ol")
gr.Markdown("Sube un archivo de audio (hasta 30 minutos) y obt茅n una transcripci贸n separada por hablantes. Optimizado para espa帽ol.")
gr.Markdown("**Nota:** Este demo usa modelos ligeros. Para audio con mucho ruido o m谩s de 10 minutos, los resultados pueden ser menos precisos.")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(label="Subir Audio", type="filepath")
num_speakers = gr.Slider(2, 6, value=3, step=1, label="N煤mero aproximado de hablantes")
run_button = gr.Button("馃殌 Iniciar Diarizaci贸n")
with gr.Column():
status_output = gr.Textbox(label="Estado", interactive=False, lines=10, max_lines=10)
text_output = gr.Textbox(label="Transcripci贸n con Hablantes", interactive=False, lines=20)
run_button.click(
fn=transcribe_and_diarize,
inputs=[audio_input, num_speakers],
outputs=[status_output, text_output],
queue=True,
concurrency_limit=1 # Limitar a 1 ejecuci贸n simult谩nea para evitar sobrecarga
)
gr.Markdown("---")
gr.Markdown("**Modelos Usados:**\n"
"* **Transcripci贸n:** Whisper (`large-v3`)\n"
"* **Diarizaci贸n:** Pyannote.Audio (`pyannote/embedding` o `speechbrain/spkrec-ecapa-voxceleb`)\n")
# Para Hugging Face Spaces
demo.launch()