wjm55 committed
Commit · aae7036
1 Parent(s): fae55a0

Implement model caching in app.py and initialize default model at startup. Update test.py to use variable for app URL.
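The heart of the change is a module-level cache keyed by model name, so each YOLO checkpoint is downloaded and constructed only once per process. A minimal, self-contained sketch of that pattern (not the Space's actual code, which follows in the diff; load_model below is a hypothetical stand-in for the hf_hub_download and YOLO construction steps):

MODEL_CACHE = {}

def load_model(model_id: str):
    # Hypothetical placeholder for the real hf_hub_download(...) + YOLO(...) steps.
    return object()

def init_model(model_id: str):
    if model_id in MODEL_CACHE:        # cache hit: reuse the already-loaded model
        return MODEL_CACHE[model_id]
    model = load_model(model_id)       # cache miss: load once...
    MODEL_CACHE[model_id] = model      # ...and remember it for later calls
    return model

# The second call returns the same object without loading again.
assert init_model("YOLOv11-Nano") is init_model("YOLOv11-Nano")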
app.py
CHANGED

@@ -9,8 +9,14 @@ from ultralytics import YOLO
 import requests
 import supervision as sv
 
+# Global variable to store model instances
+MODEL_CACHE = {}
 
 def init_model(model_id: str):
+    # Return cached model if it exists
+    if model_id in MODEL_CACHE:
+        return MODEL_CACHE[model_id]
+
     # Define models
     MODEL_OPTIONS = {
         "YOLOv11-Nano": "medieval-yolov11n.pt",
@@ -19,19 +25,26 @@ def init_model(model_id: str):
         "YOLOv11-Large": "medieval-yolov11l.pt",
         "YOLOv11-XLarge": "medieval-yolov11x.pt"
     }
+
     if model_id in MODEL_OPTIONS:
         os.makedirs("models", exist_ok=True)
         model_path = hf_hub_download(
             repo_id="biglam/medieval-manuscript-yolov11",
-            filename=MODEL_OPTIONS[model_id]
+            filename=MODEL_OPTIONS[model_id],
+            cache_dir="models"  # Specify cache directory
         )
-
-
+        model = YOLO(model_path)
+        MODEL_CACHE[model_id] = model
+        return model
     else:
         raise ValueError(f"Model {model_id} not found")
 
 app = FastAPI()
 
+# Initialize default model at startup
+@app.on_event("startup")
+async def startup_event():
+    init_model("YOLOv11-XLarge")  # Initialize default model
 
 @app.post("/predict")
 async def predict(image: UploadFile,
@@ -39,7 +52,7 @@ async def predict(image: UploadFile,
                   conf: float = 0.25,
                   iou: float = 0.7
 ):
-    #
+    # Get model from cache or initialize it
     model = init_model(model_id)
 
     # Download and open image from URL
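If app.py from this Space is importable locally (a hypothetical setup that assumes its dependencies, ultralytics, huggingface_hub, fastapi and supervision, are installed), the effect of the cache can be checked directly: the first call downloads the weights and builds the YOLO model, the second is a plain dictionary lookup.

# Hypothetical local check of the caching behaviour added above.
from app import MODEL_CACHE, init_model

m1 = init_model("YOLOv11-Nano")   # downloads medieval-yolov11n.pt and builds the YOLO model
m2 = init_model("YOLOv11-Nano")   # returned straight from MODEL_CACHE, no re-download
assert m1 is m2
assert "YOLOv11-Nano" in MODEL_CACHE

Two smaller points from the diff: cache_dir="models" keeps the huggingface_hub download cache inside the models/ directory that os.makedirs already creates, and the @app.on_event("startup") hook pre-populates the cache with YOLOv11-XLarge so the first request using the default model does not pay the download and construction cost.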
test.py
CHANGED

@@ -18,7 +18,7 @@ with open(image_path, 'rb') as f:
     }
 
     # Send POST request to the endpoint
-    response = requests.post(
+    response = requests.post(hf_app + ':7860/predict',
                              files=files,
                              params=params)
 
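Only part of test.py is visible in the hunk, so the sketch below fills in plausible surroundings: the values of hf_app and image_path, and the exact files/params dictionaries, are assumptions modelled on the /predict signature in app.py (an image upload plus model_id, conf and iou query parameters), not the script's real contents.

import requests

# Hypothetical values; in test.py these are defined above the shown hunk.
hf_app = "http://localhost"   # base URL of the running app; the script appends port 7860
image_path = "page.jpg"       # local manuscript image to send

with open(image_path, 'rb') as f:
    files = {"image": f}
    params = {
        "model_id": "YOLOv11-XLarge",  # one of the keys in MODEL_OPTIONS
        "conf": 0.25,
        "iou": 0.7,
    }

    # Send POST request to the endpoint
    response = requests.post(hf_app + ':7860/predict',
                             files=files,
                             params=params)

print(response.status_code)
print(response.text)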