|
|
|
""" |
|
Shoplifting detection Gradio app — robust startup:
|
- finds local model file (best.pt) anywhere in repo |
|
- avoids writing to unwritable locations such as /data (write-tests each preferred directory and falls back to a writable one)
|
- defaults YOLO_CONFIG_DIR (unless already set) to a local directory to silence Ultralytics permission warnings
|
""" |
|
import os |
|
import time |
|
import logging |
|
import gradio as gr |
|
import pandas as pd |
|
import tempfile |
|
|
|
|
|
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


# Point Ultralytics at a local config directory to avoid permission warnings.
# This must run before the project modules that (presumably) import ultralytics
# are imported further down. A value already present in the environment wins;
# binding the constant to setdefault's return value keeps YOLO_CONFIG_DIR, the
# directory we create, and the directory Ultralytics reads all the same path.
YOLO_CONFIG_DIR = os.environ.setdefault(
    "YOLO_CONFIG_DIR", os.path.join(os.getcwd(), ".ultralytics")
)

try:
    os.makedirs(YOLO_CONFIG_DIR, exist_ok=True)
    logger.info(f"YOLO_CONFIG_DIR set to: {YOLO_CONFIG_DIR}")
except OSError as e:
    # Best effort: a missing config dir only costs warnings, not correctness.
    logger.warning(f"Failed creating YOLO_CONFIG_DIR {YOLO_CONFIG_DIR}: {e}")
|
|
|
|
|
# File name of the trained detector weights looked up at startup.
MODEL_FILENAME = "best.pt"

# Well-known locations, checked in this order by find_local_model() before it
# falls back to walking the whole working directory: ./best.pt, ./models/best.pt
COMMON_MODEL_PATHS = [
    os.path.join(os.getcwd(), *subdirs, MODEL_FILENAME)
    for subdirs in ((), ("models",))
]
|
|
|
|
|
def choose_writable_path(preferred_paths, fallback_name):
    """Return the first directory from *preferred_paths* that is actually writable.

    Every non-empty candidate is created if missing and then probed by writing a
    uniquely named temporary file into it (the file is removed automatically).
    When no candidate passes, ``<system temp dir>/<fallback_name>`` and then
    ``<cwd>/<fallback_name>`` are tried; these fallbacks are probed the same way.

    Args:
        preferred_paths: Candidate directories in priority order. Falsy entries
            (``None``, ``""``) are skipped; repeated entries are tried only once.
        fallback_name: Sub-directory name used under the fallback roots.

    Returns:
        The first directory that could be created and written to.

    Raises:
        RuntimeError: if neither fallback directory is writable.
    """

    def _probe(path):
        # Create *path* and prove a file can be written into it. makedirs alone
        # is not enough: with exist_ok=True an existing read-only dir passes.
        # NamedTemporaryFile picks a unique name (no clash between processes
        # starting in the same second) and deletes the file even if write fails.
        os.makedirs(path, exist_ok=True)
        with tempfile.NamedTemporaryFile(dir=path, prefix=".write_test_") as probe:
            probe.write(b"ok")
            probe.flush()

    tried = set()
    for p in preferred_paths:
        if not p or p in tried:
            continue
        tried.add(p)
        try:
            _probe(p)
        except Exception as e:
            logger.warning(f"Cannot use path '{p}': {e}")
            continue
        logger.info(f"Using writable path: {p}")
        return p

    fallbacks = (
        ("temporary", os.path.join(tempfile.gettempdir(), fallback_name)),
        ("CWD", os.path.join(os.getcwd(), fallback_name)),
    )
    errors = []
    for label, candidate in fallbacks:
        try:
            _probe(candidate)
        except Exception as e:
            errors.append(e)
            continue
        logger.info(f"Falling back to {label} path: {candidate}")
        return candidate

    raise RuntimeError(
        f"Failed to create fallback dirs: {errors[0]} / {errors[1]}"
    ) from errors[-1]
