|
import os |
|
import time |
|
import tempfile |
|
from typing import List, Dict, Optional, Tuple |
|
import json |
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
# Optional sample media; the UI pre-fills its inputs with these paths only
# when the files exist.
SAMPLES_DIR = "samples"
EMBED_IMG, EMBED_VID = (
    os.path.join(SAMPLES_DIR, sample) for sample in ("uav_image.jpg", "uav_video.mp4")
)

# Optional Hugging Face access token; empty string when the variable is unset.
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip()

# Dropdown label -> (Hugging Face repo id, weights path inside that repo).
# Insertion order matters: the UI uses the first entry as the default choice.
MODEL_CHOICES: Dict[str, Tuple[str, str]] = dict(
    [
        (
            "Multi-class (Drone/Helicopter/Airplane/Bird)",
            ("Javvanny/yolov8m_flying_objects_detection", "yolov8m/weights/best.pt"),
        ),
        (
            "Drone-only (cleaner, fewer false positives)",
            ("doguilmak/Drone-Detection-YOLOv8x", "weight/best.pt"),
        ),
    ]
)
|
|
|
|
|
|
|
|
|
# Raw class names emitted by the supported checkpoints (several are Russian)
# -> English display labels. Keys are kept verbatim, spelling variants
# included, because they must match the checkpoints' class names exactly.
LABEL_MAP = {
    "Airplane": "Airplane",
    "Bird": "Bird",
    "Drone": "Drone",
    "Helicopter": "Helicopter",
    "UAV": "UAV",
    "БПЛА": "UAV",
    "БПЛА коптер": "Drone",
    "квадрокоптер": "Drone",
    "квадроcамолет": "Drone",
    "самолет": "Airplane",
    "вертолет": "Helicopter",
    "автомобиль": "Car",
    "машина": "Car",
    "БПЛА самелет": "UAV Airplane",
    "drone": "Drone",
}

# Lower-cased English labels that count as a threat in reports and overlays.
THREAT_SET = {"drone", "uav", "airplane", "helicopter"}


def map_label(name: str) -> str:
    """Translate a raw class name to its English display label.

    An exact key match wins, then a lower-cased match; unknown names come back
    unchanged. Non-string input is returned as-is.
    """
    if not isinstance(name, str):
        return name
    return LABEL_MAP.get(name, LABEL_MAP.get(name.lower(), name))


def translate_names_dict(names_dict: Dict[int, str]) -> Dict[int, str]:
    """Return a copy of a ``{class_id: name}`` mapping with English names.

    Anything that is not a dict is returned unchanged.
    """
    if not isinstance(names_dict, dict):
        return names_dict
    return {k: map_label(v) for k, v in names_dict.items()}


def is_threat(label_en: str) -> bool:
    """Return True when an English label names a threat class (case-insensitive).

    Always returns a real bool. Empty, ``None`` or non-string labels are not
    threats; previously they leaked ``""``/``None`` or raised AttributeError.
    """
    return isinstance(label_en, str) and label_en.lower() in THREAT_SET
|
|
|
|
|
|
|
|
|
# Post-detection filter thresholds, each overridable by an environment
# variable of the same name:
#   MIN_CONF     - boxes below this confidence are dropped
#   MIN_AREA_PCT - boxes covering less than this fraction of the image are dropped
#   SKY_RATIO    - boxes whose bottom edge lies below height * SKY_RATIO are dropped
MIN_CONF, MIN_AREA_PCT, SKY_RATIO = (
    float(os.environ.get(var_name, fallback))
    for var_name, fallback in (
        ("MIN_CONF", 0.30),
        ("MIN_AREA_PCT", 0.001),
        ("SKY_RATIO", 0.95),
    )
)

