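"""Gradio demo app for PiSAR: aerial search-and-rescue object detection.

Runs WALDO30 ONNX detectors on images and videos and, for videos, optionally
tracks detections with SORT, DeepSORT, or ByteTrack.
"""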
import logging
import os
import uuid
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import gradio as gr
import imageio.v3 as iio
import numpy as np
import supervision as sv
import torch
import tqdm
import trackers
from PIL import Image

from pipeline import build_pipeline
from utils import cfg, load_config, load_onnx_model


DETECTORS = {
    "yolo8n-640": "downloads/yolo8n-640.onnx",
    "yolo8n-416": "downloads/yolo8n-416.onnx",
}
DEFAULT_DETECTOR = list(DETECTORS.keys())[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.6


IMAGE_EXAMPLES = [
    {"path": "./examples/images/forest.jpg", "label": "Local Image"},
    {"path": "./examples/images/river.jpg", "label": "Local Image"},
    {"path": "./examples/images/ocean.jpg", "label": "Local Image"},
]

MAX_NUM_FRAMES = 250
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


class TrackingAlgorithm:
    BYTETRACK = "ByteTrack (2021)"
    DEEPSORT = "DeepSORT (2017)"
    SORT = "SORT (2016)"


# None means "no tracking" in the video tab dropdown.
TRACKERS = [None, TrackingAlgorithm.BYTETRACK, TrackingAlgorithm.DEEPSORT, TrackingAlgorithm.SORT]

VIDEO_EXAMPLES = [
    {"path": "./examples/videos/sea.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK, "classes": "Person, Boat"},
    {"path": "./examples/videos/forest.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK, "classes": "LightVehicle, Person, Boat"},
]


color = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
])

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def get_pipeline(config: dict, onnx_path: str):
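    """Build the detection pipeline and load ONNX weights into its detector."""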
    pipeline = build_pipeline(config)
    load_onnx_model(pipeline.detector, onnx_path)
    return pipeline


def detect_objects(
    config: dict,
    onnx_path: str,
    images: List[np.ndarray] | np.ndarray,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    target_size: Optional[Tuple[int, int]] = None,
    classes: Optional[List[str]] = None,
):
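    """Run the detector on a list (or batch array) of RGB frames.

    Returns per-frame dicts with "scores", "labels", and "boxes" tensors, plus
    the detector's id-to-label mapping. If `target_size` (height, width) is
    given, boxes are rescaled to that size; if `classes` is given, only those
    classes are kept.
    """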
    config.defrost()
    config.detector.thresholds.confidence = float(confidence_threshold)
    config.freeze()
    pipeline = get_pipeline(config, onnx_path)

    id2label = pipeline.detector.get_category_mapping()
    label2id = {v: k for k, v in id2label.items()}
    if classes is not None:
        wrong_classes = [cls for cls in classes if cls not in label2id]
        if wrong_classes:
            gr.Warning(f"Classes not found in model config: {wrong_classes}")
        keep_ids = [label2id[cls] for cls in classes if cls in label2id]
    else:
        keep_ids = None

    # Accept either a list of frames or a single (N, H, W, C) array.
    if isinstance(images, np.ndarray) and images.ndim == 4:
        images = [x for x in images]

    results = []
    for img in tqdm.tqdm(images, desc="Processing frames"):
        output_ = pipeline(img)
        output_reshaped = {
            "scores": torch.from_numpy(output_.confidence) if isinstance(output_.confidence, np.ndarray) else output_.confidence,
            "labels": torch.from_numpy(output_.class_id) if isinstance(output_.class_id, np.ndarray) else output_.class_id,
            "boxes": torch.from_numpy(output_.xyxy) if isinstance(output_.xyxy, np.ndarray) else output_.xyxy,
        }
        results.append(output_reshaped)
        if target_size:
            # target_size is (height, width): rescale boxes from the frame's
            # coordinate space to the requested output size.
            scale_x = target_size[1] / img.shape[1]
            scale_y = target_size[0] / img.shape[0]
            output_reshaped["boxes"][:, [0, 2]] *= scale_x
            output_reshaped["boxes"][:, [1, 3]] *= scale_y

    # Optionally keep only detections whose class id was requested.
    if keep_ids is not None:
        keep_ids_tensor = torch.tensor(keep_ids)
        for i, result in enumerate(results):
            keep = torch.isin(result["labels"], keep_ids_tensor)
            results[i] = {k: v[keep] for k, v in result.items()}

    return results, id2label


def process_image(
    model: str = DEFAULT_DETECTOR,
    image: Optional[Image.Image] = None,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
):
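    """Detect objects on a single PIL image and return gr.AnnotatedImage data."""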
    if image is None:
        raise gr.Error("Please provide an image.")

    load_config(cfg, f"configs/{model}.yaml")
    results, id2label = detect_objects(
        config=cfg.pipeline,
        onnx_path=DETECTORS[model],
        images=[np.array(image)],
        confidence_threshold=confidence_threshold,
    )
    result = results[0]

    # Build (box, label) pairs for gr.AnnotatedImage, clamping boxes to the image.
    annotations = []
    for label, score, box in zip(result["labels"], result["scores"], result["boxes"]):
        text_label = id2label[label.item()]
        formatted_label = f"{text_label} ({score:.2f})"
        x_min, y_min, x_max, y_max = box.cpu().numpy().round().astype(int)
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(image.width - 1, x_max)
        y_max = min(image.height - 1, y_max)
        annotations.append(((x_min, y_min, x_max, y_max), formatted_label))

    return (image, annotations)


def get_target_size(image_height, image_width, max_size: int):
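    """Return (width, height) scaled so the longer side fits max_size, rounded down to even values."""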
    if image_height < max_size and image_width < max_size:
        new_height, new_width = image_height, image_width
    elif image_height > image_width:
        new_height = max_size
        new_width = int(image_width * max_size / image_height)
    else:
        new_width = max_size
        new_height = int(image_height * max_size / image_width)

    # Keep dimensions even, as required by most H.264 encoders.
    new_height = new_height // 2 * 2
    new_width = new_width // 2 * 2

    return new_width, new_height


def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
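    """Read up to k frames from the video, keeping every i-th frame, as RGB arrays."""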
    cap = cv2.VideoCapture(video_path)
    frames = []
    i = 0
    progress_bar = tqdm.tqdm(total=k, desc="Reading frames")
    while cap.isOpened() and len(frames) < k:
        ret, frame = cap.read()
        if not ret:
            break
        if i % read_every_i_frame == 0:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            progress_bar.update(1)
        i += 1
    cap.release()
    progress_bar.close()
    return frames


def get_tracker(tracker: str, fps: float):
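    """Instantiate the selected tracking algorithm for the given frame rate."""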
    if tracker == TrackingAlgorithm.SORT:
        return trackers.SORTTracker(frame_rate=fps)
    elif tracker == TrackingAlgorithm.DEEPSORT:
        feature_extractor = trackers.DeepSORTFeatureExtractor.from_timm(
            "mobilenetv4_conv_small.e1200_r224_in1k", device="cpu"
        )
        return trackers.DeepSORTTracker(feature_extractor, frame_rate=fps)
    elif tracker == TrackingAlgorithm.BYTETRACK:
        return sv.ByteTrack(frame_rate=int(fps))
    else:
        raise ValueError(f"Invalid tracker: {tracker}")


def update_tracker(tracker, detections, frame):
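    """Advance the tracker by one frame; each backend exposes a slightly different update API."""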
    tracker_name = tracker.__class__.__name__
    if tracker_name == "SORTTracker":
        return tracker.update(detections)
    elif tracker_name == "DeepSORTTracker":
        return tracker.update(detections, frame)
    elif tracker_name == "ByteTrack":
        return tracker.update_with_detections(detections)
    else:
        raise ValueError(f"Invalid tracker: {tracker_name}")


def process_video(
    video_path: str,
    checkpoint: str,
    tracker_algorithm: Optional[str] = None,
    classes: str = "all",
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:
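    """Detect (and optionally track) objects in a video and write an annotated MP4.

    Returns the path of the rendered video inside VIDEO_OUTPUT_DIR.
    """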
    if not video_path or not os.path.isfile(video_path):
        raise ValueError(f"Invalid video path: {video_path}")

    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        raise ValueError(f"Unsupported video format: {ext}, supported formats: {ALLOWED_VIDEO_EXTENSIONS}")

    # Subsample to roughly 25 FPS and cap the output resolution at 1080 px.
    video_info = sv.VideoInfo.from_video_path(video_path)
    read_each_i_frame = max(1, video_info.fps // 25)
    target_fps = video_info.fps / read_each_i_frame
    target_width, target_height = get_target_size(video_info.height, video_info.width, 1080)

    n_frames_to_read = min(MAX_NUM_FRAMES, video_info.total_frames // read_each_i_frame)
    frames = read_video_k_frames(video_path, n_frames_to_read, read_each_i_frame)
    frames = [cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC) for frame in frames]

    # Color boxes per track when tracking, per class otherwise.
    color_lookup = sv.ColorLookup.TRACK if tracker_algorithm else sv.ColorLookup.CLASS
    box_annotator = sv.BoxAnnotator(color, color_lookup=color_lookup, thickness=1)
    label_annotator = sv.LabelAnnotator(color, color_lookup=color_lookup, text_scale=0.5)

    if classes != "all":
        classes_list = [cls.strip() for cls in classes.split(",")]
    else:
        classes_list = None

    load_config(cfg, f"configs/{checkpoint}.yaml")
    results, id2label = detect_objects(
        config=cfg.pipeline,
        onnx_path=DETECTORS[checkpoint],
        images=np.array(frames),
        confidence_threshold=confidence_threshold,
        target_size=(target_height, target_width),
        classes=classes_list,
    )

    annotated_frames = []

    if tracker_algorithm:
        tracker = get_tracker(tracker_algorithm, target_fps)
        for frame, result in progress.tqdm(zip(frames, results), desc="Tracking objects", total=len(frames)):
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            detections = update_tracker(tracker, detections, frame)
            labels = [f"#{tracker_id} {id2label[class_id]}" for class_id, tracker_id in zip(detections.class_id, detections.tracker_id)]
            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
            annotated_frames.append(annotated_frame)
    else:
        for frame, result in tqdm.tqdm(zip(frames, results), desc="Annotating frames", total=len(frames)):
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections)
            annotated_frames.append(annotated_frame)

    output_filename = os.path.join(VIDEO_OUTPUT_DIR, f"output_{uuid.uuid4()}.mp4")
    iio.imwrite(output_filename, annotated_frames, fps=target_fps, codec="h264")
    return output_filename


def create_image_inputs() -> List[gr.components.Component]:
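    """Input widgets for the image tab: image upload, model checkpoint, confidence slider."""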
    return [
        gr.Image(
            label="Upload Image",
            type="pil",
            sources=["upload", "webcam"],
            interactive=True,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=list(DETECTORS.keys()),
            label="Select Model Checkpoint",
            value=DEFAULT_DETECTOR,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_video_inputs() -> List[gr.components.Component]:
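    """Input widgets for the video tab: video upload, checkpoint, tracker, class list, confidence slider."""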
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=list(DETECTORS.keys()),
            label="Select Model Checkpoint",
            value=DEFAULT_DETECTOR,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=TRACKERS,
            label="Select Tracker (Optional)",
            value=None,
            elem_classes="input-component",
        ),
        gr.TextArea(
            label="Specify Class Names to Detect (comma separated)",
            value="all",
            lines=1,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_button_row() -> List[gr.Button]:
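    """Detect and Clear buttons shared by both tabs."""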
    return [
        gr.Button(
            "Detect Objects", variant="primary", elem_classes="action-button"
        ),
        gr.Button("Clear", variant="secondary", elem_classes="action-button"),
    ]


with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
# Pipeline for Aerial Search and Rescue Demo
Experience state-of-the-art object detection with the open-source [WALDO30](https://huggingface.co/StephanST/WALDO30) models.
- **Image** and **Video** modes are supported.
- Select a model and adjust the confidence threshold to see detections!
- In video mode, you can enable tracking powered by [Supervision](https://github.com/roboflow/supervision) and [Trackers](https://github.com/roboflow/trackers) from Roboflow.

For more details and source code, visit the [PiSAR repository](https://github.com/eadali/PiSAR).
        """,
        elem_classes="header-text",
    )

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            image_model_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                        image_detect_button, image_clear_button = create_button_row()
                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )
            gr.Examples(
                examples=[
                    [
                        DEFAULT_DETECTOR,
                        example["path"],
                        DEFAULT_CONFIDENCE_THRESHOLD,
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=[
                    image_model_checkpoint,
                    image_input,
                    image_confidence_threshold,
                ],
                outputs=[image_output],
                fn=process_image,
                label="Select an image example to populate inputs",
                cache_examples=True,
                cache_mode="lazy",
            )

with gr.Tab("Video"): |
|
gr.Markdown( |
|
f"The input video will be processed in ~25 FPS (up to {MAX_NUM_FRAMES} frames in result)." |
|
) |
|
with gr.Row(): |
|
with gr.Column(scale=1, min_width=300): |
|
with gr.Group(): |
|
video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold = create_video_inputs() |
|
video_detect_button, video_clear_button = create_button_row() |
|
with gr.Column(scale=2): |
|
video_output = gr.Video( |
|
label="Detection Results", |
|
format="mp4", |
|
elem_classes="output-component", |
|
) |
|
|
|
gr.Examples( |
|
examples=[ |
|
[example["path"], DEFAULT_DETECTOR, example["tracker"], example["classes"], DEFAULT_CONFIDENCE_THRESHOLD] |
|
for example in VIDEO_EXAMPLES |
|
], |
|
inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold], |
|
outputs=[video_output], |
|
fn=process_video, |
|
cache_examples=False, |
|
label="Select a video example to populate inputs", |
|
) |
|
|
|
|
|
    image_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_DETECTOR,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            image_input,
            image_model_checkpoint,
            image_confidence_threshold,
            image_output,
        ],
    )

    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_DETECTOR,
            None,
            "all",
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_tracker,
            video_classes,
            video_confidence_threshold,
            video_output,
        ],
    )

    image_detect_button.click(
        fn=process_image,
        inputs=[
            image_model_checkpoint,
            image_input,
            image_confidence_threshold,
        ],
        outputs=[image_output],
    )

    video_detect_button.click(
        fn=process_video,
        inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
        outputs=[video_output],
    )


if __name__ == "__main__": |
|
demo.queue(max_size=20).launch() |