import os
import platform
import time
import tempfile
from typing import List, Dict, Optional, Tuple
import json
import gradio as gr
# =========================
# CONFIG – embedded samples
# =========================
SAMPLES_DIR = "samples"
EMBED_IMG = os.path.join(SAMPLES_DIR, "uav_image.jpg")
EMBED_VID = os.path.join(SAMPLES_DIR, "uav_video.mp4")
HF_TOKEN = os.getenv("HF_TOKEN", "").strip() # optional for private/gated repos
# Selectable models (public & tested paths)
MODEL_CHOICES: Dict[str, Tuple[str, str]] = {
"Multi-class (Drone/Helicopter/Airplane/Bird)":
("Javvanny/yolov8m_flying_objects_detection", "yolov8m/weights/best.pt"),
"Drone-only (cleaner, fewer false positives)":
("doguilmak/Drone-Detection-YOLOv8x", "weight/best.pt"),
}
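# To add another checkpoint, map a display name to (repo_id, path of the
# weights file inside that repo). Hypothetical example:
#   "My custom model": ("your-user/your-yolo-repo", "weights/best.pt"),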
# =========================
# LABELS & THREAT RULES
# =========================
LABEL_MAP = {
"Airplane": "Airplane",
"Bird": "Bird",
"Drone": "Drone",
"Helicopter": "Helicopter",
"UAV": "UAV",
"БПЛА": "UAV",
"БПЛА коптер": "Drone",
"квадрокоптер": "Drone",
"квадроcамолет": "Drone",
"самолет": "Airplane",
"вертолет": "Helicopter",
"автомобиль": "Car",
"машина": "Car",
"БПЛА самелет": "UAV Airplane",
"drone": "Drone",
}
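# Keys mirror the raw class names emitted by the checkpoints (Cyrillic names
# and their original typos included) so dictionary lookups match the model
# output; values are the English display labels.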
THREAT_SET = {"drone", "uav", "airplane", "helicopter", "uav airplane"}  # "uav airplane" covers the mapped "БПЛА самелет" label
def map_label(name: str) -> str:
if not isinstance(name, str):
return name
return LABEL_MAP.get(name, LABEL_MAP.get(name.lower(), name))
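# e.g. map_label("квадрокоптер") -> "Drone", map_label("DRONE") -> "Drone";
# unknown names pass through unchanged: map_label("kite") -> "kite".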
def translate_names_dict(names_dict: Dict[int, str]) -> Dict[int, str]:
if not isinstance(names_dict, dict):
return names_dict
return {k: map_label(v) for k, v in names_dict.items()}
def is_threat(label_en: str) -> bool:
    return bool(label_en) and label_en.lower() in THREAT_SET
# =========================
# FILTERS (relaxed defaults; tighten later if needed)
# =========================
MIN_CONF = float(os.getenv("MIN_CONF", "0.30"))           # post-filter confidence
MIN_AREA_PCT = float(os.getenv("MIN_AREA_PCT", "0.001"))  # min box area as a fraction of the frame
SKY_RATIO = float(os.getenv("SKY_RATIO", "0.95"))         # sky gate nearly off by default (1.0 disables it)
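# All three gates can be overridden at launch without code changes, e.g.:
#   MIN_CONF=0.45 MIN_AREA_PCT=0.003 SKY_RATIO=0.80 python app.py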
# =========================
# LAZY GLOBAL STATE
# =========================
_model = None
_model_err = None
_model_names = None
_loaded_repo = None
_loaded_file = None
_loaded_key = None # which dropdown choice loaded
_ffmpeg_status = None
def _lazy_cv2():
import cv2
return cv2
def _ffmpeg_ok() -> bool:
global _ffmpeg_status
if _ffmpeg_status is not None:
return _ffmpeg_status
    try:
        cv2 = _lazy_cv2()
        info = cv2.getBuildInformation()
        # Build info pads its columns ("FFMPEG:        YES ..."), so inspect
        # the whole line instead of matching a fixed substring.
        ffmpeg_line = next((l for l in info.splitlines() if "FFMPEG" in l), "")
        _ffmpeg_status = "YES" in ffmpeg_line
    except Exception:
        _ffmpeg_status = False
return _ffmpeg_status
def _download_from_hf(repo_id: str, filename: str) -> str:
from huggingface_hub import hf_hub_download, login
if HF_TOKEN:
try:
login(token=HF_TOKEN)
except Exception:
pass
return hf_hub_download(repo_id=repo_id, filename=filename)
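# hf_hub_download caches files in the local Hugging Face cache
# (~/.cache/huggingface/hub by default), so repeated loads skip the network.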
def _reset_model_cache():
global _model, _model_err, _model_names, _loaded_repo, _loaded_file
_model = None
_model_err = None
_model_names = None
_loaded_repo = None
_loaded_file = None
def _get_model(model_key: str, conf: float, iou: float):
"""Load the YOLO model selected in the dropdown."""
from ultralytics import YOLO
global _model, _model_err, _model_names, _loaded_repo, _loaded_file, _loaded_key
if _loaded_key != model_key:
_reset_model_cache()
_loaded_key = model_key
if _model is None and _model_err is None:
repo, file = MODEL_CHOICES[model_key]
last_err = None
try:
weights = _download_from_hf(repo, file)
m = YOLO(weights)
# Core overrides
m.overrides["max_det"] = 300
m.overrides["conf"] = float(conf)
m.overrides["iou"] = float(iou)
m.overrides["agnostic_nms"] = True
_model = m
_loaded_repo, _loaded_file = repo, file
try:
_model_names = m.model.names if hasattr(m, "model") else None
except Exception:
_model_names = None
except Exception as e:
last_err = e
_model = None
if _model is None:
_model_err = f"Model load failed for {repo}/{file}. Error: {last_err}"
if _model_err:
raise RuntimeError(_model_err)
    # Reflect the current slider values on every call
_model.overrides["conf"] = float(conf)
_model.overrides["iou"] = float(iou)
_model.overrides["agnostic_nms"] = True
return _model
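# Usage sketch: _get_model(list(MODEL_CHOICES)[0], conf=0.25, iou=0.45).
# The cache is keyed on the dropdown choice, so selecting a different model
# resets the cache and loads the new weights on the next call.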
def _model_info_text():
repo = f"{_loaded_repo}/{_loaded_file}" if _loaded_repo else "not loaded"
try:
names = ", ".join(sorted(set(translate_names_dict(_model_names or {}).values()))) or "unknown"
except Exception:
names = "unknown"
return f"**Model:** {repo} • FFmpeg: {'Yes' if _ffmpeg_ok() else 'No'} • Python: 3.10\n\n**Classes:** {names}"
# =========================
# HELPERS
# =========================
def _results_to_rows(results) -> List[dict]:
rows: List[dict] = []
if not results:
return rows
r = results[0]
if getattr(r, "boxes", None) is None:
return rows
names_dict = getattr(r, "names", {}) or _model_names or {}
names_dict = translate_names_dict(names_dict)
import numpy as np
xyxy = r.boxes.xyxy.cpu().numpy() if hasattr(r.boxes, "xyxy") else np.zeros((0,4))
confs = r.boxes.conf.cpu().numpy() if hasattr(r.boxes, "conf") else np.zeros((0,))
clss = r.boxes.cls.cpu().numpy() if hasattr(r.boxes, "cls") else np.zeros((0,))
for i, box in enumerate(xyxy):
x1,y1,x2,y2 = [float(v) for v in box.tolist()]
cls_idx = int(clss[i]) if i < len(clss) else -1
cls_name = names_dict.get(cls_idx, str(cls_idx))
rows.append({
"class": map_label(cls_name),
"confidence": float(confs[i]) if i < len(confs) else None,
"x1": x1, "y1": y1, "x2": x2, "y2": y2,
"width": x2-x1, "height": y2-y1,
})
return rows
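# A row produced above looks like (illustrative values):
# {"class": "Drone", "confidence": 0.62, "x1": 103.0, "y1": 58.0,
#  "x2": 141.0, "y2": 92.0, "width": 38.0, "height": 34.0}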
def _filter_rows_by_geometry(r, rows: List[dict], model_key: str) -> List[dict]:
"""
Drop low-conf, tiny, ground-region boxes.
For drone-only model, DO NOT restrict classes (some checkpoints label as 'UAV'/'drone' variants).
For multi-class, keep only classes we care about.
"""
if "Multi-class" in model_key:
allowed = {"Drone", "UAV", "Helicopter", "Airplane"}
else:
allowed = set() # no restriction for drone-only
try:
H, W = r.orig_img.shape[:2]
except Exception:
H = W = None
kept = []
for row in rows:
if row.get("confidence") is not None and row["confidence"] < MIN_CONF:
continue
cls = map_label(str(row.get("class","")))
if allowed and cls not in allowed:
continue
if H and W and (W * H) > 0:
area = row["width"] * row["height"]
if area / (W * H) < MIN_AREA_PCT:
continue
y_bottom = row["y2"]
horizon = H * SKY_RATIO
if y_bottom > horizon: # below sky line → likely ground/grass noise
continue
kept.append(row)
return kept
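# Worked example for the area gate: a 40x30 px box in a 1920x1080 frame covers
# 1200 / 2073600 ≈ 0.00058 of the image, below the default MIN_AREA_PCT of
# 0.001, so it is dropped. The sky gate with SKY_RATIO=0.95 discards only
# boxes whose bottom edge falls in the lowest 5% of the frame.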
def _draw_annotations_bgr(bgr_img, rows: List[dict]):
"""Draw boxes ourselves so overlay matches filtered results."""
cv2 = _lazy_cv2()
out = bgr_img.copy()
for r in rows:
x1,y1,x2,y2 = int(r["x1"]), int(r["y1"]), int(r["x2"]), int(r["y2"])
cls = map_label(r["class"])
label = f'{cls} {float(r.get("confidence") or 0):.2f}'
        color = (255, 128, 0) if is_threat(cls) else (0, 200, 0)  # BGR order: azure for threats, green otherwise
cv2.rectangle(out, (x1,y1), (x2,y2), color, 2)
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
cv2.rectangle(out, (x1, max(0, y1- th - 6)), (x1 + tw + 6, y1), color, -1)
cv2.putText(out, label, (x1+3, y1-4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
return out
# ---------- PDF builder ----------
def _save_pdf_detections(title: str, detections: List[dict], header_note: str = "", image_path: Optional[str] = None) -> str:
from reportlab.lib.pagesizes import A4
from reportlab.pdfgen import canvas
from reportlab.lib.units import cm
from reportlab.lib.utils import ImageReader
out_path = os.path.join(tempfile.gettempdir(), f"report_{int(time.time())}.pdf")
c = canvas.Canvas(out_path, pagesize=A4)
W, H = A4
margin = 2*cm
y = H - margin
c.setFont("Helvetica-Bold", 16); c.drawString(margin, y, title); y -= 0.8*cm
c.setFont("Helvetica", 11)
for line in (header_note or "").splitlines():
c.drawString(margin, y, line[:110]); y -= 0.6*cm
total = len(detections or [])
threats = sum(1 for d in (detections or []) if d.get("threat") == "Threat")
c.drawString(margin, y, f"Detections: {total} | Threats: {threats}"); y -= 0.8*cm
if image_path and os.path.exists(image_path):
try:
img = ImageReader(image_path)
max_w, max_h = W - 2*margin, 8*cm
iw, ih = img.getSize(); scale = min(max_w/iw, max_h/ih)
w, h = iw*scale, ih*scale
c.drawImage(img, margin, y - h, width=w, height=h, preserveAspectRatio=True, mask='auto')
y -= h + 0.8*cm
except Exception:
pass
c.setFont("Helvetica-Bold", 12)
c.drawString(margin + 0*cm, y, "Timestamp")
c.drawString(margin + 5.0*cm, y, "Object")
c.drawString(margin + 10.0*cm, y, "Conf.")
c.drawString(margin + 12.0*cm, y, "Threat")
y -= 0.5*cm
c.setLineWidth(0.5); c.line(margin, y, W - margin, y); y -= 0.4*cm
c.setFont("Helvetica", 11)
for d in detections or []:
if y < 2.5*cm:
c.showPage(); y = H - margin
c.setFont("Helvetica-Bold", 12)
c.drawString(margin + 0*cm, y, "Timestamp")
c.drawString(margin + 5.0*cm, y, "Object")
c.drawString(margin + 10.0*cm, y, "Conf.")
c.drawString(margin + 12.0*cm, y, "Threat")
y -= 0.5*cm
c.setLineWidth(0.5); c.line(margin, y, W - margin, y); y -= 0.4*cm
c.setFont("Helvetica", 11)
ts = str(d.get("time",""))
obj = str(d.get("object",""))
conf = d.get("confidence"); conf_s = f"{conf:.2f}" if isinstance(conf,(int,float)) else "-"
thr = str(d.get("threat",""))
        c.drawString(margin, y, ts[:20])
c.drawString(margin + 5.0*cm, y, obj[:20])
c.drawString(margin + 10.0*cm, y, conf_s)
c.drawString(margin + 12.0*cm, y, thr)
y -= 0.55*cm
c.showPage(); c.save()
return out_path
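# Usage sketch: _save_pdf_detections("Report", det_records, "note") writes a
# one-off PDF into the system temp directory and returns its path.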
def _apply_english_overlay(r):
    """Swap the result's class-name dict for the English-mapped one, in place."""
try:
if hasattr(r, "names") and isinstance(r.names, dict):
r.names = translate_names_dict(r.names)
except Exception:
pass
# =========================
# INFERENCE (filters toggle + imgsz=1280 + debug)
# =========================
def detect_image_safe(model_key: str, image, conf: float, iou: float, bypass_filters: bool = True):
    """Detect objects in one image; returns (annotated_rgb, rows, summary, det_records, tmp_jpg_path, model_info)."""
try:
if image is None:
return None, [], "⚠️ No image provided.", [], None, _model_info_text()
cv2 = _lazy_cv2()
model = _get_model(model_key, conf, iou)
results = model.predict(image, imgsz=1280, verbose=False) # larger input helps tiny drones
r = results[0]
_apply_english_overlay(r)
rows_raw = _results_to_rows(results)
rows = rows_raw if bypass_filters else _filter_rows_by_geometry(r, rows_raw, model_key)
annotated_bgr = _draw_annotations_bgr(r.orig_img, rows)
now_utc = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
det_records = [{
"time": now_utc,
"object": map_label(row["class"]),
"confidence": float(row.get("confidence") or 0.0),
"threat": "Threat" if is_threat(map_label(row["class"])) else "Non-threat",
} for row in rows]
# Debug summary shows raw vs kept counts
summary = f"raw:{len(rows_raw)} | kept:{len(rows)}"
counts = {}
for d in det_records:
counts[d["object"]] = counts.get(d["object"], 0) + 1
if counts:
summary += " • " + ", ".join(f"{k}: {v}" for k, v in counts.items())
tmp_img = os.path.join(tempfile.gettempdir(), f"annotated_{int(time.time())}.jpg")
try:
cv2.imwrite(tmp_img, annotated_bgr)
except Exception:
tmp_img = None
        annotated_rgb = annotated_bgr[:, :, ::-1]  # BGR -> RGB for Gradio
return annotated_rgb, rows, summary, det_records, tmp_img, _model_info_text()
except Exception as e:
return None, [], f"❌ Error during image detection: {e}", [], None, _model_info_text()
def detect_video_safe(model_key: str, video_path: str, conf: float, iou: float, max_frames: int = 300, bypass_filters: bool = True):
    """Detect objects frame-by-frame (up to max_frames); returns (out_path, detections_json, summary, det_records, model_info)."""
try:
if not video_path:
return None, "{}", "⚠️ No video provided.", [], _model_info_text()
cv2 = _lazy_cv2()
model = _get_model(model_key, conf, iou)
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return None, "{}", "❌ Failed to open video.", [], _model_info_text()
fps = cap.get(cv2.CAP_PROP_FPS) or 24.0
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 1280)
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 720)
out_path = os.path.join(tempfile.gettempdir(), f"annotated_{int(time.time())}.mp4")
writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
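        # "mp4v" (MPEG-4 Part 2) is widely supported by OpenCV builds; if the
        # browser refuses to play the result, re-encoding to H.264 afterwards
        # (e.g. with ffmpeg) is a common workaround.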
        if not writer.isOpened():
return None, "{}", "❌ Video writer could not open. Try another format/resolution.", [], _model_info_text()
det_records: List[dict] = []
frames = 0
raw_total = 0
kept_total = 0
try:
while True:
ok, frame = cap.read()
if not ok:
break
frames += 1
if frames > int(max_frames):
break
results = model.predict(frame, imgsz=1280, verbose=False)
r = results[0]
_apply_english_overlay(r)
rows_raw = _results_to_rows(results)
rows = rows_raw if bypass_filters else _filter_rows_by_geometry(r, rows_raw, model_key)
raw_total += len(rows_raw)
kept_total += len(rows)
t_sec = frames / float(fps if fps > 0 else 24.0)
for row in rows:
label = map_label(row["class"])
det_records.append({
"time": f"{t_sec:.2f}s",
"object": label,
"confidence": float(row.get("confidence") or 0.0),
"threat": "Threat" if is_threat(label) else "Non-threat",
})
annotated_bgr = _draw_annotations_bgr(frame, rows)
writer.write(annotated_bgr)
finally:
cap.release()
writer.release()
# Debug summary
counts = {}
for d in det_records:
counts[d["object"]] = counts.get(d["object"], 0) + 1
summary = f"raw:{raw_total} | kept:{kept_total}"
if counts:
summary += " • " + ", ".join(f"{k}: {v}" for k, v in sorted(counts.items()))
detections_json = json.dumps(det_records[:200], ensure_ascii=False, indent=2)
return out_path, detections_json, summary, det_records, _model_info_text()
except Exception as e:
return None, "{}", f"❌ Error during video detection: {e}", [], _model_info_text()
# ---------- PDF export ----------
def export_pdf_img(det_records: List[dict], summary: str, annotated_tmp_jpg: Optional[str]):
try:
note = summary or ""
return _save_pdf_detections(
"UAV Detector — Image Report", det_records or [], note,
image_path=annotated_tmp_jpg if annotated_tmp_jpg and os.path.exists(annotated_tmp_jpg) else None
)
except Exception as e:
return _save_pdf_detections("UAV Detector — Image Report", [], f"❌ PDF export error: {e}", None)
def export_pdf_vid(det_records, summary):
"""Be forgiving: accept list[dict], DataFrame, JSON string, or None."""
try:
# Normalize detections
if det_records is None:
det_list = []
elif isinstance(det_records, list):
det_list = det_records
elif isinstance(det_records, str):
try:
det_list = json.loads(det_records)
if not isinstance(det_list, list):
det_list = []
except Exception:
det_list = []
else:
try:
import pandas as pd
if isinstance(det_records, pd.DataFrame):
det_list = det_records.to_dict(orient="records")
else:
det_list = []
except Exception:
det_list = []
note = summary if isinstance(summary, str) else (str(summary) if summary is not None else "")
return _save_pdf_detections("UAV Detector — Video Report", det_list, note, image_path=None)
except Exception as e:
return _save_pdf_detections("UAV Detector — Video Report", [], f"❌ PDF export error: {e}", None)
# =========================
# UI
# =========================
NOTE = (
"Detections include timestamp, object, confidence, and Threat/Non-threat. "
"Use 'Bypass filters (debug)' to see raw model boxes; tighten filters after you confirm detections."
)
with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
gr.Markdown("# UAV / Drone Detection (Pretrained YOLO)")
gr.Markdown("Embedded samples (optional): `samples/uav_image.jpg`, `samples/uav_video.mp4`.")
with gr.Row():
model_key = gr.Dropdown(choices=list(MODEL_CHOICES.keys()),
value=list(MODEL_CHOICES.keys())[0],
label="Model")
model_info_md = gr.Markdown(value=_model_info_text())
with gr.Tabs():
# IMAGE
with gr.TabItem("Image"):
with gr.Row():
image_in = gr.Image(
value=EMBED_IMG if os.path.exists(EMBED_IMG) else None,
type="filepath",
label="Input Image"
)
with gr.Column():
conf_img = gr.Slider(0.05, 0.9, 0.25, step=0.05, label="Model Confidence")
iou_img = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
filters_off_img = gr.Checkbox(value=True, label="Bypass filters (debug)")
run_img = gr.Button("Run Detection")
gr.Markdown(NOTE)
image_out = gr.Image(label="Annotated Image")
table_out = gr.Dataframe(headers=["class","confidence","x1","y1","x2","y2","width","height"])
msg_img = gr.Markdown()
pdf_img_btn = gr.Button("Generate PDF Report")
pdf_img_path = gr.File(label="PDF Report", interactive=False)
annotated_tmp_img_path = gr.State(value=None)
image_det_state = gr.State(value=[])
def _run_img(mkey, image, conf, iou, bypass):
return detect_image_safe(mkey, image, conf, iou, bypass)
run_img.click(
fn=_run_img,
inputs=[model_key, image_in, conf_img, iou_img, filters_off_img],
outputs=[image_out, table_out, msg_img, image_det_state, annotated_tmp_img_path, model_info_md],
)
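            # Output order must match detect_image_safe's return tuple:
            # (annotated, rows, summary, det_records, tmp_jpg, model_info).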
pdf_img_btn.click(
fn=export_pdf_img,
inputs=[image_det_state, msg_img, annotated_tmp_img_path],
outputs=[pdf_img_path],
)
# VIDEO
with gr.TabItem("Video"):
with gr.Row():
video_in = gr.Video(
value=EMBED_VID if os.path.exists(EMBED_VID) else None,
label="Input Video"
)
with gr.Column():
conf_vid = gr.Slider(0.05, 0.9, 0.25, step=0.05, label="Model Confidence")
iou_vid = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
max_frames = gr.Slider(60, 2000, 300, step=10, label="Max frames to process")
filters_off_vid = gr.Checkbox(value=True, label="Bypass filters (debug)")
run_vid = gr.Button("Run Detection")
gr.Markdown(NOTE)
video_out = gr.Video(label="Annotated Video")
detections_json_text = gr.Textbox(label="Detections (first 200)", max_lines=20)
msg_vid = gr.Markdown()
pdf_vid_btn = gr.Button("Generate PDF Report")
pdf_vid_path = gr.File(label="PDF Report", interactive=False)
video_det_state = gr.State(value=[])
def _run_vid(mkey, vpath, conf, iou, maxf, bypass):
return detect_video_safe(mkey, vpath, conf, iou, int(maxf), bypass)
run_vid.click(
fn=_run_vid,
inputs=[model_key, video_in, conf_vid, iou_vid, max_frames, filters_off_vid],
outputs=[video_out, detections_json_text, msg_vid, video_det_state, model_info_md],
)
# IMPORTANT: feed the structured state (video_det_state) to PDF — not the textbox string
pdf_vid_btn.click(
fn=export_pdf_vid,
inputs=[video_det_state, msg_vid],
outputs=[pdf_vid_path],
)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), share=True)