|
|
|
|
|
# Hugging Face cache root: $HF_HOME when set, otherwise ./.huggingface. The
# project default is passed again as a second candidate so that an unwritable
# $HF_HOME still resolves locally before choose_writable_path's temp/CWD fallback.
# NOTE(review): HF_HOME_DIR is not written back to os.environ, so Hugging Face
# libraries will still read the original $HF_HOME — confirm CACHE_DIR is passed
# explicitly wherever it matters.
PREFERRED_HF_HOME = os.getenv("HF_HOME") or os.path.join(os.getcwd(), ".huggingface")
HF_HOME_DIR = choose_writable_path(
    [PREFERRED_HF_HOME, os.path.join(os.getcwd(), ".huggingface")],
    "hf_cache",
)
CACHE_DIR = os.path.join(HF_HOME_DIR, "hub")

# Root under which every run gets its own output folder: $BASE_OUT when set,
# otherwise ./shoplift_outputs (same fallback scheme as above).
PREFERRED_BASE_OUT = os.getenv("BASE_OUT") or os.path.join(os.getcwd(), "shoplift_outputs")
BASE_OUT = choose_writable_path(
    [PREFERRED_BASE_OUT, os.path.join(os.getcwd(), "shoplift_outputs")],
    "shoplift_outputs",
)
# choose_writable_path already creates the directory it returns; kept as a guard.
os.makedirs(BASE_OUT, exist_ok=True)

logger.info(f"CACHE_DIR resolved to: {CACHE_DIR}")
logger.info(f"BASE_OUT resolved to: {BASE_OUT}")
|
|
|
|
|
def find_local_model():
    """Locate the detector weights file (``MODEL_FILENAME``) on local disk.

    Search order:

    1. Each path in ``COMMON_MODEL_PATHS``.
    2. A top-down walk of the current working directory (``.git`` is skipped;
       it holds repository metadata, not checked-out files).

    A file of 100 KiB or less is treated as a probable Git LFS pointer and is
    skipped, so a real copy elsewhere in the tree is still found.

    Returns:
        Path of the first file larger than the size threshold, or ``None`` when
        no file named ``MODEL_FILENAME`` exists anywhere in the tree.

    Raises:
        RuntimeError: if copies of the file exist but every one of them is too
            small to be real weights.
    """
    min_size = 100 * 1024
    undersized = []  # (path, size) of pointer-like copies seen so far

    def _size_of(path):
        # Unreadable or vanished files count as size 0, i.e. "not usable".
        try:
            return os.path.getsize(path)
        except OSError:
            return 0

    for p in COMMON_MODEL_PATHS:
        if not os.path.exists(p):
            continue
        size = _size_of(p)
        if size > min_size:
            return p
        logger.warning(f"Found {p} but size is small ({size} bytes) — might be a pointer file.")
        undersized.append((p, size))

    for root, dirs, files in os.walk(os.getcwd()):
        # Prune in place so os.walk does not descend into .git.
        dirs[:] = [d for d in dirs if d != ".git"]
        if MODEL_FILENAME not in files:
            continue
        candidate = os.path.join(root, MODEL_FILENAME)
        size = _size_of(candidate)
        if size > min_size:
            return candidate
        # Keep searching: a pointer file here must not hide a real model
        # further down the tree.
        undersized.append((candidate, size))

    if undersized:
        candidate, size = undersized[0]
        raise RuntimeError(
            f"Found {candidate} but its size is {size} bytes — looks like a Git LFS pointer. "
            "Make sure you uploaded the real model binary (use git lfs) or place the full .pt in the repo."
        )
    return None
|
|
|
|
|
# Resolve the weights once at import time; both pipelines and the UI read
# MODEL_PATH. Nothing below can run without a model, so abort startup here.
local_model = find_local_model()

if not local_model:
    # Nothing in this file downloads weights, so do not suggest configuring a
    # Hub token; the only fix is to ship the weights alongside the code.
    raise RuntimeError(
        "No local model 'best.pt' found in the repository. Please add the model binary to the repo.\n"
        "This app does not download weights from the Hugging Face Hub, so setting a token will not help. "
        "Put the full model binary (not a Git LFS pointer) at the project root or at models/best.pt and re-run."
    )