# Lazily populated module-level state: the active model, its load error,
# class names, source repo/file and dropdown key, plus the FFmpeg probe
# result (None = not probed yet).
_model = _model_err = _model_names = None
_loaded_repo = _loaded_file = _loaded_key = None
_ffmpeg_status = None
|
|
|
def _lazy_cv2():
    """Import OpenCV on demand and hand back the ``cv2`` module.

    Deferring the import keeps OpenCV out of module import time.
    """
    import cv2 as opencv_module
    return opencv_module
|
|
|
def _ffmpeg_ok() -> bool:
    """Report (and cache) whether the installed OpenCV was built with FFmpeg.

    OpenCV's build-information text aligns values in columns, e.g.
    ``FFMPEG:                      YES``, so the key and value can be separated
    by any run of spaces/tabs. The previous literal checks for ``"FFMPEG:YES"``
    and ``"FFMPEG: YES"`` missed that layout. Any failure (OpenCV missing,
    unexpected output) is treated as "no FFmpeg".
    """
    global _ffmpeg_status
    import re

    if _ffmpeg_status is not None:
        return _ffmpeg_status
    try:
        info = _lazy_cv2().getBuildInformation()
        # [ \t]* rather than \s* so the match never spills onto the next line.
        _ffmpeg_status = re.search(r"FFMPEG:[ \t]*YES", info) is not None
    except Exception:
        _ffmpeg_status = False
    return _ffmpeg_status
|
|
|
def _download_from_hf(repo_id: str, filename: str) -> str:
    """Fetch ``filename`` from the Hugging Face repo ``repo_id``; return the local path.

    When ``HF_TOKEN`` is set, a best-effort ``login`` is attempted first. A
    failed login is ignored so public repositories can still be downloaded.
    NOTE(review): ``login()`` stores the token in the local Hugging Face
    credential cache; passing ``token=`` to ``hf_hub_download`` would avoid
    that side effect. Confirm whether persistence is wanted.
    """
    from huggingface_hub import hf_hub_download, login

    if not HF_TOKEN:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    try:
        login(token=HF_TOKEN)
    except Exception:
        pass  # deliberate best-effort: fall back to anonymous/cached credentials
    return hf_hub_download(repo_id=repo_id, filename=filename)
|
|
|
def _reset_model_cache():
    """Forget the loaded model, its load error, class names and source repo/file.

    ``_loaded_key`` is not touched here; ``_get_model`` updates it right after
    calling this function.
    """
    global _model, _model_err, _model_names, _loaded_repo, _loaded_file
    _model = _model_err = _model_names = None
    _loaded_repo = _loaded_file = None
|
|
|
def _get_model(model_key: str, conf: float, iou: float):
    """Return the YOLO model selected in the dropdown, loading it on first use.

    The model is cached in module globals and reused until a different
    ``model_key`` is requested. ``conf`` and ``iou`` are written into the
    model's prediction overrides on every call, so slider changes take effect
    without a reload.

    A failed load is recorded in ``_model_err`` and raised, but it is no longer
    cached permanently. The next call retries, so a transient download error
    does not disable the model for the rest of the process.

    Raises:
        KeyError: ``model_key`` is not a key of ``MODEL_CHOICES``.
        RuntimeError: downloading or constructing the model failed.
    """
    from ultralytics import YOLO

    global _model, _model_err, _model_names, _loaded_repo, _loaded_file, _loaded_key

    if _loaded_key != model_key:
        _reset_model_cache()
        _loaded_key = model_key

    if _model is None:
        repo, weights_file = MODEL_CHOICES[model_key]
        try:
            m = YOLO(_download_from_hf(repo, weights_file))
        except Exception as e:
            _model_err = f"Model load failed for {repo}/{weights_file}. Error: {e}"
            raise RuntimeError(_model_err) from e

        m.overrides["max_det"] = 300
        try:
            names = m.model.names if hasattr(m, "model") else None
        except Exception:
            names = None  # class names are only used for display

        _model, _model_err, _model_names = m, None, names
        _loaded_repo, _loaded_file = repo, weights_file

    # Apply the current slider values on every call. agnostic_nms runs NMS
    # across all classes, so one object is not boxed once per class.
    _model.overrides["conf"] = float(conf)
    _model.overrides["iou"] = float(iou)
    _model.overrides["agnostic_nms"] = True
    return _model
