eadali committed
Commit 2af59b8 · verified · 1 Parent(s): 70e650c

Create app.py

Files changed (1)
app.py +568 -0
app.py ADDED
@@ -0,0 +1,568 @@
import os
import cv2
import tqdm
import uuid
import logging

import torch
import spaces
import trackers
import numpy as np
import gradio as gr
import imageio.v3 as iio
import supervision as sv

from pathlib import Path
from functools import lru_cache
from typing import List, Optional, Tuple

from PIL import Image
from transformers import AutoModelForObjectDetection, AutoImageProcessor
from transformers.image_utils import load_image


# Configuration constants
CHECKPOINTS = [
    "ustc-community/dfine-medium-obj2coco",
    "ustc-community/dfine-medium-coco",
    "ustc-community/dfine-medium-obj365",
    "ustc-community/dfine-nano-coco",
    "ustc-community/dfine-small-coco",
    "ustc-community/dfine-large-coco",
    "ustc-community/dfine-xlarge-coco",
    "ustc-community/dfine-small-obj365",
    "ustc-community/dfine-large-obj365",
    "ustc-community/dfine-xlarge-obj365",
    "ustc-community/dfine-small-obj2coco",
    "ustc-community/dfine-large-obj2coco-e25",
    "ustc-community/dfine-xlarge-obj2coco",
]
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3

TORCH_DTYPE = torch.float32

# Image
IMAGE_EXAMPLES = [
    {"path": "./examples/images/tennis.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {"path": "./examples/images/dogs.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {"path": "./examples/images/nascar.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {"path": "./examples/images/crossroad.jpg", "use_url": False, "url": "", "label": "Local Image"},
    {
        "path": None,
        "use_url": True,
        "url": "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
        "label": "Flickr Image",
    },
]

# Video
MAX_NUM_FRAMES = 250
BATCH_SIZE = 4
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


class TrackingAlgorithm:
    BYTETRACK = "ByteTrack (2021)"
    DEEPSORT = "DeepSORT (2017)"
    SORT = "SORT (2016)"


TRACKERS = [None, TrackingAlgorithm.BYTETRACK, TrackingAlgorithm.DEEPSORT, TrackingAlgorithm.SORT]
VIDEO_EXAMPLES = [
    {"path": "./examples/videos/dogs_running.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
    {"path": "./examples/videos/traffic.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK, "classes": "car, truck, bus"},
    {"path": "./examples/videos/fast_and_furious.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
    {"path": "./examples/videos/break_dance.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
]


# Create a color palette for visualization
# These hex color codes define different colors for tracking different objects
color = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


@lru_cache(maxsize=3)
def get_model_and_processor(checkpoint: str):
    # Keep up to three recently used checkpoints resident so switching back avoids a reload
    model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE)
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    return model, image_processor


@spaces.GPU(duration=20)
def detect_objects(
    checkpoint: str,
    images: List[np.ndarray] | np.ndarray,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    target_size: Optional[Tuple[int, int]] = None,
    batch_size: int = BATCH_SIZE,
    classes: Optional[List[str]] = None,
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, image_processor = get_model_and_processor(checkpoint)
    model = model.to(device)

    # map requested class names to label ids; detections of other classes are filtered out below
    if classes is not None:
        wrong_classes = [cls for cls in classes if cls not in model.config.label2id]
        if wrong_classes:
            gr.Warning(f"Classes not found in model config: {wrong_classes}")
        keep_ids = [model.config.label2id[cls] for cls in classes if cls in model.config.label2id]
    else:
        keep_ids = None

    if isinstance(images, np.ndarray) and images.ndim == 4:
        images = [x for x in images]  # split video array into list of images

    batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]

    results = []
    for batch in tqdm.tqdm(batches, desc="Processing frames"):

        # preprocess images
        inputs = image_processor(images=batch, return_tensors="pt")
        inputs = inputs.to(device).to(TORCH_DTYPE)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # postprocess outputs
        if target_size:
            target_sizes = [target_size] * len(batch)
        else:
            target_sizes = [(image.shape[0], image.shape[1]) for image in batch]

        batch_results = image_processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=confidence_threshold
        )

        results.extend(batch_results)

    # move results to cpu
    for i, result in enumerate(results):
        results[i] = {k: v.cpu() for k, v in result.items()}
        if keep_ids is not None:
            keep = torch.isin(results[i]["labels"], torch.tensor(keep_ids))
            results[i] = {k: v[keep] for k, v in results[i].items()}

    return results, model.config.id2label
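
# Note on detect_objects above: each entry in `results` is a dict with "scores", "labels"
# and "boxes" tensors, where boxes are (x_min, y_min, x_max, y_max) in pixel coordinates
# of the corresponding target size.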


def process_image(
    checkpoint: str = DEFAULT_CHECKPOINT,
    image: Optional[Image.Image] = None,
    url: Optional[str] = None,
    use_url: bool = False,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
):
    if not use_url:
        url = None

    # exactly one of `image` / `url` must be provided
    if (image is None) ^ bool(url):
        raise ValueError("Either image or url must be provided, but not both.")

    if url:
        image = load_image(url)

    results, id2label = detect_objects(
        checkpoint=checkpoint,
        images=[np.array(image)],
        confidence_threshold=confidence_threshold,
    )
    result = results[0]  # first image in batch (we have batch size 1)

    annotations = []
    for label, score, box in zip(result["labels"], result["scores"], result["boxes"]):
        text_label = id2label[label.item()]
        formatted_label = f"{text_label} ({score:.2f})"
        x_min, y_min, x_max, y_max = box.cpu().numpy().round().astype(int)
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(image.width - 1, x_max)
        y_max = min(image.height - 1, y_max)
        annotations.append(((x_min, y_min, x_max, y_max), formatted_label))

    return (image, annotations)


def get_target_size(image_height, image_width, max_size: int):
    if image_height < max_size and image_width < max_size:
        new_height, new_width = image_height, image_width
    elif image_height > image_width:
        new_height = max_size
        new_width = int(image_width * max_size / image_height)
    else:
        new_width = max_size
        new_height = int(image_height * max_size / image_width)

    # make even (for video codec compatibility)
    new_height = new_height // 2 * 2
    new_width = new_width // 2 * 2

    return new_width, new_height
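
# Worked example for get_target_size above (added for clarity): the arguments are
# (height, width, max_size) but the return value is (width, height), e.g.
# get_target_size(1920, 1080, 1080) -> (606, 1080), with both values rounded down to even.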


def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
    cap = cv2.VideoCapture(video_path)
    frames = []
    i = 0
    progress_bar = tqdm.tqdm(total=k, desc="Reading frames")
    while cap.isOpened() and len(frames) < k:
        ret, frame = cap.read()
        if not ret:
            break
        if i % read_every_i_frame == 0:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            progress_bar.update(1)
        i += 1
    cap.release()
    progress_bar.close()
    return frames
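
# Note on read_video_k_frames above: with read_every_i_frame=2 it keeps frames 0, 2, 4, ...,
# so the effective frame rate of the returned sequence is half that of the source video.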


def get_tracker(tracker: str, fps: float):
    if tracker == TrackingAlgorithm.SORT:
        return trackers.SORTTracker(frame_rate=fps)
    elif tracker == TrackingAlgorithm.DEEPSORT:
        feature_extractor = trackers.DeepSORTFeatureExtractor.from_timm("mobilenetv4_conv_small.e1200_r224_in1k", device="cpu")
        return trackers.DeepSORTTracker(feature_extractor, frame_rate=fps)
    elif tracker == TrackingAlgorithm.BYTETRACK:
        return sv.ByteTrack(frame_rate=int(fps))
    else:
        raise ValueError(f"Invalid tracker: {tracker}")


def update_tracker(tracker, detections, frame):
    tracker_name = tracker.__class__.__name__
    if tracker_name == "SORTTracker":
        return tracker.update(detections)
    elif tracker_name == "DeepSORTTracker":
        return tracker.update(detections, frame)
    elif tracker_name == "ByteTrack":
        return tracker.update_with_detections(detections)
    else:
        raise ValueError(f"Invalid tracker: {tracker}")
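
# Note: SORT and DeepSORT come from the `trackers` package while ByteTrack comes from
# `supervision`; their update methods take different arguments, hence the dispatch above.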


def process_video(
    video_path: str,
    checkpoint: str,
    tracker_algorithm: Optional[str] = None,
    classes: str = "all",
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:
    if not video_path or not os.path.isfile(video_path):
        raise ValueError(f"Invalid video path: {video_path}")

    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        raise ValueError(f"Unsupported video format: {ext}, supported formats: {ALLOWED_VIDEO_EXTENSIONS}")

    # subsample to roughly 25 FPS (e.g. a 60 FPS source is read every 2nd frame -> 30 FPS output)
    video_info = sv.VideoInfo.from_video_path(video_path)
    read_each_i_frame = max(1, video_info.fps // 25)
    target_fps = video_info.fps / read_each_i_frame
    target_width, target_height = get_target_size(video_info.height, video_info.width, 1080)

    n_frames_to_read = min(MAX_NUM_FRAMES, video_info.total_frames // read_each_i_frame)
    frames = read_video_k_frames(video_path, n_frames_to_read, read_each_i_frame)
    frames = [cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC) for frame in frames]

    # Set the color lookup mode to assign colors by track ID
    # This means objects with the same track ID are annotated with the same color
    color_lookup = sv.ColorLookup.TRACK if tracker_algorithm else sv.ColorLookup.CLASS

    box_annotator = sv.BoxAnnotator(color, color_lookup=color_lookup, thickness=1)
    label_annotator = sv.LabelAnnotator(color, color_lookup=color_lookup, text_scale=0.5)
    trace_annotator = sv.TraceAnnotator(color, color_lookup=color_lookup, thickness=1, trace_length=100)

    # preprocess classes
    if classes != "all":
        classes_list = [cls.strip().lower() for cls in classes.split(",")]
    else:
        classes_list = None

    results, id2label = detect_objects(
        images=np.array(frames),
        checkpoint=checkpoint,
        confidence_threshold=confidence_threshold,
        target_size=(target_height, target_width),
        classes=classes_list,
    )

    annotated_frames = []

    # annotate detections, with tracking if a tracker was selected
    if tracker_algorithm:
        tracker = get_tracker(tracker_algorithm, target_fps)
        for frame, result in progress.tqdm(zip(frames, results), desc="Tracking objects", total=len(frames)):
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            detections = update_tracker(tracker, detections, frame)
            labels = [f"#{tracker_id} {id2label[class_id]}" for class_id, tracker_id in zip(detections.class_id, detections.tracker_id)]
            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
            annotated_frame = trace_annotator.annotate(scene=annotated_frame, detections=detections)
            annotated_frames.append(annotated_frame)
    else:
        for frame, result in tqdm.tqdm(zip(frames, results), desc="Annotating frames", total=len(frames)):
            detections = sv.Detections.from_transformers(result, id2label=id2label)
            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections)
            annotated_frames.append(annotated_frame)

    output_filename = os.path.join(VIDEO_OUTPUT_DIR, f"output_{uuid.uuid4()}.mp4")
    iio.imwrite(output_filename, annotated_frames, fps=target_fps, codec="h264")
    return output_filename
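
# Note on process_video above: the annotated frames are RGB numpy arrays; imageio encodes
# them as an H.264 mp4 at the subsampled frame rate (target_fps), and the resulting file
# path is returned to the gr.Video output component.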


def create_image_inputs() -> List[gr.components.Component]:
    return [
        gr.Image(
            label="Upload Image",
            type="pil",
            sources=["upload", "webcam"],
            interactive=True,
            elem_classes="input-component",
        ),
        gr.Checkbox(label="Use Image URL Instead", value=False),
        gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/image.jpg",
            visible=False,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_video_inputs() -> List[gr.components.Component]:
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",  # Ensure MP4 format
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Dropdown(
            choices=TRACKERS,
            label="Select Tracker (Optional)",
            value=None,
            elem_classes="input-component",
        ),
        gr.TextArea(
            label="Specify Class Names to Detect (comma separated)",
            value="all",
            lines=1,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_button_row() -> List[gr.Button]:
    return [
        gr.Button(
            "Detect Objects", variant="primary", elem_classes="action-button"
        ),
        gr.Button("Clear", variant="secondary", elem_classes="action-button"),
    ]


# Gradio interface
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
        # Object Detection Demo
        Experience state-of-the-art object detection with USTC's [D-FINE](https://huggingface.co/docs/transformers/main/model_doc/d_fine) models.
        - **Image** and **Video** modes are supported.
        - Select a model and adjust the confidence threshold to see detections!
        - In video mode, you can enable tracking powered by [Supervision](https://github.com/roboflow/supervision) and [Trackers](https://github.com/roboflow/trackers) from Roboflow.
        """,
        elem_classes="header-text",
    )

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            use_url,
                            url_input,
                            image_model_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                    image_detect_button, image_clear_button = create_button_row()
                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )
            gr.Examples(
                examples=[
                    [
                        DEFAULT_CHECKPOINT,
                        example["path"],
                        example["url"],
                        example["use_url"],
                        DEFAULT_CONFIDENCE_THRESHOLD,
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=[
                    image_model_checkpoint,
                    image_input,
                    url_input,
                    use_url,
                    image_confidence_threshold,
                ],
                outputs=[image_output],
                fn=process_image,
                label="Select an image example to populate inputs",
                cache_examples=True,
                cache_mode="lazy",
            )

        with gr.Tab("Video"):
            gr.Markdown(
                f"The input video will be processed at ~25 FPS (up to {MAX_NUM_FRAMES} frames in the result)."
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold = create_video_inputs()
                    video_detect_button, video_clear_button = create_button_row()
                with gr.Column(scale=2):
                    video_output = gr.Video(
                        label="Detection Results",
                        format="mp4",  # Explicit MP4 format
                        elem_classes="output-component",
                    )

            gr.Examples(
                examples=[
                    [example["path"], DEFAULT_CHECKPOINT, example["tracker"], example["classes"], DEFAULT_CONFIDENCE_THRESHOLD]
                    for example in VIDEO_EXAMPLES
                ],
                inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
                outputs=[video_output],
                fn=process_video,
                cache_examples=False,
                label="Select a video example to populate inputs",
            )

    # Dynamic visibility for URL input
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Image clear button
    image_clear_button.click(
        fn=lambda: (
            None,
            False,
            "",
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            image_model_checkpoint,
            image_confidence_threshold,
            image_output,
        ],
    )

    # Video clear button
    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_CHECKPOINT,
            None,
            "all",
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_tracker,
            video_classes,
            video_confidence_threshold,
            video_output,
        ],
    )

    # Image detect button
    image_detect_button.click(
        fn=process_image,
        inputs=[
            image_model_checkpoint,
            image_input,
            url_input,
            use_url,
            image_confidence_threshold,
        ],
        outputs=[image_output],
    )

    # Video detect button
    video_detect_button.click(
        fn=process_video,
        inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
        outputs=[video_output],
    )


if __name__ == "__main__":
    demo.queue(max_size=20).launch()
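
For reference, the detection path in detect_objects/process_image above is the standard transformers object-detection flow. A minimal standalone sketch (the checkpoint is one of the CHECKPOINTS listed above; the image path is a placeholder, and the threshold matches the app default):

import torch
import numpy as np
from PIL import Image
from transformers import AutoModelForObjectDetection, AutoImageProcessor

checkpoint = "ustc-community/dfine-medium-obj2coco"  # DEFAULT_CHECKPOINT above
model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=torch.float32)
processor = AutoImageProcessor.from_pretrained(checkpoint)

image = Image.open("example.jpg")  # placeholder path
inputs = processor(images=[np.array(image)], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# rescale boxes to the original image size and drop detections below the threshold
result = processor.post_process_object_detection(
    outputs, target_sizes=[(image.height, image.width)], threshold=0.3
)[0]
for label, score, box in zip(result["labels"], result["scores"], result["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 2), box.tolist())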