import base64
import io
import logging
import os
import threading

import noisereduce as nr
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, VitsModel

logger = logging.getLogger(__name__)

# Read the Hugging Face access token from the environment (e.g. a deployment secret).
# NOTE(review): the variable name "acees-token" contains a hyphen, which most shells
# cannot `export`; it is kept unchanged so existing deployments keep working.
token = os.getenv("acees-token")

# Loaded models, keyed by Hub model id.
models = {}
# Loaded tokenizers, keyed by Hub model id (loaded once, alongside the model).
tokenizers = {}
# One lock per model id: inference sets model.speaking_rate on the shared cached
# model, so "set rate + forward pass" must not interleave between threads.
_model_locks = {}
# Guards population of the caches above so concurrent first requests load a model once.
_cache_lock = threading.Lock()

# Pick the compute device.
device = "cuda" if torch.cuda.is_available() else "cpu"


def remove_noise(audio_data, sr=16000):
    """Return *audio_data* after noise reduction with ``noisereduce``.

    Args:
        audio_data: Waveform as a NumPy array (a single 1-D channel in this
            file's usage).
        sr: Sampling rate of *audio_data* in Hz.

    Returns:
        The noise-reduced waveform returned by ``nr.reduce_noise``.
    """
    return nr.reduce_noise(y=audio_data, hop_length=256, sr=sr)


def _load_model(name_model):
    """Download a VitsModel, re-apply weight norm as before, and move it to ``device``."""
    model = VitsModel.from_pretrained(name_model, token=token)
    # NOTE(review): weight norm is applied *after* the weights are loaded. PyTorch's
    # weight_norm initialises g/v from the existing weight, so outputs should be
    # unchanged; kept from the original code — confirm whether inference needs it.
    model.decoder.apply_weight_norm()
    for flow in model.flow.flows:
        torch.nn.utils.weight_norm(flow.conv_pre)
        torch.nn.utils.weight_norm(flow.conv_post)
    model.to(device)
    return model


def get_model(name_model):
    """Return ``(model, tokenizer)`` for *name_model*, loading and caching on first use.

    Args:
        name_model: Hugging Face Hub model id.

    Returns:
        A ``(VitsModel, tokenizer)`` tuple; the model lives on ``device``.

    Raises:
        Whatever ``from_pretrained`` raises when the model cannot be loaded.
    """
    model = models.get(name_model)
    if model is not None:
        return model, tokenizers[name_model]
    with _cache_lock:
        # Re-check: another thread may have finished loading while we waited.
        model = models.get(name_model)
        if model is None:
            model = _load_model(name_model)
            tokenizers[name_model] = AutoTokenizer.from_pretrained(name_model, token=token)
            _model_locks[name_model] = threading.Lock()
            # Publish the model last: the lock-free fast path above relies on the
            # tokenizer and lock already existing whenever the model is present.
            models[name_model] = model
        return model, tokenizers[name_model]


class TTSRequest(BaseModel):
    """Request body for ``POST /predict/``."""

    # Text to synthesise.
    text: str
    # Hugging Face Hub id of the VITS model to use.
    name_model: str = "wasmdashai/vits-ar-sa-huba-v2"
    # Speed multiplier (VitsConfig.speaking_rate; 1.0 = normal, larger = faster).
    # The old default of 16000.0 looked like a sampling rate and squeezed every
    # token into a single frame; it must also be > 0 (it is used as a divisor).
    speaking_rate: float = Field(1.0, gt=0)


# Create the application.
app = FastAPI(title="VITS TTS API", description="Convert Arabic/English text to speech using VITS models")


@app.get("/", summary="Health check")
def home():
    """Report that the service is running."""
    return {"message": "FastAPI VITS TTS service is running"}


@app.post("/predict/", summary="Text-to-Speech", description="Convert text to audio (WAV, Base64)")
def modelspeech(req: TTSRequest):
    """Synthesise ``req.text`` and return it as a Base64-encoded WAV file.

    Returns:
        ``{"sampling_rate": int, "audio_base64": str}``.

    Raises:
        HTTPException: 400 for empty text; 500 for any failure while loading the
            model or synthesising (the detail carries the error message).
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty")
    try:
        model, tokenizer = get_model(req.name_model)
        inputs = tokenizer(req.text, return_tensors="pt").to(device)
        sampling_rate = model.config.sampling_rate

        # speaking_rate is an attribute of the shared cached model, so setting it
        # and running the forward pass must be atomic with respect to other requests.
        with _model_locks[req.name_model], torch.no_grad():
            model.speaking_rate = req.speaking_rate
            outputs = model(**inputs)
        waveform = outputs.waveform[0].cpu().numpy()

        # Denoise at the model's real sampling rate, not a hard-coded 16 kHz.
        waveform = remove_noise(waveform, sr=sampling_rate)

        # Encode the audio as a Base64 WAV payload.
        buffer = io.BytesIO()
        sf.write(buffer, waveform, samplerate=sampling_rate, format="WAV")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return {
            "sampling_rate": sampling_rate,
            "audio_base64": audio_base64,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("TTS synthesis failed for model %r", req.name_model)
        raise HTTPException(status_code=500, detail=str(e)) from e