|
|
|
def _model_info_text():
    """Build the Markdown status text: loaded weights, FFmpeg support, Python version, classes.

    The Python version is read from the running interpreter; it used to be
    hard-coded as "3.10". The class list falls back to "unknown" when names
    are missing or cannot be rendered.
    """
    import sys

    repo = f"{_loaded_repo}/{_loaded_file}" if _loaded_repo else "not loaded"
    try:
        names = ", ".join(sorted(set(translate_names_dict(_model_names or {}).values()))) or "unknown"
    except Exception:
        names = "unknown"
    py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
    return (
        f"**Model:** {repo} • FFmpeg: {'Yes' if _ffmpeg_ok() else 'No'} • Python: {py_version}"
        f"\n\n**Classes:** {names}"
    )
|
|
|
|
|
|
|
|
|
def _results_to_rows(results) -> List[dict]:
    """Flatten the boxes of the first prediction result into plain dict rows.

    Each row carries the English ``class`` label and ``confidence`` (None when
    the boxes expose no confidences). It also carries the corner coordinates
    ``x1``/``y1``/``x2``/``y2`` as reported by ``boxes.xyxy``, plus the
    derived ``width``/``height``. An empty list is returned when there are no
    results or the first result has no boxes.
    """
    if not results:
        return []
    first = results[0]
    boxes = getattr(first, "boxes", None)
    if boxes is None:
        return []

    # Prefer the names stored on the result, then the cached model names.
    labels = translate_names_dict(getattr(first, "names", {}) or _model_names or {})

    import numpy as np

    def _host_array(attr: str, empty_shape):
        # .cpu().numpy() moves the values off the inference device before conversion.
        return getattr(boxes, attr).cpu().numpy() if hasattr(boxes, attr) else np.zeros(empty_shape)

    coords = _host_array("xyxy", (0, 4))
    scores = _host_array("conf", (0,))
    class_ids = _host_array("cls", (0,))

    out: List[dict] = []
    for idx, corners in enumerate(coords):
        x1, y1, x2, y2 = (float(c) for c in corners.tolist())
        class_id = int(class_ids[idx]) if idx < len(class_ids) else -1
        out.append({
            "class": map_label(labels.get(class_id, str(class_id))),
            "confidence": float(scores[idx]) if idx < len(scores) else None,
            "x1": x1, "y1": y1, "x2": x2, "y2": y2,
            "width": x2 - x1, "height": y2 - y1,
        })
    return out
|
|
|
def _filter_rows_by_geometry(r, rows: List[dict], model_key: str) -> List[dict]:
    """Drop low-confidence, tiny and ground-region boxes.

    The multi-class model keeps only Drone/UAV/Helicopter/Airplane. The
    drone-only model is not class-restricted, because some checkpoints label
    targets with 'UAV'/'drone' variants. Size and horizon checks apply only
    when the original image dimensions can be read from ``r.orig_img``.
    """
    # An empty set means "accept every class".
    allowed = {"Drone", "UAV", "Helicopter", "Airplane"} if "Multi-class" in model_key else set()

    try:
        img_h, img_w = r.orig_img.shape[:2]
    except Exception:
        img_h = img_w = None

    def _keep(row: dict) -> bool:
        score = row.get("confidence")
        if score is not None and score < MIN_CONF:
            return False
        if allowed and map_label(str(row.get("class", ""))) not in allowed:
            return False
        if not (img_h and img_w and (img_w * img_h) > 0):
            return True
        if row["width"] * row["height"] / (img_w * img_h) < MIN_AREA_PCT:
            return False
        # A bottom edge below the horizon line is treated as ground clutter.
        return not row["y2"] > img_h * SKY_RATIO

    return [row for row in rows if _keep(row)]
