# yolo-detect / app.py
# Last commit by jake2004: 161024b "Update app.py" (verified)
import streamlit as st
import cv2
from PIL import Image
import numpy as np
from ultralytics import YOLO
import tempfile
import os
import supervision as sv
from streamlit_webrtc import webrtc_streamer, WebRtcMode
import functools
# --- Page Configuration ---
# Browser-tab title and icon, a wide layout, and the options sidebar open on load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Traffic Lane and Object Detection",
    page_icon=":camera_video:",
)
# --- Sidebar ---
# Everything created inside this container is rendered in the sidebar.
with st.sidebar:
    st.header("Traffic Lane and Object Detection Options")
    # Selects which input branch of the main page is shown.
    source_type = st.radio(
        "Select Input Source:",
        ("Image", "Video", "Live Camera Feed"),
    )
    # Forwarded to every YOLO call as ``conf``.
    confidence_threshold = st.slider(
        "Confidence Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.25,
        step=0.05,
    )
    # Forwarded to every YOLO call as ``iou``.
    iou_threshold = st.slider(
        "IoU Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.45,
        step=0.05,
        help="Intersection over Union threshold for NMS",
    )
# --- Load YOLO Model ---
@st.cache_resource
def load_model():
    """Return the pretrained YOLOv8x model.

    Wrapped in ``st.cache_resource`` so the weights are loaded once and reused
    across script reruns instead of being reloaded on every widget change.
    """
    return YOLO("yolov8x.pt")


model = load_model()
# --- Functions ---
def detect_lanes_and_objects_image(image, model, confidence_threshold, iou_threshold):
    """Run YOLO on a single image and draw the detected boxes with Supervision.

    Args:
        image: PIL image (any mode; it is converted to RGB first). A numpy
            array is also accepted and is handed to YOLO unchanged.
        model: loaded Ultralytics YOLO model.
        confidence_threshold: minimum detection confidence, passed as ``conf``.
        iou_threshold: NMS IoU threshold, passed as ``iou``.

    Returns:
        numpy array with bounding boxes drawn. For PIL input it is RGB, which
        matches ``st.image``'s default channel order.
    """
    if isinstance(image, Image.Image):
        # Normalise RGBA / palette / greyscale uploads to 3-channel RGB.
        image = image.convert("RGB")
        # Give YOLO the PIL image itself: Ultralytics reads PIL input as RGB but
        # raw numpy arrays as BGR, so np.array(image) would swap red and blue.
        model_input = image
    else:
        model_input = np.asarray(image)
    results = model(model_input, conf=confidence_threshold, iou=iou_threshold)
    detections = sv.Detections.from_ultralytics(results[0])
    annotator = sv.BoxAnnotator(thickness=2)
    # np.array copies, so annotating in place never mutates the caller's image.
    scene = np.array(image)
    return annotator.annotate(scene=scene, detections=detections)
def detect_lanes_and_objects_video(
    video_path, model, confidence_threshold, iou_threshold, output_path=None
):
    """Run YOLO on every frame of a video, preview it live and save an annotated copy.

    Args:
        video_path: path of the input video, opened with OpenCV.
        model: loaded Ultralytics YOLO model.
        confidence_threshold: minimum detection confidence, passed as ``conf``.
        iou_threshold: NMS IoU threshold, passed as ``iou``.
        output_path: where to write the annotated video. Defaults to a fresh
            ``.mp4`` temp file, so concurrent sessions never overwrite each
            other's output. The caller is responsible for deleting it.

    Returns:
        Path of the written video.

    Raises:
        ValueError: if OpenCV cannot open ``video_path``.
    """
    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        video.release()
        raise ValueError(f"Could not open video file: {video_path}")

    frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep the fractional rate (e.g. 29.97) so playback speed is preserved;
    # fall back to 30 when the container does not report a usable value.
    fps = video.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        fps = 30.0

    if output_path is None:
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name

    # NOTE(review): browsers often cannot play MPEG-4 Part 2 ("mp4v") in
    # st.video; consider H.264 ("avc1") if the deployed OpenCV build supports it.
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, codec, fps, (frame_width, frame_height))
    # One annotator for the whole clip instead of rebuilding it per frame.
    annotator = sv.BoxAnnotator(thickness=2)
    stframe = st.empty()
    try:
        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                break
            results = model(frame, conf=confidence_threshold, iou=iou_threshold)
            detections = sv.Detections.from_ultralytics(results[0])
            annotated_frame = annotator.annotate(scene=frame, detections=detections)
            out.write(annotated_frame)
            stframe.image(annotated_frame, channels="BGR", use_column_width=True)
    finally:
        # Always release handles, even if inference fails mid-clip. No OpenCV
        # windows are opened, so destroyAllWindows (which errors on headless
        # builds) is not needed.
        video.release()
        out.release()
    return output_path
