# Author: Fzina
# Commit: e3876cb "Update app.py" (verified)
import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from ultralytics import YOLO
import cv2
import numpy as np
import io
from PIL import Image
import base64
import os
from io import BytesIO
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

app = FastAPI()

# Path to the YOLO weights. Defaults to the pre-downloaded location; it can be
# overridden with the YOLO_MODEL_PATH environment variable (an unset or empty
# variable falls back to the default).
YOLO_MODEL_PATH = os.environ.get("YOLO_MODEL_PATH") or "/tmp/models/yolov9s.pt"

# Fail fast at import time: the service cannot answer requests without weights.
# isfile (not exists) so that a directory at this path is rejected here rather
# than failing later inside the YOLO loader.
if not os.path.isfile(YOLO_MODEL_PATH):
    raise FileNotFoundError(f"YOLO model not found at {YOLO_MODEL_PATH}")

# Load the YOLOv9 model once at start-up; the same instance is used by every
# request handled by analyze_traffic.
try:
    model = YOLO(YOLO_MODEL_PATH)
except Exception as e:
    # Chain the original error so the loader's traceback is preserved.
    raise RuntimeError(f"Failed to load YOLO model: {str(e)}") from e

# Class ids counted as vehicles (cars, motorbikes, buses, trucks).
# NOTE(review): these look like COCO ids (2=car, 3=motorcycle, 5=bus, 7=truck),
# which assumes a COCO-pretrained checkpoint -- confirm if custom weights are used.
vehicle_classes = [2, 3, 5, 7]  # Adjust as necessary for your use case
@app.get("/")
async def root():
    """Liveness probe: return a static payload saying the service is up and naming its model."""
    return dict(status="OK", model="YOLOv9s")
@app.post("/analyze_traffic/")
async def analyze_traffic(file: UploadFile = File(...)):
    """
    Analyze the traffic image using YOLOv9 and return the results along with a processed image.

    The upload is decoded with Pillow and run through the module-level YOLO
    ``model``. Every detection whose class id is in ``vehicle_classes`` is
    counted and outlined with a green rectangle on the returned image.

    Args:
        file: Uploaded image file (any format Pillow can open).

    Returns:
        JSONResponse with ``vehicle_count`` (int), ``congestion_level``
        ("Low" / "Medium" / "High"), ``flow_rate`` ("Smooth" / "Heavy") and
        ``processed_image_base64`` (base64 of a JPEG with the boxes drawn).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 on OpenCV failures or any other error during analysis.
    """
    try:
        # Load image from the uploaded file.
        image_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except OSError as decode_error:
            # PIL's UnidentifiedImageError and truncated-file errors are OSError
            # subclasses: the client sent bad data, so this is not a server fault.
            logging.warning("Rejected upload that is not a readable image: %s", decode_error)
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a valid image."
            ) from decode_error

        # PIL yields RGB, but both Ultralytics (for NumPy input) and OpenCV
        # (drawing/encoding) treat arrays as BGR. Convert once so the model sees
        # correct colours and the returned JPEG is not red/blue swapped.
        image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        # YOLO model detection.
        logging.info("Running YOLO detection on the uploaded image...")
        results = model(image_bgr)
        detections = results[0]  # one input image -> first Results entry
        # Lazy %-formatting: the (large) repr is only built when DEBUG is enabled.
        logging.debug("Raw YOLO results: %s", results)

        # Collect integer pixel boxes (x1, y1, x2, y2) for vehicle classes only.
        vehicle_boxes = []
        for det in detections.boxes:
            cls_id = int(det.cls[0]) if hasattr(det, "cls") else None
            if cls_id in vehicle_classes:
                box = det.xyxy[0]
                vehicle_boxes.append((int(box[0]), int(box[1]), int(box[2]), int(box[3])))
            logging.debug("Detection structure: %s", det.__dict__)
        vehicle_count = len(vehicle_boxes)

        logging.info("Vehicle count: %s", vehicle_count)
        logging.info("Vehicle bounding boxes: %s", vehicle_boxes)

        # Congestion level from vehicle count (thresholds: >20 High, >10 Medium).
        if vehicle_count > 20:
            congestion_level = "High"
        elif vehicle_count > 10:
            congestion_level = "Medium"
        else:
            congestion_level = "Low"
        # Anything above "Low" congestion is reported as heavy flow.
        flow_rate = "Smooth" if congestion_level == "Low" else "Heavy"

        # Draw bounding boxes on the BGR image (green in either channel order).
        for x1, y1, x2, y2 in vehicle_boxes:
            cv2.rectangle(image_bgr, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)

        # Encode the processed image as JPEG, then base64, for the JSON payload.
        ok, buffer = cv2.imencode(".jpg", image_bgr)
        if not ok:
            raise RuntimeError("cv2.imencode failed to encode the processed image")
        processed_image_base64 = base64.b64encode(buffer).decode("utf-8")

        return JSONResponse(content={
            "vehicle_count": vehicle_count,
            "congestion_level": congestion_level,
            "flow_rate": flow_rate,
            "processed_image_base64": processed_image_base64,
        })
    except HTTPException:
        # Already mapped to an HTTP status (e.g. the 400 above); do not rewrap as 500.
        raise
    except cv2.error as cv_error:
        # Must come before the generic handler: cv2.error subclasses Exception.
        logging.exception("OpenCV error: %s", cv_error)
        raise HTTPException(status_code=500, detail="Image processing error.") from cv_error
    except Exception as e:
        logging.exception("Error analyzing traffic: %s", e)
        raise HTTPException(status_code=500, detail=f"Error analyzing traffic: {str(e)}") from e