|
|
|
def _draw_annotations_bgr(bgr_img, rows: List[dict]):
    """Draw boxes and labels for ``rows`` on a copy of a BGR image and return it.

    Boxes are drawn here, rather than with the model's own plotter, so the
    overlay matches the filtered rows. The label sits above the box when there
    is room; otherwise it is placed just inside the top edge. Previously a box
    touching the top of the frame got its label drawn off-image.
    """
    cv2 = _lazy_cv2()
    out = bgr_img.copy()
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2
    for r in rows:
        x1, y1, x2, y2 = int(r["x1"]), int(r["y1"]), int(r["x2"]), int(r["y2"])
        cls = map_label(r["class"])
        label = f'{cls} {float(r.get("confidence") or 0):.2f}'
        # NOTE(review): colours are in OpenCV's BGR order, so (255, 128, 0)
        # renders blue-ish, not orange. If orange was intended for threats,
        # use (0, 128, 255). Confirm before changing.
        color = (255, 128, 0) if is_threat(cls) else (0, 200, 0)
        cv2.rectangle(out, (x1, y1), (x2, y2), color, 2)

        (tw, th), _ = cv2.getTextSize(label, font, scale, thickness)
        pad_h = th + 6
        if y1 - pad_h >= 0:
            bg_top, bg_bottom = y1 - pad_h, y1  # room above the box
        else:
            bg_top, bg_bottom = y1, y1 + pad_h  # no room: draw inside the box
        cv2.rectangle(out, (x1, bg_top), (x1 + tw + 6, bg_bottom), color, -1)
        cv2.putText(out, label, (x1 + 3, bg_bottom - 4), font, scale, (0, 0, 0), thickness, cv2.LINE_AA)
    return out
|
|
|
|
|
def _save_pdf_detections(title: str, detections: List[dict], header_note: str = "", image_path: Optional[str] = None) -> str:
    """Render a detections report to a new PDF in the temp directory; return its path.

    Layout, top to bottom:
    - a bold title and the ``header_note`` lines (each cut to 110 characters);
    - a "Detections | Threats" tally (threat = ``d["threat"] == "Threat"``);
    - an optional image scaled into an 8 cm tall box;
    - a Timestamp/Object/Conf./Threat table.

    The table continues on new A4 pages, repeating its header, when it reaches
    the bottom margin. Each call writes a uniquely named file via
    ``tempfile.mkstemp``. The old ``report_<unix seconds>.pdf`` name let two
    reports generated in the same second overwrite each other.
    """
    from reportlab.lib.pagesizes import A4
    from reportlab.pdfgen import canvas
    from reportlab.lib.units import cm
    from reportlab.lib.utils import ImageReader

    fd, out_path = tempfile.mkstemp(prefix="report_", suffix=".pdf")
    os.close(fd)  # reportlab opens the path itself
    c = canvas.Canvas(out_path, pagesize=A4)
    W, H = A4
    margin = 2 * cm
    # (column title, x offset from the left margin)
    columns = (("Timestamp", 0 * cm), ("Object", 5.0 * cm), ("Conf.", 10.0 * cm), ("Threat", 12.0 * cm))

    def draw_table_header(y: float) -> float:
        """Draw the bold column titles and a rule below them; return the next y."""
        c.setFont("Helvetica-Bold", 12)
        for text, dx in columns:
            c.drawString(margin + dx, y, text)
        y -= 0.5 * cm
        c.setLineWidth(0.5)
        c.line(margin, y, W - margin, y)
        y -= 0.4 * cm
        c.setFont("Helvetica", 11)
        return y

    detections = detections or []
    y = H - margin

    c.setFont("Helvetica-Bold", 16)
    c.drawString(margin, y, title)
    y -= 0.8 * cm
    c.setFont("Helvetica", 11)
    for line in (header_note or "").splitlines():
        c.drawString(margin, y, line[:110])
        y -= 0.6 * cm

    threats = sum(1 for d in detections if d.get("threat") == "Threat")
    c.drawString(margin, y, f"Detections: {len(detections)} | Threats: {threats}")
    y -= 0.8 * cm

    if image_path and os.path.exists(image_path):
        try:
            img = ImageReader(image_path)
            max_w, max_h = W - 2 * margin, 8 * cm
            iw, ih = img.getSize()
            scale = min(max_w / iw, max_h / ih)
            w, h = iw * scale, ih * scale
            c.drawImage(img, margin, y - h, width=w, height=h, preserveAspectRatio=True, mask='auto')
            y -= h + 0.8 * cm
        except Exception:
            pass  # best-effort: an unreadable image must not block the text report

    y = draw_table_header(y)
    for d in detections:
        if y < 2.5 * cm:
            c.showPage()
            y = draw_table_header(H - margin)
        conf = d.get("confidence")
        conf_s = f"{conf:.2f}" if isinstance(conf, (int, float)) else "-"
        cells = (
            str(d.get("time", ""))[:20],
            str(d.get("object", ""))[:20],
            conf_s,
            str(d.get("threat", "")),
        )
        for (_, dx), text in zip(columns, cells):
            c.drawString(margin + dx, y, text)
        y -= 0.55 * cm

    c.showPage()
    c.save()
    return out_path
