# gradio-Ui / app.py
# Nour190's picture
# update app.py
# 8c74959 verified
# app.py
"""
Shoplifting detection Gradio app — robust startup:
- finds the local model file (best.pt) anywhere in the repo
- avoids writing to /data if it is not writable (chooses a writable fallback)
- sets YOLO_CONFIG_DIR to a writable dir to silence Ultralytics permission warnings
"""
import logging
import os
import shutil
import tempfile
import time

import gradio as gr
import pandas as pd
# Configure root logging at INFO level so startup decisions are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Point Ultralytics at a config folder inside the workspace to avoid its permission
# warnings. A YOLO_CONFIG_DIR already present in the environment wins: setdefault
# returns whichever value is actually in effect, so the folder created and logged
# below is the one Ultralytics will see (previously a pre-set value was ignored here
# while a different, unused folder was created and reported).
# NOTE(review): this presumably must run before ultralytics is imported (by the
# later video/image imports) — confirm those modules read it at import time.
YOLO_CONFIG_DIR = os.environ.setdefault(
    "YOLO_CONFIG_DIR", os.path.join(os.getcwd(), ".ultralytics")
)
try:
    os.makedirs(YOLO_CONFIG_DIR, exist_ok=True)
    logger.info(f"YOLO_CONFIG_DIR set to: {YOLO_CONFIG_DIR}")
except OSError as e:
    logger.warning(f"Failed creating YOLO_CONFIG_DIR {YOLO_CONFIG_DIR}: {e}")
# ---- model config ----
# Weights file name that find_local_model() looks for.
MODEL_FILENAME = "best.pt"
# Fast-path locations (project root, then models/) that find_local_model()
# checks before falling back to a recursive workspace search.
COMMON_MODEL_PATHS = [
    os.path.join(folder, MODEL_FILENAME)
    for folder in (os.getcwd(), os.path.join(os.getcwd(), "models"))
]
# ---- utility: choose writable path (for cache and outputs) ----
def _ensure_writable_dir(path):
    """Create ``path`` if needed and prove it is writable; raise OSError if not."""
    os.makedirs(path, exist_ok=True)
    # TemporaryFile is removed automatically, even when the write itself fails,
    # so a failed probe never leaves junk behind.
    with tempfile.TemporaryFile(dir=path) as probe:
        probe.write(b"ok")


def choose_writable_path(preferred_paths, fallback_name):
    """Return the first directory that can be created and actually written to.

    Candidates are tried in order: each non-empty entry of ``preferred_paths``
    (duplicates probed once), then ``<system temp dir>/<fallback_name>``, then
    ``<cwd>/<fallback_name>``. Every candidate — fallbacks included — must pass
    a real write test; a directory that merely exists (e.g. a read-only mount)
    is not accepted.

    Args:
        preferred_paths: Directory paths in priority order; falsy entries
            (None, "") are skipped.
        fallback_name: Sub-directory name used under the temp dir and the cwd.

    Returns:
        The chosen directory path (created if it did not exist).

    Raises:
        RuntimeError: if no candidate directory is writable.
    """
    candidates = [
        (path, "Using writable path: {}")
        for path in dict.fromkeys(preferred_paths or ())
        if path
    ]
    candidates.append(
        (os.path.join(tempfile.gettempdir(), fallback_name), "Falling back to temporary path: {}")
    )
    candidates.append(
        (os.path.join(os.getcwd(), fallback_name), "Falling back to CWD path: {}")
    )

    failures = []
    for path, success_msg in candidates:
        try:
            _ensure_writable_dir(path)
        except OSError as e:
            logger.warning(f"Cannot use path '{path}': {e}")
            failures.append(f"{path}: {e}")
            continue
        logger.info(success_msg.format(path))
        return path
    raise RuntimeError("Failed to find a writable directory: " + " / ".join(failures))
