File size: 5,750 Bytes
86694c3
 
d909077
86694c3
 
 
 
 
 
 
 
 
d909077
86694c3
d909077
86694c3
 
 
 
d909077
 
 
86694c3
 
56a1afa
 
 
 
 
86694c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22d5503
86694c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d909077
86694c3
 
 
 
 
 
 
 
d909077
86694c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from contextlib import asynccontextmanager
import uuid
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
import os
from starlette.exceptions import HTTPException as StarletteHTTPException
from fastapi.middleware.cors import CORSMiddleware
import pathlib
from contextlib import asynccontextmanager
from glob import glob
import os

from models.launch import inference, train_model
from models.spectrogram_cnn import get_model

# distinguish model type for reshaping

load_dotenv()

# Base URL prepended to the paths returned by /upload/. Default to "" so an unset
# API_URL yields plain relative paths instead of the literal prefix "None"
# (str(None)), which would corrupt every returned URL.
SERVER = os.environ.get('API_URL', '')

# Absolute directory containing this module, and the temp folder inside it.
path = os.path.dirname(os.path.realpath(__file__))
tempFolderPath = os.path.join(path, "temp")
# NOTE(review): the request handlers and the /temp static mount use the
# CWD-relative path "temp"; it is the same folder only when the server is
# started from this module's directory.
os.makedirs(tempFolderPath, exist_ok=True)
    
    
def load_model_and_parameters():
    """Obtain the C6XL model and its parameters file via ``train_model``.

    ``train_model`` is called with ``get_model`` as the model callback and a
    fixed configuration (C6XL on the InverSynth dataset, 1 epoch,
    ``resume=True``). Whether this trains or only restores a checkpoint is up to
    ``train_model`` — TODO confirm.

    Returns:
        ``(model, parameters_file)`` unpacked from ``train_model``'s result, or
        ``(None, None)`` if the call or the unpacking raised. In that case the
        error is printed rather than propagated.
    """
    setup = dict(
        model_name="C6XL",
        dataset_name="InverSynth",
        epochs=1,
        dataset_dir="test_datasets",
        output_dir="output",
        dataset_file=None,
        parameters_file=None,
        data_format="channels_last",
        run_name=None,
        resume=True,
        # Distinguishes the model type (spectrogram input) for reshaping.
        model_type="STFT",
    )

    try:
        # Load the model; unpacking stays inside the try so a malformed result
        # is also reported as a load failure.
        model, parameters_file = train_model(model_callback=get_model, **setup)
    except Exception as e:
        print(f"Couldn't load model: {e}")
        return None, None
    return model, parameters_file

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: empty ``<module dir>/temp`` at startup.

    Every regular file directly inside the folder is deleted. Sub-directories
    are left alone, and a failure on one file is printed and skipped. If the
    folder does not exist it is created. Nothing is done at shutdown.
    """
    temp_dir = os.path.join(path, "temp")

    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)
    else:
        for entry_name in os.listdir(temp_dir):
            entry_path = os.path.join(temp_dir, entry_name)
            try:
                if os.path.isfile(entry_path):
                    os.remove(entry_path)
            except Exception as e:
                print(f"Error deleting file {entry_path}: {e}")

    yield
 
# The ASGI application; `lifespan` clears the temp folder at startup.
app = FastAPI(lifespan=lifespan)

# String form of this module's directory, used below to locate the front-end build.
str_p = f"{path}"


class SPAStaticFiles(StaticFiles):
    """``StaticFiles`` for a single-page app.

    A request that would produce a 404 is answered with ``index.html`` instead,
    so client-side routes reach the SPA entry point. Any other HTTP error is
    re-raised unchanged.
    """

    async def get_response(self, path: str, scope):
        try:
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as exc:
            if exc.status_code != 404:
                raise
        # Only reached on a 404: serve the SPA shell.
        return await super().get_response("index.html", scope)

@app.get("/download/{file_id}")
async def generate_audio(file_id: str):
    """Find a file in ``temp/`` by ID prefix and return its relative path.

    Despite the name, nothing is generated. The handler globs
    ``temp/<file_id>*`` (relative to the current working directory) and
    returns the first match.

    Args:
        file_id: Prefix of the stored file name.

    Returns:
        ``JSONResponse`` with body ``{"url": "<matching path>"}``.

    Raises:
        HTTPException: 404 if no file matches, 500 if the lookup itself fails.
    """
    from glob import escape as glob_escape

    try:
        # Escape glob metacharacters so a crafted ID such as "*" or "[a-z]"
        # cannot match (and expose) unrelated files in the temp folder.
        matching_files = glob(f"temp/{glob_escape(file_id)}*")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Raised outside the try so it reaches the client as a 404 rather than
    # being rewrapped into a 500.
    if not matching_files:
        print(f"No file found for file ID {file_id}")
        raise HTTPException(status_code=404, detail="File not found")

    # glob() result order is not guaranteed; the first match is returned.
    source_file_path = matching_files[0]
    return JSONResponse(content={"url": f"{source_file_path}"})


def is_valid_audio(file_extension):
    """Return True if *file_extension* names a supported audio format.

    The extension must include the leading dot (as returned by
    ``os.path.splitext``), e.g. ``".mp3"``. The comparison is case-insensitive.
    """
    normalized = file_extension.lower()
    return normalized in (".mp3", ".wav", ".ogg", ".flac", ".m4a")


@app.post("/upload/")
async def upload_audio_file(file: UploadFile = File(...)):
    """Save an uploaded audio file under ``temp/`` and run inference on it.

    The upload is stored as ``temp/<uuid4><original extension>`` (relative to
    the current working directory) and handed to ``start_inference``.

    Returns:
        dict with keys ``file_path``, ``csv_path`` and ``output_file_path``.
        These are elements 0, 1 and 2 of the inference output, each prefixed
        with ``SERVER``.

    Raises:
        HTTPException: 400 if the extension is not a supported audio type.
            500 if the model cannot be loaded, or if saving or inference fails.
    """
    import shutil  # local import: only used to stream the upload to disk

    # Validate first so invalid uploads don't pay for loading the model.
    # `filename` can be None when the client omits it; treat that as no extension.
    _, file_extension = os.path.splitext(file.filename or "")
    if not is_valid_audio(file_extension):
        raise HTTPException(status_code=400, detail="Invalid audio file format")

    # load_model_and_parameters() signals failure by returning (None, None)
    # rather than raising. Handle both so inference never runs without a model.
    # NOTE(review): the model is rebuilt on every request; consider loading it
    # once at startup.
    try:
        model, parameters_file = load_model_and_parameters()
    except Exception as e:
        raise HTTPException(status_code=500, detail="Couldn't load model") from e
    if model is None:
        raise HTTPException(status_code=500, detail="Couldn't load model")

    # Unique identifier for this upload; it also names the stored file.
    file_id = str(uuid.uuid4())
    file_path = os.path.join("temp", file_id + file_extension)

    try:
        # Stream to disk in chunks instead of reading the whole upload into memory.
        with open(file_path, "wb") as audio_file:
            shutil.copyfileobj(file.file, audio_file)
        output = await start_inference(
            model=model,
            parameters_file=parameters_file,
            file_id=file_id,
            file_extension=file_extension,
        )
        print(SERVER + output[0])
        return {
            "file_path": SERVER + output[0],
            "csv_path": SERVER + output[1],
            "output_file_path": SERVER + output[2],
        }
    except HTTPException:
        # Keep deliberate HTTP errors (and their status codes) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

async def start_inference(model, parameters_file, file_id: str, file_extension: str):
    """Run ``inference`` on ``temp/<file_id><file_extension>`` and return its result.

    ``inference`` is called synchronously (not awaited), so it runs on the
    event-loop thread for its whole duration.
    NOTE(review): consider offloading it to a thread pool if it is slow.
    """
    upload_path = os.path.join("temp", file_id + file_extension)
    result = inference(
        model=model,
        parameters_file=parameters_file,
        file_path=upload_path,
        file_id=file_id,
    )
    return result


# CORS: accept any origin, method and header.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm this is intended before relying on cookies/auth headers.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Uploaded/generated files, served from the CWD-relative "temp" directory.
# check_dir=True makes startup fail if that directory is missing.
app.mount(
    "/temp",
    StaticFiles(directory="temp", check_dir=True, html=True),
    name="temp",
)

# The front-end build in <parent of this module's dir>/front/dist. Unknown paths
# fall back to index.html via SPAStaticFiles. Keep this mount last: "/" matches
# every path.
_frontend_dist_dir = f"{pathlib.PurePath(str_p).parent}/front/dist"
app.mount(
    "/",
    SPAStaticFiles(directory=_frontend_dist_dir, html=True),
    name="dist",
)


if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)