ff / app.py
wasmdashai's picture
Update app.py
bf679e1 verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, VitsModel
import torch
import numpy as np
import os
import noisereduce as nr
import base64
import io
import soundfile as sf
# Read the access token from the environment (Secrets).
# It is passed as ``token=`` to every ``from_pretrained`` call in this file.
token = os.environ.get("acees-token")

# Cache of loaded models, keyed by model name.
models: dict = {}

# Pick the compute device: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Noise removal
def remove_noise(audio_data, sr=16000):
    """Run ``noisereduce.reduce_noise`` over ``audio_data`` and return the result.

    Args:
        audio_data: Audio samples, forwarded to ``reduce_noise`` as ``y``.
        sr: Sample rate of ``audio_data``, forwarded as ``sr``. Defaults to 16000.

    Returns:
        The value returned by ``nr.reduce_noise`` (called with ``hop_length=256``).
    """
    denoised = nr.reduce_noise(y=audio_data, sr=sr, hop_length=256)
    return denoised
# Model loading
# Tokenizers already loaded by get_model(), keyed by model name.
_tokenizers = {}


def get_model(name_model):
    """Return the ``(model, tokenizer)`` pair for ``name_model``.

    Both objects are loaded at most once per model name. The model is kept in
    the module-level ``models`` cache and the tokenizer in ``_tokenizers``, so
    repeated requests for the same model do not reload either of them.

    Args:
        name_model: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to ``device``.
    """
    model = models.get(name_model)
    if model is None:
        model = VitsModel.from_pretrained(name_model, token=token)
        # Apply weight normalisation to the decoder and to the conv_pre /
        # conv_post layers of every flow.
        # NOTE(review): presumably mirrors how the checkpoints were trained;
        # confirm it is required for inference.
        model.decoder.apply_weight_norm()
        for flow in model.flow.flows:
            torch.nn.utils.weight_norm(flow.conv_pre)
            torch.nn.utils.weight_norm(flow.conv_post)
        model.to(device)
        models[name_model] = model

    # Previously the tokenizer was re-created on every call, even on a model
    # cache hit; cache it so only the first request pays the loading cost.
    tokenizer = _tokenizers.get(name_model)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(name_model, token=token)
        _tokenizers[name_model] = tokenizer

    return model, tokenizer
# Request body schema for the POST endpoint.
# Documented with comments rather than a class docstring so that the schema
# description generated from this model stays unchanged.
class TTSRequest(BaseModel):
    # Text to convert to speech; passed to the tokenizer.
    text: str

    # Model identifier handed to get_model().
    name_model: str = "wasmdashai/vits-ar-sa-huba-v2"

    # Value assigned to model.speaking_rate before inference.
    # NOTE(review): 16000.0 looks like a sampling rate. The Hugging Face
    # VitsConfig documentation describes speaking_rate as a multiplier that
    # defaults to 1.0 -- confirm the intended default before changing it.
    speaking_rate: float = 16_000.0
# Create the application
app = FastAPI(
    title="VITS TTS API",
    description="Convert Arabic/English text to speech using VITS models",
)
@app.get("/", summary="Health check")
def home():
    # Health-check endpoint: returns a fixed status message.
    # A comment is used instead of a docstring because this route sets no
    # explicit description, and a docstring would be picked up as one.
    status = {"message": "FastAPI VITS TTS service is running"}
    return status
@app.post("/predict/", summary="Text-to-Speech", description="Convert text to audio (WAV, Base64)")
def modelspeech(req: TTSRequest):
    """Synthesise ``req.text`` and return it as a base64-encoded WAV.

    Args:
        req: Request body with the text, the model name and the speaking rate.

    Returns:
        A dict with ``sampling_rate`` (taken from the model config) and
        ``audio_base64`` (the WAV file bytes, base64-encoded as UTF-8 text).

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when any
            step (model loading, inference, denoising, encoding) fails.
    """
    try:
        model, tokenizer = get_model(req.name_model)
        sampling_rate = model.config.sampling_rate

        inputs = tokenizer(req.text, return_tensors="pt").to(device)

        # get_model() hands back a cached, shared model instance, so this
        # assignment stays in effect until another request overwrites it.
        model.speaking_rate = req.speaking_rate
        with torch.no_grad():
            outputs = model(**inputs)
        waveform = outputs.waveform[0].cpu().numpy()

        # Noise removal. Use the model's own sampling rate instead of
        # remove_noise()'s 16000 default, so denoising and the WAV header
        # written below agree on the rate.
        waveform = remove_noise(waveform, sr=sampling_rate)

        # Encode the audio as a base64 WAV
        buffer = io.BytesIO()
        sf.write(buffer, waveform, samplerate=sampling_rate, format="WAV")
        buffer.seek(0)
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")

        return {
            "sampling_rate": sampling_rate,
            "audio_base64": audio_base64
        }
    except Exception as e:
        # Top-level boundary: report any failure as a 500, keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e