|
|
|
def _apply_english_overlay(r):
    """Replace ``r.names`` in place with English labels when it is a dict.

    Best-effort: any failure (missing or read-only attribute, translation
    error) leaves ``r`` untouched.
    """
    try:
        current = getattr(r, "names", None)
        if isinstance(current, dict):
            r.names = translate_names_dict(current)
    except Exception:
        pass
|
|
|
|
|
|
|
|
|
def detect_image_safe(model_key: str, image, conf: float, iou: float, bypass_filters: bool = True):
    """Run detection on one image and package everything the Image tab shows.

    Returns ``(annotated_rgb, rows, summary, det_records, annotated_jpg_path,
    model_info_markdown)``:

    - ``rows``: box dicts from ``_results_to_rows``, optionally geometry-filtered;
    - ``det_records``: per-box time/object/confidence/threat dicts for the PDF;
    - ``annotated_jpg_path``: a saved copy of the overlay, or None if saving failed.

    Errors never propagate; they are reported in ``summary`` with the other
    outputs emptied.

    Changes from the previous version:
    - the JPG goes to a unique temp file (same-second runs used to collide);
    - a False return from ``cv2.imwrite`` now yields None;
    - per-class counts are sorted, as in the video summary;
    - the RGB output is a contiguous copy, not a negative-stride view.
    """
    try:
        if image is None:
            return None, [], "⚠️ No image provided.", [], None, _model_info_text()
        cv2 = _lazy_cv2()
        model = _get_model(model_key, conf, iou)
        results = model.predict(image, imgsz=1280, verbose=False)
        r = results[0]
        _apply_english_overlay(r)

        rows_raw = _results_to_rows(results)
        rows = rows_raw if bypass_filters else _filter_rows_by_geometry(r, rows_raw, model_key)

        annotated_bgr = _draw_annotations_bgr(r.orig_img, rows)
        now_utc = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
        det_records = []
        for row in rows:
            label = map_label(row["class"])
            det_records.append({
                "time": now_utc,
                "object": label,
                "confidence": float(row.get("confidence") or 0.0),
                "threat": "Threat" if is_threat(label) else "Non-threat",
            })

        counts = {}
        for d in det_records:
            counts[d["object"]] = counts.get(d["object"], 0) + 1
        summary = f"raw:{len(rows_raw)} | kept:{len(rows)}"
        if counts:
            summary += " • " + ", ".join(f"{k}: {v}" for k, v in sorted(counts.items()))

        # Best-effort copy of the overlay for the PDF report; None means "no image".
        tmp_img = None
        try:
            fd, jpg_path = tempfile.mkstemp(prefix="annotated_", suffix=".jpg")
            os.close(fd)
            if cv2.imwrite(jpg_path, annotated_bgr):
                tmp_img = jpg_path
            else:
                os.remove(jpg_path)
        except Exception:
            tmp_img = None

        # BGR -> RGB for display.
        annotated_rgb = annotated_bgr[:, :, ::-1].copy()
        return annotated_rgb, rows, summary, det_records, tmp_img, _model_info_text()
    except Exception as e:
        return None, [], f"❌ Error during image detection: {e}", [], None, _model_info_text()
