import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from ultralytics import YOLO
import cv2
import numpy as np
import io
from PIL import Image
import base64
import os
from io import BytesIO

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

app = FastAPI()

# Set the path to the YOLO model
YOLO_MODEL_PATH = "/tmp/models/yolov9s.pt"

# Fail fast at import time if the pre-downloaded weights are missing.
if not os.path.exists(YOLO_MODEL_PATH):
    raise FileNotFoundError(f"YOLO model not found at {YOLO_MODEL_PATH}")

# Load the YOLOv9 model once at import; the module-level instance is reused
# by every request to /analyze_traffic/.
try:
    model = YOLO(YOLO_MODEL_PATH)  # Load the YOLO model from the pre-downloaded path
except Exception as e:
    # Chain the original error so its traceback is not lost.
    raise RuntimeError(f"Failed to load YOLO model: {str(e)}") from e

# Class labels for vehicles (cars, motorbikes, buses, trucks, etc.)
vehicle_classes = [2, 3, 5, 7]  # Adjust as necessary for your use case


def _decode_upload_bgr(image_bytes: bytes) -> np.ndarray:
    """Decode uploaded bytes into an OpenCV-style BGR image array.

    The payload is decoded with PIL and normalised to RGB (drops alpha,
    expands greyscale/palette), then converted once to BGR. Both Ultralytics
    (for raw numpy input) and OpenCV (``rectangle``/``imencode``) use BGR
    channel order, so everything downstream can share this one array.

    Raises:
        HTTPException: 400 if the payload is empty or cannot be decoded
            as an image.
    """
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError and truncated-data errors are OSError
        # subclasses; both mean the client sent something we cannot decode.
        raise HTTPException(
            status_code=400, detail=f"Invalid image file: {err}"
        ) from err
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)


def _extract_vehicle_boxes(detections) -> list:
    """Return ``(x1, y1, x2, y2)`` integer boxes for vehicle-class detections.

    Args:
        detections: a single YOLO result; its ``boxes`` are iterated one
            detection at a time.

    Returns:
        One tuple per detection whose class id is in ``vehicle_classes``.
    """
    vehicle_boxes = []
    for det in detections.boxes:
        cls_id = int(det.cls[0]) if hasattr(det, "cls") else None
        if cls_id in vehicle_classes:
            box = det.xyxy[0]  # Bounding box corners
            vehicle_boxes.append(
                (int(box[0]), int(box[1]), int(box[2]), int(box[3]))
            )
        # Lazy %-formatting: __dict__ is only rendered when DEBUG is enabled.
        logging.debug("Detection structure: %s", det.__dict__)
    return vehicle_boxes


def _congestion_level(vehicle_count: int) -> str:
    """Map a vehicle count to ``"High"`` (>20), ``"Medium"`` (>10) or ``"Low"``."""
    if vehicle_count > 20:
        return "High"
    if vehicle_count > 10:
        return "Medium"
    return "Low"


def _encode_jpeg_base64(image_bgr: np.ndarray) -> str:
    """JPEG-encode a BGR image and return it as a base64 ASCII string.

    Raises:
        RuntimeError: if OpenCV reports that encoding failed.
    """
    ok, buffer = cv2.imencode(".jpg", image_bgr)
    if not ok:
        raise RuntimeError("Failed to encode processed image as JPEG.")
    return base64.b64encode(buffer).decode("utf-8")


@app.get("/")
async def root():
    """Health check reporting service status and the model name."""
    return {"status": "OK", "model": "YOLOv9s"}


@app.post("/analyze_traffic/")
async def analyze_traffic(file: UploadFile = File(...)):
    """
    Analyze the traffic image using YOLOv9 and return the results along with a processed image.

    Counts detections whose class id is in ``vehicle_classes``, derives a
    congestion level (>20 High, >10 Medium, otherwise Low) and a flow rate
    ("Smooth" only when congestion is Low, else "Heavy"), and returns the
    image with green boxes drawn around each vehicle as a base64 JPEG.

    Returns:
        JSON with keys ``vehicle_count``, ``congestion_level``,
        ``flow_rate`` and ``processed_image_base64``.

    Raises:
        HTTPException: 400 for an empty or undecodable upload; 500 for
            OpenCV failures or any other processing error.
    """
    try:
        # Load image from the uploaded file (BGR, as YOLO/OpenCV expect).
        image_bgr = _decode_upload_bgr(await file.read())

        # YOLO model detection.
        # NOTE(review): inference is synchronous and runs on the event-loop
        # thread, so concurrent requests are serialised; moving it to a
        # worker thread would require confirming the model is thread-safe.
        logging.info("Running YOLO detection on the uploaded image...")
        results = model(image_bgr)
        detections = results[0]  # Single image in, so take the first result
        logging.debug("Raw YOLO results: %s", results)

        # Extract vehicle details
        vehicle_boxes = _extract_vehicle_boxes(detections)
        vehicle_count = len(vehicle_boxes)
        logging.info("Vehicle count: %s", vehicle_count)
        logging.info("Vehicle bounding boxes: %s", vehicle_boxes)

        congestion_level = _congestion_level(vehicle_count)
        flow_rate = "Smooth" if congestion_level == "Low" else "Heavy"

        # Draw bounding boxes on the processed image
        for (x1, y1, x2, y2) in vehicle_boxes:
            cv2.rectangle(image_bgr, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)

        processed_image_base64 = _encode_jpeg_base64(image_bgr)

        return JSONResponse(content={
            "vehicle_count": vehicle_count,
            "congestion_level": congestion_level,
            "flow_rate": flow_rate,
            "processed_image_base64": processed_image_base64
        })
    except HTTPException:
        # Client errors raised above already carry the right status code.
        raise
    except cv2.error as cv_error:
        # Must precede `except Exception`: cv2.error is an Exception subclass
        # and would otherwise never reach this handler.
        logging.exception("OpenCV error: %s", cv_error)
        raise HTTPException(
            status_code=500, detail="Image processing error."
        ) from cv_error
    except Exception as e:
        # Log any exceptions that occur, with traceback
        logging.exception("Error analyzing traffic: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error analyzing traffic: {str(e)}"
        ) from e