import os

TMP_PATH = "/tmp"

# Redirect every library cache/config directory (and HOME) under /tmp.
# These assignments precede the ML imports below, presumably because those
# libraries read the variables when imported (assumes the default HOME is not
# writable in the deployment environment — TODO confirm).
os.environ["HF_HOME"] = os.path.join(TMP_PATH, "huggingface")
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"] + "/transformers"
os.environ["HF_DATASETS_CACHE"] = os.environ["HF_HOME"] + "/datasets"
os.environ["HF_METRICS_CACHE"] = os.environ["HF_HOME"] + "/metrics"
os.environ["MPLCONFIGDIR"] = os.path.join(TMP_PATH, "matplotlib")
os.environ["TORCH_HOME"] = os.path.join(TMP_PATH, "torch")
os.environ["XDG_CACHE_HOME"] = os.path.join(TMP_PATH, "xdg-cache")
os.environ["HOME"] = TMP_PATH

# Create the directories if needed.
for path in [
    os.environ["HF_HOME"],
    os.environ["MPLCONFIGDIR"],
    os.environ["TORCH_HOME"],
    os.environ["XDG_CACHE_HOME"],
    TMP_PATH,
]:
    os.makedirs(path, exist_ok=True)

# --- Main imports ---
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from transformers import pipeline, set_seed
from deep_translator import GoogleTranslator
from TTS.api import TTS
import whisper
import io
import torch
import scipy.io.wavfile
import numpy as np
import traceback
import contextlib
import tempfile
import threading

# --- FastAPI init ---
app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins if credentials matter.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Config ---
# NOTE(review): `device` is not referenced elsewhere in this file.
device = "cuda" if torch.cuda.is_available() else "cpu"
TEMP_DIR = "/tmp"
os.makedirs(TEMP_DIR, exist_ok=True)

# --- Lazily loaded models ---
loaded_tts_models = {}  # lang code -> transformers "text-to-audio" pipeline
loaded_tts_lingala = None  # Coqui TTS model used for speech-to-speech output
asr_model = None  # Whisper "tiny" model, loaded on first use

# Serializes all model loading and inference. The caches above are plain
# globals and set_seed() mutates global RNG state, so model jobs run one at a
# time (as they did when they blocked the event loop), now off the loop.
_MODEL_LOCK = threading.Lock()


# --- Utilities ---
def load_tts_model(lang: str):
    """Return the cached MMS text-to-audio pipeline for *lang*, loading it on first use.

    Raises:
        RuntimeError: if the pipeline for ``OlameMend/mms-tts-{lang}`` cannot be loaded.
    """
    if lang not in loaded_tts_models:
        model_id = f"OlameMend/mms-tts-{lang}"
        try:
            loaded_tts_models[lang] = pipeline("text-to-audio", model=model_id)
        except Exception as e:
            raise RuntimeError(
                f"Erreur lors du chargement du modèle TTS '{model_id}': {e}"
            ) from e
    return loaded_tts_models[lang]


def load_asr_model():
    """Return the cached Whisper "tiny" ASR model, loading it on first use."""
    global asr_model
    if asr_model is None:
        asr_model = whisper.load_model("tiny")
    return asr_model


def load_tts_lingala_model():
    """Return the cached Coqui Lingala VITS model, loading it on first use."""
    global loaded_tts_lingala
    if loaded_tts_lingala is None:
        loaded_tts_lingala = TTS("tts_models/lin/openbible/vits")
    return loaded_tts_lingala


def preprocess_text_tts(text: str) -> str:
    """Normalize text before synthesis (currently only strips surrounding whitespace)."""
    return text.strip()


