import os
import shutil
import time
from pathlib import Path
from typing import Optional

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image

# Import your existing inference endpoint implementation
from handler import EndpointHandler

# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------
REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
REVISION = os.environ.get("ASSETS_REVISION")  # optional pin, e.g. "main" or a commit
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")  # where the handler will look

# Optional: Hugging Face token for private repos (first non-empty variant wins)
HF_TOKEN = (
    os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)

# Exact filenames the handler expects to find inside MODEL_DIR.
REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]


def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str) -> None:
    """Download and stage the model assets into *target_dir*.

    1) ``snapshot_download`` the upstream repo (cached by HF Hub),
       filtered to REQUIRED_FILES to avoid pulling the whole repo.
    2) Copy the missing files into *target_dir* under the exact
       filenames the handler expects.

    No-op when every required file is already present.

    Raises:
        FileNotFoundError: if a required file is absent from the snapshot.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    # Only download if something is missing
    missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
    if not missing:
        return

    # Download snapshot (filtered to speed up); authenticate if repo is private.
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        allow_patterns=REQUIRED_FILES,
        token=HF_TOKEN,
    )

    # Copy only the files that are actually missing (previously all required
    # files were re-copied even when already present).
    for fname in missing:
        src = Path(snapshot_path) / fname
        if not src.exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )
        shutil.copyfile(src, target / fname)
# Fetch assets (no-op if they already exist)
ensure_assets(REPO_ID, REVISION, MODEL_DIR)

# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------
handler = EndpointHandler(MODEL_DIR)
DEVICE_LABEL = f"Device: {handler.device.upper()}"


# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------
def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
    mode_val: str,
    topk_general_val: int,
    topk_character_val: int,
    include_scores_val: bool,
    underscore_mode_val: bool,
):
    """Run the tagger and format its output for the UI.

    ``source_choice`` is a string flag ("url" or "image") injected via
    ``gr.State`` by whichever Run button invoked the function.

    Returns a 6-tuple: (general tags, character tags, IP tags,
    combined tags, metadata dict, raw handler output).

    Raises:
        gr.Error: on missing input or any handler failure.
    """
    # Pick the input source based on the invoking button's flag.
    if source_choice == "image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url.strip()}

    params = {
        "general_threshold": float(general_threshold),
        "character_threshold": float(character_threshold),
        "mode": mode_val,
        "topk_general": int(topk_general_val),
        "topk_character": int(topk_character_val),
        "include_scores": bool(include_scores_val),
    }
    data = {"inputs": inputs, "parameters": params}

    started = time.time()
    try:
        out = handler(data)
    except Exception as e:
        # Surface any handler failure as a user-visible Gradio error.
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.time() - started, 4)

    # --- Individual output panes -------------------------------------------
    if underscore_mode_val:
        # Raw underscore tags, space-separated (booru-style prompt format).
        characters = " ".join(out.get("character", [])) or "—"
        ips = " ".join(out.get("ip", [])) or "—"
        features = " ".join(out.get("feature", [])) or "—"
    elif include_scores_val:
        # Humanized tags annotated with scores, sorted by score descending.
        gen_scores = out.get("feature_scores", {})
        char_scores = out.get("character_scores", {})
        characters = ", ".join(
            f"{k.replace('_', ' ')} ({char_scores[k]:.2f})"
            for k in sorted(char_scores, key=char_scores.get, reverse=True)
        ) or "—"
        ips = ", ".join(tag.replace("_", " ") for tag in out.get("ip", [])) or "—"
        features = ", ".join(
            f"{k.replace('_', ' ')} ({gen_scores[k]:.2f})"
            for k in sorted(gen_scores, key=gen_scores.get, reverse=True)
        ) or "—"
    else:
        # Humanized tags, alphabetical.
        characters = ", ".join(sorted(t.replace("_", " ") for t in out.get("character", []))) or "—"
        ips = ", ".join(tag.replace("_", " ") for tag in out.get("ip", [])) or "—"
        features = ", ".join(sorted(t.replace("_", " ") for t in out.get("feature", []))) or "—"

    # --- Combined pane ------------------------------------------------------
    # Probability-descending if scores are available; else character, IP, general.
    if underscore_mode_val:
        combined = " ".join(
            out.get("character", []) + out.get("ip", []) + out.get("feature", [])
        ) or "—"
    else:
        char_scores = out.get("character_scores") or {}
        gen_scores = out.get("feature_scores") or {}
        if include_scores_val and (char_scores or gen_scores):
            # Build (tag, score) pairs; IP tags carry no score and are
            # pinned to 1.0 so they sort ahead of everything scored.
            char_pairs = [(k, float(char_scores.get(k, 0.0))) for k in out.get("character", [])]
            ip_pairs = [(k, 1.0) for k in out.get("ip", [])]
            gen_pairs = [(k, float(gen_scores.get(k, 0.0))) for k in out.get("feature", [])]
            all_pairs = char_pairs + ip_pairs + gen_pairs
            all_pairs.sort(key=lambda t: t[1], reverse=True)
            combined = ", ".join(
                f"{k.replace('_', ' ')} ({score:.2f})"
                if (k in char_scores or k in gen_scores)
                else k.replace("_", " ")
                for k, score in all_pairs
            ) or "—"
        else:
            combined = ", ".join(
                sorted(t.replace("_", " ") for t in out.get("character", []))
                + [tag.replace("_", " ") for tag in out.get("ip", [])]
                + sorted(t.replace("_", " ") for t in out.get("feature", []))
            ) or "—"

    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        **out.get("_timings", {}),
        "params": out.get("_params", {}),
    }
    return features, characters, ips, combined, meta, out


theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="violet", radius_size="lg")
with gr.Blocks(
    title="PixAI Tagger v0.9 — Demo",
    fill_height=True,
    theme=theme,
    analytics_enabled=False,
) as demo:
    gr.Markdown(
        """
        # PixAI Tagger v0.9 — Gradio Demo
        """
    )

    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}** — adjust thresholds or switch to Top-K mode.")

    with gr.Accordion("Settings", open=False):
        mode = gr.Radio(choices=["threshold", "topk"], value="threshold", label="Mode")
        with gr.Group(visible=True) as threshold_group:
            general_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
            )
            character_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
            )
        with gr.Group(visible=False) as topk_group:
            topk_general = gr.Slider(
                minimum=0, maximum=100, step=1, value=25, label="Top-K general"
            )
            topk_character = gr.Slider(
                minimum=0, maximum=100, step=1, value=10, label="Top-K character"
            )
        include_scores = gr.Checkbox(value=False, label="Include scores in output")
        underscore_mode = gr.Checkbox(value=False, label="Underscore-separated output")

        def toggle_mode(selected):
            # Show exactly one of the two settings groups for the chosen mode.
            return (
                gr.update(visible=(selected == "threshold")),
                gr.update(visible=(selected == "topk")),
            )

        mode.change(toggle_mode, inputs=[mode], outputs=[threshold_group, topk_group])

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            image = gr.Image(label="Upload image", type="pil", visible=True, height="420px")
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=True)

        with gr.Column(scale=3):
            # No source choice; show both inputs and two run buttons.
            with gr.Row():
                run_image_btn = gr.Button("Run from image", variant="primary")
                run_url_btn = gr.Button("Run from URL")
                clear_btn = gr.Button("Clear")
            gr.Markdown("### Combined Output (character → IP → general)")
            combined_out = gr.Textbox(label="Combined tags", lines=10)
            copy_combined = gr.Button("Copy combined")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Character / General / IP")
    with gr.Row():
        with gr.Column():
            characters_out = gr.Textbox(label="Character tags", lines=5)
        with gr.Column():
            features_out = gr.Textbox(label="General tags", lines=5)
        with gr.Column():
            ip_out = gr.Textbox(label="IP tags", lines=5)
    with gr.Row():
        copy_characters = gr.Button("Copy character")
        copy_features = gr.Button("Copy general")
        copy_ip = gr.Button("Copy IP")

    with gr.Accordion("Metadata & Raw Output", open=False):
        with gr.Row():
            with gr.Column():
                meta_out = gr.JSON(label="Timings/Device")
                raw_out = gr.JSON(label="Raw JSON")
                copy_raw = gr.Button("Copy raw JSON")

    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            [
                None,
                "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg",
                0.30,
                0.85,
                "threshold",
                25,
                10,
                False,
                False,
            ],
        ],
        inputs=[
            image,
            url,
            general_threshold,
            character_threshold,
            mode,
            topk_general,
            topk_character,
            include_scores,
            underscore_mode,
        ],
        cache_examples=False,
    )

    def clear():
        # Reset: image, url, both thresholds, the four text panes
        # (general/character/IP/combined), and the two JSON panes.
        return (None, "", 0.30, 0.85, "", "", "", "", {}, {})

    # Bind buttons separately with a gr.State flag telling run_inference
    # which input source to read ("url" vs "image").
    run_url_btn.click(
        run_inference,
        inputs=[
            gr.State("url"),
            image,
            url,
            general_threshold,
            character_threshold,
            mode,
            topk_general,
            topk_character,
            include_scores,
            underscore_mode,
        ],
        outputs=[features_out, characters_out, ip_out, combined_out, meta_out, raw_out],
        api_name="predict_url",
    )
    run_image_btn.click(
        run_inference,
        inputs=[
            gr.State("image"),
            image,
            url,
            general_threshold,
            character_threshold,
            mode,
            topk_general,
            topk_character,
            include_scores,
            underscore_mode,
        ],
        outputs=[features_out, characters_out, ip_out, combined_out, meta_out, raw_out],
        api_name="predict_image",
    )

    # Copy buttons: identity passthrough (re-emits the textbox value).
    copy_combined.click(lambda x: x, inputs=[combined_out], outputs=[combined_out])
    copy_characters.click(lambda x: x, inputs=[characters_out], outputs=[characters_out])
    copy_features.click(lambda x: x, inputs=[features_out], outputs=[features_out])
    copy_ip.click(lambda x: x, inputs=[ip_out], outputs=[ip_out])
    copy_raw.click(lambda x: x, inputs=[raw_out], outputs=[raw_out])

    # FIX: clear() returns 10 values but the outputs list previously had only
    # 9 components (combined_out was missing), which makes Gradio raise
    # "too many output values" and left the combined pane un-cleared.
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image,
            url,
            general_threshold,
            character_threshold,
            features_out,
            characters_out,
            ip_out,
            combined_out,
            meta_out,
            raw_out,
        ],
    )

if __name__ == "__main__":
    demo.queue(max_size=8).launch()