Spaces:
Sleeping
Sleeping
| from contextlib import asynccontextmanager | |
| import uuid | |
| from dotenv import load_dotenv | |
| from fastapi import Depends, FastAPI, File, HTTPException, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import os | |
| from starlette.exceptions import HTTPException as StarletteHTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import pathlib | |
| from contextlib import asynccontextmanager | |
| from glob import glob | |
| import os | |
| from models.launch import inference, train_model | |
| from models.spectrogram_cnn import get_model | |
| # distinguish model type for reshaping | |
# Load variables from a local .env file (if any) before reading the environment.
load_dotenv()

# Prefix concatenated in front of the relative paths returned by
# upload_audio_file. Default to an empty string when API_URL is unset:
# str(None) would otherwise put the literal text "None" in every URL.
SERVER = os.environ.get("API_URL", "")

# Directory containing this module and the scratch folder that lives beside it.
path = os.path.dirname(os.path.realpath(__file__))
tempFolderPath = os.path.join(path, "temp")

# exist_ok=True avoids the race between a separate existence check and the
# directory creation.
os.makedirs(tempFolderPath, exist_ok=True)
def load_model_and_parameters():
    """Obtain the model and its parameters file through ``train_model``.

    The call is made with a fixed configuration (``C6XL`` model, ``InverSynth``
    dataset, ``STFT`` model type, ``resume=True``) and ``get_model`` as the
    model callback.

    Returns:
        The ``(model, parameters_file)`` pair produced by ``train_model``, or
        ``(None, None)`` when that call (or unpacking its result) raises.
    """
    config = dict(
        model_name="C6XL",
        dataset_name="InverSynth",
        epochs=1,
        dataset_dir="test_datasets",
        output_dir="output",
        dataset_file=None,
        parameters_file=None,
        data_format="channels_last",
        run_name=None,
        resume=True,
        model_type="STFT",
    )

    try:
        # Load the model.
        model, parameters_file = train_model(model_callback=get_model, **config)
    except Exception as e:
        # Best effort: report the problem and let the caller deal with None.
        print(f"Couldn't load model: {e}")
        return None, None
    else:
        return model, parameters_file
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Clear leftover files from the temp folder when the application starts.

    Only regular files directly inside the folder are removed; sub-directories
    are left alone. A file that cannot be deleted is reported and skipped so
    one bad entry does not prevent startup. The folder is (re)created if it is
    missing before control is handed to the application.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Nothing runs after the ``yield``, so there is no shutdown work.
    """
    temp_dir = os.path.join(path, "temp")

    if os.path.exists(temp_dir):
        for file_name in os.listdir(temp_dir):
            file_path = os.path.join(temp_dir, file_name)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except Exception as e:
                print(f"Error deleting file {file_path}: {e}")

    os.makedirs(temp_dir, exist_ok=True)

    yield
# Application instance; `lifespan` runs the temp-folder cleanup around startup.
app = FastAPI(
    lifespan=lifespan,
)

# String form of the module directory, used when building the front-end path.
str_p = str(path)
class SPAStaticFiles(StaticFiles):
    """Static file handler that answers unknown paths with ``index.html``."""

    async def get_response(self, path: str, scope):
        """Serve ``path``, falling back to ``index.html`` when it is not found.

        Args:
            path: Requested path, as passed in by the parent class machinery.
            scope: Request scope forwarded unchanged to the parent class.

        Returns:
            The parent's response for ``path``, or its response for
            ``index.html`` if the first lookup raised a 404.

        Raises:
            HTTPException / StarletteHTTPException: Any non-404 error raised
                by the parent lookup is propagated unchanged.
        """
        try:
            response = await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as ex:
            # Anything other than "not found" is a genuine error.
            if ex.status_code != 404:
                raise
            return await super().get_response("index.html", scope)
        return response
async def generate_audio(file_id: str):
    """Look up a file in ``temp/`` whose name starts with ``file_id``.

    Args:
        file_id: Prefix of the file name to search for.

    Returns:
        A ``JSONResponse`` of the form ``{"url": <first matching path>}``.

    Raises:
        HTTPException: 404 when no file matches, 500 when the lookup itself
            fails.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file; confirm how it is registered.
    # NOTE(review): file_id is interpolated into the glob pattern as-is, so
    # wildcard or "../" segments would widen the search; confirm callers only
    # pass generated identifiers.
    try:
        # Use glob to find files starting with the specified ID.
        matching_files = glob(f"temp/{file_id}*")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Raised outside the try block so the 404 is not rewritten into a 500 by
    # the generic exception handler.
    if not matching_files:
        print(f"No file found for file ID {file_id}")
        raise HTTPException(status_code=404, detail="File not found")

    # Several files may share the prefix; the first match is returned.
    source_file_path = matching_files[0]
    return JSONResponse(content={"url": f"{source_file_path}"})