# Resolve the HF cache and output folders to locations that are actually writable.
# /data (or any env-provided path) may be missing or read-only, in which case
# choose_writable_path() falls back to a temp or CWD folder.
# NOTE(review): in the code visible here HF_HOME_DIR / CACHE_DIR are only logged;
# HF_HOME is not exported back into os.environ — confirm whether that is intended.
PREFERRED_HF_HOME = os.getenv("HF_HOME") or os.path.join(os.getcwd(), ".huggingface")
HF_HOME_DIR = choose_writable_path(
    [PREFERRED_HF_HOME, os.path.join(os.getcwd(), ".huggingface")],
    "hf_cache",
)
CACHE_DIR = os.path.join(HF_HOME_DIR, "hub")

PREFERRED_BASE_OUT = os.getenv("BASE_OUT") or os.path.join(os.getcwd(), "shoplift_outputs")
BASE_OUT = choose_writable_path(
    [PREFERRED_BASE_OUT, os.path.join(os.getcwd(), "shoplift_outputs")],
    "shoplift_outputs",
)
# Safeguard only: choose_writable_path() already created the directory it returned.
os.makedirs(BASE_OUT, exist_ok=True)
logger.info(f"CACHE_DIR resolved to: {CACHE_DIR}")
logger.info(f"BASE_OUT resolved to: {BASE_OUT}")
# ---- prepare model path: search locally first (recommended) ----
def find_local_model(search_root=None, min_size_bytes=100 * 1024):
    """Locate the model weights file (MODEL_FILENAME) on local disk.

    Search order:
      1. Each regular file listed in COMMON_MODEL_PATHS.
      2. A recursive, top-down walk of ``search_root``. Directories are visited
         in sorted order for determinism, and ``.git`` is skipped.

    Files whose size is not strictly greater than ``min_size_bytes`` are treated
    as placeholders (typically Git LFS pointer files). They are remembered, but
    the search continues, so a real model elsewhere in the tree still wins.

    Args:
        search_root: Directory to walk; defaults to the current working directory.
        min_size_bytes: Size threshold for a "real" model file (default 100 KiB).

    Returns:
        Path of the first real model file found, or None if no file named
        MODEL_FILENAME exists at all.

    Raises:
        RuntimeError: if only placeholder-sized MODEL_FILENAME files were found.
    """
    pointer_candidate = None
    pointer_size = 0

    # 1) cheap check of the common locations
    for path in COMMON_MODEL_PATHS:
        if not os.path.isfile(path):
            continue
        try:
            size = os.path.getsize(path)
        except OSError:
            continue
        if size > min_size_bytes:
            return path
        logger.warning(f"Found {path} but size is small ({size} bytes) — might be a pointer file.")
        if pointer_candidate is None:
            pointer_candidate, pointer_size = path, size

    # 2) recursive search within the workspace
    root_dir = search_root if search_root is not None else os.getcwd()
    for root, dirs, files in os.walk(root_dir):
        # Prune in place: deterministic order, never descend into git internals.
        dirs[:] = sorted(d for d in dirs if d != ".git")
        if MODEL_FILENAME not in files:
            continue
        candidate = os.path.join(root, MODEL_FILENAME)
        try:
            size = os.path.getsize(candidate)
        except OSError:
            size = 0
        if size > min_size_bytes:
            return candidate
        # Small file -> likely a Git LFS pointer; remember it but keep looking,
        # since a real binary may live deeper in the tree.
        if pointer_candidate is None:
            pointer_candidate, pointer_size = candidate, size

    if pointer_candidate is not None:
        raise RuntimeError(
            f"Found {pointer_candidate} but its size is {pointer_size} bytes — looks like a Git LFS pointer. "
            "Make sure you uploaded the real model binary (use git lfs) or place the full .pt in the repo."
        )
    return None
# Resolve the model weights once at import time. The app cannot serve requests
# without them, so fail fast with actionable instructions.
local_model = find_local_model()
if not local_model:
    # This app only loads weights from local disk (no Hub download happens here),
    # so the only remedy is to ship the binary with the app.
    raise RuntimeError(
        f"No local model '{MODEL_FILENAME}' found in the repository ({os.getcwd()}). "
        "Add the full model binary (not a Git LFS pointer) at the project root "
        f"or at models/{MODEL_FILENAME} and restart the app."
    )
