from contextlib import asynccontextmanager
from glob import glob
import os
import pathlib
import uuid

from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from starlette.exceptions import HTTPException as StarletteHTTPException

from models.launch import inference, train_model
from models.spectrogram_cnn import get_model  # distinguish model type for reshaping

load_dotenv()
SERVER = str(os.environ.get("API_URL"))

path = os.path.dirname(os.path.realpath(__file__))
tempFolderPath = os.path.join(path, "temp")
if not os.path.exists(tempFolderPath):
    os.makedirs(tempFolderPath)


def load_model_and_parameters():
    setup = {
        "model_name": "C6XL",
        "dataset_name": "InverSynth",
        "epochs": 1,
        "dataset_dir": "test_datasets",
        "output_dir": "output",
        "dataset_file": None,
        "parameters_file": None,
        "data_format": "channels_last",
        "run_name": None,
        "resume": True,
    }
    setup["model_type"] = "STFT"
    try:
        # Load the model (training resumes from the saved checkpoint, see "resume": True)
        model, parameters_file = train_model(model_callback=get_model, **setup)
    except Exception as e:
        print(f"Couldn't load model: {e}")
        return None, None
    return model, parameters_file


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Remove all files in the temp folder on startup
    tempFolderPath = os.path.join(path, "temp")
    if os.path.exists(tempFolderPath):
        for file_name in os.listdir(tempFolderPath):
            file_path = os.path.join(tempFolderPath, file_name)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except Exception as e:
                print(f"Error deleting file {file_path}: {e}")
    if not os.path.exists(tempFolderPath):
        os.makedirs(tempFolderPath)
    yield


app = FastAPI(lifespan=lifespan)
str_p = str(path)


class SPAStaticFiles(StaticFiles):
    async def get_response(self, path: str, scope):
        try:
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as ex:
            # Fall back to the SPA entry point so client-side routes still resolve
            if ex.status_code == 404:
                return await super().get_response("index.html", scope)
            else:
                raise ex


@app.get("/download/{file_id}")
async def generate_audio(file_id: str):
    try:
        # Use glob to find files starting with the specified ID
        matching_files = glob(f"temp/{file_id}*")
        if not matching_files:
            # Handle the case when no matching file is found
            print(f"No file found for file ID {file_id}")
            raise HTTPException(status_code=404, detail="File not found")
        # Use the first matching file
        source_file_path = matching_files[0]
        return JSONResponse(content={"url": f"{source_file_path}"})
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 404 above) instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def is_valid_audio(file_extension):
    # List of accepted audio file extensions
    valid_audio_extensions = [".mp3", ".wav", ".ogg", ".flac", ".m4a"]
    # Check if the provided file extension is in the list of valid audio extensions
    return file_extension.lower() in valid_audio_extensions


@app.post("/upload/")
async def upload_audio_file(file: UploadFile = File(...)):
    model, parameters_file = load_model_and_parameters()
    if model is None:
        raise HTTPException(status_code=500, detail="Couldn't load model")
    try:
        # Create a unique identifier for the uploaded file
        file_id = str(uuid.uuid4())
        # Extract the original file extension
        _, file_extension = os.path.splitext(file.filename)
        # Check if the file has a valid audio extension
        if not is_valid_audio(file_extension):
            raise HTTPException(status_code=400, detail="Invalid audio file format")
        # Construct the file path with the original file extension
        file_path = os.path.join("temp", file_id + file_extension)
        with open(file_path, "wb") as audio_file:
            audio_file.write(file.file.read())
        # generate_output_audio(file_path, output_file_path)
        output = await start_inference(
            model=model,
            parameters_file=parameters_file,
            file_id=file_id,
            file_extension=file_extension,
        )
        # Send a confirmation with the identifier
        print(SERVER + output[0])
        return {
            "file_path": SERVER + output[0],
            "csv_path": SERVER + output[1],
            "output_file_path": SERVER + output[2],
        }
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400 above) instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


async def start_inference(model, parameters_file, file_id: str, file_extension: str):
    file_path = os.path.join("temp", file_id + file_extension)
    output = inference(
        model=model, parameters_file=parameters_file, file_path=file_path, file_id=file_id
    )
    return output


origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount(
    "/temp", StaticFiles(directory="temp", check_dir=True, html=True), name="temp"
)

app.mount(
    "/",
    SPAStaticFiles(directory=f"{pathlib.PurePath(str_p).parent}/front/dist", html=True),
    name="dist",
)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
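
# Example requests against this server -- a sketch only; it assumes the server is
# running locally on port 7860 and that API_URL points at the same host. The file
# name "some_clip.wav" and the <file_id> placeholder are hypothetical.
#
#   curl -X POST -F "file=@some_clip.wav" http://localhost:7860/upload/
#   # -> {"file_path": "...", "csv_path": "...", "output_file_path": "..."}
#
#   curl http://localhost:7860/download/<file_id>
#   # -> {"url": "temp/<file_id>.wav"}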