import os
import time
import shutil
from pathlib import Path
from typing import Optional

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image

# Inference endpoint handler shipped alongside this app (handler.py)
from handler import EndpointHandler


# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------

REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
REVISION = os.environ.get("ASSETS_REVISION")  # optional pin, e.g. "main" or a commit
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")  # where the handler will look

# Optional: Hugging Face token for private repos
HF_TOKEN = (
    os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)

REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]

def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str):
    """
    1) snapshot_download the upstream repo (cached by HF Hub)
    2) copy the required files into `target_dir` with the exact filenames expected
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    # Only download if something is missing
    missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
    if not missing:
        return

    # Download a snapshot filtered to just the required files (cached by the HF Hub)
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        allow_patterns=REQUIRED_FILES,  # only pull what we need
        token=HF_TOKEN,  # authenticate if repo is private
    )

    # Copy files into target_dir with the required names
    for fname in REQUIRED_FILES:
        src = Path(snapshot_path) / fname
        dst = target / fname
        if not src.exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )
        shutil.copyfile(src, dst)


# Fetch assets (no-op if they already exist)
ensure_assets(REPO_ID, REVISION, MODEL_DIR)
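
# Example launch with overrides (hypothetical values): the assets revision and the
# local asset directory can be changed via environment variables without editing code,
# e.g.  ASSETS_REVISION=main MODEL_DIR=/data/tagger-assets python app.py
# (adjust the script name to however this file is actually launched).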


# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------

handler = EndpointHandler(MODEL_DIR)
DEVICE_LABEL = f"Device: {handler.device.upper()}"
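
# Optional smoke test (sketch): run_inference below sends the handler a payload of the
# form {"inputs": <PIL.Image or {"url": ...}>, "parameters": {...}}, so a quick manual
# check could look like this. "test.jpg" is a placeholder path and the env-var gate is
# illustrative; the block never runs during normal startup.
if os.environ.get("TAGGER_SMOKE_TEST"):
    _test_image = Image.open("test.jpg")  # placeholder local file
    print(handler({"inputs": _test_image, "parameters": {"mode": "threshold"}}))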


# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------

def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
    mode_val: str,
    topk_general_val: int,
    topk_character_val: int,
    include_scores_val: bool,
    underscore_mode_val: bool,
):
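    """
    Run the handler on either an uploaded image or a URL and format the output for the UI.

    Assumes the handler returns a dict with list-valued keys "character", "ip", and
    "feature", plus optional "character_scores", "feature_scores", "_timings", and
    "_params" (that is how the keys are read below).
    """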
    # Determine which input to use based on which Run button invoked the function.
    # We'll pass a string flag via source_choice: either "url" or "image".
    if source_choice == "image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url.strip()}

    params = {
        "general_threshold": float(general_threshold),
        "character_threshold": float(character_threshold),
        "mode": mode_val,
        "topk_general": int(topk_general_val),
        "topk_character": int(topk_character_val),
        "include_scores": bool(include_scores_val),
    }
    data = {"inputs": inputs, "parameters": params}

    started = time.time()
    try:
        out = handler(data)
    except Exception as e:
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.time() - started, 4)

    # Individual outputs
    if underscore_mode_val:
        characters = " ".join(out.get("character", [])) or "β€”"
        ips = " ".join(out.get("ip", [])) or "β€”"
        features = " ".join(out.get("feature", [])) or "β€”"
    elif include_scores_val:
        gen_scores = out.get("feature_scores", {})
        char_scores = out.get("character_scores", {})
        characters = ", ".join(
            f"{k.replace('_', ' ')} ({char_scores[k]:.2f})" for k in sorted(char_scores, key=char_scores.get, reverse=True)
        ) or "β€”"
        ips = ", ".join(tag.replace("_", " ") for tag in out.get("ip", [])) or "β€”"
        features = ", ".join(
            f"{k.replace('_', ' ')} ({gen_scores[k]:.2f})" for k in sorted(gen_scores, key=gen_scores.get, reverse=True)
        ) or "β€”"
    else:
        characters = ", ".join(sorted(t.replace("_", " ") for t in out.get("character", []))) or "β€”"
        ips = ", ".join(tag.replace("_", " ") for tag in out.get("ip", [])) or "β€”"
        features = ", ".join(sorted(t.replace("_", " ") for t in out.get("feature", []))) or "β€”"

    # Combined output: probability-descending if scores available; else character, IP, general
    if underscore_mode_val:
        combined = " ".join(out.get("character", []) + out.get("ip", []) + out.get("feature", [])) or "β€”"
    else:
        char_scores = out.get("character_scores") or {}
        gen_scores = out.get("feature_scores") or {}
        if include_scores_val and (char_scores or gen_scores):
            # Build (tag, score) pairs
            char_pairs = [(k, float(char_scores.get(k, 0.0))) for k in out.get("character", [])]
            ip_pairs = [(k, 1.0) for k in out.get("ip", [])]  # IP tags carry no score; 1.0 pins them ahead of scored tags
            gen_pairs = [(k, float(gen_scores.get(k, 0.0))) for k in out.get("feature", [])]
            all_pairs = char_pairs + ip_pairs + gen_pairs
            all_pairs.sort(key=lambda t: t[1], reverse=True)
            combined = ", ".join(
                [f"{k.replace('_', ' ')} ({score:.2f})" if (k in char_scores or k in gen_scores) else k.replace('_', ' ') for k, score in all_pairs]
            ) or "β€”"
        else:
            combined = ", ".join(
                list(sorted(t.replace("_", " ") for t in out.get("character", []))) +
                [tag.replace("_", " ") for tag in out.get("ip", [])] +
                list(sorted(t.replace("_", " ") for t in out.get("feature", [])))
            ) or "β€”"

    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        **out.get("_timings", {}),
        "params": out.get("_params", {}),
    }

    return features, characters, ips, combined, meta, out


theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="violet", radius_size="lg",)

with gr.Blocks(title="PixAI Tagger v0.9 — Demo", fill_height=True, theme=theme, analytics_enabled=False) as demo:
    gr.Markdown(
        """
        # PixAI Tagger v0.9 — Gradio Demo
        """
    )
    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}** β€” adjust thresholds or switch to Top-K mode.")

    with gr.Accordion("Settings", open=False):
        mode = gr.Radio(
            choices=["threshold", "topk"], value="threshold", label="Mode"
        )
        with gr.Group(visible=True) as threshold_group:
            general_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
            )
            character_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
            )
        with gr.Group(visible=False) as topk_group:
            topk_general = gr.Slider(
                minimum=0, maximum=100, step=1, value=25, label="Top-K general"
            )
            topk_character = gr.Slider(
                minimum=0, maximum=100, step=1, value=10, label="Top-K character"
            )
        include_scores = gr.Checkbox(value=False, label="Include scores in output")
        underscore_mode = gr.Checkbox(value=False, label="Underscore-separated output")

        def toggle_mode(selected):
            return (
                gr.update(visible=(selected == "threshold")),
                gr.update(visible=(selected == "topk")),
            )

        mode.change(toggle_mode, inputs=[mode], outputs=[threshold_group, topk_group])

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            image = gr.Image(label="Upload image", type="pil", visible=True, height="420px")
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=True)

        with gr.Column(scale=3):
            # No source choice; show both inputs and two run buttons
            with gr.Row():
                run_image_btn = gr.Button("Run from image", variant="primary")
                run_url_btn = gr.Button("Run from URL")
                clear_btn = gr.Button("Clear")

            gr.Markdown("### Combined Output (character β†’ IP β†’ general)")
            combined_out = gr.Textbox(label="Combined tags", lines=10,)
            copy_combined = gr.Button("Copy combined")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Character / General / IP")
            with gr.Row():
                with gr.Column():
                    characters_out = gr.Textbox(label="Character tags", lines=5,)
                with gr.Column():
                    features_out = gr.Textbox(label="General tags", lines=5,)
                with gr.Column():
                    ip_out = gr.Textbox(label="IP tags", lines=5,)
            with gr.Row():
                copy_characters = gr.Button("Copy character")
                copy_features = gr.Button("Copy general")
                copy_ip = gr.Button("Copy IP")

    with gr.Accordion("Metadata & Raw Output", open=False):
        with gr.Row():
            with gr.Column():
                meta_out = gr.JSON(label="Timings/Device")
                raw_out = gr.JSON(label="Raw JSON")
                copy_raw = gr.Button("Copy raw JSON")

    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            [None, "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg", 0.30, 0.85, "threshold", 25, 10, False, False],
        ],
        inputs=[image, url, general_threshold, character_threshold, mode, topk_general, topk_character, include_scores, underscore_mode],
        cache_examples=False,
    )

    def clear():
        # Reset inputs, thresholds, and every output (10 values, matching the outputs list below)
        return (None, "", 0.30, 0.85, "", "", "", "", {}, {})

    # Bind buttons separately with a flag for source
    run_url_btn.click(
        run_inference,
        inputs=[
            gr.State("url"), image, url,
            general_threshold, character_threshold,
            mode, topk_general, topk_character, include_scores, underscore_mode,
        ],
        outputs=[features_out, characters_out, ip_out, combined_out, meta_out, raw_out],
        api_name="predict_url",
    )

    run_image_btn.click(
        run_inference,
        inputs=[
            gr.State("image"), image, url,
            general_threshold, character_threshold,
            mode, topk_general, topk_character, include_scores, underscore_mode,
        ],
        outputs=[features_out, characters_out, ip_out, combined_out, meta_out, raw_out],
        api_name="predict_image",
    )

    # "Copy" buttons: these simply echo the current value back into the textbox
    # (the backend has no clipboard access); gr.Textbox(show_copy_button=True)
    # would provide a real copy-to-clipboard control.
    copy_combined.click(lambda x: x, inputs=[combined_out], outputs=[combined_out])
    copy_characters.click(lambda x: x, inputs=[characters_out], outputs=[characters_out])
    copy_features.click(lambda x: x, inputs=[features_out], outputs=[features_out])
    copy_ip.click(lambda x: x, inputs=[ip_out], outputs=[ip_out])
    copy_raw.click(lambda x: x, inputs=[raw_out], outputs=[raw_out])
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, combined_out, meta_out, raw_out
        ],
    )

if __name__ == "__main__":
    demo.queue(max_size=8).launch()
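
# ------------------------------------------------------------------------------
# Programmatic access (sketch)
# ------------------------------------------------------------------------------
# The two click handlers above are exposed with api_name="predict_url" and
# api_name="predict_image". With gradio_client the call might look like the
# commented lines below; the exact argument list (including how the gr.State
# source flag is surfaced) depends on the Gradio version, so inspect
# Client.view_api() against the running app first. The Space id is hypothetical.
#
#   from gradio_client import Client
#   client = Client("your-username/pixai-tagger-demo")
#   print(client.view_api())
#   # result = client.predict(..., api_name="/predict_url")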