|
|
|
def detect_video_safe(model_key: str, video_path: str, conf: float, iou: float, max_frames: int = 300, bypass_filters: bool = True):
    """Run detection on up to ``max_frames`` frames and write an annotated MP4.

    Returns ``(annotated_mp4_path, detections_json, summary, det_records,
    model_info_markdown)``:

    - ``detections_json``: the first 200 records as indented JSON;
    - ``det_records``: per-box records stamped with the frame time ``"<sec>s"``.

    Errors never propagate; they are reported in ``summary`` with the path set
    to None and the JSON set to ``"{}"``.

    Fixes relative to the previous version:
    - the capture (and writer) are released on every exit path, including when
      the writer fails to open;
    - timestamps start at 0.00s for the first frame instead of one frame late;
    - frames whose size differs from the writer size are resized, because
      ``cv2.VideoWriter`` silently drops mismatched frames (e.g. when the
      1280x720 fallback is used);
    - 0/NaN/negative FPS falls back to 24 for both the writer and the timestamps;
    - the output goes to a unique temp file.
    """
    try:
        if not video_path:
            return None, "{}", "⚠️ No video provided.", [], _model_info_text()
        cv2 = _lazy_cv2()
        model = _get_model(model_key, conf, iou)

        det_records: List[dict] = []
        raw_total = kept_total = 0

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return None, "{}", "❌ Failed to open video.", [], _model_info_text()

            fps = cap.get(cv2.CAP_PROP_FPS)
            if not (fps and fps > 0):  # 0, NaN or negative -> assume 24 fps
                fps = 24.0
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 1280)
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 720)

            fd, out_path = tempfile.mkstemp(prefix="annotated_", suffix=".mp4")
            os.close(fd)  # VideoWriter reopens the path itself
            writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
            try:
                if not writer.isOpened():
                    return None, "{}", "❌ Video writer could not open. Try another format/resolution.", [], _model_info_text()

                frames = 0
                limit = int(max_frames)
                while frames < limit:
                    ok, frame = cap.read()
                    if not ok:
                        break
                    t_sec = frames / fps  # presentation time of this frame (first = 0.00s)
                    frames += 1

                    results = model.predict(frame, imgsz=1280, verbose=False)
                    r = results[0]
                    _apply_english_overlay(r)

                    rows_raw = _results_to_rows(results)
                    rows = rows_raw if bypass_filters else _filter_rows_by_geometry(r, rows_raw, model_key)
                    raw_total += len(rows_raw)
                    kept_total += len(rows)

                    for row in rows:
                        label = map_label(row["class"])
                        det_records.append({
                            "time": f"{t_sec:.2f}s",
                            "object": label,
                            "confidence": float(row.get("confidence") or 0.0),
                            "threat": "Threat" if is_threat(label) else "Non-threat",
                        })

                    annotated_bgr = _draw_annotations_bgr(frame, rows)
                    if annotated_bgr.shape[1] != w or annotated_bgr.shape[0] != h:
                        annotated_bgr = cv2.resize(annotated_bgr, (w, h))
                    writer.write(annotated_bgr)
            finally:
                writer.release()
        finally:
            cap.release()

        counts = {}
        for d in det_records:
            counts[d["object"]] = counts.get(d["object"], 0) + 1
        summary = f"raw:{raw_total} | kept:{kept_total}"
        if counts:
            summary += " • " + ", ".join(f"{k}: {v}" for k, v in sorted(counts.items()))

        detections_json = json.dumps(det_records[:200], ensure_ascii=False, indent=2)
        return out_path, detections_json, summary, det_records, _model_info_text()
    except Exception as e:
        return None, "{}", f"❌ Error during video detection: {e}", [], _model_info_text()
