"""Interactive VitTrack video object tracking demo built with Gradio.

Upload a video, click exactly four points on its first frame to outline the
target object, and run the OpenCV VitTrack tracker over the whole clip.
"""

import ast
import os
import tempfile

import cv2 as cv
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download

from vittrack import VitTrack

# Download the ONNX model once at startup.
MODEL_PATH = hf_hub_download(
    repo_id="opencv/object_tracking_vittrack",
    filename="object_tracking_vittrack_2023sep.onnx",
)

backend_id = cv.dnn.DNN_BACKEND_OPENCV
target_id = cv.dnn.DNN_TARGET_CPU

car_on_road_video = "examples/car.mp4"
car_in_desert_video = "examples/desert_car.mp4"

# Global UI state shared by the Gradio callbacks.
state = {
    "points": [],        # up to 4 (x, y) clicks on the first frame
    "bbox": None,        # (x, y, w, h) bounding rect computed from the clicks
    "video_path": None,  # path of the currently loaded video
    "first_frame": None, # first frame of the loaded video (BGR, OpenCV order)
}

# Pre-defined bounding boxes for the bundled examples, keyed by file name.
# Values are literal "(x, y, w, h)" strings, parsed with ast.literal_eval.
bbox_dict = {
    "car.mp4": "(152, 356, 332, 104)",
    "desert_car.mp4": "(758, 452, 119, 65)",
}


def load_first_frame(video_path):
    """Load a video, grab its first frame, and reset the selection state.

    Returns the first frame converted to RGB for display in Gradio, or
    None if the video cannot be read.
    """
    state["video_path"] = video_path
    cap = cv.VideoCapture(video_path)
    has_frame, frame = cap.read()
    cap.release()
    if not has_frame:
        return None
    state["first_frame"] = frame.copy()
    return cv.cvtColor(frame, cv.COLOR_BGR2RGB)


def select_point(img, evt: gr.SelectData):
    """Accumulate up to 4 clicks and draw the polygon + bounding box.

    Once the fourth point is placed, the axis-aligned bounding rectangle of
    the four points is stored in ``state["bbox"]`` for the tracker.
    """
    if state["first_frame"] is None:
        return None

    x, y = int(evt.index[0]), int(evt.index[1])
    if len(state["points"]) < 4:
        state["points"].append((x, y))

    vis = state["first_frame"].copy()

    # Draw each selected point.
    for pt in state["points"]:
        cv.circle(vis, pt, 5, (0, 255, 0), -1)

    # Draw the connecting polygon.
    if len(state["points"]) > 1:
        pts = np.array(state["points"], dtype=np.int32)
        cv.polylines(vis, [pts], isClosed=False, color=(255, 255, 0), thickness=2)

    # Once we have exactly 4 points, compute & draw the bounding rect.
    if len(state["points"]) == 4:
        pts = np.array(state["points"], dtype=np.int32)
        x0, y0, w, h = cv.boundingRect(pts)
        state["bbox"] = (x0, y0, w, h)
        cv.rectangle(vis, (x0, y0), (x0 + w, y0 + h), (0, 0, 255), 2)

    return cv.cvtColor(vis, cv.COLOR_BGR2RGB)


def clear_points():
    """Reset the selected points only, keeping the loaded video."""
    state["points"].clear()
    state["bbox"] = None
    if state["first_frame"] is None:
        return None
    return cv.cvtColor(state["first_frame"], cv.COLOR_BGR2RGB)


def clear_all():
    """Reset everything: points, bbox, video path, and first frame."""
    state["points"].clear()
    state["bbox"] = None
    state["video_path"] = None
    state["first_frame"] = None
    return None, None, None