MODEL_PATH = local_model
logger.info(f"Using local model file at: {MODEL_PATH}")
# ---- imports that rely on MODEL_PATH being set ----
from video import process_video_stream
from image import process_image
# ---- SMTP / Email settings (read from env vars / Space Secrets only) ----
# Credentials must never have defaults in source: anything committed here is public
# through the repo history. The email address and app password previously
# hard-coded in this file should be treated as leaked and revoked.
SMTP_SERVER = os.getenv("SMTP_SERVER", "smtp.gmail.com")
_smtp_port_raw = os.getenv("SMTP_PORT", "587")
try:
    SMTP_PORT = int(_smtp_port_raw)
except ValueError:
    # A malformed port should not crash app startup.
    logger.warning(f"Invalid SMTP_PORT {_smtp_port_raw!r}; falling back to 587.")
    SMTP_PORT = 587
EMAIL_USER = os.getenv("EMAIL_USER", "")
EMAIL_PASS = os.getenv("EMAIL_PASS", "")  # set as a Space Secret (Gmail: use an app password)
if not (EMAIL_USER and EMAIL_PASS):
    logger.warning(
        "EMAIL_USER and/or EMAIL_PASS not set. Email sending will be disabled until you set both as env vars "
        "(use an app password for Gmail). Set EMAIL_USER and EMAIL_PASS in Space Settings -> Variables/Secrets."
    )
def make_smtp_cfg(email_to):
    """Build the SMTP config dict passed to the processing pipelines.

    Args:
        email_to: Recipient address from the UI; None or blank disables email.

    Returns:
        ``{"enabled": False}`` when no recipient is given or when the sender
        credentials (EMAIL_USER / EMAIL_PASS) are not configured. Otherwise a dict
        with ``enabled=True``, the SMTP server/port, the sender credentials, and
        the stripped recipient.
    """
    recipient = (email_to or "").strip()
    if not recipient:
        return {"enabled": False}
    if not (EMAIL_USER and EMAIL_PASS):
        # Without credentials every send attempt would fail; honour the startup
        # warning that email is disabled in this case.
        logger.warning("Recipient email given but EMAIL_USER/EMAIL_PASS are not configured; email disabled.")
        return {"enabled": False}
    return {
        "enabled": True,
        "smtp_server": SMTP_SERVER,
        "smtp_port": SMTP_PORT,
        "email_user": EMAIL_USER,
        "email_pass": EMAIL_PASS,
        "email_to": recipient,
    }
def make_openrouter_cfg():
    """Build the OpenRouter client config from environment variables.

    OPENROUTER_API_KEY is required and deliberately has no default: secrets must
    not live in source (the key previously hard-coded here should be revoked).
    OPENROUTER_BASEURL and OPENROUTER_MODEL are optional and have defaults.

    Returns:
        dict with ``api_key`` (stripped), ``base_url`` and ``model_name``, or
        None when no non-blank API key is configured.
    """
    api_key = os.getenv("OPENROUTER_API_KEY", "").strip()
    if not api_key:
        return None
    return {
        "api_key": api_key,
        "base_url": os.getenv("OPENROUTER_BASEURL", "https://openrouter.ai/api/v1"),
        "model_name": os.getenv("OPENROUTER_MODEL", "google/gemma-3-12b-it:free"),
    }
def _new_run_dir():
    """Create and return a fresh, uniquely named run directory under BASE_OUT."""
    # A bare run_<unix-seconds> name collides when two requests start in the same
    # second, letting their inputs/outputs overwrite each other. mkdtemp keeps the
    # run_<ts> prefix and appends a unique suffix.
    return tempfile.mkdtemp(prefix=f"run_{int(time.time())}_", dir=BASE_OUT)


def _upload_path(uploaded_file):
    """Return the local filesystem path of a Gradio upload.

    Depending on the Gradio version / ``gr.File(type=...)``, the handler may get
    a plain path or a tempfile-like object exposing ``.name``; both are accepted.
    """
    if isinstance(uploaded_file, (str, os.PathLike)):
        return os.fspath(uploaded_file)
    return uploaded_file.name


