wjm55 commited on
Commit
aae7036
·
1 Parent(s): fae55a0

Implement model caching in app.py and initialize default model at startup. Update test.py to use variable for app URL.

Browse files
Files changed (2) hide show
  1. app.py +17 -4
  2. test.py +1 -1
app.py CHANGED
@@ -9,8 +9,14 @@ from ultralytics import YOLO
9
  import requests
10
  import supervision as sv
11
 
 
 
12
 
13
  def init_model(model_id: str):
 
 
 
 
14
  # Define models
15
  MODEL_OPTIONS = {
16
  "YOLOv11-Nano": "medieval-yolov11n.pt",
@@ -19,19 +25,26 @@ def init_model(model_id: str):
19
  "YOLOv11-Large": "medieval-yolov11l.pt",
20
  "YOLOv11-XLarge": "medieval-yolov11x.pt"
21
  }
 
22
  if model_id in MODEL_OPTIONS:
23
  os.makedirs("models", exist_ok=True)
24
  model_path = hf_hub_download(
25
  repo_id="biglam/medieval-manuscript-yolov11",
26
- filename=MODEL_OPTIONS[model_id]
 
27
  )
28
- local_path = os.path.join("models", model_path)
29
- return YOLO(local_path)
 
30
  else:
31
  raise ValueError(f"Model {model_id} not found")
32
 
33
  app = FastAPI()
34
 
 
 
 
 
35
 
36
  @app.post("/predict")
37
  async def predict(image: UploadFile,
@@ -39,7 +52,7 @@ async def predict(image: UploadFile,
39
  conf: float = 0.25,
40
  iou: float = 0.7
41
  ):
42
- # Initialize model at startup
43
  model = init_model(model_id)
44
 
45
  # Download and open image from URL
 
9
  import requests
10
  import supervision as sv
11
 
12
+ # Global variable to store model instances
13
+ MODEL_CACHE = {}
14
 
15
  def init_model(model_id: str):
16
+ # Return cached model if it exists
17
+ if model_id in MODEL_CACHE:
18
+ return MODEL_CACHE[model_id]
19
+
20
  # Define models
21
  MODEL_OPTIONS = {
22
  "YOLOv11-Nano": "medieval-yolov11n.pt",
 
25
  "YOLOv11-Large": "medieval-yolov11l.pt",
26
  "YOLOv11-XLarge": "medieval-yolov11x.pt"
27
  }
28
+
29
  if model_id in MODEL_OPTIONS:
30
  os.makedirs("models", exist_ok=True)
31
  model_path = hf_hub_download(
32
  repo_id="biglam/medieval-manuscript-yolov11",
33
+ filename=MODEL_OPTIONS[model_id],
34
+ cache_dir="models" # Specify cache directory
35
  )
36
+ model = YOLO(model_path)
37
+ MODEL_CACHE[model_id] = model
38
+ return model
39
  else:
40
  raise ValueError(f"Model {model_id} not found")
41
 
42
  app = FastAPI()
43
 
44
+ # Initialize default model at startup
45
+ @app.on_event("startup")
46
+ async def startup_event():
47
+ init_model("YOLOv11-XLarge") # Initialize default model
48
 
49
  @app.post("/predict")
50
  async def predict(image: UploadFile,
 
52
  conf: float = 0.25,
53
  iou: float = 0.7
54
  ):
55
+ # Get model from cache or initialize it
56
  model = init_model(model_id)
57
 
58
  # Download and open image from URL
test.py CHANGED
@@ -18,7 +18,7 @@ with open(image_path, 'rb') as f:
18
  }
19
 
20
  # Send POST request to the endpoint
21
- response = requests.post(local_app + ':7860/predict',
22
  files=files,
23
  params=params)
24
 
 
18
  }
19
 
20
  # Send POST request to the endpoint
21
+ response = requests.post(hf_app + ':7860/predict',
22
  files=files,
23
  params=params)
24