| """Streaming support for the TTS API.""" | |
| import asyncio | |
| import io | |
| import logging | |
| import time | |
| from typing import AsyncGenerator, Optional, List | |
| import torch | |
| import torchaudio | |
| from fastapi import APIRouter, Request, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from app.api.schemas import SpeechRequest, ResponseFormat | |
| from app.prompt_engineering import split_into_segments | |
| from app.models import Segment | |
| logger = logging.getLogger(__name__) | |
| router = APIRouter() | |
class AudioChunker:
    """Handle audio chunking for streaming responses."""

    def __init__(self,
                 sample_rate: int,
                 format: str = "mp3",
                 chunk_size_ms: int = 200):  # Smaller chunks for better streaming
        """
        Initialize the audio chunker.

        Args:
            sample_rate: Audio sample rate in Hz
            format: Output audio format (mp3, opus, etc.)
            chunk_size_ms: Size of each chunk in milliseconds
        """
        self.sample_rate = sample_rate
        self.format = format.lower()
        self.chunk_size_samples = int(sample_rate * (chunk_size_ms / 1000))
        logger.info(f"Audio chunker initialized with {chunk_size_ms}ms chunks ({self.chunk_size_samples} samples)")
    async def chunk_audio(self,
                          audio: torch.Tensor,
                          delay_ms: int = 0) -> AsyncGenerator[bytes, None]:
        """
        Convert an audio tensor to streaming chunks.

        Args:
            audio: Audio tensor to stream
            delay_ms: Artificial delay between chunks (for testing)

        Yields:
            Audio chunks as bytes
        """
        # Ensure audio is on CPU
        if audio.is_cuda:
            audio = audio.cpu()

        # Calculate the number of chunks
        num_samples = audio.shape[0]
        num_chunks = (num_samples + self.chunk_size_samples - 1) // self.chunk_size_samples
        logger.info(f"Streaming {num_samples} samples as {num_chunks} chunks")

        for i in range(num_chunks):
            start_idx = i * self.chunk_size_samples
            end_idx = min(start_idx + self.chunk_size_samples, num_samples)

            # Extract the chunk
            chunk = audio[start_idx:end_idx]

            # Convert to bytes in the requested format
            chunk_bytes = await self._format_chunk(chunk)

            # Add an artificial delay if requested (for testing)
            if delay_ms > 0:
                await asyncio.sleep(delay_ms / 1000)

            yield chunk_bytes
    async def _format_chunk(self, chunk: torch.Tensor) -> bytes:
        """Convert an audio chunk to bytes in the specified format."""
        buf = io.BytesIO()

        # Ensure chunk is 2D (channels, samples), as torchaudio.save expects
        if len(chunk.shape) == 1:
            chunk = chunk.unsqueeze(0)  # Add channel dimension

        # Ensure chunk is on CPU
        if chunk.is_cuda:
            chunk = chunk.cpu()

        # Save to the buffer in the specified format, defaulting to mp3
        supported_formats = {"mp3", "opus", "aac", "flac", "wav"}
        fmt = self.format if self.format in supported_formats else "mp3"
        torchaudio.save(buf, chunk, self.sample_rate, format=fmt)

        # Get bytes from the buffer
        buf.seek(0)
        return buf.read()
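
# Hedged usage sketch (not part of the module's API): exercising AudioChunker
# directly, outside the HTTP layer. The 24 kHz sample rate and the second of
# silence below are illustrative assumptions, not values used by this module.
async def _demo_chunk_audio() -> int:
    """Count the chunks produced for one second of silent audio."""
    chunker = AudioChunker(sample_rate=24000, format="wav", chunk_size_ms=200)
    silence = torch.zeros(24000)  # 1 second of silence at 24 kHz
    count = 0
    async for _chunk in chunker.chunk_audio(silence):
        count += 1  # 24000 samples / 4800 samples per chunk -> 5 chunks
    return count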
# Helper function to get the speaker ID for a voice
def get_speaker_id(app_state, voice):
    """Get the speaker ID from a voice name or ID."""
    if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map:
        return app_state.voice_speaker_map[voice]

    # Standard voices mapping
    voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}
    if voice in voice_to_speaker:
        return voice_to_speaker[voice]

    # Try parsing as an integer speaker ID
    try:
        speaker_id = int(voice)
        if 0 <= speaker_id < 6:
            return speaker_id
    except (ValueError, TypeError):
        pass

    # Check cloned voices if the voice cloner exists
    if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
        # Check by ID
        if voice in app_state.voice_cloner.cloned_voices:
            return app_state.voice_cloner.cloned_voices[voice].speaker_id
        # Check by name
        for v_id, v_info in app_state.voice_cloner.cloned_voices.items():
            if v_info.name.lower() == voice.lower():
                return v_info.speaker_id

    # Default to alloy
    return 0
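
# Hedged illustration of get_speaker_id's fallback order, using an empty
# stand-in for request.app.state (no voice map and no voice cloner configured).
# _BareState is hypothetical and exists only for this sketch.
def _demo_get_speaker_id():
    class _BareState:
        pass

    state = _BareState()
    assert get_speaker_id(state, "nova") == 4      # standard voice table
    assert get_speaker_id(state, "3") == 3         # numeric string parsed as a speaker ID
    assert get_speaker_id(state, "unknown") == 0   # falls back to alloy (0)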
async def stream_speech(
    request: Request,
    speech_request: SpeechRequest,
):
    """
    Stream audio of text being spoken by a realistic voice.

    This endpoint provides an OpenAI-compatible streaming interface for TTS.
    """
    # Check that the model is loaded
    if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please try again later."
        )

    # Get request parameters
    model = speech_request.model
    input_text = speech_request.input
    voice = speech_request.voice
    response_format = speech_request.response_format
    speed = speech_request.speed
    temperature = speech_request.temperature
    max_audio_length_ms = speech_request.max_audio_length_ms

    # Log the request
    logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'")

    # Reject empty input
    if not input_text or len(input_text.strip()) == 0:
        raise HTTPException(
            status_code=400,
            detail="Input text cannot be empty"
        )

    # Get the speaker ID for the voice
    speaker_id = get_speaker_id(request.app.state, voice)
    if speaker_id is None:
        raise HTTPException(
            status_code=400,
            detail=f"Voice '{voice}' not found. Available voices: {request.app.state.available_voices}"
        )
    try:
        # Map the response format to a media type
        media_type = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "aac": "audio/aac",
            "flac": "audio/flac",
            "wav": "audio/wav",
        }.get(response_format, "audio/mpeg")

        # Create the chunker for streaming
        sample_rate = request.app.state.sample_rate
        chunker = AudioChunker(sample_rate, response_format)

        # Split text into segments; smaller segments reduce time to first audio
        text_segments = split_into_segments(input_text, max_chars=50)
        logger.info(f"Split text into {len(text_segments)} segments for incremental streaming")
        async def generate_streaming_audio():
            # Check for a cloned voice
            voice_info = None
            from_cloned_voice = False
            if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
                voice_info = request.app.state.get_voice_info(voice)
                from_cloned_voice = voice_info and voice_info["type"] == "cloned"
                if from_cloned_voice:
                    # Use the cloned voice context for the first segment
                    voice_cloner = request.app.state.voice_cloner
                    context = voice_cloner.get_voice_context(voice_info["voice_id"])
                else:
                    # Use the standard voice context
                    from app.voice_enhancement import get_voice_segments
                    context = get_voice_segments(voice, request.app.state.device)
            else:
                # Use the standard voice context
                from app.voice_enhancement import get_voice_segments
                context = get_voice_segments(voice, request.app.state.device)

            # Send an empty chunk to initialize the connection
            yield b''
            # Process each text segment incrementally and stream in real time
            for i, segment_text in enumerate(text_segments):
                try:
                    logger.info(f"Generating segment {i+1}/{len(text_segments)}")

                    # Generate audio for this segment in a worker thread to avoid blocking the event loop
                    if from_cloned_voice:
                        # Generate with the cloned voice
                        voice_cloner = request.app.state.voice_cloner
                        segment_audio = await asyncio.to_thread(
                            voice_cloner.generate_speech,
                            segment_text,
                            voice_info["voice_id"],
                            temperature=temperature,
                            topk=30,
                            max_audio_length_ms=2000  # Keep segments short for streaming
                        )
                    else:
                        # Use the standard voice with the generator
                        segment_audio = await asyncio.to_thread(
                            request.app.state.generator.generate,
                            segment_text,
                            speaker_id,
                            context,
                            max_audio_length_ms=2000,  # Short for quicker generation
                            temperature=temperature
                        )

                    # Process audio quality for this segment
                    if hasattr(request.app.state, "voice_enhancement_enabled") and request.app.state.voice_enhancement_enabled:
                        from app.voice_enhancement import process_generated_audio
                        segment_audio = process_generated_audio(
                            audio=segment_audio,
                            voice_name=voice,
                            sample_rate=sample_rate,
                            text=segment_text
                        )
                    # Handle speed adjustment
                    if speed != 1.0 and speed > 0:
                        try:
                            # Adjust speed using the sox tempo effect
                            effects = [["tempo", str(speed)]]
                            audio_cpu = segment_audio.cpu()
                            adjusted_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
                                audio_cpu.unsqueeze(0),
                                sample_rate,
                                effects
                            )
                            segment_audio = adjusted_audio.squeeze(0)
                        except Exception as e:
                            logger.warning(f"Failed to adjust speech speed: {e}")
                    # Encode this segment to bytes and stream it immediately
                    buf = io.BytesIO()
                    audio_to_save = segment_audio.unsqueeze(0) if len(segment_audio.shape) == 1 else segment_audio
                    torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=response_format)
                    buf.seek(0)
                    segment_bytes = buf.read()

                    yield segment_bytes

                    # Use this segment as context for the next generation
                    context = [
                        Segment(
                            text=segment_text,
                            speaker=speaker_id,
                            audio=segment_audio
                        )
                    ]
                except Exception as e:
                    logger.error(f"Error generating segment {i+1}: {e}")
                    # Continue with the next segment
        # Return the streaming response; the ASGI server applies chunked
        # transfer encoding itself, so Transfer-Encoding is not set manually
        return StreamingResponse(
            generate_streaming_audio(),
            media_type=media_type,
            headers={
                "Content-Disposition": f'attachment; filename="speech.{response_format}"',
                "X-Accel-Buffering": "no",  # Prevent buffering in nginx
                "Cache-Control": "no-cache, no-store, must-revalidate",  # Prevent caching
                "Pragma": "no-cache",
                "Expires": "0",
                "Connection": "keep-alive",
            }
        )
    except Exception as e:
        logger.error(f"Error in stream_speech: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
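
# Note: this module creates `router` but does not attach the handlers above to
# it; registration happens wherever the application wires up its routes. A
# plausible (assumed, illustrative) registration could look like:
#
#   router.add_api_route("/v1/audio/speech/stream", stream_speech, methods=["POST"])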
async def openai_stream_speech(
    request: Request,
    speech_request: SpeechRequest,
):
    """
    Stream audio in the OpenAI-compatible streaming format.

    This endpoint is compatible with the OpenAI streaming TTS API.
    """
    # Same logic as the stream_speech endpoint, kept under a separate name
    # to follow the OpenAI API naming convention
    return await stream_speech(request, speech_request)
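
# Hedged client-side sketch: consuming the streaming endpoint with only the
# standard library. The URL, port, path, and model name are assumptions for
# illustration; the real mount point depends on how the app registers routes.
def _demo_stream_client():
    import json
    import urllib.request

    payload = json.dumps({
        "model": "csm-1b",  # assumed model name
        "input": "Hello, streaming world!",
        "voice": "alloy",
        "response_format": "mp3",
    }).encode("utf-8")
    req = urllib.request.Request(
        "http://localhost:8000/v1/audio/speech/stream",  # assumed path
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp, open("speech.mp3", "wb") as out:
        # Read the chunked response incrementally as audio arrives
        while True:
            chunk = resp.read(4096)
            if not chunk:
                break
            out.write(chunk)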