MODEL_PATH = local_model
logger.info(f"Using local model file at: {MODEL_PATH}")
|
|
|
|
|
from video import process_video_stream |
|
from image import process_image |
|
|
|
|
|
# SMTP settings for alert e-mails. Credentials are read from the environment only.
# NOTE(security): this block previously defaulted EMAIL_USER / EMAIL_PASS to a
# real Gmail address and app password committed in source. Treat those
# credentials as leaked: revoke the app password and set secrets via env vars.
SMTP_SERVER = os.getenv("SMTP_SERVER", "smtp.gmail.com")
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
EMAIL_USER = os.getenv("EMAIL_USER", "")
EMAIL_PASS = os.getenv("EMAIL_PASS", "")

if not EMAIL_PASS:
    logger.warning(
        "EMAIL_PASS not set. Email sending will be disabled until you set EMAIL_PASS as an env var "
        "(use an app password for Gmail). Set EMAIL_USER and EMAIL_PASS in Space Settings -> Variables/Secrets."
    )
|
|
|
def make_smtp_cfg(email_to):
    """Build the SMTP config dict handed to the processing pipelines.

    E-mail is enabled only when a non-blank recipient is given AND sender
    credentials (``EMAIL_USER`` / ``EMAIL_PASS``) are configured. Without
    credentials any send attempt presumably fails at login, so the config is
    disabled instead, matching the startup warning that email is disabled.

    Args:
        email_to: Recipient address from the UI; ``None``/blank disables e-mail.

    Returns:
        ``{"enabled": False}`` when disabled, otherwise a dict with keys
        ``enabled``, ``smtp_server``, ``smtp_port``, ``email_user``,
        ``email_pass`` and ``email_to`` (stripped).
    """
    recipient = (email_to or "").strip()
    if not recipient:
        return {"enabled": False}

    if not (EMAIL_USER and EMAIL_PASS):
        logger.warning(
            "Recipient email given but EMAIL_USER/EMAIL_PASS are not configured; email alerts disabled."
        )
        return {"enabled": False}

    return {
        "enabled": True,
        "smtp_server": SMTP_SERVER,
        "smtp_port": SMTP_PORT,
        "email_user": EMAIL_USER,
        "email_pass": EMAIL_PASS,
        "email_to": recipient,
    }
|
|
|
|
|
def make_openrouter_cfg():
    """Build the OpenRouter client config from environment variables.

    Reads ``OPENROUTER_API_KEY`` (required), plus ``OPENROUTER_BASEURL`` and
    ``OPENROUTER_MODEL``, which fall back to defaults.

    Returns:
        ``{"api_key", "base_url", "model_name"}`` when a non-blank API key is
        set, otherwise ``None`` (OpenRouter features disabled).
    """
    # NOTE(security): the key must come from the environment only. This function
    # previously shipped a literal API key as the default; that key was exposed
    # in source and should be revoked.
    api_key = (os.getenv("OPENROUTER_API_KEY") or "").strip()
    if not api_key:
        return None
    return {
        "api_key": api_key,
        "base_url": os.getenv("OPENROUTER_BASEURL", "https://openrouter.ai/api/v1"),
        "model_name": os.getenv("OPENROUTER_MODEL", "google/gemma-3-12b-it:free"),
    }