def process_frame(frame, model, confidence_threshold, iou_threshold):
    """Annotate one webcam frame; used as streamlit-webrtc's ``video_frame_callback``.

    Args:
        frame: incoming video frame from streamlit-webrtc (an ``av.VideoFrame``).
        model: loaded Ultralytics YOLO model.
        confidence_threshold: minimum detection confidence, passed as ``conf``.
        iou_threshold: NMS IoU threshold, passed as ``iou``.

    Returns:
        A new frame of the same type as ``frame`` holding the annotated BGR
        image. streamlit-webrtc expects the callback to return a video frame,
        not a numpy array.
    """
    img = frame.to_ndarray(format="bgr24")
    # Ultralytics reads numpy input as BGR, which matches "bgr24" above.
    results = model(img, conf=confidence_threshold, iou=iou_threshold)
    detections = sv.Detections.from_ultralytics(results[0])
    annotator = sv.BoxAnnotator(thickness=2)
    annotated_frame = annotator.annotate(scene=img, detections=detections)
    # type(frame) is the frame class (av.VideoFrame); calling its static
    # from_ndarray avoids importing PyAV directly here.
    return type(frame).from_ndarray(annotated_frame, format="bgr24")
st.title("Traffic Lane and Object Detection")
if source_type == "Image":
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)
        if st.button("Run Detection"):
            with st.spinner("Running YOLOv8..."):
                detected_image = detect_lanes_and_objects_image(
                    image, model, confidence_threshold, iou_threshold
                )
            st.image(detected_image, caption="Detected Image", use_column_width=True)
elif source_type == "Video":
    uploaded_video = st.file_uploader("Upload a video", type=["mp4", "avi", "mov"])
    # Write the upload to disk only when detection actually runs, so reruns
    # (e.g. moving a slider) do not leave orphaned temp files behind.
    if uploaded_video is not None and st.button("Run Detection"):
        suffix = os.path.splitext(uploaded_video.name)[1]
        # Exiting the `with` closes (and flushes) the file before OpenCV opens
        # it by name; otherwise the tail of the video could still be buffered.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tfile.write(uploaded_video.getvalue())
            video_path = tfile.name
        output_video_path = None
        try:
            with st.spinner("Running YOLOv8 on video..."):
                output_video_path = detect_lanes_and_objects_video(
                    video_path, model, confidence_threshold, iou_threshold
                )
            # Hand st.video the bytes so deleting the file below is safe.
            with open(output_video_path, "rb") as result_file:
                st.video(result_file.read())
        finally:
            os.unlink(video_path)
            if output_video_path is not None and os.path.exists(output_video_path):
                os.remove(output_video_path)
elif source_type == "Live Camera Feed":
    st.subheader("Live Camera Feed")
    if model is None:
        st.error("YOLO model is not loaded. Cannot run live feed.")
        st.stop()
    # Bind the model and current slider values so the callback only takes a frame.
    custom_process_frame = functools.partial(
        process_frame,
        model=model,
        confidence_threshold=confidence_threshold,
        iou_threshold=iou_threshold,
    )
    webrtc_streamer(
        key="live-feed",
        video_frame_callback=custom_process_frame,
        mode=WebRtcMode.SENDRECV,
        media_stream_constraints={"video": True, "audio": False},
    )
# Footer note shown beneath every input mode. (The leftover
# "Your markdown content here" placeholder has been removed.)
st.markdown(
    """
**Note:** This example uses YOLOv8 for object detection. Lane detection is a more complex task and requires additional image processing techniques. This is a simplified demo and will likely not perform well on complex or noisy video.
"""
)