def _stream_updates(gen, media_key):
    """Normalize update dicts from a processing generator for the Gradio UI.

    Each update may carry "status", "live_frame", "suspicious_list", "csv_path"
    and ``media_key`` (the annotated output path). Yields
    (status, annotated_media, gallery, dataframe, live_frame) tuples. The last
    non-empty annotated media, the last readable CSV and the last existing live
    frame are carried forward so the UI does not blank out between updates.
    """
    last_df = pd.DataFrame()
    last_media = None
    last_gallery = []
    last_live = None
    try:
        for update in gen:
            csv_path = update.get("csv_path", "")
            if csv_path and os.path.exists(csv_path):
                try:
                    last_df = pd.read_csv(csv_path)
                except (OSError, ValueError):
                    pass  # unreadable/partial CSV: keep showing the previous table
            media = update.get(media_key)
            if media:
                last_media = media
            live_frame = update.get("live_frame", "")  # path to last suspicious frame
            if live_frame and os.path.exists(live_frame):
                last_live = live_frame
            last_gallery = update.get("suspicious_list", []) or []
            yield update.get("status", ""), last_media, last_gallery, last_df, last_live
    except Exception as e:
        # UI boundary: report the failure in the status box and keep the partial
        # results visible instead of aborting the event with a generic error.
        logger.exception("Processing pipeline failed")
        yield f"Error during processing: {e}", last_media, last_gallery, last_df, last_live
    finally:
        # Release the pipeline's resources promptly if the UI stops consuming early.
        close = getattr(gen, "close", None)
        if callable(close):
            close()


def run_video_pipeline(uploaded_video_file, email_to, conf_thresh, confirm_conf_thresh):
    """Run video shoplifting detection and stream normalized outputs to Gradio.

    Copies the upload into a fresh run directory under BASE_OUT, drives
    ``process_video_stream``, and yields 6-tuples matching the UI outputs:
    (status, annotated_video, None, gallery, csv_dataframe, live_frame).
    """
    if uploaded_video_file is None:
        yield "Please upload a video.", None, None, [], pd.DataFrame(), None
        return
    try:
        run_dir = _new_run_dir()
        video_local = os.path.join(run_dir, "input_video.mp4")
        # copyfile streams the data instead of reading the whole video into memory.
        shutil.copyfile(_upload_path(uploaded_video_file), video_local)
    except Exception as e:
        logger.exception("Failed to save uploaded video")
        yield f"Error saving uploaded video: {e}", None, None, [], pd.DataFrame(), None
        return

    smtp_cfg = make_smtp_cfg(email_to)
    openrouter_cfg = make_openrouter_cfg()
    gen = process_video_stream(
        video_path=video_local,
        model_path=MODEL_PATH,
        out_root=run_dir,
        openrouter_cfg=openrouter_cfg,
        smtp_cfg=smtp_cfg,
        conf_thresh=float(conf_thresh),
        confirm_conf_thresh=float(confirm_conf_thresh),
        send_interval=4.0,
        confirmed_block_seconds=1000.0,
        progress_interval_frames=30
    )
    # Video mode: annotated video filled, annotated image slot left empty.
    for status, video, gallery_list, df, live in _stream_updates(gen, "annotated_video"):
        yield status, video, None, gallery_list, df, live


def run_image_pipeline(uploaded_image_file, email_to, conf_thresh, confirm_conf_thresh):
    """Run image shoplifting detection and stream normalized outputs to Gradio.

    Copies the upload (keeping its file name) into a fresh run directory under
    BASE_OUT, drives ``process_image``, and yields 6-tuples matching the UI
    outputs: (status, None, annotated_image, gallery, csv_dataframe, live_frame).
    """
    if uploaded_image_file is None:
        yield "Please upload an image.", None, None, [], pd.DataFrame(), None
        return
    try:
        src_path = _upload_path(uploaded_image_file)
        run_dir = _new_run_dir()
        image_local = os.path.join(run_dir, os.path.basename(src_path))
        shutil.copyfile(src_path, image_local)
    except Exception as e:
        logger.exception("Failed to save uploaded image")
        yield f"Error saving uploaded image: {e}", None, None, [], pd.DataFrame(), None
        return

    smtp_cfg = make_smtp_cfg(email_to)
    openrouter_cfg = make_openrouter_cfg()
    gen = process_image(
        image_path=image_local,
        model_path=MODEL_PATH,
        out_root=run_dir,
        openrouter_cfg=openrouter_cfg,
        smtp_cfg=smtp_cfg,
        conf_thresh=float(conf_thresh),
        confirm_conf_thresh=float(confirm_conf_thresh)
    )
    # Image mode: annotated video slot left empty, annotated image filled.
    for status, image, gallery_list, df, live in _stream_updates(gen, "annotated_image"):
        yield status, None, image, gallery_list, df, live