def track_video():
    """Initialize VitTrack on the stored bbox and process the entire video.

    Returns the path of the rendered output video, or None when no video /
    bbox is available or the video cannot be read.
    """
    if state["video_path"] is None or state["bbox"] is None:
        return None

    model = VitTrack(
        model_path=MODEL_PATH,
        backend_id=backend_id,
        target_id=target_id,
    )

    cap = cv.VideoCapture(state["video_path"])
    fps = cap.get(cv.CAP_PROP_FPS)
    # Some containers report 0 (or NaN) FPS; fall back to a sane default so
    # the VideoWriter is not created with an invalid frame rate.
    if not fps or fps != fps:
        fps = 30.0
    w = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))

    # Prepare a temporary output file.
    out_path = os.path.join(tempfile.gettempdir(), "vittrack_output.mp4")
    writer = cv.VideoWriter(
        out_path,
        cv.VideoWriter_fourcc(*"mp4v"),
        fps,
        (w, h),
    )

    try:
        # Read & init on the first frame; bail out cleanly if the video is
        # empty instead of passing an invalid frame to the tracker.
        has_frame, first_frame = cap.read()
        if not has_frame:
            return None
        model.init(first_frame, state["bbox"])

        tm = cv.TickMeter()
        while True:
            has_frame, frame = cap.read()
            if not has_frame:
                break

            tm.start()
            isLocated, bbox, score = model.infer(frame)
            tm.stop()

            vis = frame.copy()

            # Overlay per-frame FPS.
            cv.putText(vis, f"FPS:{tm.getFPS():.2f}", (w // 4, 30),
                       cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            # Draw the tracking box, or a loss message when confidence drops.
            if isLocated and score >= 0.3:
                x, y, w_, h_ = bbox
                cv.rectangle(vis, (x, y), (x + w_, y + h_), (0, 255, 0), 2)
                cv.putText(vis, f"{score:.2f}", (x, y - 10),
                           cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            else:
                cv.putText(vis, "Target lost!", (w // 2, h // 4),
                           cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

            writer.write(vis)
            tm.reset()
    finally:
        # Always release capture/writer even if tracking raises mid-loop.
        cap.release()
        writer.release()

    return out_path


def example_pipeline(video_path):
    """Run tracking on a bundled example using its pre-defined bounding box."""
    clear_all()
    # os.path.basename is portable (split('/') breaks on Windows paths).
    filename = os.path.basename(video_path)
    state["video_path"] = video_path
    # literal_eval safely parses the "(x, y, w, h)" string; eval would
    # execute arbitrary code.
    state["bbox"] = ast.literal_eval(bbox_dict[filename])
    return track_video()


with gr.Blocks(css='''.example * {
    font-style: italic;
    font-size: 18px !important;
    color: #0ea5e9 !important;
}''') as demo:
    gr.Markdown("## VitTrack: Interactive Video Object Tracking")
    gr.Markdown(
        """
        **How to use this tool:**

        1. **Upload a video** file (e.g., `.mp4` or `.avi`).
        2. The **first frame** of the video will appear.
        3. **Click exactly 4 points** on the object you want to track.
           These points should outline the object as closely as possible.
        4. A **bounding box** will be drawn around the selected region automatically.
        5. Click the **Track** button to start object tracking across the entire video.
        6. The output video with tracking overlay will appear below.

        You can also use:
        - 🧹 **Clear Points** to reset the 4-point selection on the first frame.
        - 🔄 **Clear All** to reset the uploaded video, frame, and selections.
        """
    )

    with gr.Row():
        video_in = gr.Video(label="Upload Video")
        first_frame = gr.Image(label="First Frame", interactive=True)
        output_video = gr.Video(label="Tracking Result")

    with gr.Row():
        track_btn = gr.Button("Track", variant="primary")
        clear_pts_btn = gr.Button("Clear Points")
        clear_all_btn = gr.Button("Clear All")

    gr.Markdown("Click any row to load an example.", elem_classes=["example"])
    examples = [
        [car_on_road_video],
        [car_in_desert_video],
    ]
    gr.Examples(
        examples=examples,
        inputs=[video_in],
        outputs=[output_video],
        fn=example_pipeline,
        cache_examples=False,
        run_on_click=True,
    )
    gr.Markdown("Example videos credit: https://pixabay.com/")

    # Event wiring.
    video_in.change(fn=load_first_frame, inputs=video_in, outputs=first_frame)
    first_frame.select(fn=select_point, inputs=first_frame, outputs=first_frame)
    clear_pts_btn.click(fn=clear_points, outputs=first_frame)
    clear_all_btn.click(fn=clear_all, outputs=[video_in, first_frame, output_video])
    track_btn.click(fn=track_video, outputs=output_video)

if __name__ == "__main__":
    demo.launch()