|
|
|
|
|
def run_video_pipeline(uploaded_video_file, email_to, conf_thresh, confirm_conf_thresh):
    """Wrapper generator for video processing to normalize outputs for Gradio.

    Copies the upload into a fresh run directory under ``BASE_OUT``, runs
    ``process_video_stream`` on it, and converts each update dict it yields into
    the 6-tuple expected by the Gradio outputs:
    ``(status, annotated_video, annotated_image, gallery, csv_dataframe, live_frame)``.
    ``annotated_image`` is always ``None`` in video mode.

    Args:
        uploaded_video_file: Value of the ``gr.File`` input. Depending on the
            Gradio version this is a path string or a file-like object exposing
            ``.name``; both are accepted.
        email_to: Recipient address; empty disables e-mail alerts.
        conf_thresh: Detection confidence threshold (coerced to float).
        confirm_conf_thresh: Confirmation threshold (coerced to float).

    Yields:
        6-tuples as described above. The last annotated video, CSV table and
        live frame are carried forward when an update omits them.
    """
    import shutil

    if uploaded_video_file is None:
        yield "Please upload a video.", None, None, [], pd.DataFrame(), None
        return

    # One unique directory per run. A bare ``run_<unix seconds>`` name was shared
    # by any two runs started within the same second (e.g. concurrent users),
    # which then overwrote each other's input video and outputs.
    os.makedirs(BASE_OUT, exist_ok=True)
    run_dir = tempfile.mkdtemp(prefix=f"run_{int(time.time())}_", dir=BASE_OUT)

    video_local = os.path.join(run_dir, "input_video.mp4")
    src_path = getattr(uploaded_video_file, "name", uploaded_video_file)
    try:
        # copyfile streams in chunks instead of loading the whole video into RAM.
        shutil.copyfile(src_path, video_local)
    except Exception as e:
        yield f"Error saving uploaded video: {e}", None, None, [], pd.DataFrame(), None
        return

    smtp_cfg = make_smtp_cfg(email_to)
    openrouter_cfg = make_openrouter_cfg()

    gen = process_video_stream(
        video_path=video_local,
        model_path=MODEL_PATH,
        out_root=run_dir,
        openrouter_cfg=openrouter_cfg,
        smtp_cfg=smtp_cfg,
        conf_thresh=float(conf_thresh),
        confirm_conf_thresh=float(confirm_conf_thresh),
        send_interval=4.0,
        confirmed_block_seconds=1000.0,
        progress_interval_frames=30,
    )

    # Values carried forward between updates so the UI does not blank out.
    last_csv_df = pd.DataFrame()
    last_annotated_video = None
    last_live = None

    for update in gen:
        status = update.get("status", "")
        live_frame = update.get("live_frame", "")
        gallery_list = update.get("suspicious_list", []) or []
        csv_path = update.get("csv_path", "")
        annotated_video = update.get("annotated_video", None)

        # Reload the CSV log on every update; keep the previous table when the
        # file is missing or cannot be parsed (presumably mid-write).
        if csv_path and os.path.exists(csv_path):
            try:
                last_csv_df = pd.read_csv(csv_path)
            except Exception as e:
                logger.debug(f"Could not read CSV log {csv_path}: {e}")

        if live_frame and os.path.exists(live_frame):
            last_live = live_frame
        if annotated_video:
            last_annotated_video = annotated_video

        yield status, last_annotated_video, None, gallery_list, last_csv_df, last_live
|
|
|
|
|
def run_image_pipeline(uploaded_image_file, email_to, conf_thresh, confirm_conf_thresh):
    """Wrapper generator for image processing to normalize outputs for Gradio.

    Copies the upload (keeping its base name) into a fresh run directory under
    ``BASE_OUT``, runs ``process_image`` on it, and converts each update dict it
    yields into the 6-tuple expected by the Gradio outputs:
    ``(status, annotated_video, annotated_image, gallery, csv_dataframe, live_frame)``.
    ``annotated_video`` is always ``None`` in image mode.

    Args:
        uploaded_image_file: Value of the ``gr.File`` input. Depending on the
            Gradio version this is a path string or a file-like object exposing
            ``.name``; both are accepted.
        email_to: Recipient address; empty disables e-mail alerts.
        conf_thresh: Detection confidence threshold (coerced to float).
        confirm_conf_thresh: Confirmation threshold (coerced to float).

    Yields:
        6-tuples as described above. The last annotated image, CSV table and
        live frame are carried forward when an update omits them.
    """
    import shutil

    if uploaded_image_file is None:
        yield "Please upload an image.", None, None, [], pd.DataFrame(), None
        return

    # One unique directory per run; ``run_<unix seconds>`` alone collided for
    # runs started within the same second.
    os.makedirs(BASE_OUT, exist_ok=True)
    run_dir = tempfile.mkdtemp(prefix=f"run_{int(time.time())}_", dir=BASE_OUT)

    src_path = getattr(uploaded_image_file, "name", uploaded_image_file)
    try:
        # basename() keeps the copy inside run_dir whatever path was uploaded.
        image_local = os.path.join(run_dir, os.path.basename(src_path))
        shutil.copyfile(src_path, image_local)
    except Exception as e:
        yield f"Error saving uploaded image: {e}", None, None, [], pd.DataFrame(), None
        return

    smtp_cfg = make_smtp_cfg(email_to)
    openrouter_cfg = make_openrouter_cfg()

    gen = process_image(
        image_path=image_local,
        model_path=MODEL_PATH,
        out_root=run_dir,
        openrouter_cfg=openrouter_cfg,
        smtp_cfg=smtp_cfg,
        conf_thresh=float(conf_thresh),
        confirm_conf_thresh=float(confirm_conf_thresh),
    )

    # Values carried forward between updates so the UI does not blank out.
    last_csv_df = pd.DataFrame()
    last_annotated_image = None
    last_live = None

    for update in gen:
        status = update.get("status", "")
        live_frame = update.get("live_frame", "")
        gallery_list = update.get("suspicious_list", []) or []
        csv_path = update.get("csv_path", "")
        annotated_image = update.get("annotated_image", None)

        # Reload the CSV log on every update; keep the previous table when the
        # file is missing or cannot be parsed (presumably mid-write).
        if csv_path and os.path.exists(csv_path):
            try:
                last_csv_df = pd.read_csv(csv_path)
            except Exception as e:
                logger.debug(f"Could not read CSV log {csv_path}: {e}")

        if live_frame and os.path.exists(live_frame):
            last_live = live_frame
        if annotated_image:
            last_annotated_image = annotated_image

        yield status, None, last_annotated_image, gallery_list, last_csv_df, last_live
