#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Audio Feature Extractor Script Main Functions: 1. Extract word-level features from single audio files 2. Use Whisper for speech recognition 3. Use Wav2Vec2 for word-level alignment with Triton optimization 4. Support English audio processing Usage: extractor = AudioFeatureExtractor() features = extractor.extract_features("path/to/audio.wav") Basic Implementation Logic: 1. Load audio file using librosa 2. Use Whisper for speech transcription 3. Use Wav2Vec2 for word-level alignment (Triton optimized) 4. Extract audio features based on word-level timestamps 5. Return feature dictionary with word segments and audio features """ import os import json import warnings import argparse from typing import List, Optional, Dict, Any from dataclasses import dataclass import numpy as np import torch import librosa import soundfile as sf import parselmouth # Check Triton availability try: import triton import triton.language as tl TRITON_AVAILABLE = True except ImportError: TRITON_AVAILABLE = False print("Warning: Triton not available, will use original implementation") # HuggingFace libraries from transformers import ( WhisperProcessor, WhisperForConditionalGeneration, Wav2Vec2Processor, Wav2Vec2ForCTC ) # ===== Configuration Constants ===== # Audio processing parameters SAMPLE_RATE = 16000 # Standard sample rate for Whisper and Wav2Vec2 MAX_DURATION = 30 # Maximum audio segment duration (seconds) # Default model paths for English DEFAULT_WHISPER_MODEL = "openai/whisper-large-v3" DEFAULT_ALIGN_MODEL = "facebook/wav2vec2-large-960h-lv60-self" # ===== Data Structure Definitions ===== @dataclass class WordSegment: """Word-level segment information""" word: str # Word text start: Optional[float] # Start time (seconds) end: Optional[float] # End time (seconds) score: Optional[float] # Confidence score @dataclass class AlignedSegment: """Aligned sentence segment""" text: str # Sentence text start: Optional[float] # Start time (seconds) end: Optional[float] # End time (seconds) words: List[WordSegment] # Word-level information list class AudioFeatureExtractor: """ Audio Feature Extractor Class Main Functions: 1. Load and process single audio files 2. Extract word-level features using Triton optimization 3. Support English audio processing Usage: extractor = AudioFeatureExtractor() features = extractor.extract_features("audio.wav") """ def __init__(self, whisper_model: str = DEFAULT_WHISPER_MODEL, align_model: str = DEFAULT_ALIGN_MODEL, device: str = "auto", merge_threshold: float = 0.5): """ Initialize Audio Feature Extractor Args: whisper_model: Path to Whisper model for speech recognition align_model: Path to Wav2Vec2 model for word alignment device: Computing device ("auto", "cpu", "cuda") merge_threshold: Word merging threshold (seconds) Implementation Logic: 1. Set up device configuration 2. Load Whisper and Wav2Vec2 models 3. Initialize vocabulary for alignment 4. Ensure Triton optimization is available """ self.merge_threshold = merge_threshold # Device selection if device == "auto": self.device = "cuda" if torch.cuda.is_available() else "cpu" else: self.device = device print(f"🚀 Initializing Audio Feature Extractor...") print(f" Device: {self.device}") print(f" Whisper Model: {whisper_model}") print(f" Align Model: {align_model}") print(f" Word Merge Threshold: {merge_threshold}s") print(f" Triton Available: {TRITON_AVAILABLE}") # Ensure Triton is available for optimization if not TRITON_AVAILABLE: raise RuntimeError("Triton is required but not available. Please install triton.") if self.device != "cuda": raise RuntimeError("Triton optimization requires CUDA device.") # Load models self._load_models(whisper_model, align_model) print("✅ Audio Feature Extractor initialized successfully") def _load_models(self, whisper_model: str, align_model: str): """ Load Whisper and Wav2Vec2 models Args: whisper_model: Path to Whisper model align_model: Path to Wav2Vec2 alignment model Implementation Logic: 1. Load Whisper model for speech recognition 2. Load Wav2Vec2 model for word-level alignment 3. Set models to evaluation mode 4. Build character-level vocabulary dictionary """ try: # Load Whisper model print("📥 Loading Whisper model...") self.whisper_processor = WhisperProcessor.from_pretrained(whisper_model) self.whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model) self.whisper_model.to(self.device) self.whisper_model.eval() # Load Wav2Vec2 alignment model print(f"📥 Loading Wav2Vec2 alignment model...") self.align_processor = Wav2Vec2Processor.from_pretrained(align_model) self.align_model = Wav2Vec2ForCTC.from_pretrained(align_model) self.align_model.to(self.device) self.align_model.eval() # Build character-level vocabulary dictionary labels = self.align_processor.tokenizer.get_vocab() # Create character to ID mapping, convert all characters to lowercase self.vocab = {char.lower(): code for char, code in labels.items()} self.id_to_token = {v: k for k, v in self.vocab.items()} print("✅ Models loaded successfully") print(f"Vocabulary size: {len(self.vocab)}") except Exception as e: print(f"❌ Model loading failed: {e}") raise def load_audio_file(self, audio_path: str) -> tuple[np.ndarray, int]: """ Load audio file from given path Args: audio_path: Path to audio file (absolute or relative) Returns: tuple: (audio_array, sampling_rate) Implementation Logic: 1. Check if path exists 2. Load audio using librosa 3. Return audio array and sampling rate 4. Handle errors gracefully """ try: # Check if file exists if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") # Load audio file audio_array, sampling_rate = librosa.load(audio_path, sr=None) print(f"📁 Loaded audio: {audio_path}") print(f" Duration: {len(audio_array)/sampling_rate:.2f}s") print(f" Sample Rate: {sampling_rate}Hz") return audio_array, sampling_rate except Exception as e: print(f"❌ Failed to load audio file: {audio_path}, Error: {e}") raise def transcribe_audio(self, audio: np.ndarray, sampling_rate: int) -> str: """ Transcribe audio using Whisper model Args: audio: Audio array sampling_rate: Sampling rate Returns: Transcribed text Implementation Logic: 1. Resample audio to 16kHz if needed 2. Preprocess audio for Whisper 3. Generate transcription using Whisper model 4. Return cleaned transcription text """ try: # Resample to 16kHz if sampling_rate != SAMPLE_RATE: audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=SAMPLE_RATE) # Preprocess audio inputs = self.whisper_processor( audio, sampling_rate=SAMPLE_RATE, return_tensors="pt" ) inputs = inputs.to(self.device) # Generate transcription with torch.no_grad(): predicted_ids = self.whisper_model.generate(inputs["input_features"]) transcription = self.whisper_processor.batch_decode( predicted_ids, skip_special_tokens=True )[0] print(f"🎯 Transcription: {transcription}") return transcription.strip() except Exception as e: print(f"❌ Transcription failed: {e}") return "" def get_word_timestamps(self, audio: np.ndarray, text: str) -> List[AlignedSegment]: """ Get word-level timestamps using Wav2Vec2 forced alignment Args: audio: Audio array text: Transcribed text Returns: List of aligned segments with word-level timestamps Implementation Logic: 1. Preprocess text for alignment 2. Use Wav2Vec2 model for CTC alignment 3. Calculate word-level boundaries using Triton optimization 4. Return aligned segments with timestamps """ try: print("🔄 Starting Wav2Vec2 forced alignment...") # Preprocess text clean_transcript = self._preprocess_text(text) if not clean_transcript: print("Warning: Text preprocessing resulted in empty string") return [AlignedSegment( text=text, start=0.0, end=len(audio) / SAMPLE_RATE, words=[WordSegment( word=text, start=0.0, end=len(audio) / SAMPLE_RATE, score=0.0 )] )] # Preprocess audio inputs = self.align_processor( audio, sampling_rate=SAMPLE_RATE, return_tensors="pt" ) inputs = inputs.to(self.device) # Get model output with torch.no_grad(): logits = self.align_model(inputs.input_values).logits # Convert to log probabilities log_probs = torch.log_softmax(logits, dim=-1) emission = log_probs[0] # Remove batch dimension # Perform CTC alignment with Triton optimization aligned_segments = self._ctc_align_triton(emission, clean_transcript, audio) print("✅ Forced alignment completed") return aligned_segments except Exception as e: print(f"❌ Word-level alignment failed: {e}") return [] def _preprocess_text(self, text: str) -> str: """ Preprocess text by removing characters not in vocabulary Args: text: Original text Returns: Cleaned text Implementation Logic: 1. Convert to lowercase 2. Replace spaces with | (Wav2Vec2 convention) 3. Keep only characters in vocabulary 4. Replace unknown characters with wildcards """ # Convert to lowercase text = text.lower().strip() # Replace spaces with | (Wav2Vec2 convention for English) text = text.replace(" ", "|") # Keep only characters in vocabulary clean_chars = [] for char in text: if char in self.vocab: clean_chars.append(char) else: # Replace unknown characters with wildcards clean_chars.append("*") return "".join(clean_chars) def _ctc_align_triton(self, emission: torch.Tensor, transcript: str, audio: np.ndarray) -> List[AlignedSegment]: """ Perform CTC forced alignment using Triton optimization Args: emission: Model output emission probabilities transcript: Cleaned transcript text audio: Original audio array Returns: List of aligned segments Implementation Logic: 1. Convert text to token IDs 2. Build trellis using Triton-optimized kernels 3. Backtrack optimal path 4. Merge repeated characters 5. Generate word alignments with timestamps """ # Convert text to token IDs tokens = [self.vocab.get(char, self.vocab.get("[UNK]", 0)) for char in transcript] # Get blank token ID blank_id = self.vocab.get("[PAD]", 0) if "[PAD]" not in self.vocab: blank_id = self.vocab.get("", 0) # Build trellis using Triton optimization trellis = self._get_trellis(emission, tokens, blank_id) # Backtrack optimal path path = self._backtrack(trellis, emission, tokens, blank_id) if path is None: print("Warning: CTC alignment failed, returning original timestamps") return [AlignedSegment( text=transcript.replace("|", " "), start=0.0, end=len(audio) / SAMPLE_RATE, words=[WordSegment( word=transcript.replace("|", " "), start=0.0, end=len(audio) / SAMPLE_RATE, score=0.0 )] )] # Merge repeated characters char_segments = self._merge_repeats(path, transcript) # Convert to timestamps duration = len(audio) / SAMPLE_RATE time_ratio = duration / (emission.size(0) - 1) # Generate word-level alignments words = self._generate_word_alignments(char_segments, transcript, time_ratio) return [AlignedSegment( text=transcript.replace("|", " "), start=words[0].start if words else 0.0, end=words[-1].end if words else duration, words=words )] @staticmethod @triton.jit def _trellis_row_kernel_optimized( # Pointers trellis_t_ptr, trellis_tm1_ptr, emission_t_ptr, tokens_ptr, # Scalar arguments num_tokens, blank_emit: tl.float32, t, # Tensor strides trellis_stride_n, # Meta-parameters BLOCK_SIZE_N: tl.constexpr, ): """ Triton-optimized kernel for trellis row computation This kernel computes one row of the CTC trellis matrix in parallel, significantly speeding up the forced alignment process. Implementation Logic: 1. Calculate parallel indices for token positions 2. Load previous trellis values (stay and advance paths) 3. Load emission probabilities for current tokens 4. Compute path scores using numerically stable logsumexp 5. Store results back to trellis """ # Calculate parallel indices starting from j=1 pid = tl.program_id(axis=0) offs_n = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + 1 # Ensure indices are within bounds mask = (offs_n < num_tokens) & (offs_n <= t) # Load trellis[t-1, j] (stay path) prev_stay_ptr = trellis_tm1_ptr + offs_n * trellis_stride_n prev_stay_score = tl.load(prev_stay_ptr, mask=mask, other=float('-inf')) # Load trellis[t-1, j-1] (advance path) prev_advance_ptr = trellis_tm1_ptr + (offs_n - 1) * trellis_stride_n prev_advance_score = tl.load(prev_advance_ptr, mask=mask, other=float('-inf')) # Load emission[t, tokens[j]] tokens_j = tl.load(tokens_ptr + offs_n, mask=mask, other=0) emission_token = tl.load(emission_t_ptr + tokens_j, mask=mask, other=float('-inf')) # Calculate path scores stay_score = prev_stay_score + blank_emit advance_score = prev_advance_score + emission_token # Numerically stable logsumexp max_val = tl.maximum(stay_score, advance_score) min_val = tl.minimum(stay_score, advance_score) log_sum = tl.where( max_val > float('-inf'), max_val + tl.log(1.0 + tl.exp(min_val - max_val)), float('-inf') ) # Store results trellis_t_ptr_j = trellis_t_ptr + offs_n * trellis_stride_n tl.store(trellis_t_ptr_j, log_sum, mask=mask) def _get_trellis(self, emission: torch.Tensor, tokens: List[int], blank_id: int) -> torch.Tensor: """ Build CTC alignment trellis using Triton optimization Args: emission: Emission probability matrix [T, V] tokens: Token ID list blank_id: Blank token ID Returns: Trellis matrix [T, N] """ # Use Triton optimized version if available and on CUDA device if TRITON_AVAILABLE and torch.cuda.is_available() and emission.device.type == 'cuda': return self._get_trellis_triton(emission, tokens, blank_id) else: # Fallback to original implementation print("Warning: Triton not enabled or CUDA unavailable, falling back to original CTC alignment implementation") return self._get_trellis_original(emission, tokens, blank_id) def _get_trellis_original(self, emission: torch.Tensor, tokens: List[int], blank_id: int) -> torch.Tensor: """ Original trellis construction implementation (as fallback) """ num_frame = emission.size(0) num_tokens = len(tokens) # Initialize trellis trellis = torch.full((num_frame, num_tokens), float('-inf'), device=emission.device) trellis[0, 0] = emission[0, blank_id] # Fill first row for t in range(1, num_frame): trellis[t, 0] = trellis[t-1, 0] + emission[t, blank_id] # Fill trellis for t in range(1, num_frame): for j in range(1, min(num_tokens, t + 1)): # Stay at current token (insert blank) stay_score = trellis[t-1, j] + emission[t, blank_id] # Advance to next token advance_score = trellis[t-1, j-1] + emission[t, tokens[j]] trellis[t, j] = torch.logsumexp(torch.stack([stay_score, advance_score]), dim=0) return trellis def _get_trellis_triton(self, emission: torch.Tensor, tokens: List[int], blank_id: int) -> torch.Tensor: """ Triton-optimized trellis construction - optimized version """ assert emission.is_cuda, "Input tensor must be on CUDA device" num_frame, vocab_size = emission.size() tokens_tensor = torch.tensor(tokens, device=emission.device, dtype=torch.long) num_tokens = len(tokens_tensor) # --- Optimization 3: Ensure memory contiguity --- # Fix: Remove memory_format parameter, use .contiguous() method to ensure memory contiguity trellis = torch.full((num_frame, num_tokens), float('-inf'), device=emission.device, dtype=torch.float32).contiguous() if num_tokens == 0: return trellis # --- Optimization 4: Use vectorized cumsum to initialize first column --- # Calculate cumulative blank probabilities from t=0 # Note: For consistency with original logic, use emission directly instead of log_softmax trellis[:, 0] = emission[:, blank_id].cumsum(dim=0) # --- Optimization 2: Dynamic block size --- # Adapt to different token sequence lengths, improve GPU utilization BLOCK_SIZE_N = min(1024, triton.next_power_of_2(num_tokens)) if num_tokens > 1 else 1 # Main loop for t in range(1, num_frame): # --- Optimization 1: Scalar broadcasting --- # Pass blank emission as scalar to avoid redundant loading blank_emit = emission[t, blank_id].item() # Launch grid, only compute j > 0 part if num_tokens > 1: grid = lambda meta: (triton.cdiv(num_tokens - 1, meta['BLOCK_SIZE_N']),) self._trellis_row_kernel_optimized[grid]( trellis_t_ptr=trellis[t], trellis_tm1_ptr=trellis[t-1], emission_t_ptr=emission[t], tokens_ptr=tokens_tensor, num_tokens=num_tokens, blank_emit=blank_emit, t=t, trellis_stride_n=trellis.stride(1), BLOCK_SIZE_N=BLOCK_SIZE_N, ) return trellis def _backtrack(self, trellis: torch.Tensor, emission: torch.Tensor, tokens: List[int], blank_id: int) -> Optional[List]: """ Backtrack through trellis to find optimal alignment path Args: trellis: Completed trellis matrix emission: Emission probabilities tokens: Token ID list blank_id: Blank token ID Returns: Optimal path through trellis Implementation Logic: 1. Start from final position in trellis 2. Trace back through highest probability path 3. Record token positions and timestamps 4. Return path for character merging """ # Implementation details would be similar to the original _backtrack method # This is a simplified version for demonstration try: num_frame, num_tokens = trellis.size() # Start from the end t, j = num_frame - 1, num_tokens - 1 path = [] while t >= 0 and j >= 0: path.append((t, j)) if j == 0: t -= 1 elif t == 0: j -= 1 else: # Choose path with higher probability stay_score = trellis[t-1, j] + emission[t, blank_id] advance_score = trellis[t-1, j-1] + emission[t, tokens[j]] if stay_score > advance_score: t -= 1 else: t -= 1 j -= 1 return list(reversed(path)) except Exception as e: print(f"Backtracking failed: {e}") return None def _merge_repeats(self, path: List, transcript: str) -> List: """ Merge repeated characters in alignment path Args: path: Alignment path from backtracking transcript: Original transcript Returns: List of character segments with merged repeats Implementation Logic: 1. Group consecutive identical characters 2. Calculate start and end frames for each character 3. Return character segments for word boundary detection """ if not path: return [] char_segments = [] current_char = None start_frame = None for t, j in path: char = transcript[j] if j < len(transcript) else None if char != current_char: if current_char is not None: char_segments.append({ 'char': current_char, 'start': start_frame, 'end': t - 1 }) current_char = char start_frame = t # Add final character if current_char is not None: char_segments.append({ 'char': current_char, 'start': start_frame, 'end': path[-1][0] }) return char_segments def _generate_word_alignments(self, char_segments: List, transcript: str, time_ratio: float) -> List[WordSegment]: """ Generate word-level alignments from character segments Args: char_segments: Character-level segments transcript: Original transcript time_ratio: Frame to time conversion ratio Returns: List of word segments with timestamps Implementation Logic: 1. Group characters into words using | delimiter 2. Calculate word boundaries from character segments 3. Convert frame indices to time stamps 4. Return word segments with confidence scores """ words = [] current_word = "" word_start = None word_chars = [] for segment in char_segments: char = segment['char'] if char == '|': # Word boundary if current_word and word_chars: # Calculate word timing start_time = word_chars[0]['start'] * time_ratio end_time = word_chars[-1]['end'] * time_ratio words.append(WordSegment( word=current_word, start=start_time, end=end_time, score=1.0 # Simplified confidence score )) current_word = "" word_chars = [] else: current_word += char word_chars.append(segment) # Add final word if current_word and word_chars: start_time = word_chars[0]['start'] * time_ratio end_time = word_chars[-1]['end'] * time_ratio words.append(WordSegment( word=current_word, start=start_time, end=end_time, score=1.0 )) return words def merge_short_words(self, word_segments: List[WordSegment]) -> List[WordSegment]: """ Merge short words with neighboring words Args: word_segments: List of word segments Returns: List of merged word segments Implementation Logic: 1. Identify words shorter than merge threshold 2. Find shortest neighboring word for merging 3. Merge words and update timestamps 4. Repeat until no more merging is needed """ if not word_segments: return [] merged_segments = word_segments.copy() while True: # Find short words short_indices = [] for i, segment in enumerate(merged_segments): if segment.start is not None and segment.end is not None: duration = segment.end - segment.start if duration < self.merge_threshold: short_indices.append(i) if not short_indices: break # Merge shortest word with its shortest neighbor shortest_idx = min(short_indices, key=lambda i: merged_segments[i].end - merged_segments[i].start) neighbor_idx = self._find_shortest_neighbor(merged_segments, shortest_idx) if neighbor_idx is not None: # Merge segments merged_segment = self._merge_two_segments( merged_segments[shortest_idx], merged_segments[neighbor_idx] ) # Remove original segments and insert merged one indices_to_remove = sorted([shortest_idx, neighbor_idx], reverse=True) for idx in indices_to_remove: merged_segments.pop(idx) # Insert merged segment at appropriate position insert_pos = min(shortest_idx, neighbor_idx) merged_segments.insert(insert_pos, merged_segment) else: break return merged_segments def _find_shortest_neighbor(self, segments: List[WordSegment], current_idx: int) -> Optional[int]: """ Find the shortest neighboring word for merging Args: segments: List of word segments current_idx: Index of current word Returns: Index of shortest neighbor, or None if no valid neighbor """ neighbors = [] # Check left neighbor if current_idx > 0: neighbors.append(current_idx - 1) # Check right neighbor if current_idx < len(segments) - 1: neighbors.append(current_idx + 1) if not neighbors: return None # Find shortest neighbor shortest_neighbor = min(neighbors, key=lambda i: segments[i].end - segments[i].start if segments[i].start and segments[i].end else float('inf')) return shortest_neighbor def _merge_two_segments(self, segment1: WordSegment, segment2: WordSegment) -> WordSegment: """ Merge two word segments into one Args: segment1: First word segment segment2: Second word segment Returns: Merged word segment Implementation Logic: 1. Combine word texts with space 2. Use earliest start time 3. Use latest end time 4. Average confidence scores """ # Determine order based on start times if segment1.start <= segment2.start: first, second = segment1, segment2 else: first, second = segment2, segment1 # Merge word texts merged_word = f"{first.word} {second.word}" # Merge timestamps merged_start = first.start merged_end = second.end # Average confidence scores merged_score = (first.score + second.score) / 2 if first.score and second.score else None return WordSegment( word=merged_word, start=merged_start, end=merged_end, score=merged_score ) def extract_audio_features(self, audio: np.ndarray, sampling_rate: int, word_segments: List[WordSegment]) -> Dict[str, Any]: """ Extract comprehensive audio features for word segments (compatible with extractor.py format) Args: audio: Audio array sampling_rate: Sampling rate word_segments: List of word segments with timestamps Returns: Dictionary containing extracted features matching extractor.py format Implementation Logic: 1. Extract features for each word segment 2. Calculate acoustic features (pitch, energy, spectral) 3. Calculate speaking rate and other statistics 4. Return features in extractor.py compatible format """ # Calculate total duration total_duration = len(audio) / sampling_rate # Count original and processed words original_word_count = len(word_segments) processed_word_count = 0 word_features = [] for word_segment in word_segments: if word_segment.start is None or word_segment.end is None: continue # Extract audio segment for this word start_sample = int(word_segment.start * sampling_rate) end_sample = int(word_segment.end * sampling_rate) word_audio = audio[start_sample:end_sample] if len(word_audio) == 0: continue # Extract acoustic features word_feature = self._extract_word_features(word_audio, sampling_rate, word_segment) word_features.append(word_feature) processed_word_count += 1 # Calculate speaking rate (words per minute) speaking_rate = (processed_word_count / total_duration * 60) if total_duration > 0 else 0.0 # Return features in extractor.py compatible format return { "total_duration": total_duration, "speaking_rate": speaking_rate, "original_word_count": original_word_count, "processed_word_count": processed_word_count, "word_features": word_features } def _extract_word_features(self, word_audio: np.ndarray, sampling_rate: int, word_segment: WordSegment) -> Dict[str, Any]: """ Extract acoustic features for a single word (compatible with extractor.py format) Args: word_audio: Audio array for the word sampling_rate: Sampling rate word_segment: Word segment information Returns: Dictionary of acoustic features matching extractor.py format Implementation Logic: 1. Calculate basic timing features 2. Extract pitch features using Parselmouth (pitch_mean, pitch_slope) 3. Calculate energy features (energy_rms, energy_slope) 4. Calculate spectral features (spectral_centroid) 5. Return features in extractor.py compatible format """ # Clean word text, remove special symbols clean_word = word_segment.word # Initialize features with basic information features = { "word": clean_word, "start_time": word_segment.start, "end_time": word_segment.end, "duration": word_segment.end - word_segment.start, "confidence_score": word_segment.score } try: # === Pitch feature extraction === avg_pitch = np.nan pitch_slope = np.nan try: # Create temporary audio file for parselmouth analysis import tempfile with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_audio: sf.write(temp_audio.name, word_audio, sampling_rate) sound = parselmouth.Sound(temp_audio.name) pitch = sound.to_pitch(pitch_floor=50.0, pitch_ceiling=600.0) pitch_times = pitch.xs() pitch_values = pitch.selected_array['frequency'] # Remove unvoiced frames (pitch = 0) pitch_values[pitch_values == 0] = np.nan valid_pitch = pitch_values[~np.isnan(pitch_values)] if len(valid_pitch) > 0: avg_pitch = np.mean(valid_pitch) # Calculate pitch change trend (slope) - safer method if len(valid_pitch) >= 15: # Only calculate slope when there are enough points (using linear regression) duration = features["duration"] time_points = np.linspace(0, duration, len(valid_pitch)) coeffs = np.polyfit(time_points, valid_pitch, 1) pitch_slope = coeffs[0] # Slope else: # Set to NaN when too few data points, safer pitch_slope = np.nan else: # Set to default values when no valid pitch values avg_pitch = np.nan pitch_slope = np.nan # Clean up temporary file try: os.unlink(temp_audio.name) except: pass except Exception as e: print(f"Warning: Pitch extraction failed for word '{clean_word}': {e}") avg_pitch = np.nan pitch_slope = np.nan # === Energy feature extraction === rms_energy = np.sqrt(np.mean(word_audio**2)) # RMS energy try: # 使用已创建的 sound 对象 intensity = sound.to_intensity() intensity_times = intensity.xs() intensity_values = intensity.values[0] # 计算对应时间段的强度值 word_duration = word_segment.end - word_segment.start if len(intensity_values) > 2: start_energy = np.nanmean(intensity_values[:len(intensity_values)//3]) end_energy = np.nanmean(intensity_values[len(intensity_values)*2//3:]) if not (np.isnan(start_energy) or np.isnan(end_energy)): energy_slope = (end_energy - start_energy) / word_duration except Exception as e: print(f"Warning: Energy slope calculation failed for word '{clean_word}': {e}") energy_slope = np.nan # === Spectral feature extraction === try: # Spectral centroid: reflects timbre characteristics segment_length = len(word_audio) if segment_length < 2048: n_fft = 2 ** int(np.log2(segment_length)) n_fft = max(n_fft, 512) # Minimum value is 512 else: n_fft = 2048 # Default value spectral_centroid = librosa.feature.spectral_centroid( y=word_audio, sr=sampling_rate, n_fft=n_fft)[0].mean() except Exception as e: print(f"Warning: Spectral centroid calculation failed for word '{clean_word}': {e}") spectral_centroid = np.nan # Store all features for current word (matching extractor.py format) features.update({ "pitch_mean": float(avg_pitch) if not np.isnan(avg_pitch) else None, "pitch_slope": float(pitch_slope) if not np.isnan(pitch_slope) else None, "energy_rms": float(rms_energy), "energy_slope": float(energy_slope) if not np.isnan(energy_slope) else None, "spectral_centroid": float(spectral_centroid) if not np.isnan(spectral_centroid) else None }) except Exception as e: print(f"Warning: Feature extraction failed for word '{clean_word}': {e}") # Set default values for failed features features.update({ "pitch_mean": None, "pitch_slope": None, "energy_rms": 0.0, "energy_slope": None, "spectral_centroid": None }) return features def extract_features(self, audio_path: str, text: Optional[str] = None, enable_word_merging: bool = True) -> Dict[str, Any]: """ Main method to extract features from audio file Args: audio_path: Path to audio file text: Optional transcription text (if None, will use Whisper) enable_word_merging: Whether to merge short words Returns: Dictionary containing all extracted features Implementation Logic: 1. Load audio file 2. Transcribe audio if text not provided 3. Get word-level timestamps using Triton optimization 4. Optionally merge short words 5. Extract comprehensive audio features 6. Return complete feature dictionary """ try: print(f"🎵 Starting feature extraction for: {audio_path}") # Load audio file audio_array, sampling_rate = self.load_audio_file(audio_path) # Resample to standard rate if needed if sampling_rate != SAMPLE_RATE: audio_array = librosa.resample( audio_array, orig_sr=sampling_rate, target_sr=SAMPLE_RATE ) sampling_rate = SAMPLE_RATE # Transcribe audio if text not provided if text is None: text = self.transcribe_audio(audio_array, sampling_rate) if not text.strip(): return { "error": "Transcription text is empty", "word_features": [], "audio_path": audio_path } # Get word-level timestamps aligned_segments = self.get_word_timestamps(audio_array, text) if not aligned_segments: return { "error": "Word-level alignment failed", "word_features": [], "audio_path": audio_path, "transcribed_text": text } # Collect all word segments all_word_segments = [] for segment in aligned_segments: all_word_segments.extend(segment.words) # Record original word count original_word_count = len(all_word_segments) # Optional word merging if enable_word_merging: all_word_segments = self.merge_short_words(all_word_segments) print(f"📊 Word merging: {original_word_count} → {len(all_word_segments)} words") # Extract audio features features = self.extract_audio_features(audio_array, sampling_rate, all_word_segments) # Format word features using the provided formatting function formatted_word_features = self.format_word_features(features.get("word_features", [])) # Add metadata and formatted features features.update({ "audio_path": audio_path, "transcribed_text": text, "original_word_count": original_word_count, "final_word_count": len(all_word_segments), "word_merging_enabled": enable_word_merging, "triton_optimization": True, "formatted_word_features": formatted_word_features # Add formatted features }) print(f"✅ Feature extraction completed successfully") print(f" Transcribed text: {text}") print(f" Word count: {len(all_word_segments)}") print(f" Total duration: {features['total_duration']:.2f}s") return features except Exception as e: error_msg = f"Feature extraction failed: {str(e)}" print(f"❌ {error_msg}") return { "error": error_msg, "word_features": [], "audio_path": audio_path } def clean_word_text(self, word: str) -> str: """ Clean word text by removing special symbols and formatting Args: word: Original word text Returns: Cleaned word text """ if not word: return "" # Remove common punctuation and special characters import re # Keep only letters, numbers, and basic punctuation cleaned = re.sub(r'[^\w\s\-\']', '', word) # Remove extra whitespace cleaned = cleaned.strip() return cleaned def format_word_features(self, word_features): """ 格式化word_features数据 保留原有的格式化逻辑,处理word_features列表中的每个特征字典 Args: word_features (list): 原始word_features列表 Returns: str: 格式化后的word_features字符串 """ if not isinstance(word_features, list): print(f"Warning: word_features is not a list: {type(word_features)}") return "" formatted_features = [] for feature_dict in word_features: if not isinstance(feature_dict, dict): print(f"Warning: Feature item is not a dict: {type(feature_dict)}") continue # 提取并格式化各个字段 formatted_feature = {} # 处理word字段 word = feature_dict.get('word', '') formatted_feature['word'] = self.clean_word_text(word) # 处理数值字段,按照要求的精度格式化 numeric_fields = { 'pitch_mean': 0, # 仅保留整数 'pitch_slope': 0, # 仅保留整数 'energy_rms': 3, # 保留三位小数 'energy_slope': 0, # 仅保留整数 'spectral_centroid': 0 # 仅保留整数 } for field, decimal_places in numeric_fields.items(): value = feature_dict.get(field) if value is not None: try: if decimal_places == 0: formatted_feature[field] = int(float(value)) else: formatted_feature[field] = round(float(value), decimal_places) except (ValueError, TypeError): print(f"Warning: Failed to format {field}: {value}") formatted_feature[field] = None else: formatted_feature[field] = None formatted_features.append(formatted_feature) # 转换为字符串格式 return str(formatted_features) def extract_json_from_response(response_text: str) -> Dict[str, Any]: """ Extract JSON data from response text Args: response_text: Text containing JSON data Returns: Extracted JSON dictionary Implementation Logic: 1. Try to parse entire text as JSON 2. If that fails, search for JSON blocks 3. Return parsed JSON or error information """ try: # Try to parse entire text as JSON return json.loads(response_text) except json.JSONDecodeError: # Search for JSON blocks in text import re json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}' matches = re.findall(json_pattern, response_text, re.DOTALL) for match in matches: try: return json.loads(match) except json.JSONDecodeError: continue return {"error": "No valid JSON found in response"} def main(): """ Main function for Audio Feature Extractor with command line support Implementation Logic: 1. Parse command line arguments for audio input and output paths 2. Initialize AudioFeatureExtractor with default settings 3. Process specified audio file 4. Extract features using Triton optimization 5. Save results to specified output file """ # Set up command line argument parser parser = argparse.ArgumentParser( description="Audio Feature Extractor - Extract word-level features from audio files", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python audio_feature_extractor.py --audio input.wav --output features.json python audio_feature_extractor.py -a speech.wav -o output.json python audio_feature_extractor.py --audio /path/to/audio.wav --output /path/to/features.json """ ) parser.add_argument( '--audio', '-a', type=str, required=True, help='Path to input audio file (required)' ) parser.add_argument( '--output', '-o', type=str, required=True, help='Path to output JSON file for extracted features (required)' ) parser.add_argument( '--device', type=str, default="auto", choices=["auto", "cpu", "cuda"], help='Computing device (default: auto)' ) parser.add_argument( '--merge-threshold', type=float, default=0.5, help='Word merging threshold in seconds (default: 0.5)' ) parser.add_argument( '--text', type=str, default=None, help='Pre-transcribed text (optional, if not provided Whisper will transcribe)' ) parser.add_argument( '--no-word-merging', action='store_true', help='Disable word merging for short segments' ) # Parse arguments args = parser.parse_args() print("🎵 Audio Feature Extractor") print("=" * 50) print(f"📁 Input audio: {args.audio}") print(f"💾 Output file: {args.output}") print(f"🖥️ Device: {args.device}") print(f"⏱️ Merge threshold: {args.merge_threshold}s") print(f"📝 Pre-transcribed text: {'Yes' if args.text else 'No (will use Whisper)'}") print(f"🔗 Word merging: {'Disabled' if args.no_word_merging else 'Enabled'}") try: # Check if input audio file exists if not os.path.exists(args.audio): print(f"❌ Audio file not found: {args.audio}") return 1 # Create output directory if it doesn't exist output_dir = os.path.dirname(args.output) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) print(f"📁 Created output directory: {output_dir}") # Initialize extractor print(f"\n🚀 Initializing Audio Feature Extractor...") extractor = AudioFeatureExtractor( device=args.device, merge_threshold=args.merge_threshold ) # Extract features print(f"\n🔄 Processing audio file: {args.audio}") features = extractor.extract_features( audio_path=args.audio, text=args.text, enable_word_merging=not args.no_word_merging ) # Display results print("\n📊 Extraction Results:") print("-" * 30) if "error" in features: print(f"❌ Error: {features['error']}") return 1 else: print(f"✅ Success!") print(f" Audio file: {features.get('audio_path', 'N/A')}") print(f" Transcribed text: {features.get('transcribed_text', 'N/A')}") print(f" Total duration: {features.get('total_duration', 0):.2f}s") print(f" Word count: {features.get('final_word_count', 0)}") print(f" Triton optimization: {features.get('triton_optimization', False)}") # Display word-level features summary word_features = features.get('word_features', []) if word_features: print(f"\n📝 Word-level Features Summary:") print(f" Total words: {len(word_features)}") if len(word_features) > 0: print(f" First word: '{word_features[0].get('word', 'N/A')}'") print(f" Last word: '{word_features[-1].get('word', 'N/A')}'") # Show first 3 words details if available if len(word_features) >= 3: print(f"\n📝 First 3 words details:") for i, word_feat in enumerate(word_features[:3]): print(f" Word {i+1}: '{word_feat.get('word', 'N/A')}'") print(f" Duration: {word_feat.get('duration', 0):.3f}s") print(f" Pitch mean: {word_feat.get('pitch_mean', 0):.1f}Hz") print(f" RMS energy: {word_feat.get('rms_energy', 0):.4f}") # Save results to specified output file print(f"\n💾 Saving results to: {args.output}") with open(args.output, 'w', encoding='utf-8') as f: json.dump(features, f, indent=2, ensure_ascii=False) print(f"✅ Results saved successfully!") print("\n🎉 Processing completed successfully!") return 0 except Exception as e: print(f"❌ Processing failed: {e}") import traceback traceback.print_exc() return 1 if __name__ == "__main__": main()