def is_valid_audio(file_extension):
    """Tell whether ``file_extension`` is one of the accepted audio extensions.

    The comparison ignores case and expects the leading dot (e.g. ``".wav"``).

    Args:
        file_extension: Extension string to check.

    Returns:
        True for ``.mp3``, ``.wav``, ``.ogg``, ``.flac`` and ``.m4a``;
        False otherwise.
    """
    normalized = file_extension.lower()
    return normalized in {".mp3", ".wav", ".ogg", ".flac", ".m4a"}
async def upload_audio_file(file: UploadFile = File(...)):
    """Store an uploaded audio file in ``temp/`` and run inference on it.

    Args:
        file: The uploaded file. Its name must end with one of the extensions
            accepted by ``is_valid_audio``.

    Returns:
        A dict with ``file_path``, ``csv_path`` and ``output_file_path``, each
        being ``SERVER`` concatenated with the corresponding element of the
        inference output.

    Raises:
        HTTPException: 400 for an unsupported (or missing) file extension,
            500 when the model cannot be loaded or when saving/inference fails.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file; confirm how it is registered.

    # Validate first: rejecting a bad upload should not cost a model load.
    # The filename may be absent, which would otherwise crash splitext.
    _, file_extension = os.path.splitext(file.filename or "")
    if not is_valid_audio(file_extension):
        raise HTTPException(status_code=400, detail="Invalid audio file format")

    # load_model_and_parameters reports failure by returning (None, None)
    # rather than raising, so the result has to be checked explicitly.
    model, parameters_file = load_model_and_parameters()
    if model is None:
        raise HTTPException(status_code=500, detail="Couldn't load model")

    try:
        # Create a unique identifier for the uploaded file.
        file_id = str(uuid.uuid4())

        # Store the upload under its identifier, keeping the original extension.
        file_path = os.path.join("temp", file_id + file_extension)
        with open(file_path, "wb") as audio_file:
            audio_file.write(await file.read())

        output = await start_inference(
            model=model,
            parameters_file=parameters_file,
            file_id=file_id,
            file_extension=file_extension,
        )

        print(SERVER + output[0])
        return {
            "file_path": SERVER + output[0],
            "csv_path": SERVER + output[1],
            "output_file_path": SERVER + output[2],
        }
    except HTTPException:
        # Keep deliberate HTTP errors (status code and detail) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def start_inference(model, parameters_file, file_id: str, file_extension: str):
    """Run ``inference`` on the stored upload identified by ``file_id``.

    Args:
        model: Model object forwarded to ``inference``.
        parameters_file: Parameters file forwarded to ``inference``.
        file_id: Identifier the upload was saved under.
        file_extension: Extension (with leading dot) of the saved file.

    Returns:
        Whatever ``inference`` returns.
    """
    # The upload lives at temp/<file_id><file_extension>.
    stored_name = file_id + file_extension
    return inference(
        model=model,
        parameters_file=parameters_file,
        file_path=os.path.join("temp", stored_name),
        file_id=file_id,
    )
# Single wildcard list reused for allowed origins, methods and headers.
origins = ["*"]

# NOTE(review): credentials are enabled together with wildcard origins,
# methods and headers; confirm this permissive policy is intended.
_cors_settings = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": origins,
    "allow_headers": origins,
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Files in the temp folder are exposed under /temp. This mount is registered
# before the catch-all mount at "/" below.
_temp_files = StaticFiles(directory="temp", check_dir=True, html=True)
app.mount("/temp", _temp_files, name="temp")

# Everything else is answered from the built front-end, located in
# front/dist next to this module's parent directory.
_frontend_dist = f"{pathlib.PurePath(str_p).parent}/front/dist"
app.mount("/", SPAStaticFiles(directory=_frontend_dist, html=True), name="dist")
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")