# coding=utf-8
import os
import time
from io import BytesIO
from typing import Optional, Dict, Any, List, Set, Union, Tuple

# Third-party imports
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import HTMLResponse
import numpy as np
import torch
import torchaudio
import gradio as gr
from funasr import AutoModel
from dotenv import load_dotenv
# Load environment variables
load_dotenv()

# Read the API token
API_TOKEN: Optional[str] = os.getenv("API_TOKEN")
if not API_TOKEN:
    raise RuntimeError("API_TOKEN environment variable is not set")

# Set up Bearer-token authentication
security = HTTPBearer()
app = FastAPI(
    title="SenseVoice API",
    description="Speech recognition API service",
    version="1.0.0"
)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize the model
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda"
)
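# Note: the model above is pinned to CUDA. A defensive variant (an assumption, not
# part of the original code) would pick the device at runtime, e.g.
#   device="cuda" if torch.cuda.is_available() else "cpu"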
# Reuse the original formatting functions
emotion_dict: Dict[str, str] = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict: Dict[str, str] = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict: Dict[str, str] = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

lang_dict: Dict[str, str] = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
def format_text_basic(text: str) -> str:
    """Replace special tokens with corresponding emojis"""
    for token in emoji_dict:
        text = text.replace(token, emoji_dict[token])
    return text
def format_text_with_emotion(text: str) -> str:
    """Format text with emotion and event markers"""
    token_count: Dict[str, int] = {}
    original_text = text
    for token in emoji_dict:
        token_count[token] = text.count(token)
    # Determine dominant emotion
    dominant_emotion = "<|NEUTRAL|>"
    for emotion in emotion_dict:
        if token_count[emotion] > token_count[dominant_emotion]:
            dominant_emotion = emotion
    # Add event markers
    text = original_text
    for event in event_dict:
        if token_count[event] > 0:
            text = event_dict[event] + text
    # Replace all tokens with their emoji equivalents
    for token in emoji_dict:
        text = text.replace(token, emoji_dict[token])
    # Add dominant emotion
    text = text + emotion_dict[dominant_emotion]
    # Clean up emoji spacing
    for emoji in emo_set.union(event_set):
        text = text.replace(" " + emoji, emoji)
        text = text.replace(emoji + " ", emoji)
    return text.strip()
def format_text_advanced(text: str) -> str:
    """Advanced text formatting with multilingual and complex token handling"""
    def get_emotion(text: str) -> Optional[str]:
        # Guard against empty strings before indexing the last character
        return text[-1] if text and text[-1] in emo_set else None

    def get_event(text: str) -> Optional[str]:
        # Guard against empty strings before indexing the first character
        return text[0] if text and text[0] in event_set else None
    # Handle special cases
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        text = text.replace(lang, "<|lang|>")
    # Process text segments
    text_segments: List[str] = [format_text_with_emotion(segment).strip() for segment in text.split("<|lang|>")]
    formatted_text = " " + text_segments[0]
    current_event = get_event(formatted_text)
    # Merge segments
    for i in range(1, len(text_segments)):
        if not text_segments[i]:
            continue
        if get_event(text_segments[i]) == current_event and get_event(text_segments[i]) is not None:
            text_segments[i] = text_segments[i][1:]
        current_event = get_event(text_segments[i])
        if get_emotion(text_segments[i]) is not None and get_emotion(text_segments[i]) == get_emotion(formatted_text):
            formatted_text = formatted_text[:-1]
        formatted_text += text_segments[i].strip()
    formatted_text = formatted_text.replace("The.", " ")
    return formatted_text.strip()
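# Illustrative example of the formatting above (the exact raw tag layout is an
# assumption based on the token tables): a raw result such as
#   "<|en|><|HAPPY|><|Speech|><|withitn|>Nice to meet you"
# is reduced by format_text_advanced() to "Nice to meet you😊": language/ITN tags are
# dropped, event tags become leading emojis (e.g. 🎼 for BGM), and the dominant
# emotion emoji is appended.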
async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Process audio data and return transcription result"""
    try:
        # Decode the uploaded bytes into a waveform
        audio_buffer = BytesIO(audio_data)
        waveform, sample_rate = torchaudio.load(audio_buffer)
        # Convert to a mono 1-D signal (average channels if stereo)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0)
        else:
            waveform = waveform.squeeze(0)
        # Convert to a float32 numpy array
        input_wav = waveform.numpy().astype(np.float32)
        # Resample to 16 kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            input_wav = resampler(torch.from_numpy(input_wav)[None, :])[0, :].numpy()
        # Model inference
        text = model.generate(
            input=input_wav,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )
        # Format result
        result = text[0]["text"]
        result = format_text_advanced(result)
        return result
    except Exception as e:
        import traceback
        traceback.print_exc()
        traceback.print_stack()
        raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
    """Verify Bearer Token authentication"""
    if credentials.credentials != API_TOKEN:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials
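# Clients authenticate by sending the value of the API_TOKEN environment variable as
# a standard Bearer header on every request, e.g.:
#   Authorization: Bearer <API_TOKEN>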
# NOTE: the route path below is an assumption; adjust it to match your deployment.
@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: Optional[str] = "FunAudioLLM/SenseVoiceSmall",
    language: Optional[str] = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
) -> Dict[str, Union[str, int, float]]:
    """Audio transcription endpoint

    Args:
        file: Audio file (supports common audio formats)
        model: Model name, currently only supports FunAudioLLM/SenseVoiceSmall
        language: Language code, supports auto/zh/en/yue/ja/ko/nospeech

    Returns:
        Dict[str, Union[str, int, float]]: {
            "text": "Transcription result",
            "error_code": 0,
            "error_msg": "",
            "process_time": 1.234  # Processing time in seconds
        }
    """
    start_time = time.time()
    try:
        # Validate file format
        if not file.filename.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported audio format",
                "process_time": time.time() - start_time
            }
        # Validate model
        if model != "FunAudioLLM/SenseVoiceSmall":
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported model",
                "process_time": time.time() - start_time
            }
        # Validate language
        if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported language",
                "process_time": time.time() - start_time
            }
        # Process audio
        content = await file.read()
        text = await process_audio(content, language)
        return {
            "text": text,
            "error_code": 0,
            "error_msg": "",
            "process_time": time.time() - start_time
        }
    except Exception as e:
        return {
            "text": "",
            "error_code": 500,
            "error_msg": str(e),
            "process_time": time.time() - start_time
        }
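# Example request against the endpoint above (path as assumed there; the audio file is
# the one referenced by the Gradio examples below). Because "language" is a plain
# parameter rather than a Form field, FastAPI treats it as a query parameter:
#   curl -X POST "http://localhost:7860/v1/audio/transcriptions?language=en" \
#        -H "Authorization: Bearer $API_TOKEN" \
#        -F "file=@examples/en.mp3"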
def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
    """Gradio interface for audio transcription"""
    try:
        if audio is None:
            return "Please upload an audio file"
        # Extract audio data
        sample_rate, input_wav = audio
        # Normalize audio
        input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
        # Convert to mono
        if len(input_wav.shape) > 1:
            input_wav = input_wav.mean(-1)
        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            input_wav_tensor = torch.from_numpy(input_wav).to(torch.float32)
            input_wav = resampler(input_wav_tensor[None, :])[0, :].numpy()
        # Model inference
        text = model.generate(
            input=input_wav,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )
        # Format result
        result = text[0]["text"]
        result = format_text_advanced(result)
        return result
    except Exception as e:
        return f"Processing failed: {str(e)}"
# Create Gradio interface with localized labels
demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="numpy",
            label="Upload audio or record from microphone"
        ),
        gr.Dropdown(
            choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
            value="auto",
            label="Select Language"
        )
    ],
    outputs=gr.Textbox(label="Recognition Result"),
    title="SenseVoice Speech Recognition",
    description="Multi-language speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
    examples=[
        ["examples/zh.mp3", "zh"],
        ["examples/en.mp3", "en"],
    ]
)
# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")
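# With the Gradio UI mounted at "/", FastAPI routes registered earlier (including the
# automatic docs at /docs) still take precedence, since routes are matched in
# registration order; anything registered after this mount would be shadowed by it.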
# Custom Swagger UI redirect
async def custom_swagger_ui_html():
    return HTMLResponse("""
    <!DOCTYPE html>
    <html>
    <head>
        <title>SenseVoice API Documentation</title>
        <meta http-equiv="refresh" content="0;url=/docs/" />
    </head>
    <body>
        <p>Redirecting to API documentation...</p>
    </body>
    </html>
    """)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)