def generate_tts_audio(lang: str, text: str) -> io.BytesIO:
    """Synthesize *text* with the MMS model for *lang*; return WAV bytes rewound to 0.

    *lang* is lower-cased before it is used to select the model.
    """
    lang = lang.lower()
    synthesizer = load_tts_model(lang)
    processed_text = preprocess_text_tts(text)
    # Fixed seed so repeated requests produce the same audio.
    set_seed(555)
    speech = synthesizer(processed_text)
    wav_io = io.BytesIO()
    # Assumes speech["audio"] is (1, n_samples) and takes the first row — TODO confirm
    # against the installed transformers version.
    scipy.io.wavfile.write(wav_io, rate=speech["sampling_rate"], data=speech["audio"][0])
    wav_io.seek(0)
    return wav_io


def speech_2_speech_ling(source_audio_path: str, lang: str) -> io.BytesIO:
    """Transcribe the source audio, translate the text to Lingala, and speak it.

    The Lingala speech is rendered with voice conversion toward the speaker in
    the source audio.

    Args:
        source_audio_path: path of the uploaded source recording.
        lang: language passed to Whisper for transcription.

    Returns:
        WAV bytes rewound to position 0.

    Raises:
        ValueError: if Whisper produced no text to translate.
    """
    asr = load_asr_model()
    tts_lingala = load_tts_lingala_model()
    result = asr.transcribe(source_audio_path, language=lang)
    text = result["text"]
    if not text.strip():
        # Nothing to translate or synthesize; fail early with a clear message.
        raise ValueError("Aucune parole détectée dans l'audio source.")
    translated_text = GoogleTranslator(source="auto", target="ln").translate(text)
    wav_io = io.BytesIO()
    # NOTE(review): relies on tts_with_vc_to_file accepting a file-like object as
    # file_path — confirm against the installed Coqui TTS version.
    tts_lingala.tts_with_vc_to_file(
        text=translated_text, speaker_wav=source_audio_path, file_path=wav_io
    )
    wav_io.seek(0)
    return wav_io


async def _run_model_job(func, *args):
    """Run blocking model work *func(*args)* in the thread pool, one job at a time."""

    def _job():
        with _MODEL_LOCK:
            return func(*args)

    return await run_in_threadpool(_job)


# --- Endpoints ---
@app.get("/")
def greet_json():
    """Return a static greeting (usable as a liveness check)."""
    return {"Hello": "World!"}


@app.post("/tts/")
async def api_tts(
    lang: str = Form(...),
    text: str = Form(None),
    file: UploadFile = File(None),
):
    """Synthesize speech from form text or an uploaded UTF-8 text file.

    If *file* is provided, its content replaces *text*. Responds 400 when no
    usable text is supplied (or the file is not valid UTF-8), 500 on model errors.
    """
    try:
        if file:
            content = await file.read()
            # utf-8-sig also drops a leading BOM, which would otherwise reach the model.
            text = content.decode("utf-8-sig")
        if not text or not text.strip():
            raise ValueError("Aucun texte fourni (champ texte ou fichier manquant).")
        wav_io = await _run_model_job(generate_tts_audio, lang, text)
        wav_io.seek(0)
        return StreamingResponse(wav_io, media_type="audio/wav")
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Erreur TTS : {str(e)}") from e


@app.post("/speech-to-speech/")
async def api_s2st(
    source_audio: UploadFile = File(...),
    lang: str = Form(...),
):
    """Translate uploaded speech into Lingala speech; responds 500 on any failure."""
    # The client-supplied filename is untrusted: keep only its extension (which
    # downstream audio loaders may use) and let tempfile pick a unique, safe name.
    suffix = os.path.splitext(os.path.basename(source_audio.filename or ""))[1]
    source_path = None
    try:
        with tempfile.NamedTemporaryFile(
            dir=TEMP_DIR, prefix="source_", suffix=suffix, delete=False
        ) as tmp:
            source_path = tmp.name
            tmp.write(await source_audio.read())
        wav_io = await _run_model_job(speech_2_speech_ling, source_path, lang)
        wav_io.seek(0)
        return StreamingResponse(wav_io, media_type="audio/wav")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Erreur S2ST : {str(e)}") from e
    finally:
        # The result is fully buffered in memory, so the upload can be removed now.
        if source_path is not None:
            with contextlib.suppress(OSError):
                os.remove(source_path)