import logging

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from ultralytics import YOLO
import cv2
import numpy as np
import io
from PIL import Image
import base64
import os
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
app = FastAPI()

# Set the path to the YOLO model
YOLO_MODEL_PATH = "/tmp/models/yolov9s.pt"

# Model check
if not os.path.exists(YOLO_MODEL_PATH):
    raise FileNotFoundError(f"YOLO model not found at {YOLO_MODEL_PATH}")

# Load the YOLOv9 model
try:
    model = YOLO(YOLO_MODEL_PATH)  # Load the YOLO model from the pre-downloaded path
except Exception as e:
    raise RuntimeError(f"Failed to load YOLO model: {str(e)}")
# COCO class IDs for vehicles: 2 = car, 3 = motorcycle, 5 = bus, 7 = truck
vehicle_classes = [2, 3, 5, 7]  # Adjust as necessary for your use case

@app.get("/")  # Health-check endpoint (route path assumed)
async def root():
    return {"status": "OK", "model": "YOLOv9s"}

@app.post("/analyze")  # Analysis endpoint (route path assumed)
async def analyze_traffic(file: UploadFile = File(...)):
    """
    Analyze the uploaded traffic image using YOLOv9 and return the results
    along with a processed image.
    """
    try:
        # Load image from the uploaded file
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image_np = np.array(image)

        # YOLO model detection
        logging.info("Running YOLO detection on the uploaded image...")
        results = model(image_np)  # Perform detection
        detections = results[0]  # Detections for the first image (batch size = 1)

        # Log raw results for debugging
        logging.debug(f"Raw YOLO results: {results}")
        # Extract vehicle details
        vehicle_count = 0
        vehicle_boxes = []
        for det in detections.boxes:
            cls_id = int(det.cls[0]) if hasattr(det, 'cls') else None
            if cls_id in vehicle_classes:
                vehicle_count += 1
                box = det.xyxy[0]  # Bounding box in (x1, y1, x2, y2) format
                vehicle_boxes.append((int(box[0]), int(box[1]), int(box[2]), int(box[3])))
            # Log the detection structure for debugging
            logging.debug(f"Detection structure: {det.__dict__}")

        # Log detected vehicle details
        logging.info(f"Vehicle count: {vehicle_count}")
        logging.info(f"Vehicle bounding boxes: {vehicle_boxes}")

        # Calculate congestion level based on vehicle count
        if vehicle_count > 20:
            congestion_level = "High"
        elif vehicle_count > 10:
            congestion_level = "Medium"
        else:
            congestion_level = "Low"

        # Determine traffic flow rate based on congestion level
        flow_rate = "Smooth" if congestion_level == "Low" else "Heavy"

        # Draw bounding boxes on the processed image
        for (x1, y1, x2, y2) in vehicle_boxes:
            cv2.rectangle(image_np, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)

        # Convert from RGB (PIL) to BGR (OpenCV) so the encoded JPEG has correct colors
        image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        # Encode the processed image to base64
        _, buffer = cv2.imencode('.jpg', image_bgr)
        processed_image_base64 = base64.b64encode(buffer).decode('utf-8')

        # Return the analysis results along with the processed image
        return JSONResponse(content={
            "vehicle_count": vehicle_count,
            "congestion_level": congestion_level,
            "flow_rate": flow_rate,
            "processed_image_base64": processed_image_base64
        })
    except cv2.error as cv_error:
        # OpenCV errors must be caught before the generic Exception handler
        logging.error(f"OpenCV error: {cv_error}")
        raise HTTPException(status_code=500, detail="Image processing error.")
    except Exception as e:
        # Log any other exceptions that occur
        logging.error(f"Error analyzing traffic: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error analyzing traffic: {str(e)}")
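

# Local run sketch (assumption): the source does not show how this app is launched.
# One common option is uvicorn; the host, port, and the sample request below are
# illustrative only and match the "/analyze" path assumed above.
#
#   curl -X POST -F "file=@traffic.jpg" http://localhost:7860/analyze
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)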