Spaces:
Paused
Paused
| """ | |
| Voice cloning module for CSM-1B TTS API. | |
| This module provides functionality to clone voices from audio samples, | |
| with advanced audio preprocessing and voice adaptation techniques. | |
| """ | |
| import os | |
| import io | |
| import time | |
| import tempfile | |
| import logging | |
| import asyncio | |
| import yt_dlp | |
| import whisper | |
| from typing import Dict, List, Optional, Union, Tuple, BinaryIO | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from pydantic import BaseModel | |
| from fastapi import UploadFile | |
| from app.models import Segment | |
| # Set up logging | |
| logger = logging.getLogger(__name__) | |
| # Directory for storing cloned voice data | |
| CLONED_VOICES_DIR = "/app/cloned_voices" | |
| os.makedirs(CLONED_VOICES_DIR, exist_ok=True) | |
| class ClonedVoice(BaseModel): | |
| """Model representing a cloned voice.""" | |
| id: str | |
| name: str | |
| created_at: float | |
| speaker_id: int | |
| description: Optional[str] = None | |
| audio_duration: float | |
| sample_count: int | |
| class VoiceCloner: | |
| """Voice cloning utility for CSM-1B model.""" | |
| def __init__(self, generator, device="cuda"): | |
| """Initialize the voice cloner with a generator instance.""" | |
| self.generator = generator | |
| self.device = device | |
| self.sample_rate = generator.sample_rate | |
| self.cloned_voices = self._load_existing_voices() | |
| logger.info(f"Voice cloner initialized with {len(self.cloned_voices)} existing voices") | |
| def _load_existing_voices(self) -> Dict[str, ClonedVoice]: | |
| """Load existing cloned voices from disk.""" | |
| voices = {} | |
| if not os.path.exists(CLONED_VOICES_DIR): | |
| return voices | |
| for voice_dir in os.listdir(CLONED_VOICES_DIR): | |
| voice_path = os.path.join(CLONED_VOICES_DIR, voice_dir) | |
| if not os.path.isdir(voice_path): | |
| continue | |
| info_path = os.path.join(voice_path, "info.json") | |
| if os.path.exists(info_path): | |
| try: | |
| import json | |
| with open(info_path, "r") as f: | |
| voice_info = json.load(f) | |
| voices[voice_dir] = ClonedVoice(**voice_info) | |
| logger.info(f"Loaded cloned voice: {voice_dir}") | |
| except Exception as e: | |
| logger.error(f"Error loading voice {voice_dir}: {e}") | |
| return voices | |
| async def process_audio_file( | |
| self, | |
| file: Union[UploadFile, BinaryIO, str], | |
| transcript: Optional[str] = None | |
| ) -> Tuple[torch.Tensor, Optional[str], float]: | |
| """ | |
| Process an audio file for voice cloning. | |
| Args: | |
| file: The audio file (UploadFile, file-like object, or path) | |
| transcript: Optional transcript of the audio | |
| Returns: | |
| Tuple of (processed_audio, transcript, duration_seconds) | |
| """ | |
| temp_path = None | |
| try: | |
| # Handle different input types | |
| if isinstance(file, str): | |
| # It's a file path | |
| audio_path = file | |
| logger.info(f"Processing audio from file path: {audio_path}") | |
| else: | |
| # Create a temporary file | |
| temp_fd, temp_path = tempfile.mkstemp(suffix=".wav") | |
| os.close(temp_fd) # Close the file descriptor | |
| if isinstance(file, UploadFile): | |
| # It's a FastAPI UploadFile | |
| logger.info("Processing audio from UploadFile") | |
| contents = await file.read() | |
| with open(temp_path, "wb") as f: | |
| f.write(contents) | |
| elif hasattr(file, 'read'): | |
| # It's a file-like object - check if it's async | |
| logger.info("Processing audio from file-like object") | |
| if asyncio.iscoroutinefunction(file.read): | |
| # It's an async read method | |
| contents = await file.read() | |
| else: | |
| # It's a sync read method | |
| contents = file.read() | |
| with open(temp_path, "wb") as f: | |
| f.write(contents) | |
| else: | |
| raise ValueError(f"Unsupported file type: {type(file)}") | |
| audio_path = temp_path | |
| logger.info(f"Saved uploaded audio to temporary file: {audio_path}") | |
| # Load audio | |
| logger.info(f"Loading audio from {audio_path}") | |
| audio, sr = torchaudio.load(audio_path) | |
| # Convert to mono if stereo | |
| if audio.shape[0] > 1: | |
| logger.info(f"Converting {audio.shape[0]} channels to mono") | |
| audio = torch.mean(audio, dim=0, keepdim=True) | |
| # Remove first dimension if it's 1 | |
| if audio.shape[0] == 1: | |
| audio = audio.squeeze(0) | |
| # Resample if necessary | |
| if sr != self.sample_rate: | |
| logger.info(f"Resampling from {sr}Hz to {self.sample_rate}Hz") | |
| audio = torchaudio.functional.resample( | |
| audio, orig_freq=sr, new_freq=self.sample_rate | |
| ) | |
| # Get audio duration | |
| duration_seconds = len(audio) / self.sample_rate | |
| # Process audio for better quality | |
| logger.info(f"Preprocessing audio for quality enhancement") | |
| processed_audio = self._preprocess_audio(audio) | |
| processed_duration = len(processed_audio) / self.sample_rate | |
| logger.info( | |
| f"Processed audio: original duration={duration_seconds:.2f}s, " | |
| f"processed duration={processed_duration:.2f}s" | |
| ) | |
| return processed_audio, transcript, duration_seconds | |
| except Exception as e: | |
| logger.error(f"Error processing audio: {e}", exc_info=True) | |
| raise RuntimeError(f"Failed to process audio file: {e}") | |
| finally: | |
| # Clean up temp file if we created one | |
| if temp_path and os.path.exists(temp_path): | |
| try: | |
| os.unlink(temp_path) | |
| logger.debug(f"Deleted temporary file {temp_path}") | |
| except Exception as e: | |
| logger.warning(f"Failed to delete temporary file {temp_path}: {e}") | |
| def _preprocess_audio(self, audio: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Preprocess audio for better voice cloning quality. | |
| Args: | |
| audio: Raw audio tensor | |
| Returns: | |
| Processed audio tensor | |
| """ | |
| # Normalize volume | |
| if torch.max(torch.abs(audio)) > 0: | |
| audio = audio / torch.max(torch.abs(audio)) | |
| # Remove silence with dynamic threshold | |
| audio = self._remove_silence(audio, threshold=0.02) # Slightly higher threshold to remove more noise | |
| # Remove DC offset (very low frequency noise) | |
| audio = audio - torch.mean(audio) | |
| # Apply simple noise reduction | |
| # This filters out very high frequencies that might contain noise | |
| try: | |
| audio_np = audio.cpu().numpy() | |
| from scipy import signal | |
| # Apply a bandpass filter to focus on speech frequencies (80Hz - 8000Hz) | |
| sos = signal.butter(3, [80, 8000], 'bandpass', fs=self.sample_rate, output='sos') | |
| filtered = signal.sosfilt(sos, audio_np) | |
| # Normalize the filtered audio | |
| filtered = filtered / (np.max(np.abs(filtered)) + 1e-8) | |
| # Convert back to torch tensor | |
| audio = torch.tensor(filtered, device=audio.device) | |
| except Exception as e: | |
| logger.warning(f"Advanced audio filtering failed, using basic processing: {e}") | |
| # Ensure audio has correct amplitude | |
| audio = audio * 0.9 # Slightly reduce volume to prevent clipping | |
| return audio | |
| def _remove_silence( | |
| self, | |
| audio: torch.Tensor, | |
| threshold: float = 0.015, | |
| min_silence_duration: float = 0.2 | |
| ) -> torch.Tensor: | |
| """ | |
| Remove silence from audio while preserving speech rhythm. | |
| Args: | |
| audio: Input audio tensor | |
| threshold: Energy threshold for silence detection | |
| min_silence_duration: Minimum silence duration in seconds | |
| Returns: | |
| Audio with silence removed | |
| """ | |
| # Convert to numpy for easier processing | |
| audio_np = audio.cpu().numpy() | |
| # Calculate energy | |
| energy = np.abs(audio_np) | |
| # Find regions above threshold (speech) | |
| is_speech = energy > threshold | |
| # Convert min_silence_duration to samples | |
| min_silence_samples = int(min_silence_duration * self.sample_rate) | |
| # Find speech segments | |
| speech_segments = [] | |
| in_speech = False | |
| speech_start = 0 | |
| for i in range(len(is_speech)): | |
| if is_speech[i] and not in_speech: | |
| # Start of speech segment | |
| in_speech = True | |
| speech_start = i | |
| elif not is_speech[i] and in_speech: | |
| # Potential end of speech segment | |
| # Only end if silence is long enough | |
| silence_count = 0 | |
| for j in range(i, min(len(is_speech), i + min_silence_samples)): | |
| if not is_speech[j]: | |
| silence_count += 1 | |
| else: | |
| break | |
| if silence_count >= min_silence_samples: | |
| # End of speech segment | |
| in_speech = False | |
| speech_segments.append((speech_start, i)) | |
| # Handle case where audio ends during speech | |
| if in_speech: | |
| speech_segments.append((speech_start, len(is_speech))) | |
| # If no speech segments found, return original audio | |
| if not speech_segments: | |
| logger.warning("No speech segments detected, returning original audio") | |
| return audio | |
| # Add small buffer around segments | |
| buffer_samples = int(0.05 * self.sample_rate) # 50ms buffer | |
| processed_segments = [] | |
| for start, end in speech_segments: | |
| buffered_start = max(0, start - buffer_samples) | |
| buffered_end = min(len(audio_np), end + buffer_samples) | |
| processed_segments.append(audio_np[buffered_start:buffered_end]) | |
| # Concatenate all segments with small pauses between them | |
| small_pause = np.zeros(int(0.15 * self.sample_rate)) # 150ms pause | |
| result = processed_segments[0] | |
| for segment in processed_segments[1:]: | |
| result = np.concatenate([result, small_pause, segment]) | |
| return torch.tensor(result, device=audio.device) | |
| def _enhance_speech(self, audio: torch.Tensor) -> torch.Tensor: | |
| """Enhance speech quality for better cloning results.""" | |
| # This is a placeholder for more advanced speech enhancement | |
| # In a production implementation, you could add: | |
| # - Noise reduction | |
| # - Equalization for speech frequencies | |
| # - Gentle compression for better dynamics | |
| return audio | |
| async def clone_voice( | |
| self, | |
| audio_file: Union[UploadFile, BinaryIO, str], | |
| voice_name: str, | |
| transcript: Optional[str] = None, | |
| description: Optional[str] = None, | |
| speaker_id: Optional[int] = None # Make this optional | |
| ) -> ClonedVoice: | |
| """ | |
| Clone a voice from an audio file. | |
| Args: | |
| audio_file: Audio file with the voice to clone | |
| voice_name: Name for the cloned voice | |
| transcript: Transcript of the audio (optional) | |
| description: Description of the voice (optional) | |
| speaker_id: Speaker ID to use (default: auto-assigned) | |
| Returns: | |
| ClonedVoice object with voice information | |
| """ | |
| logger.info(f"Cloning new voice '{voice_name}' from audio file") | |
| # Process the audio file | |
| processed_audio, provided_transcript, duration = await self.process_audio_file( | |
| audio_file, transcript | |
| ) | |
| # Use a better speaker ID assignment - use a small number similar to the built-in voices | |
| # This prevents issues with the speaker ID being interpreted as speech | |
| if speaker_id is None: | |
| # Use a number between 10-20 to avoid conflicts with built-in voices (0-5) | |
| # but not too large like 999 which might cause issues | |
| existing_ids = [v.speaker_id for v in self.cloned_voices.values()] | |
| for potential_id in range(10, 20): | |
| if potential_id not in existing_ids: | |
| speaker_id = potential_id | |
| break | |
| else: | |
| # If all IDs in range are taken, use a fallback | |
| speaker_id = 10 | |
| # Generate a unique ID for the voice | |
| voice_id = f"{int(time.time())}_{voice_name.lower().replace(' ', '_')}" | |
| # Create directory for the voice | |
| voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id) | |
| os.makedirs(voice_dir, exist_ok=True) | |
| # Save the processed audio | |
| audio_path = os.path.join(voice_dir, "reference.wav") | |
| torchaudio.save(audio_path, processed_audio.unsqueeze(0).cpu(), self.sample_rate) | |
| # Save the transcript if provided | |
| if provided_transcript: | |
| transcript_path = os.path.join(voice_dir, "transcript.txt") | |
| with open(transcript_path, "w") as f: | |
| f.write(provided_transcript) | |
| # Create and save voice info | |
| voice_info = ClonedVoice( | |
| id=voice_id, | |
| name=voice_name, | |
| created_at=time.time(), | |
| speaker_id=speaker_id, | |
| description=description, | |
| audio_duration=duration, | |
| sample_count=len(processed_audio) | |
| ) | |
| # Save voice info as JSON | |
| import json | |
| with open(os.path.join(voice_dir, "info.json"), "w") as f: | |
| f.write(json.dumps(voice_info.dict())) | |
| # Add to cloned voices dictionary | |
| self.cloned_voices[voice_id] = voice_info | |
| logger.info(f"Voice '{voice_name}' cloned successfully with ID: {voice_id} and speaker_id: {speaker_id}") | |
| return voice_info | |
| def get_voice_context(self, voice_id: str) -> List[Segment]: | |
| """ | |
| Get context segments for a cloned voice. | |
| Args: | |
| voice_id: ID of the cloned voice | |
| Returns: | |
| List of context segments for the voice | |
| """ | |
| if voice_id not in self.cloned_voices: | |
| logger.warning(f"Voice ID {voice_id} not found") | |
| return [] | |
| voice = self.cloned_voices[voice_id] | |
| voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id) | |
| audio_path = os.path.join(voice_dir, "reference.wav") | |
| if not os.path.exists(audio_path): | |
| logger.error(f"Audio file for voice {voice_id} not found at {audio_path}") | |
| return [] | |
| try: | |
| # Load the audio | |
| audio, sr = torchaudio.load(audio_path) | |
| audio = audio.squeeze(0) | |
| # Resample if necessary | |
| if sr != self.sample_rate: | |
| audio = torchaudio.functional.resample( | |
| audio, orig_freq=sr, new_freq=self.sample_rate | |
| ) | |
| # Trim to a maximum of 5 seconds to avoid sequence length issues | |
| # This is a balance between voice quality and model limitations | |
| max_samples = 5 * self.sample_rate # 5 seconds | |
| if audio.shape[0] > max_samples: | |
| logger.info(f"Trimming voice sample from {audio.shape[0]} to {max_samples} samples") | |
| # Take from beginning for better voice characteristics | |
| audio = audio[:max_samples] | |
| # Load transcript if available | |
| transcript_path = os.path.join(voice_dir, "transcript.txt") | |
| transcript = "" | |
| if os.path.exists(transcript_path): | |
| with open(transcript_path, "r") as f: | |
| full_transcript = f.read() | |
| # Take a portion of transcript that roughly matches our audio portion | |
| words = full_transcript.split() | |
| # Estimate 3 words per second as a rough average | |
| word_count = min(len(words), int(5 * 3)) # 5 seconds * 3 words/second | |
| transcript = " ".join(words[:word_count]) | |
| else: | |
| transcript = f"Voice sample for {voice.name}" | |
| # Create context segment | |
| segment = Segment( | |
| text=transcript, | |
| speaker=voice.speaker_id, | |
| audio=audio.to(self.device) | |
| ) | |
| logger.info(f"Created voice context segment with {audio.shape[0]/self.sample_rate:.1f}s audio") | |
| return [segment] | |
| except Exception as e: | |
| logger.error(f"Error getting voice context for {voice_id}: {e}") | |
| return [] | |
| def list_voices(self) -> List[ClonedVoice]: | |
| """List all available cloned voices.""" | |
| return list(self.cloned_voices.values()) | |
| def delete_voice(self, voice_id: str) -> bool: | |
| """ | |
| Delete a cloned voice. | |
| Args: | |
| voice_id: ID of the voice to delete | |
| Returns: | |
| True if successful, False otherwise | |
| """ | |
| if voice_id not in self.cloned_voices: | |
| return False | |
| voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id) | |
| if os.path.exists(voice_dir): | |
| try: | |
| import shutil | |
| shutil.rmtree(voice_dir) | |
| del self.cloned_voices[voice_id] | |
| return True | |
| except Exception as e: | |
| logger.error(f"Error deleting voice {voice_id}: {e}") | |
| return False | |
| return False | |
| async def clone_voice_from_youtube( | |
| self, # Don't forget the self parameter for class methods | |
| youtube_url: str, | |
| voice_name: str, | |
| start_time: int = 0, | |
| duration: int = 180, | |
| description: str = None | |
| ) -> ClonedVoice: | |
| """ | |
| Clone a voice from a YouTube video. | |
| Args: | |
| youtube_url: URL of the YouTube video | |
| voice_name: Name for the cloned voice | |
| start_time: Start time in seconds | |
| duration: Duration to extract in seconds | |
| description: Optional description of the voice | |
| Returns: | |
| ClonedVoice object with voice information | |
| """ | |
| logger.info(f"Cloning voice '{voice_name}' from YouTube: {youtube_url}") | |
| # Create temporary directory for processing | |
| with tempfile.TemporaryDirectory() as temp_dir: | |
| # Step 1: Download audio from YouTube | |
| audio_path = await self._download_youtube_audio(youtube_url, temp_dir, start_time, duration) | |
| # Step 2: Generate transcript using Whisper | |
| transcript = await self._generate_transcript(audio_path) | |
| # Step 3: Clone the voice using the extracted audio and transcript | |
| voice = await self.clone_voice( | |
| audio_file=audio_path, | |
| voice_name=voice_name, | |
| transcript=transcript, | |
| description=description or f"Voice cloned from YouTube: {youtube_url}" | |
| ) | |
| return voice | |
| async def _download_youtube_audio( | |
| self, # Don't forget the self parameter | |
| url: str, | |
| output_dir: str, | |
| start_time: int = 0, | |
| duration: int = 180 | |
| ) -> str: | |
| """ | |
| Download audio from a YouTube video. | |
| Args: | |
| url: YouTube URL | |
| output_dir: Directory to save the audio | |
| start_time: Start time in seconds | |
| duration: Duration to extract in seconds | |
| Returns: | |
| Path to the downloaded audio file | |
| """ | |
| output_path = os.path.join(output_dir, "youtube_audio.wav") | |
| # Configure yt-dlp options | |
| ydl_opts = { | |
| 'format': 'bestaudio/best', | |
| 'postprocessors': [{ | |
| 'key': 'FFmpegExtractAudio', | |
| 'preferredcodec': 'wav', | |
| 'preferredquality': '192', | |
| }], | |
| 'outtmpl': output_path.replace(".wav", ""), | |
| 'quiet': True, | |
| 'no_warnings': True | |
| } | |
| # Download the video | |
| with yt_dlp.YoutubeDL(ydl_opts) as ydl: | |
| ydl.download([url]) | |
| # Trim the audio to the specified segment | |
| if start_time > 0 or duration < float('inf'): | |
| import ffmpeg | |
| trimmed_path = os.path.join(output_dir, "trimmed_audio.wav") | |
| # Use ffmpeg to trim the audio | |
| ( | |
| ffmpeg.input(output_path) | |
| .audio | |
| .filter('atrim', start=start_time, duration=duration) | |
| .output(trimmed_path) | |
| .run(quiet=True, overwrite_output=True) | |
| ) | |
| return trimmed_path | |
| return output_path | |
| async def _generate_transcript(self, audio_path: str) -> str: | |
| """ | |
| Generate transcript from audio using Whisper. | |
| Args: | |
| audio_path: Path to the audio file | |
| Returns: | |
| Transcript text | |
| """ | |
| # Load Whisper model (use small model for faster processing) | |
| model = whisper.load_model("small") | |
| # Transcribe the audio | |
| result = model.transcribe(audio_path) | |
| return result["text"] | |
| def generate_speech( | |
| self, | |
| text: str, | |
| voice_id: str, | |
| temperature: float = 0.65, | |
| topk: int = 30, | |
| max_audio_length_ms: int = 15000 | |
| ) -> torch.Tensor: | |
| """ | |
| Generate speech with a cloned voice. | |
| Args: | |
| text: Text to synthesize | |
| voice_id: ID of the cloned voice to use | |
| temperature: Sampling temperature (lower = more stable, higher = more varied) | |
| topk: Top-k sampling parameter | |
| max_audio_length_ms: Maximum audio length in milliseconds | |
| Returns: | |
| Generated audio tensor | |
| """ | |
| # Remove any async/await keywords - this is a synchronous function | |
| if voice_id not in self.cloned_voices: | |
| raise ValueError(f"Voice ID {voice_id} not found") | |
| voice = self.cloned_voices[voice_id] | |
| context = self.get_voice_context(voice_id) | |
| if not context: | |
| raise ValueError(f"Could not get context for voice {voice_id}") | |
| # Preprocess text for better pronunciation | |
| processed_text = self._preprocess_text(text) | |
| logger.info(f"Generating speech with voice '{voice.name}' (ID: {voice_id}, speaker: {voice.speaker_id})") | |
| try: | |
| # Check if text is too long and should be split | |
| if len(processed_text) > 200: | |
| logger.info(f"Text is long ({len(processed_text)} chars), splitting for better quality") | |
| from app.prompt_engineering import split_into_segments | |
| # Split text into manageable segments | |
| segments = split_into_segments(processed_text, max_chars=150) | |
| logger.info(f"Split text into {len(segments)} segments") | |
| all_audio_chunks = [] | |
| # Process each segment | |
| for i, segment_text in enumerate(segments): | |
| logger.info(f"Generating segment {i+1}/{len(segments)}") | |
| # Generate this segment - using plain text without formatting | |
| segment_audio = self.generator.generate( | |
| text=segment_text, # Use plain text, no formatting | |
| speaker=voice.speaker_id, | |
| context=context, | |
| max_audio_length_ms=min(max_audio_length_ms, 10000), | |
| temperature=temperature, | |
| topk=topk, | |
| ) | |
| all_audio_chunks.append(segment_audio) | |
| # Use this segment as context for the next one for consistency | |
| if i < len(segments) - 1: | |
| context = [ | |
| Segment( | |
| text=segment_text, | |
| speaker=voice.speaker_id, | |
| audio=segment_audio | |
| ) | |
| ] | |
| # Combine chunks with small silence between them | |
| if len(all_audio_chunks) == 1: | |
| audio = all_audio_chunks[0] | |
| else: | |
| silence_samples = int(0.1 * self.sample_rate) # 100ms silence | |
| silence = torch.zeros(silence_samples, device=all_audio_chunks[0].device) | |
| # Join segments with silence | |
| audio_parts = [] | |
| for i, chunk in enumerate(all_audio_chunks): | |
| audio_parts.append(chunk) | |
| if i < len(all_audio_chunks) - 1: # Don't add silence after the last chunk | |
| audio_parts.append(silence) | |
| # Concatenate all parts | |
| audio = torch.cat(audio_parts) | |
| return audio | |
| else: | |
| # For short text, generate directly - using plain text without formatting | |
| audio = self.generator.generate( | |
| text=processed_text, # Use plain text, no formatting | |
| speaker=voice.speaker_id, | |
| context=context, | |
| max_audio_length_ms=max_audio_length_ms, | |
| temperature=temperature, | |
| topk=topk, | |
| ) | |
| return audio | |
| except Exception as e: | |
| logger.error(f"Error generating speech with voice {voice_id}: {e}") | |
| raise | |
| def _preprocess_text(self, text: str) -> str: | |
| """Preprocess text for better pronunciation and voice cloning.""" | |
| # Make sure text ends with punctuation for better phrasing | |
| text = text.strip() | |
| if not text.endswith(('.', '?', '!', ';')): | |
| text = text + '.' | |
| return text | |