| """Advanced voice enhancement and consistency system for CSM-1B.""" | |
| import os | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import soundfile as sf | |
| from typing import Dict, List, Optional, Tuple | |
| import logging | |
| from dataclasses import dataclass | |
| from scipy import signal | |
| # Setup logging | |
| logger = logging.getLogger(__name__) | |
| # Define persistent paths | |
| VOICE_REFERENCES_DIR = "/app/voice_references" | |
| VOICE_PROFILES_DIR = "/app/voice_profiles" | |
| # Ensure directories exist | |
| os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True) | |
| os.makedirs(VOICE_PROFILES_DIR, exist_ok=True) | |

@dataclass
class VoiceProfile:
    """Detailed voice profile with acoustic characteristics."""
    name: str
    speaker_id: int
    # Acoustic parameters
    pitch_range: Tuple[float, float]      # Min/max pitch in Hz
    intensity_range: Tuple[float, float]  # Min/max intensity (volume)
    spectral_tilt: float                  # Brightness vs. darkness
    prosody_pattern: str                  # Pattern of intonation and rhythm
    speech_rate: float                    # Relative speech rate (1.0 = normal)
    formant_shift: float                  # Formant frequency shift (1.0 = no shift)
    # Reference audio
    reference_segments: List[torch.Tensor]
    # Normalization parameters
    target_rms: float = 0.2
    target_peak: float = 0.95

    def get_enhancement_params(self) -> Dict:
        """Get parameters for enhancing generated audio."""
        return {
            "target_rms": self.target_rms,
            "target_peak": self.target_peak,
            "pitch_range": self.pitch_range,
            "formant_shift": self.formant_shift,
            "speech_rate": self.speech_rate,
            "spectral_tilt": self.spectral_tilt,
        }
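
# Illustrative sketch only: how a profile might be constructed and consumed.
# The "custom" voice and all numbers below are hypothetical, not one of the
# shipped voices defined in VOICE_PROFILES.
#
#   custom = VoiceProfile(
#       name="custom", speaker_id=99,
#       pitch_range=(100.0, 200.0), intensity_range=(0.15, 0.3),
#       spectral_tilt=0.0, prosody_pattern="balanced",
#       speech_rate=1.0, formant_shift=1.0,
#       reference_segments=[],
#   )
#   custom.get_enhancement_params()
#   # -> {"target_rms": 0.2, "target_peak": 0.95, "pitch_range": (100.0, 200.0), ...}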

# Voice profiles with carefully tuned parameters
VOICE_PROFILES = {
    "alloy": VoiceProfile(
        name="alloy",
        speaker_id=0,
        pitch_range=(85, 180),         # Hz - balanced range
        intensity_range=(0.15, 0.3),   # moderate intensity
        spectral_tilt=0.0,             # neutral tilt
        prosody_pattern="balanced",
        speech_rate=1.0,               # normal rate
        formant_shift=1.0,             # no shift
        reference_segments=[],
        target_rms=0.2,
        target_peak=0.95
    ),
    "echo": VoiceProfile(
        name="echo",
        speaker_id=1,
        pitch_range=(75, 165),         # Hz - lower, resonant
        intensity_range=(0.2, 0.35),   # slightly stronger
        spectral_tilt=-0.2,            # more low frequencies
        prosody_pattern="deliberate",
        speech_rate=0.95,              # slightly slower
        formant_shift=0.95,            # slightly lower formants
        reference_segments=[],
        target_rms=0.22,               # slightly louder
        target_peak=0.95
    ),
    "fable": VoiceProfile(
        name="fable",
        speaker_id=2,
        pitch_range=(120, 250),        # Hz - higher range
        intensity_range=(0.15, 0.28),  # moderate intensity
        spectral_tilt=0.2,             # more high frequencies
        prosody_pattern="animated",
        speech_rate=1.05,              # slightly faster
        formant_shift=1.05,            # slightly higher formants
        reference_segments=[],
        target_rms=0.19,
        target_peak=0.95
    ),
    "onyx": VoiceProfile(
        name="onyx",
        speaker_id=3,
        pitch_range=(65, 150),         # Hz - deeper range
        intensity_range=(0.18, 0.32),  # moderate-strong
        spectral_tilt=-0.3,            # more low frequencies
        prosody_pattern="authoritative",
        speech_rate=0.93,              # slightly slower
        formant_shift=0.9,             # lower formants
        reference_segments=[],
        target_rms=0.23,               # stronger
        target_peak=0.95
    ),
    "nova": VoiceProfile(
        name="nova",
        speaker_id=4,
        pitch_range=(90, 200),         # Hz - warm midrange
        intensity_range=(0.15, 0.27),  # moderate
        spectral_tilt=-0.1,            # slightly warm
        prosody_pattern="flowing",
        speech_rate=1.0,               # normal rate
        formant_shift=1.0,             # no shift
        reference_segments=[],
        target_rms=0.2,
        target_peak=0.95
    ),
    "shimmer": VoiceProfile(
        name="shimmer",
        speaker_id=5,
        pitch_range=(140, 280),        # Hz - brighter, higher
        intensity_range=(0.15, 0.25),  # moderate-light
        spectral_tilt=0.3,             # more high frequencies
        prosody_pattern="light",
        speech_rate=1.07,              # slightly faster
        formant_shift=1.1,             # higher formants
        reference_segments=[],
        target_rms=0.18,               # slightly softer
        target_peak=0.95
    )
}

# Voice-specific prompt templates, crafted to establish voice identity clearly
VOICE_PROMPTS = {
    "alloy": [
        "Hello, I'm Alloy. I speak with a balanced, natural tone that's easy to understand.",
        "This is Alloy speaking. My voice is designed to be clear and conversational.",
        "Alloy here - I have a neutral, friendly voice with balanced tone qualities."
    ],
    "echo": [
        "Hello, I'm Echo. I speak with a resonant, deeper voice that carries well.",
        "This is Echo speaking. My voice has a rich, resonant quality with depth.",
        "Echo here - My voice is characterized by its warm, resonant tones."
    ],
    "fable": [
        "Hello, I'm Fable. I speak with a bright, higher-pitched voice that's full of energy.",
        "This is Fable speaking. My voice is characterized by its clear, bright quality.",
        "Fable here - My voice is light, articulate, and slightly higher-pitched."
    ],
    "onyx": [
        "Hello, I'm Onyx. I speak with a deep, authoritative voice that commands attention.",
        "This is Onyx speaking. My voice has a powerful, deep quality with gravitas.",
        "Onyx here - My voice is characterized by its depth and commanding presence."
    ],
    "nova": [
        "Hello, I'm Nova. I speak with a warm, pleasant mid-range voice that's easy to listen to.",
        "This is Nova speaking. My voice has a smooth, harmonious quality.",
        "Nova here - My voice is characterized by its warm, friendly mid-tones."
    ],
    "shimmer": [
        "Hello, I'm Shimmer. I speak with a light, bright voice that's expressive and clear.",
        "This is Shimmer speaking. My voice has an airy, higher-pitched quality.",
        "Shimmer here - My voice is characterized by its bright, crystalline tones."
    ]
}

def initialize_voice_profiles():
    """Initialize voice profiles with default settings.

    Loads existing voice profiles from disk if available; otherwise keeps
    the defaults defined above.
    """
    global VOICE_PROFILES

    # Try to load existing profiles from persistent storage
    profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
    if os.path.exists(profile_path):
        try:
            logger.info(f"Loading voice profiles from {profile_path}")
            saved_profiles = torch.load(profile_path)

            # Update existing profiles with saved data
            for name, data in saved_profiles.items():
                if name in VOICE_PROFILES:
                    VOICE_PROFILES[name].reference_segments = [
                        seg.to(torch.device("cpu")) for seg in data.get('reference_segments', [])
                    ]
            logger.info(f"Loaded voice profiles for {len(saved_profiles)} voices")
        except Exception as e:
            logger.error(f"Error loading voice profiles: {e}")
            logger.info("Using default voice profiles")
    else:
        logger.info("No saved voice profiles found, using defaults")

    # Ensure all voices have at least empty reference segments
    for name, profile in VOICE_PROFILES.items():
        if not hasattr(profile, 'reference_segments'):
            profile.reference_segments = []

    logger.info(f"Voice profiles initialized for {len(VOICE_PROFILES)} voices")
    return VOICE_PROFILES

def normalize_audio(audio: torch.Tensor, target_rms: float = 0.2, target_peak: float = 0.95) -> torch.Tensor:
    """Normalize audio to a target RMS level, then limit its peaks.

    Args:
        audio: Audio tensor
        target_rms: Target RMS level for normalization
        target_peak: Target peak level for limiting

    Returns:
        Normalized audio tensor
    """
    # Ensure audio is on CPU for processing
    audio_cpu = audio.detach().cpu()

    # Handle silent audio
    if audio_cpu.abs().max() < 1e-6:
        logger.warning("Audio is nearly silent, returning original")
        return audio

    # Calculate current RMS
    current_rms = torch.sqrt(torch.mean(audio_cpu ** 2))

    # Apply RMS normalization
    if current_rms > 0:
        gain = target_rms / current_rms
        normalized = audio_cpu * gain
    else:
        normalized = audio_cpu

    # Apply peak limiting
    current_peak = normalized.abs().max()
    if current_peak > target_peak:
        normalized = normalized * (target_peak / current_peak)

    # Return to original device
    return normalized.to(audio.device)
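
# Worked example (illustrative): a full-scale 440 Hz sine has RMS ~= 0.707,
# so normalizing to target_rms=0.2 applies a gain of ~= 0.283; the resulting
# peak (~= 0.283) is already below target_peak=0.95, so the limiter stays idle.
#
#   t = torch.linspace(0, 1, 24000)
#   sine = torch.sin(2 * torch.pi * 440 * t)
#   out = normalize_audio(sine, target_rms=0.2, target_peak=0.95)
#   assert abs(out.pow(2).mean().sqrt().item() - 0.2) < 1e-3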

def apply_anti_muffling(audio: torch.Tensor, sample_rate: int, clarity_boost: float = 1.2) -> torch.Tensor:
    """Apply anti-muffling to improve clarity.

    Args:
        audio: Audio tensor
        sample_rate: Audio sample rate
        clarity_boost: Amount of high frequency boost (1.0 = no boost)

    Returns:
        Processed audio tensor
    """
    # Convert to numpy for filtering
    audio_np = audio.detach().cpu().numpy()

    try:
        # scipy.signal has no built-in high-shelf design, so approximate one:
        # high-pass the signal with a second-order Butterworth filter, scale
        # the result by the clarity boost, and mix it back with the original.
        cutoff = 2000  # Hz
        b, a = signal.butter(2, cutoff / (sample_rate / 2), btype='high', analog=False)

        # Apply the filter with the clarity boost gain
        boosted = signal.filtfilt(b, a, audio_np, axis=0) * clarity_boost

        # Mix with original to maintain some warmth
        mix_ratio = 0.7  # 70% processed, 30% original
        processed = mix_ratio * boosted + (1 - mix_ratio) * audio_np
    except Exception as e:
        logger.warning(f"Audio enhancement failed, using original: {e}")
        # Return original audio if enhancement fails
        return audio

    # Convert back to tensor on original device
    return torch.tensor(processed, dtype=audio.dtype, device=audio.device)
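
# Rough gain picture for the defaults above (clarity_boost=1.2, mix_ratio=0.7):
# content well above the 2 kHz cutoff comes out at about 0.7 * 1.2 + 0.3 = 1.14x
# its original level, while content well below the cutoff tends toward 0.3x,
# which is where the perceived brightening comes from. Illustrative call:
#
#   noise = torch.randn(24000) * 0.1
#   brighter = apply_anti_muffling(noise, sample_rate=24000, clarity_boost=1.2)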

def enhance_audio(audio: torch.Tensor, sample_rate: int, voice_profile: VoiceProfile) -> torch.Tensor:
    """Apply comprehensive audio enhancement based on voice profile.

    Args:
        audio: Audio tensor
        sample_rate: Audio sample rate
        voice_profile: Voice profile containing enhancement parameters

    Returns:
        Enhanced audio tensor
    """
    if audio is None or audio.numel() == 0:
        logger.error("Cannot enhance empty audio")
        return audio

    try:
        # Step 1: Normalize audio levels
        params = voice_profile.get_enhancement_params()
        normalized = normalize_audio(
            audio,
            target_rms=params["target_rms"],
            target_peak=params["target_peak"]
        )

        # Step 2: Apply anti-muffling based on spectral tilt.
        # A positive tilt already means a brighter voice, so only voices
        # with a negative (darker) tilt receive extra clarity boost.
        clarity_boost = 1.0 + max(0, -params["spectral_tilt"]) * 0.5
        clarified = apply_anti_muffling(
            normalized,
            sample_rate,
            clarity_boost=clarity_boost
        )

        # Log the enhancement
        logger.debug(
            f"Enhanced audio for {voice_profile.name}: "
            f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{clarified.pow(2).mean().sqrt().item():.3f}, "
            f"Peak: {audio.abs().max().item():.3f}->{clarified.abs().max().item():.3f}"
        )

        return clarified
    except Exception as e:
        logger.error(f"Error in audio enhancement: {e}")
        return audio  # Return original audio if enhancement fails

def validate_generated_audio(
    audio: torch.Tensor,
    voice_name: str,
    sample_rate: int,
    min_expected_duration: float = 0.5
) -> Tuple[bool, torch.Tensor, str]:
    """Validate and fix generated audio.

    Args:
        audio: Audio tensor to validate
        voice_name: Name of the voice used
        sample_rate: Audio sample rate
        min_expected_duration: Minimum expected duration in seconds

    Returns:
        Tuple of (is_valid, fixed_audio, message)
    """
    if audio is None:
        return False, torch.zeros(1), "Audio is None"

    # Check for NaN values
    if torch.isnan(audio).any():
        logger.warning(f"Audio for {voice_name} contains NaN values, replacing with zeros")
        audio = torch.where(torch.isnan(audio), torch.zeros_like(audio), audio)

    # Check audio duration
    duration = audio.shape[0] / sample_rate
    if duration < min_expected_duration:
        logger.warning(f"Audio for {voice_name} is too short ({duration:.2f}s < {min_expected_duration}s)")
        return False, audio, f"Audio too short: {duration:.2f}s"

    # Check for silent sections - these can indicate generation problems
    rms = torch.sqrt(torch.mean(audio ** 2))
    if rms < 0.01:  # Very low RMS indicates near silence
        logger.warning(f"Audio for {voice_name} is nearly silent (RMS: {rms:.6f})")
        return False, audio, f"Audio nearly silent: RMS = {rms:.6f}"

    # Check whether the audio cuts off suddenly - this detects premature stopping.
    # Calculate RMS in the last 100ms.
    last_samples = int(0.1 * sample_rate)
    if audio.shape[0] > last_samples:
        end_rms = torch.sqrt(torch.mean(audio[-last_samples:] ** 2))
        if end_rms > 0.1:  # High RMS at the end suggests an abrupt cutoff
            logger.warning(f"Audio for {voice_name} may have cut off prematurely (end RMS: {end_rms:.3f})")
            return True, audio, "Audio may have cut off prematurely"

    return True, audio, "Audio validation passed"
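
# Illustrative: a second of pure silence passes the duration check but trips
# the near-silence check, so the caller gets is_valid=False.
#
#   silent = torch.zeros(24000)  # 1 s at an assumed 24 kHz sample rate
#   ok, fixed, msg = validate_generated_audio(silent, "alloy", 24000)
#   # ok == False, msg == "Audio nearly silent: RMS = 0.000000"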

def create_voice_segments(app_state, regenerate: bool = False):
    """Create high-quality voice reference segments.

    Args:
        app_state: Application state containing generator
        regenerate: Whether to regenerate existing references
    """
    generator = app_state.generator
    if not generator:
        logger.error("Cannot create voice segments: generator not available")
        return

    # Use persistent directory for voice reference segments
    os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)

    for voice_name, profile in VOICE_PROFILES.items():
        voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
        os.makedirs(voice_dir, exist_ok=True)

        # Check if we already have references
        if not regenerate and profile.reference_segments:
            logger.info(f"Voice {voice_name} already has {len(profile.reference_segments)} reference segments")
            continue

        # Get prompts for this voice
        prompts = VOICE_PROMPTS[voice_name]

        # Generate reference segments
        logger.info(f"Generating reference segments for voice: {voice_name}")
        reference_segments = []

        for i, prompt in enumerate(prompts):
            ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")

            # Skip generation if the file exists and we're not regenerating
            if not regenerate and os.path.exists(ref_path):
                try:
                    # Load existing reference
                    audio_tensor, sr = torchaudio.load(ref_path)
                    if sr != generator.sample_rate:
                        audio_tensor = torchaudio.functional.resample(
                            audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
                        )
                    else:
                        audio_tensor = audio_tensor.squeeze(0)
                    reference_segments.append(audio_tensor.to(generator.device))
                    logger.info(f"Loaded existing reference {i+1}/{len(prompts)} for {voice_name}")
                    continue
                except Exception as e:
                    logger.warning(f"Failed to load existing reference {i+1} for {voice_name}: {e}")

            try:
                # Use a lower temperature for more stability in reference samples
                logger.info(f"Generating reference {i+1}/{len(prompts)} for {voice_name}: '{prompt}'")

                # We want references to be as clean as possible
                audio = generator.generate(
                    text=prompt,
                    speaker=profile.speaker_id,
                    context=[],  # No context for initial samples to prevent voice bleed
                    max_audio_length_ms=6000,  # Shorter for more control
                    temperature=0.7,  # Lower temperature for more stability
                    topk=30,  # More focused sampling
                )

                # Validate and enhance the audio
                is_valid, audio, message = validate_generated_audio(
                    audio, voice_name, generator.sample_rate
                )

                if is_valid:
                    # Enhance the audio
                    audio = enhance_audio(audio, generator.sample_rate, profile)

                    # Save the reference to persistent storage
                    torchaudio.save(ref_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
                    reference_segments.append(audio)
                    logger.info(f"Generated reference {i+1} for {voice_name}: {message}")
                else:
                    logger.warning(f"Invalid reference for {voice_name}: {message}")
                    # Try again with the next prompt if one remains
                    if i < len(prompts) - 1:
                        logger.info("Trying again with next prompt")
                        continue
            except Exception as e:
                logger.error(f"Error generating reference for {voice_name}: {e}")

        # Update the voice profile with references
        if reference_segments:
            VOICE_PROFILES[voice_name].reference_segments = reference_segments
            logger.info(f"Updated {voice_name} with {len(reference_segments)} reference segments")

    # Save the updated profiles to persistent storage
    save_voice_profiles()

def get_voice_segments(voice_name: str, device: torch.device) -> List:
    """Get context segments for a given voice.

    Args:
        voice_name: Name of the voice to use
        device: Device to place tensors on

    Returns:
        List of context segments
    """
    from app.models import Segment

    if voice_name not in VOICE_PROFILES:
        logger.warning(f"Voice {voice_name} not found, defaulting to alloy")
        voice_name = "alloy"

    profile = VOICE_PROFILES[voice_name]

    # If we don't have reference segments yet, try loading them from disk
    if not profile.reference_segments:
        try:
            # Load from persistent storage
            voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
            if os.path.exists(voice_dir):
                reference_segments = []
                prompts = VOICE_PROMPTS[voice_name]
                for i, prompt in enumerate(prompts):
                    ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")
                    if os.path.exists(ref_path):
                        audio_tensor, sr = torchaudio.load(ref_path)
                        audio_tensor = audio_tensor.squeeze(0)
                        reference_segments.append(audio_tensor)
                if reference_segments:
                    profile.reference_segments = reference_segments
                    logger.info(f"Loaded {len(reference_segments)} reference segments for {voice_name}")
        except Exception as e:
            logger.error(f"Error loading reference segments for {voice_name}: {e}")

    # Create context segments from references
    context = []
    if profile.reference_segments:
        for i, ref_audio in enumerate(profile.reference_segments):
            # Use the corresponding prompt if available, otherwise a generic one
            text = (
                VOICE_PROMPTS[voice_name][i]
                if i < len(VOICE_PROMPTS[voice_name])
                else f"Voice reference for {voice_name}"
            )
            context.append(
                Segment(
                    speaker=profile.speaker_id,
                    text=text,
                    audio=ref_audio.to(device)
                )
            )

    logger.info(f"Returning {len(context)} context segments for {voice_name}")
    return context

def save_voice_profiles():
    """Save voice profiles to persistent storage."""
    os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
    profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")

    # Create a serializable version of the profiles
    serializable_profiles = {}
    for name, profile in VOICE_PROFILES.items():
        serializable_profiles[name] = {
            'reference_segments': [seg.cpu() for seg in profile.reference_segments]
        }

    # Save to persistent storage
    torch.save(serializable_profiles, profile_path)
    logger.info(f"Saved voice profiles to {profile_path}")

def process_generated_audio(
    audio: torch.Tensor,
    voice_name: str,
    sample_rate: int,
    text: str
) -> torch.Tensor:
    """Process generated audio for consistency and quality.

    Args:
        audio: Audio tensor
        voice_name: Name of voice used
        sample_rate: Audio sample rate
        text: Text that was spoken

    Returns:
        Processed audio tensor
    """
    # Validate the audio
    is_valid, audio, message = validate_generated_audio(audio, voice_name, sample_rate)
    if not is_valid:
        logger.warning(f"Generated audio validation issue: {message}")

    # Get voice profile for enhancement
    profile = VOICE_PROFILES.get(voice_name, VOICE_PROFILES["alloy"])

    # Enhance the audio based on voice profile
    enhanced = enhance_audio(audio, sample_rate, profile)

    # Log the enhancement
    original_duration = audio.shape[0] / sample_rate
    enhanced_duration = enhanced.shape[0] / sample_rate
    logger.info(
        f"Processed audio for '{voice_name}': "
        f"Duration: {original_duration:.2f}s->{enhanced_duration:.2f}s, "
        f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{enhanced.pow(2).mean().sqrt().item():.3f}"
    )

    return enhanced
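
if __name__ == "__main__":
    # Minimal offline smoke test - an illustrative sketch, not part of the
    # service. It assumes a 24 kHz sample rate and that the /app persistent
    # directories created at import time are writable (as in the container
    # this module targets); no CSM generator or app state is needed.
    logging.basicConfig(level=logging.INFO)
    initialize_voice_profiles()

    sr = 24000
    t = torch.linspace(0, 1, sr)
    tone = 0.5 * torch.sin(2 * torch.pi * 220 * t)  # synthetic stand-in for generated speech

    out = process_generated_audio(tone, "onyx", sr, "smoke test")
    print(f"Processed {out.shape[0] / sr:.2f}s of audio, peak {out.abs().max().item():.3f}")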