# app.py
import os
import gc
import datetime
import tempfile
import subprocess
import warnings

import numpy as np
import torch
import whisper
import gradio as gr
from pyannote.audio import Audio
from pyannote.core import Segment
from sklearn.cluster import AgglomerativeClustering

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- Model configuration ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
# --- Load Whisper (try a more recent version if feasible) ---
WHISPER_MODEL_NAME = "small"  # Start with 'small' for Spaces. Try 'medium' or 'large-v3' if resources allow.
try:
    print(f"Loading Whisper model '{WHISPER_MODEL_NAME}'...")
    whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=DEVICE)
    print(f"Whisper model '{WHISPER_MODEL_NAME}' loaded successfully.")
except Exception as e:
    print(f"Error loading Whisper '{WHISPER_MODEL_NAME}': {e}")
    print("Trying to load 'base' as a fallback...")
    WHISPER_MODEL_NAME = "base"
    whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=DEVICE)
    print(f"Whisper model '{WHISPER_MODEL_NAME}' loaded.")
# --- Load the pyannote.audio 3.x embedding model ---
# Use the embedding model recommended for pyannote.audio 3.x
EMBEDDING_MODEL_NAME = "pyannote/embedding"
EMBEDDING_REVISION = "main"  # Or pin a specific commit if needed
try:
    print(f"Loading embedding model '{EMBEDDING_MODEL_NAME}'...")
    # Import the pyannote v3 Model class
    from pyannote.audio import Model
    embedding_model = Model.from_pretrained(
        EMBEDDING_MODEL_NAME,
        # pyannote models on the Hub are gated; pass a token if one is configured
        use_auth_token=os.environ.get("HF_TOKEN", False),
        revision=EMBEDDING_REVISION
    )
    embedding_model.to(DEVICE)
    print(f"Embedding model '{EMBEDDING_MODEL_NAME}' loaded.")
except Exception as e:
    print(f"Error loading embedding model '{EMBEDDING_MODEL_NAME}': {e}")
    print("Trying speechbrain as a fallback...")
    # Fall back to the SpeechBrain model if the pyannote one fails
    try:
        from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
        embedding_model = PretrainedSpeakerEmbedding(
            "speechbrain/spkrec-ecapa-voxceleb",
            device=DEVICE
        )
        print("Fallback embedding model 'speechbrain/spkrec-ecapa-voxceleb' loaded.")
    except Exception as e_fallback:
        print(f"Critical error loading an embedding model: {e_fallback}")
        raise RuntimeError("Could not load any embedding model.")
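
# Note: the two backends emit different embedding sizes (per their model cards,
# pyannote/embedding is 512-dim and the SpeechBrain ECAPA model is 192-dim), which
# is why transcribe_and_diarize() probes the dimension at runtime instead of
# hardcoding it.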
audio_processor = Audio()

def format_time(secs):
    # Renamed from `time` to avoid shadowing the stdlib module name
    return datetime.timedelta(seconds=round(secs))
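
# Example: format_time(3661.4) rounds to 3661 s and renders as "1:01:01" in an f-string.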
def convert_to_wav(input_path):
    """Convert any audio file to 16 kHz mono WAV using ffmpeg."""
    # Checking whether a .wav input is already mono/16 kHz could save work, but
    # for simplicity we always re-encode so downstream code can rely on the format.
    # Use a safe temporary name
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpfile:
        output_path = tmpfile.name
    # Pass ffmpeg arguments as a list so paths with spaces or quotes are handled safely
    cmd = ["ffmpeg", "-y", "-i", input_path, "-ac", "1", "-ar", "16000",
           "-acodec", "pcm_s16le", output_path]
    print(f"Running conversion: {' '.join(cmd)}")
    subprocess.run(cmd, check=False, capture_output=True)
    if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
        raise RuntimeError("WAV conversion failed or produced an empty file.")
    return output_path
def get_duration(path):
    import soundfile as sf
    try:
        info = sf.info(path)
        return info.duration
    except Exception as e:
        print(f"Error reading duration with soundfile: {e}")
        # Fall back to the stdlib wave module (less robust)
        import wave
        import contextlib
        with contextlib.closing(wave.open(path, 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            return frames / float(rate)
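
# The wave fallback only understands uncompressed PCM WAV, which is exactly what
# convert_to_wav() produces, so it is a safe last resort here.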
def segment_embedding(path, segment, duration):
    start = segment["start"]
    end = min(duration, segment["end"])
    clip = Segment(start, end)
    try:
        waveform, sample_rate = audio_processor.crop(path, clip)
        with torch.no_grad():
            # For pyannote v3-style models
            if hasattr(embedding_model, 'encode'):
                # Newer pyannote models may return dictionaries
                output = embedding_model.encode(waveform[None].to(DEVICE))
                if isinstance(output, dict) and 'embedding' in output:
                    embedding = output['embedding']
                else:
                    embedding = output
            else:
                # Fallback for models using the older API or SpeechBrain
                embedding = embedding_model(waveform[None].to(DEVICE))
        # Ensure the embedding is a tensor, then convert to numpy
        if isinstance(embedding, torch.Tensor):
            return embedding.squeeze().cpu().numpy()
        else:
            # For embeddings that are already numpy (e.g. the SpeechBrain wrapper)
            return np.squeeze(embedding)
    except Exception as e:
        print(f"Error extracting embedding for segment {start}-{end}: {e}")
        # Return a zero embedding on error
        return np.zeros(512)  # Adjust if the embedding dimension is known to differ
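
# Caveat: zero vectors returned on failure are all identical, so failed segments
# tend to be lumped into a single (meaningless) cluster. Acceptable for a demo,
# but worth flagging before production use.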
def transcribe_and_diarize(audio_file, num_speakers):
    """Main transcription and diarization function."""
    temp_files = []
    try:
        status_update = ""
        # --- 1. Conversion ---
        status_update += "1. Converting audio to WAV (16 kHz, mono)...\n"
        yield status_update, ""
        wav_path = convert_to_wav(audio_file)
        temp_files.append(wav_path)  # For cleanup later

        # --- 2. Duration ---
        status_update += "2. Reading audio duration...\n"
        yield status_update, ""
        duration = get_duration(wav_path)
        if duration > 30 * 60:  # Limit to 30 minutes
            yield status_update + "Error: the audio is too long (30 minutes maximum).\n", ""
            return

        # --- 3. Transcription ---
        status_update += f"3. Transcribing audio with Whisper (model '{WHISPER_MODEL_NAME}')...\n"
        yield status_update, ""
        # Transcribe in Spanish
        result = whisper_model.transcribe(wav_path, language='es', task='transcribe', verbose=False)
        segments = result["segments"]
        if not segments:
            yield status_update + "Error: no speech detected in the audio.\n", ""
            return

        # --- 4. Diarization ---
        status_update += "4. Preparing diarization...\n"
        yield status_update, ""
        # Clamp the number of speakers
        num_speakers = max(2, min(6, int(num_speakers)))
        num_speakers = min(num_speakers, len(segments))
        if len(segments) <= 1:
            segments[0]['speaker'] = 'SPEAKER 1'
            status_update += "   -> Only one speech segment detected. Assigning a single speaker.\n"
        else:
            status_update += "   -> Extracting audio embeddings...\n"
            yield status_update, ""
            # Probe the embedding dimension with a sample segment
            sample_embedding = segment_embedding(wav_path, segments[0], duration)
            embedding_dim = sample_embedding.shape[-1] if hasattr(sample_embedding, 'shape') else 512
            print(f"Detected embedding dimension: {embedding_dim}")
            embeddings = np.zeros(shape=(len(segments), embedding_dim))
            for i, segment in enumerate(segments):
                embeddings[i] = segment_embedding(wav_path, segment, duration)
            embeddings = np.nan_to_num(embeddings)

            status_update += "   -> Clustering speakers...\n"
            yield status_update, ""
            # Clustering
            clustering = AgglomerativeClustering(n_clusters=num_speakers).fit(embeddings)
            labels = clustering.labels_
            for i in range(len(segments)):
                segments[i]["speaker"] = f'SPEAKER {labels[i] + 1}'
        # --- 5. Output formatting ---
        status_update += "5. Generating the final transcript...\n"
        yield status_update, ""
        output_text = ""
        for (i, segment) in enumerate(segments):
            if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
                if i != 0:
                    output_text += '\n\n'
                output_text += f"{segment['speaker']} [{format_time(segment['start'])}]\n\n"
            output_text += segment["text"].strip() + ' '
        yield status_update + "Process complete!\n", output_text
    except Exception as e:
        error_msg = f"Error during processing: {str(e)}"
        print(error_msg)
        yield f"Error: {error_msg}\n", ""
    finally:
        # Clean up temporary files
        for f in temp_files:
            try:
                os.remove(f)
                print(f"Temporary file removed: {f}")
            except OSError:
                pass
        # Free GPU/CPU memory. The models themselves are deliberately kept loaded:
        # deleting the module-level whisper_model/embedding_model here would break
        # every request after the first one.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# --- Gradio interface ---
with gr.Blocks(title="Spanish Audio Diarization") as demo:
    gr.Markdown("# 🎙️ Spanish Audio Diarization")
    gr.Markdown("Upload an audio file (up to 30 minutes) and get a transcript split by speaker. Optimized for Spanish.")
    gr.Markdown("**Note:** this demo uses lightweight models. For noisy audio or recordings over 10 minutes, results may be less accurate.")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio", type="filepath")
            num_speakers = gr.Slider(2, 6, value=3, step=1, label="Approximate number of speakers")
            run_button = gr.Button("🚀 Start Diarization")
        with gr.Column():
            status_output = gr.Textbox(label="Status", interactive=False, lines=10, max_lines=10)
            text_output = gr.Textbox(label="Transcript with Speakers", interactive=False, lines=20)
    run_button.click(
        fn=transcribe_and_diarize,
        inputs=[audio_input, num_speakers],
        outputs=[status_output, text_output],
        queue=True,
        concurrency_limit=1  # Limit to one concurrent run to avoid overload
    )
    gr.Markdown("---")
    gr.Markdown("**Models used:**\n"
                "* **Transcription:** Whisper (`small`, falling back to `base`)\n"
                "* **Diarization:** pyannote.audio (`pyannote/embedding` or `speechbrain/spkrec-ecapa-voxceleb`)\n")

# For Hugging Face Spaces
demo.launch()
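
# Assumed dependencies for this Space (a sketch based on the imports above, not a
# pinned list): openai-whisper, pyannote.audio>=3.1, scikit-learn, soundfile, torch,
# gradio>=4 (for concurrency_limit), plus ffmpeg installed as a system package
# (packages.txt on Hugging Face Spaces).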