File size: 3,699 Bytes
e0f4d55
 
 
 
 
 
 
 
 
7ad5fd3
e0f4d55
2f3566f
 
 
 
7ad5fd3
 
 
 
 
 
 
2f3566f
1b36964
2f3566f
 
 
 
 
 
e0f4d55
7ad5fd3
 
 
e0f4d55
2f3566f
e0f4d55
 
7ad5fd3
e0f4d55
 
7ad5fd3
e0f4d55
7ad5fd3
 
cf254d2
7ad5fd3
 
 
 
 
 
2f3566f
7ad5fd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0f4d55
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
# NOVIC Gradio Space

# Imports
import os
import glob
from typing import Optional
import PIL.Image
import gradio as gr
import app_novic

# Sample images: every file directly inside sample_images/ (resolved relative to the current working
# directory) whose extension is one of IMAGE_EXTS, as a single path-sorted list
IMAGE_EXTS = ('jpg', 'jpeg', 'png', 'webp')
SAMPLE_IMAGES = sorted([
	sample_path
	for ext in IMAGE_EXTS
	for sample_path in glob.glob(os.path.join('sample_images', f'*.{ext}'))
])

# Model checkpoints: model display name -> checkpoint path relative to the checkpoints/ directory
# (dict insertion order is preserved, so listing the keys yields the models in the order written here)
MODEL_CHECKPOINTS = {
	name: os.path.join(run_dir, checkpoint_file)
	for name, run_dir, checkpoint_file in (
		('NOVIC SigLIP B/16 FT2', 'ovod_20240610_105233', 'ovod_chunk0899_20240612_005748.train'),
		('NOVIC SigLIP SO/14 FT2', 'ovod_20240626_001447', 'ovod_chunk0899_20240627_112729.train'),
		('NOVIC DFN-5B H/14-378 FT2', 'ovod_20240620_162925', 'ovod_chunk0899_20240621_202727.train'),
		('NOVIC DFN-5B H/14-378 FT0', 'ovod_20240628_142131', 'ovod_chunk0433_20240630_235415.train'),
	)
}
DEFAULT_MODEL = 'NOVIC DFN-5B H/14-378 FT0'

# Resolve a model display name to its checkpoint path
def get_checkpoint(model: str) -> str:
	"""Return the checkpoint path for the named model, rooted at the checkpoints/ directory.

	Args:
		model: A key of MODEL_CHECKPOINTS.

	Returns:
		The path 'checkpoints/<run dir>/<checkpoint file>' (joined with os.path.join).

	Raises:
		KeyError: If model is not a key of MODEL_CHECKPOINTS.
	"""
	checkpoint_relpath = MODEL_CHECKPOINTS[model]
	return os.path.join('checkpoints', checkpoint_relpath)

# Load the default model once at import time, before the UI is built or served.
# NOTE(review): presumably app_novic.get_model caches the loaded model so later classify calls reuse it — confirm
app_novic.get_model(checkpoint=get_checkpoint(DEFAULT_MODEL))

# Run the selected NOVIC model on an image (event handler for the Gradio UI)
def classify_image(image: Optional[PIL.Image.Image], model: Optional[str]) -> dict[str, float]:
	"""Classify an image with the chosen model.

	Args:
		image: The input image, or None when the image input is empty.
		model: Display name of the model (a key of MODEL_CHECKPOINTS), or None when nothing is selected.

	Returns:
		The result of app_novic.classify_image for the image and the model's checkpoint (presumably a
		label -> score mapping, per the return annotation), or an empty dict if either input is missing.
	"""
	if image is not None and model is not None:
		return app_novic.classify_image(image=image, checkpoint=get_checkpoint(model=model))
	return {}

# Gradio UI
# Declarative layout: inside these `with` contexts Gradio places components in the order they are constructed,
# so the statement order below determines the on-page layout. The resulting app is bound to `demo`.
with gr.Blocks(
	theme=None,
	analytics_enabled=True,  # Gradio usage analytics are explicitly enabled
	title="🖼️ NOVIC Demo",  # Page (browser tab) title
	fill_width=False,
) as demo:

	# Page header: title, usage instructions, CPU-vs-GPU inference note, and links to the GitHub repository and paper
	gr.HTML("<h1 style='text-align: center; margin-bottom: 1rem'>🖼️ NOVIC: Unconstrained Open Vocabulary Image Classification</h1><div style='text-align: center'><i>Select</i> an example image below <b>OR</b> <i>Upload</i> an image file <b>OR</b> <i>Capture</i> a camera image <b>OR</b> <i>Copy-paste</i> an image from your clipboard ⇒ The label predictions on the right will update automatically!<br>Note that inference on GPU is naturally <i>MUCH</i> faster (real-time) than the CPU inference in this demo. CPU inference is also slightly numerically different than proper GPU inference.<br><span style=\"margin-right: 20px;\"><b>GitHub:</b> <a href=\"https://github.com/pallgeuer/novic\" target=\"_blank\">https://github.com/pallgeuer/novic</a></span><span><b>Paper:</b> <a href=\"https://arxiv.org/abs/2407.11211\" target=\"_blank\">https://arxiv.org/abs/2407.11211</a></span></div>")

	# Main row: inputs (model choice and image) on the left, predictions on the right
	with gr.Row(equal_height=True):

		# Left column: model selector above the image input
		with gr.Column(scale=1):
			# Only the predefined MODEL_CHECKPOINTS keys are selectable (no custom values, no filtering);
			# the handler receives the selected display name as a string
			input_model = gr.Dropdown(
				choices=list(MODEL_CHECKPOINTS),
				value=DEFAULT_MODEL,
				type='value',
				multiselect=False,
				allow_custom_value=False,
				filterable=False,
				label='NOVIC model',
				show_label=True,
				interactive=True,
			)
			# Image input, handed to the handler as an RGB PIL image
			input_image = gr.Image(
				height=400,
				image_mode='RGB',
				type='pil',
				label='Input image',
				show_label=True,
				interactive=True,
				show_fullscreen_button=True,
			)

		# Right column: the three highest-scoring predicted labels
		with gr.Column(scale=1):
			output_label = gr.Label(
				num_top_classes=3,
				label='Predicted label',
				show_label=True,
				scale=1,
				show_heading=True,
			)

	# Button row below the main row
	with gr.Row(equal_height=True):
		# Resets the image input; the resulting change event should re-run classify_image, which returns {} for a None image
		gr.ClearButton(
			components=[input_image],
			value='Clear input image',
			variant='secondary',
			size='lg',
			scale=1,
		)
		# Gradio deep-link button (shareable link to the current app state)
		gr.DeepLinkButton(
			variant='secondary',
			size='lg',
			scale=1,
		)

	# Clickable gallery of the sample images; choosing one loads it into input_image.
	# With cache_examples=False, predictions are not precomputed and come from the normal change handler.
	gr.Examples(
		examples=SAMPLE_IMAGES,
		inputs=input_image,
		cache_examples=False,
		examples_per_page=100,
		label='Example images',
	)

	# Re-classify whenever the image or the selected model changes; also exposed as the 'classify' API endpoint
	gr.on(
		triggers=[input_image.change, input_model.change],
		fn=classify_image,
		inputs=[input_image, input_model],  # noqa
		outputs=output_label,
		api_name='classify',
		show_progress='full',
	)

# Run demo
if __name__ == '__main__':
	demo.launch()  # Start the Gradio server for the app built above
# EOF