def run_handler(mode, video_file, image_file, email_to, conf_thresh, confirm_conf_thresh):
    """Dispatch a Start click to the video or image pipeline.

    Any mode other than "Video" is treated as image mode. Every tuple the chosen
    pipeline yields is passed straight through, in the order of the Gradio outputs:
    (status_txt, annotated_vid, annotated_img, gallery, csv_table (df), live_frame).
    """
    is_video = mode == "Video"
    pipeline = run_video_pipeline if is_video else run_image_pipeline
    upload = video_file if is_video else image_file
    yield from pipeline(upload, email_to, conf_thresh, confirm_conf_thresh)
# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Shoplifting Detection — Video or Image")
    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column(scale=2):
            mode = gr.Radio(["Video", "Image"], label="Mode", value="Video")
            video_file = gr.File(label="Upload video (mp4...)", file_types=["video"], visible=True)
            image_file = gr.File(label="Upload image (jpg/png...)", file_types=["image"], visible=False)
            email_to = gr.Textbox(label="Recipient email (to) — leave empty to disable email")
            conf_thresh = gr.Slider(label="Confidence threshold", minimum=0.01, maximum=1.0, value=0.5, step=0.01)
            confirm_conf = gr.Slider(label="Confirmation threshold", minimum=0.01, maximum=1.0, value=0.7, step=0.01)
            start_btn = gr.Button("Start")
            gr.Markdown(f"**Using MODEL_PATH (static):** `{MODEL_PATH}`")
            gr.Markdown("**Note:** SMTP / OpenRouter API key are read from env vars if set.")
        # Right column: streamed results.
        with gr.Column(scale=2):
            status_txt = gr.Textbox(label="Status", lines=3)
            annotated_vid = gr.Video(label="Annotated video (final)", visible=True)
            annotated_img = gr.Image(label="Annotated image (final)", visible=False)
            gallery = gr.Gallery(label="Suspicious frames (click to preview)", columns=4, height="auto")
            csv_table = gr.Dataframe(label="CSV Log")
            live_frame = gr.Image(label="Live detected frame (real-time)")

    def toggle_mode(m):
        """Show only the upload/result components that belong to mode ``m``.

        Returns visibility updates for (video_file, image_file, annotated_vid,
        annotated_img); any mode other than "Video" shows the image components.
        """
        is_video = m == "Video"
        return (
            gr.update(visible=is_video),
            gr.update(visible=not is_video),
            gr.update(visible=is_video),
            gr.update(visible=not is_video),
        )

    # Connect mode change: updates (video_file, image_file, annotated_vid, annotated_img)
    mode.change(toggle_mode, inputs=[mode], outputs=[video_file, image_file, annotated_vid, annotated_img])
    # Start button triggers the dispatcher. Outputs: status_txt, annotated_vid, annotated_img, gallery, csv_table, live_frame
    start_btn.click(
        fn=run_handler,
        inputs=[mode, video_file, image_file, email_to, conf_thresh, confirm_conf],
        outputs=[status_txt, annotated_vid, annotated_img, gallery, csv_table, live_frame]
    )
if __name__ == "__main__":
    # run_handler (the Start handler) is a generator. Per the Gradio docs, 3.x only
    # accepts generator handlers when the queue is enabled, while 4.x enables it by
    # default. Calling queue() explicitly works on both; it returns the Blocks itself.
    demo.queue().launch(share=True)