|
|
|
|
|
def export_pdf_img(det_records: List[dict], summary: str, annotated_tmp_jpg: Optional[str]):
    """Build the Image-tab PDF report and return its file path.

    The annotated JPG is embedded only when the path is set and the file still
    exists. If report generation fails, a fallback PDF containing the error
    message is produced instead.
    """
    report_title = "UAV Detector — Image Report"
    try:
        note = summary or ""
        records = det_records or []
        embedded = None
        if annotated_tmp_jpg and os.path.exists(annotated_tmp_jpg):
            embedded = annotated_tmp_jpg
        return _save_pdf_detections(report_title, records, note, image_path=embedded)
    except Exception as e:
        return _save_pdf_detections(report_title, [], f"❌ PDF export error: {e}", None)
|
|
|
def _coerce_video_records(det_records) -> List[dict]:
    """Normalize the Video-tab detections input into a list for the PDF table.

    Accepts None, a list, a JSON string holding a list, or a pandas DataFrame.
    Anything else, or anything that fails to parse/convert, yields ``[]``.
    """
    if det_records is None:
        return []
    if isinstance(det_records, list):
        return det_records
    if isinstance(det_records, str):
        try:
            parsed = json.loads(det_records)
        except Exception:
            return []
        return parsed if isinstance(parsed, list) else []
    try:
        import pandas as pd
        if isinstance(det_records, pd.DataFrame):
            return det_records.to_dict(orient="records")
    except Exception:
        pass
    return []


def export_pdf_vid(det_records, summary):
    """Build the Video-tab PDF report and return its file path.

    Forgiving about its inputs: ``det_records`` may be a list of dicts, a
    DataFrame, a JSON string or None, and ``summary`` is stringified when it is
    not already a string. On failure a fallback PDF with the error message is
    produced.
    """
    report_title = "UAV Detector — Video Report"
    try:
        records = _coerce_video_records(det_records)
        if summary is None:
            note = ""
        elif isinstance(summary, str):
            note = summary
        else:
            note = str(summary)
        return _save_pdf_detections(report_title, records, note, image_path=None)
    except Exception as e:
        return _save_pdf_detections(report_title, [], f"❌ PDF export error: {e}", None)
