# Hugging Face Space app (Space status at capture time: Running)
import os | |
import time | |
import shutil | |
from pathlib import Path | |
from typing import Optional | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
from PIL import Image | |
# Import your existing inference endpoint implementation | |
from handler import EndpointHandler | |
# ------------------------------------------------------------------------------
# Asset configuration: which Hub repo provides the weights/tags/mapping and the
# local directory the handler will look in (local filenames stay unchanged).
# ------------------------------------------------------------------------------
REPO_ID = os.getenv("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
# Optional revision pin, e.g. "main" or a commit SHA; None means "default".
REVISION = os.getenv("ASSETS_REVISION")
# Local directory where the handler will look for the assets.
MODEL_DIR = os.getenv("MODEL_DIR", "./assets")


def _first_set_env(*names):
    """Return the first non-empty value among the environment variables *names*.

    Behaves like an ``a or b or ...`` chain: when none of them holds a
    non-empty string, the lookup result for the LAST name (``None`` or ``""``)
    is returned.
    """
    value = None
    for name in names:
        value = os.environ.get(name)
        if value:
            break
    return value


# Optional Hugging Face token for private repos; several common spellings of
# the environment variable are accepted, checked in this order.
HF_TOKEN = _first_set_env(
    "HUGGINGFACE_HUB_TOKEN",
    "HF_TOKEN",
    "HUGGINGFACE_TOKEN",
    "HUGGINGFACEHUB_API_TOKEN",
)

# Files that must exist in MODEL_DIR, under exactly these names.
REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]
def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str):
    """Make sure every file in ``REQUIRED_FILES`` is present in *target_dir*.

    If any file is missing, the required files are fetched with
    ``snapshot_download`` (restricted to ``REQUIRED_FILES`` and authenticated
    with ``HF_TOKEN``). They are then copied into *target_dir* under the exact
    filenames the handler expects. When every file already exists, nothing is
    downloaded.

    Args:
        repo_id: Hub repository to download from.
        revision: Branch, tag or commit to pin; ``None`` uses the default.
        target_dir: Local directory to populate; created if it does not exist.

    Raises:
        FileNotFoundError: if the downloaded snapshot lacks a required file.
            All files are checked before any copy, so *target_dir* is not left
            half-updated.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    # Only go to the Hub if something is missing.
    if all((target / fname).exists() for fname in REQUIRED_FILES):
        return

    snapshot_path = Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            allow_patterns=REQUIRED_FILES,  # only pull what we need
            token=HF_TOKEN,  # authenticate if repo is private
        )
    )

    # Validate the whole snapshot first so a missing file aborts before any
    # existing asset in target_dir is overwritten.
    for fname in REQUIRED_FILES:
        if not (snapshot_path / fname).exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )

    for fname in REQUIRED_FILES:
        dst = target / fname
        # Copy to a sibling temp name, then rename into place. An interrupted
        # copy leaves only the ".partial" file behind, so the existence check
        # above never mistakes a truncated file for a complete asset on the
        # next start.
        tmp = dst.with_name(dst.name + ".partial")
        shutil.copyfile(snapshot_path / fname, tmp)
        tmp.replace(dst)
# Populate MODEL_DIR before building the handler (nothing is downloaded when
# the required files are already there).
ensure_assets(repo_id=REPO_ID, revision=REVISION, target_dir=MODEL_DIR)

# ------------------------------------------------------------------------------
# Handler construction
# ------------------------------------------------------------------------------
handler = EndpointHandler(MODEL_DIR)
# Banner text naming the device the handler reports, upper-cased for display.
DEVICE_LABEL = "Device: {}".format(handler.device.upper())
# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------
# Shown in an output textbox when that category produced no tags.
_EMPTY_PLACEHOLDER = "—"


def _tag_list(out, key):
    """Return ``out[key]`` as a list; a missing or ``None`` entry becomes ``[]``."""
    value = out.get(key)
    return [] if value is None else list(value)


def _humanize(tag):
    """Return *tag* with underscores replaced by spaces, for display."""
    return tag.replace("_", " ")


def _join_plain(tags, *, sort=True):
    """Comma-join humanized *tags*; alphabetical unless *sort* is False."""
    pretty = [_humanize(t) for t in tags]
    return ", ".join(sorted(pretty) if sort else pretty)


def _join_scored(scores):
    """Comma-join ``"tag (0.93)"`` for every key of *scores*, highest score first."""
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ", ".join(f"{_humanize(k)} ({scores[k]:.2f})" for k in ranked)


def _label(tag, scores):
    """Humanized *tag*, with its score appended when *scores* contains it."""
    if tag in scores:
        return f"{_humanize(tag)} ({float(scores[tag]):.2f})"
    return _humanize(tag)


def _join_combined_by_score(char_tags, ip_tags, gen_tags, char_scores, gen_scores):
    """Comma-join every tag in descending score order.

    IP tags have no score of their own: they rank as 1.0 and are shown without
    a value. Character/general tags missing from their score dict rank as 0.0.
    ``list.sort`` is stable, so ties keep the character -> IP -> general order.
    """
    ranked = (
        [(float(char_scores.get(t, 0.0)), _label(t, char_scores)) for t in char_tags]
        + [(1.0, _humanize(t)) for t in ip_tags]
        + [(float(gen_scores.get(t, 0.0)), _label(t, gen_scores)) for t in gen_tags]
    )
    ranked.sort(key=lambda item: item[0], reverse=True)
    return ", ".join(label for _, label in ranked)


def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
    mode_val: str,
    topk_general_val: int,
    topk_character_val: int,
    include_scores_val: bool,
    underscore_mode_val: bool,
):
    """Run the tagger on an uploaded image or an image URL and format the result.

    Args:
        source_choice: ``"image"`` to use *image*; any other value uses *url*.
            Each Run button passes its own constant via ``gr.State``.
        image: Uploaded image (required when ``source_choice == "image"``).
        url: Image URL (required otherwise); surrounding whitespace is stripped.
        general_threshold, character_threshold: Score cut-offs sent to the handler.
        mode_val: Selection mode sent to the handler ("threshold"/"topk" in this UI).
        topk_general_val, topk_character_val: Top-K limits sent to the handler.
        include_scores_val: Request scores and show them next to the tags.
        underscore_mode_val: Show raw tags (underscores kept) separated by
            spaces; takes precedence over *include_scores_val* for display.

    Returns:
        ``(features, characters, ips, combined, meta, out)``. The first four are
        display strings (a placeholder when empty), ``meta`` holds the device and
        timings, and ``out`` is the raw handler output.

    Raises:
        gr.Error: if the selected input is missing or the handler raises.
    """
    # Each Run button passes a string flag telling us which input to use.
    if source_choice == "image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url.strip()}

    params = {
        "general_threshold": float(general_threshold),
        "character_threshold": float(character_threshold),
        "mode": mode_val,
        "topk_general": int(topk_general_val),
        "topk_character": int(topk_character_val),
        "include_scores": bool(include_scores_val),
    }
    data = {"inputs": inputs, "parameters": params}

    # perf_counter is monotonic, so wall-clock adjustments cannot skew latency.
    started = time.perf_counter()
    try:
        out = handler(data)
    except Exception as e:
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.perf_counter() - started, 4)

    char_tags = _tag_list(out, "character")
    ip_tags = _tag_list(out, "ip")
    gen_tags = _tag_list(out, "feature")
    # Guard against a missing *or* None score dict in every code path.
    char_scores = out.get("character_scores") or {}
    gen_scores = out.get("feature_scores") or {}

    if underscore_mode_val:
        characters = " ".join(char_tags)
        ips = " ".join(ip_tags)
        features = " ".join(gen_tags)
        combined = " ".join(char_tags + ip_tags + gen_tags)
    else:
        ips = _join_plain(ip_tags, sort=False)
        # When scores were requested but the handler returned none for a
        # category, fall back to the plain tag list (as the combined box does)
        # instead of showing only the placeholder.
        characters = (
            _join_scored(char_scores)
            if include_scores_val and char_scores
            else _join_plain(char_tags)
        )
        features = (
            _join_scored(gen_scores)
            if include_scores_val and gen_scores
            else _join_plain(gen_tags)
        )
        if include_scores_val and (char_scores or gen_scores):
            combined = _join_combined_by_score(
                char_tags, ip_tags, gen_tags, char_scores, gen_scores
            )
        else:
            # Plain order: character, IP, general.
            combined = ", ".join(
                sorted(_humanize(t) for t in char_tags)
                + [_humanize(t) for t in ip_tags]
                + sorted(_humanize(t) for t in gen_tags)
            )

    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        **(out.get("_timings") or {}),
        "params": out.get("_params", {}),
    }
    return (
        features or _EMPTY_PLACEHOLDER,
        characters or _EMPTY_PLACEHOLDER,
        ips or _EMPTY_PLACEHOLDER,
        combined or _EMPTY_PLACEHOLDER,
        meta,
        out,
    )
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="violet", radius_size="lg")

# Browser-side clipboard handlers for the "Copy" buttons. They are wired with
# fn=None and js=..., so they run in the page and receive the bound
# component's current value (no server round-trip).
_COPY_TEXT_JS = "(text) => { navigator.clipboard.writeText(text ?? ''); }"
_COPY_JSON_JS = "(value) => { navigator.clipboard.writeText(JSON.stringify(value ?? {}, null, 2)); }"

with gr.Blocks(title="PixAI Tagger v0.9 — Demo", fill_height=True, theme=theme, analytics_enabled=False) as demo:
    gr.Markdown(
        """
