"""Transcription pipeline: VAD segmentation, per-segment Whisper decoding, and optional forced alignment."""

import os
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import soundfile as sf
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

from .whisperx.audio import load_audio, SAMPLE_RATE
from .whisperx.vads import Pyannote, Silero
from .whisperx.types import TranscriptionResult, SingleSegment, AlignedTranscriptionResult
from .whisperx.alignment import load_align_model, align


class MazeWhisperModel:
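    """Loads the maze-whisper checkpoint and its processor, and transcribes individual audio segments."""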

    def __init__(self, model_name: str = "sven33/maze-whisper-3000", device: str = "cuda"):
        self.device = device
        self.model_name = model_name

        print(f"Loading Maze Whisper model: {model_name}")
        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
        self.tokenizer = self.processor.tokenizer
        self.model.eval()

    def transcribe_segment(self, audio_segment: np.ndarray) -> str:
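        """Transcribe a single audio segment (mono waveform at SAMPLE_RATE) with beam search and return the text."""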
        with torch.no_grad():
            inputs = self.processor(
                audio_segment,
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt"
            ).to(self.device)

            generated_ids = self.model.generate(
                inputs["input_features"],
                max_length=448,
                num_beams=5,
                early_stopping=True,
                use_cache=True
            )

            transcription = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )[0]

        return transcription.strip()


class WhisperXPipeline:
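    """End-to-end pipeline: a VAD model (Pyannote or Silero) splits the audio into chunks,
    MazeWhisperModel transcribes each chunk, and whisperx forced alignment optionally adds
    word-level timestamps.
    """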

    def __init__(self, model_name: str = "sven33/maze-whisper-3000", device: str = "cuda",
                 vad_method: str = "pyannote", chunk_size: int = 30,
                 enable_alignment: bool = True, align_language: str = "en"):
        self.device = device
        self.chunk_size = chunk_size
        self.enable_alignment = enable_alignment
        self.align_language = align_language

        self.whisper_model = MazeWhisperModel(model_name, device)
        self._init_vad_model(vad_method)

        self.align_model = None
        self.align_metadata = None
        if enable_alignment:
            self._init_alignment_model()

    def _init_vad_model(self, vad_method: str):
        default_vad_options = {
            "chunk_size": self.chunk_size,
            "vad_onset": 0.500,
            "vad_offset": 0.363
        }

        if vad_method == "silero":
            self.vad_model = Silero(**default_vad_options)
        elif vad_method == "pyannote":
            device_vad = 'cuda:0' if self.device == 'cuda' else self.device
            self.vad_model = Pyannote(torch.device(device_vad), **default_vad_options)
        else:
            raise ValueError(f"Invalid vad_method: {vad_method}")

    def _init_alignment_model(self):
        try:
            print(f"Loading alignment model for language: {self.align_language}")
            self.align_model, self.align_metadata = load_align_model(
                self.align_language,
                self.device
            )
        except Exception as e:
            print(f"Warning: Could not load alignment model: {e}")
            print("Continuing without forced alignment...")
            self.enable_alignment = False

    def transcribe(self, audio: Union[str, np.ndarray], verbose: bool = False) -> Union[TranscriptionResult, AlignedTranscriptionResult]:
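        """Transcribe a file path or a pre-loaded waveform.

        Returns segment-level results, plus word-level timestamps when forced alignment succeeds.
        """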
        if isinstance(audio, str):
            audio_path = audio
            audio = load_audio(audio)
        else:
            audio_path = None

        if hasattr(self.vad_model, 'preprocess_audio'):
            waveform = self.vad_model.preprocess_audio(audio)
        else:
            waveform = torch.from_numpy(audio).unsqueeze(0)

        vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})

        if hasattr(self.vad_model, 'merge_chunks'):
            vad_segments = self.vad_model.merge_chunks(
                vad_segments,
                self.chunk_size,
                onset=0.500,
                offset=0.363,
            )

        segments: List[SingleSegment] = []

        print(f"Processing {len(vad_segments)} segments...")

        for idx, seg in enumerate(vad_segments):
            start_sample = int(seg['start'] * SAMPLE_RATE)
            end_sample = int(seg['end'] * SAMPLE_RATE)
            audio_segment = audio[start_sample:end_sample]

            text = self.whisper_model.transcribe_segment(audio_segment)

            if not text.strip() or len(text.strip()) < 2:
                if verbose:
                    print(f"Skipping empty/short segment {idx+1}: [{seg['start']:.3f}s - {seg['end']:.3f}s]")
                continue

            if verbose:
                print(f"Segment {idx+1}/{len(vad_segments)}: [{seg['start']:.3f}s - {seg['end']:.3f}s] {text}")

            segments.append({
                "text": text,
                "start": round(seg['start'], 3),
                "end": round(seg['end'], 3)
            })

        result = {"segments": segments, "language": self.align_language}

        if self.enable_alignment and self.align_model is not None and len(segments) > 0:
            print("Preparing segments for forced alignment...")

            cleaned_segments = []
            for segment in segments:
                original_text = segment["text"]
                cleaned_text = clean_text_for_alignment(original_text)

                if cleaned_text.strip() and len(cleaned_text.strip()) >= 2:
                    cleaned_segment = {
                        "text": cleaned_text,
                        "start": segment["start"],
                        "end": segment["end"]
                    }
                    cleaned_segments.append({
                        "cleaned": cleaned_segment,
                        "original": segment
                    })

            if len(cleaned_segments) > 0:
                print(f"Performing forced alignment on {len(cleaned_segments)} segments...")
                try:
                    segments_for_alignment = [item["cleaned"] for item in cleaned_segments]

                    aligned_result = align(
                        segments_for_alignment,
                        self.align_model,
                        self.align_metadata,
                        audio_path if audio_path else audio,
                        self.device,
                        interpolate_method="nearest",
                        return_char_alignments=False,
                        print_progress=verbose
                    )

                    final_segments = []
                    aligned_segments = aligned_result.get("segments", [])

                    for i, aligned_seg in enumerate(aligned_segments):
                        if i < len(cleaned_segments):
                            original_segment = cleaned_segments[i]["original"]

                            final_segment = {
                                "text": original_segment["text"],
                                "start": aligned_seg["start"],
                                "end": aligned_seg["end"],
                                "words": aligned_seg.get("words", [])
                            }

                            if "words" in final_segment and final_segment["words"]:
                                final_segment["words"] = fix_word_alignment(
                                    final_segment["words"],
                                    original_segment["text"],
                                    cleaned_segments[i]["cleaned"]["text"]
                                )

                            final_segments.append(final_segment)

                    final_result = {
                        "segments": final_segments,
                        "word_segments": [],
                        "language": self.align_language
                    }

                    for segment in final_segments:
                        if "words" in segment:
                            final_result["word_segments"].extend(segment["words"])

                    print(f"Alignment completed! {len(final_segments)} segments with {len(final_result['word_segments'])} words")
                    return final_result

                except Exception as e:
                    print(f"Warning: Alignment failed: {e}")
                    print("Returning transcription without alignment...")
            else:
                print("Warning: No segments remaining after cleaning for alignment")

        return result


def clean_text_for_alignment(text: str) -> str:
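    """Strip markup and punctuation that the alignment model cannot handle.

    Angle-bracket tags, square/curly brackets, most symbols, and all periods are removed,
    and whitespace is collapsed, e.g. "<unk> Hello, [world]." -> "Hello, world".
    """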
    cleaned_text = re.sub(r'<[^>]*>', '', text)
    cleaned_text = re.sub(r'[\[\]{}]', '', cleaned_text)
    cleaned_text = re.sub(r'[^\w\s\.\,\?\!\-\']', '', cleaned_text)
    cleaned_text = cleaned_text.replace('.', '')
    cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
    return cleaned_text


def fix_word_alignment(words: List[Dict], original_text: str, cleaned_text: str) -> List[Dict]:
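    """Hook for remapping word timings from the cleaned text back onto the original text.

    Currently every code path returns `words` unchanged.
    """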
    try:
        original_tokens = original_text.split()
        cleaned_tokens = cleaned_text.split()

        if len(words) == 0 or len(cleaned_tokens) == 0:
            return words

        if abs(len(original_tokens) - len(cleaned_tokens)) <= 1:
            return words

        return words

    except Exception as e:
        print(f"Warning: Could not fix word alignment: {e}")
        return words


def generate_session_id() -> str:
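    """Return the next six-digit, zero-padded session id based on existing ./session_data subdirectories."""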
    session_data_dir = Path("./session_data")

    if not session_data_dir.exists():
        return "000001"

    existing_sessions = []
    for item in session_data_dir.iterdir():
        if item.is_dir() and item.name.isdigit() and len(item.name) == 6:
            existing_sessions.append(int(item.name))

    if not existing_sessions:
        return "000001"

    next_id = max(existing_sessions) + 1
    return f"{next_id:06d}"


def translate_audio_file(model: str = "mazeWhisper", audio_path: str = "", device: str = "cuda",
                         enable_alignment: bool = True, align_language: str = "en",
                         original_filename: Optional[str] = None) -> Tuple[Dict[str, Any], str]:
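    """Transcribe an audio file, write transcription.json into a new session directory,
    and return (result_data, session_id).
    """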
    if model != "mazeWhisper":
        raise ValueError("Currently only 'mazeWhisper' model is supported")

    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    session_id = generate_session_id()
    session_data_dir = Path("./session_data")
    session_dir = session_data_dir / session_id
    session_dir.mkdir(parents=True, exist_ok=True)

    print(f"Session ID: {session_id}")
    print(f"Session directory: {session_dir}")

    try:
        pipeline = WhisperXPipeline(
            model_name="sven33/maze-whisper-3000",
            device=device,
            vad_method="pyannote",
            chunk_size=10,
            enable_alignment=enable_alignment,
            align_language=align_language
        )

        print("Starting transcription...")
        # transcribe() loads the audio itself when given a path.
        result = pipeline.transcribe(audio_path, verbose=True)

        has_word_timestamps = (
            isinstance(result, dict) and
            "segments" in result and
            len(result["segments"]) > 0 and
            "words" in result["segments"][0]
        )

        formatted_segments = []
        for segment in result["segments"]:
            formatted_segment = {
                "start": segment["start"],
                "end": segment["end"],
                "speaker": "",
                "text": segment["text"],
                "words": []
            }

            if "words" in segment and segment["words"]:
                for word_info in segment["words"]:
                    formatted_word = {
                        "word": word_info["word"],
                        "start": word_info["start"],
                        "end": word_info["end"]
                    }
                    formatted_segment["words"].append(formatted_word)

            formatted_segments.append(formatted_segment)

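        # Shape of the transcription.json written below (illustrative values):
        # {
        #   "filename": "<audio file name>",
        #   "segments": [
        #     {"start": 0.0, "end": 2.5, "speaker": "", "text": "...",
        #      "words": [{"word": "...", "start": 0.0, "end": 0.4}]}
        #   ]
        # }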
        filename = original_filename if original_filename else os.path.basename(audio_path)
        output_data = {
            "filename": filename,
            "segments": formatted_segments
        }

        json_path = session_dir / "transcription.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)

        print(f"Transcription saved: {json_path}")

        if has_word_timestamps:
            total_words = sum(len(seg.get("words", [])) for seg in result["segments"])
            print(f"Forced alignment completed! Total words with timestamps: {total_words}")
        elif enable_alignment:
            print("Forced alignment was enabled but failed - only segment-level timestamps available")
        else:
            print("Forced alignment disabled - only segment-level timestamps available")

        print(f"Transcription complete! Session: {session_id}")

        result_data = {
            "session_id": session_id,
            "audio_path": audio_path,
            "model": "sven33/maze-whisper-3000",
            "device": device,
            "alignment_enabled": enable_alignment,
            "has_word_timestamps": has_word_timestamps,
            "align_language": align_language,
            "transcription": result
        }

        return result_data, session_id

    except Exception as e:
        print(f"Error during transcription: {str(e)}")
        raise


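# Example direct invocation (hypothetical path; main_socket is the intended test entry point):
#     result_data, session_id = translate_audio_file(
#         model="mazeWhisper",
#         audio_path="/path/to/audio.wav",
#         device="cuda",
#         enable_alignment=True,
#     )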

if __name__ == "__main__":
    print("use main_socket to test transcription model")