|
|
|
|
|
|
|
|
|
# Help text rendered under the controls of both the Image and Video tabs.
NOTE = " ".join((
    "Detections include timestamp, object, confidence, and Threat/Non-threat.",
    "Use 'Bypass filters (debug)' to see raw model boxes; tighten filters after you confirm detections.",
))
|
|
|
# Gradio UI: a shared model selector plus Image and Video tabs, each with
# detection controls, results and a PDF export button.
with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
    gr.Markdown("# UAV / Drone Detection (Pretrained YOLO)")
    gr.Markdown("Embedded samples (optional): `samples/uav_image.jpg`, `samples/uav_video.mp4`.")

    with gr.Row():
        model_key = gr.Dropdown(
            choices=list(MODEL_CHOICES.keys()),
            value=list(MODEL_CHOICES.keys())[0],
            label="Model",
        )
        model_info_md = gr.Markdown(value=_model_info_text())

    with gr.Tabs():

        with gr.TabItem("Image"):
            with gr.Row():
                image_in = gr.Image(
                    value=EMBED_IMG if os.path.exists(EMBED_IMG) else None,
                    type="filepath",
                    label="Input Image",
                )
                with gr.Column():
                    conf_img = gr.Slider(0.05, 0.9, 0.25, step=0.05, label="Model Confidence")
                    iou_img = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
                    filters_off_img = gr.Checkbox(value=True, label="Bypass filters (debug)")
                    run_img = gr.Button("Run Detection")
                    gr.Markdown(NOTE)

            image_out = gr.Image(label="Annotated Image")
            # Column order of the results table; also used to flatten row dicts.
            image_table_headers = ["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"]
            table_out = gr.Dataframe(headers=image_table_headers)
            msg_img = gr.Markdown()
            pdf_img_btn = gr.Button("Generate PDF Report")
            pdf_img_path = gr.File(label="PDF Report", interactive=False)
            annotated_tmp_img_path = gr.State(value=None)
            image_det_state = gr.State(value=[])

            def _run_img(mkey, image, conf, iou, bypass):
                """Run image detection and convert the row dicts into table rows.

                gr.Dataframe expects a list of row lists (or a DataFrame), while
                the detector returns one dict per box. The dicts are therefore
                flattened in ``image_table_headers`` order.
                """
                annotated, rows, summary, records, tmp_path, info = detect_image_safe(mkey, image, conf, iou, bypass)
                table = [[row.get(col) for col in image_table_headers] for row in (rows or [])]
                return annotated, table, summary, records, tmp_path, info

            run_img.click(
                fn=_run_img,
                inputs=[model_key, image_in, conf_img, iou_img, filters_off_img],
                outputs=[image_out, table_out, msg_img, image_det_state, annotated_tmp_img_path, model_info_md],
            )

            pdf_img_btn.click(
                fn=export_pdf_img,
                inputs=[image_det_state, msg_img, annotated_tmp_img_path],
                outputs=[pdf_img_path],
            )

        with gr.TabItem("Video"):
            with gr.Row():
                video_in = gr.Video(
                    value=EMBED_VID if os.path.exists(EMBED_VID) else None,
                    label="Input Video",
                )
                with gr.Column():
                    conf_vid = gr.Slider(0.05, 0.9, 0.25, step=0.05, label="Model Confidence")
                    iou_vid = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
                    max_frames = gr.Slider(60, 2000, 300, step=10, label="Max frames to process")
                    filters_off_vid = gr.Checkbox(value=True, label="Bypass filters (debug)")
                    run_vid = gr.Button("Run Detection")
                    gr.Markdown(NOTE)

            video_out = gr.Video(label="Annotated Video")
            detections_json_text = gr.Textbox(label="Detections (first 200)", max_lines=20)
            msg_vid = gr.Markdown()
            pdf_vid_btn = gr.Button("Generate PDF Report")
            pdf_vid_path = gr.File(label="PDF Report", interactive=False)
            video_det_state = gr.State(value=[])

            def _run_vid(mkey, vpath, conf, iou, maxf, bypass):
                """Run video detection; the slider value is cast to int frames."""
                return detect_video_safe(mkey, vpath, conf, iou, int(maxf), bypass)

            run_vid.click(
                fn=_run_vid,
                inputs=[model_key, video_in, conf_vid, iou_vid, max_frames, filters_off_vid],
                outputs=[video_out, detections_json_text, msg_vid, video_det_state, model_info_md],
            )

            pdf_vid_btn.click(
                fn=export_pdf_vid,
                inputs=[video_det_state, msg_vid],
                outputs=[pdf_vid_path],
            )
|
|
|
if __name__ == "__main__":
    # share=True opens a public gradio.live tunnel. It stays the default, but
    # GRADIO_SHARE=0/false/no/off now turns it off for local runs.
    _share = os.getenv("GRADIO_SHARE", "1").strip().lower() not in {"0", "false", "no", "off"}
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)), share=_share)