# PixAI Tagger v0.9 — Gradio Demo
"""
    )
    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}** — adjust thresholds or switch to Top-K mode.")

    # --- Settings ---------------------------------------------------------------
    with gr.Accordion("Settings", open=False):
        mode = gr.Radio(choices=["threshold", "topk"], value="threshold", label="Mode")
        with gr.Group(visible=True) as threshold_group:
            general_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
            )
            character_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
            )
        with gr.Group(visible=False) as topk_group:
            topk_general = gr.Slider(
                minimum=0, maximum=100, step=1, value=25, label="Top-K general"
            )
            topk_character = gr.Slider(
                minimum=0, maximum=100, step=1, value=10, label="Top-K character"
            )
        include_scores = gr.Checkbox(value=False, label="Include scores in output")
        underscore_mode = gr.Checkbox(value=False, label="Underscore-separated output")

        def toggle_mode(selected):
            """Show only the slider group that matches the selected mode."""
            return (
                gr.update(visible=(selected == "threshold")),
                gr.update(visible=(selected == "topk")),
            )

        mode.change(toggle_mode, inputs=[mode], outputs=[threshold_group, topk_group])

    # --- Inputs and combined output --------------------------------------------
    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            image = gr.Image(label="Upload image", type="pil", visible=True, height="420px")
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=True)
        with gr.Column(scale=3):
            # No source selector: both inputs stay visible, and each Run button
            # tells run_inference which one to use.
            with gr.Row():
                run_image_btn = gr.Button("Run from image", variant="primary")
                run_url_btn = gr.Button("Run from URL")
                clear_btn = gr.Button("Clear")
            gr.Markdown("### Combined Output (character → IP → general)")
            combined_out = gr.Textbox(label="Combined tags", lines=10)
            copy_combined = gr.Button("Copy combined")

    # --- Per-category outputs ---------------------------------------------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Character / General / IP")
            with gr.Row():
                with gr.Column():
                    characters_out = gr.Textbox(label="Character tags", lines=5)
                with gr.Column():
                    features_out = gr.Textbox(label="General tags", lines=5)
                with gr.Column():
                    ip_out = gr.Textbox(label="IP tags", lines=5)
            with gr.Row():
                copy_characters = gr.Button("Copy character")
                copy_features = gr.Button("Copy general")
                copy_ip = gr.Button("Copy IP")

    with gr.Accordion("Metadata & Raw Output", open=False):
        with gr.Row():
            with gr.Column():
                meta_out = gr.JSON(label="Timings/Device")
                raw_out = gr.JSON(label="Raw JSON")
                copy_raw = gr.Button("Copy raw JSON")

    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            [None, "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg", 0.30, 0.85, "threshold", 25, 10, False, False],
        ],
        inputs=[image, url, general_threshold, character_threshold, mode, topk_general, topk_character, include_scores, underscore_mode],
        cache_examples=False,
    )

    def clear():
        """Reset image, URL and thresholds, and blank all six output components."""
        return (None, "", 0.30, 0.85, "", "", "", "", {}, {})

    # run_inference returns (features, characters, ips, combined, meta, raw).
    inference_outputs = [features_out, characters_out, ip_out, combined_out, meta_out, raw_out]

    # Bind the Run buttons separately, each with its own source flag.
    run_url_btn.click(
        run_inference,
        inputs=[
            gr.State("url"), image, url,
            general_threshold, character_threshold,
            mode, topk_general, topk_character, include_scores, underscore_mode,
        ],
        outputs=inference_outputs,
        api_name="predict_url",
    )
    run_image_btn.click(
        run_inference,
        inputs=[
            gr.State("image"), image, url,
            general_threshold, character_threshold,
            mode, topk_general, topk_character, include_scores, underscore_mode,
        ],
        outputs=inference_outputs,
        api_name="predict_image",
    )

    # Copy buttons: write the component value to the user's clipboard.
    copy_combined.click(None, inputs=[combined_out], outputs=None, js=_COPY_TEXT_JS)
    copy_characters.click(None, inputs=[characters_out], outputs=None, js=_COPY_TEXT_JS)
    copy_features.click(None, inputs=[features_out], outputs=None, js=_COPY_TEXT_JS)
    copy_ip.click(None, inputs=[ip_out], outputs=None, js=_COPY_TEXT_JS)
    copy_raw.click(None, inputs=[raw_out], outputs=None, js=_COPY_JSON_JS)

    # One output component per value returned by clear(), in the same order.
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, combined_out, meta_out, raw_out,
        ],
    )
if __name__ == "__main__":
    # Enable the request queue with at most 8 pending events, then serve the UI.
    app = demo.queue(max_size=8)
    app.launch()