|
|
|
|
|
def run_handler(mode, video_file, image_file, email_to, conf_thresh, confirm_conf_thresh):
    """Route a Start click to the video or image pipeline and relay its updates.

    ``mode == "Video"`` selects the video pipeline with *video_file*; any other
    value selects the image pipeline with *image_file*.

    Yields:
        Tuples matching the Gradio outputs:
        (status_txt, annotated_vid, annotated_img, gallery, csv_table (df), live_frame)
    """
    if mode != "Video":
        pipeline, upload = run_image_pipeline, image_file
    else:
        pipeline, upload = run_video_pipeline, video_file

    for payload in pipeline(upload, email_to, conf_thresh, confirm_conf_thresh):
        yield payload
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Shoplifting Detection — Video or Image")

    with gr.Row():
        # Left column: inputs. Only the uploader for the selected mode is visible.
        with gr.Column(scale=2):
            mode = gr.Radio(["Video", "Image"], label="Mode", value="Video")
            video_file = gr.File(label="Upload video (mp4...)", file_types=["video"], visible=True)
            image_file = gr.File(label="Upload image (jpg/png...)", file_types=["image"], visible=False)
            email_to = gr.Textbox(label="Recipient email (to) — leave empty to disable email")
            conf_thresh = gr.Slider(label="Confidence threshold", minimum=0.01, maximum=1.0, value=0.5, step=0.01)
            confirm_conf = gr.Slider(label="Confirmation threshold", minimum=0.01, maximum=1.0, value=0.7, step=0.01)
            start_btn = gr.Button("Start")
            gr.Markdown(f"**Using MODEL_PATH (static):** `{MODEL_PATH}`")
            gr.Markdown("**Note:** SMTP / OpenRouter API key are read from env vars if set.")

        # Right column: results, in the same order as the tuples run_handler yields.
        with gr.Column(scale=2):
            status_txt = gr.Textbox(label="Status", lines=3)
            annotated_vid = gr.Video(label="Annotated video (final)", visible=True)
            annotated_img = gr.Image(label="Annotated image (final)", visible=False)
            gallery = gr.Gallery(label="Suspicious frames (click to preview)", columns=4, height="auto")
            csv_table = gr.Dataframe(label="CSV Log")
            live_frame = gr.Image(label="Live detected frame (real-time)")

    def toggle_mode(m):
        """Show the widgets for mode *m* and hide the other mode's widgets.

        Returns visibility updates for
        (video_file, image_file, annotated_vid, annotated_img).
        """
        is_video = m == "Video"
        return (
            gr.update(visible=is_video),
            gr.update(visible=not is_video),
            gr.update(visible=is_video),
            gr.update(visible=not is_video),
        )

    mode.change(toggle_mode, inputs=[mode], outputs=[video_file, image_file, annotated_vid, annotated_img])

    start_btn.click(
        fn=run_handler,
        inputs=[mode, video_file, image_file, email_to, conf_thresh, confirm_conf],
        outputs=[status_txt, annotated_vid, annotated_img, gallery, csv_table, live_frame],
    )


if __name__ == "__main__":
    demo.launch(share=True)
|
|