import os
import time
import shutil
from pathlib import Path
from typing import Optional
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
# Import your existing inference endpoint implementation
from handler import EndpointHandler
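# For reference, this app relies only on the following handler surface (a sketch
# inferred from the calls below, not the handler's full API): EndpointHandler(model_dir)
# exposes a string `device` attribute and is callable with
# {"inputs": <PIL.Image | {"url": str}>, "parameters": {...}}; it returns a dict with
# "character", "ip", and "feature" tag lists, optional "character_scores" /
# "feature_scores" maps, and optional "_timings" / "_params" metadata.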
# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------
REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
REVISION = os.environ.get("ASSETS_REVISION") # optional pin, e.g. "main" or a commit
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets") # where the handler will look
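# Example launch with the assets pinned to a specific revision and a custom target
# directory (both values below are placeholders):
#   ASSETS_REVISION=<commit-sha> MODEL_DIR=/data/assets python app.py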
# Optional: Hugging Face token for private repos
HF_TOKEN = (
os.environ.get("HUGGINGFACE_HUB_TOKEN")
or os.environ.get("HF_TOKEN")
or os.environ.get("HUGGINGFACE_TOKEN")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
REQUIRED_FILES = [
"model_v0.9.pth",
"tags_v0.9_13k.json",
"char_ip_map.json",
]
def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str):
"""
1) snapshot_download the upstream repo (cached by HF Hub)
2) copy the required files into `target_dir` with the exact filenames expected
"""
target = Path(target_dir)
target.mkdir(parents=True, exist_ok=True)
# Only download if something is missing
missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
if not missing:
return
# Download snapshot (optionally filtered to speed up)
snapshot_path = snapshot_download(
repo_id=repo_id,
revision=revision,
allow_patterns=REQUIRED_FILES, # only pull what we need
token=HF_TOKEN, # authenticate if repo is private
)
# Copy files into target_dir with the required names
for fname in REQUIRED_FILES:
src = Path(snapshot_path) / fname
dst = target / fname
if not src.exists():
raise FileNotFoundError(
f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
)
shutil.copyfile(src, dst)
# Fetch assets (no-op if they already exist)
ensure_assets(REPO_ID, REVISION, MODEL_DIR)
# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------
handler = EndpointHandler(MODEL_DIR)
DEVICE_LABEL = f"Device: {handler.device.upper()}"
# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------
def run_inference(
source_choice: str,
image: Optional[Image.Image],
url: str,
general_threshold: float,
character_threshold: float,
mode_val: str,
topk_general_val: int,
topk_character_val: int,
include_scores_val: bool,
underscore_mode_val: bool,
):
# Determine which input to use based on which Run button invoked the function.
# We'll pass a string flag via source_choice: either "url" or "image".
if source_choice == "image":
if image is None:
raise gr.Error("Please upload an image.")
inputs = image
else:
if not url or not url.strip():
raise gr.Error("Please provide an image URL.")
inputs = {"url": url.strip()}
params = {
"general_threshold": float(general_threshold),
"character_threshold": float(character_threshold),
"mode": mode_val,
"topk_general": int(topk_general_val),
"topk_character": int(topk_character_val),
"include_scores": bool(include_scores_val),
}
data = {"inputs": inputs, "parameters": params}
started = time.time()
try:
out = handler(data)
except Exception as e:
raise gr.Error(f"Inference error: {e}") from e
latency = round(time.time() - started, 4)
# Individual outputs
if underscore_mode_val:
characters = " ".join(out.get("character", [])) or "β€”"
ips = " ".join(out.get("ip", [])) or "β€”"
features = " ".join(out.get("feature", [])) or "β€”"
elif include_scores_val:
gen_scores = out.get("feature_scores", {})
char_scores = out.get("character_scores", {})
characters = ", ".join(
f"{k.replace('_', ' ')} ({char_scores[k]:.2f})" for k in sorted(char_scores, key=char_scores.get, reverse=True)
) or "β€”"
ips = ", ".join(tag.replace("_", " ") for tag in out.get("ip", [])) or "β€”"
features = ", ".join(
f"{k.replace('_', ' ')} ({gen_scores[k]:.2f})" for k in sorted(gen_scores, key=gen_scores.get, reverse=True)
) or "β€”"
else:
characters = ", ".join(sorted(t.replace("_", " ") for t in out.get("character", []))) or "β€”"
ips = ", ".join(tag.replace("_", " ") for tag in out.get("ip", [])) or "β€”"
features = ", ".join(sorted(t.replace("_", " ") for t in out.get("feature", []))) or "β€”"
# Combined output: probability-descending if scores available; else character, IP, general
if underscore_mode_val:
combined = " ".join(out.get("character", []) + out.get("ip", []) + out.get("feature", [])) or "β€”"
else:
char_scores = out.get("character_scores") or {}
gen_scores = out.get("feature_scores") or {}
if include_scores_val and (char_scores or gen_scores):
# Build (tag, score) pairs
char_pairs = [(k, float(char_scores.get(k, 0.0))) for k in out.get("character", [])]
            ip_pairs = [(k, 1.0) for k in out.get("ip", [])]  # IP tags have no scores; weight 1.0 so they sort first
gen_pairs = [(k, float(gen_scores.get(k, 0.0))) for k in out.get("feature", [])]
all_pairs = char_pairs + ip_pairs + gen_pairs
all_pairs.sort(key=lambda t: t[1], reverse=True)
combined = ", ".join(
[f"{k.replace('_', ' ')} ({score:.2f})" if (k in char_scores or k in gen_scores) else k.replace('_', ' ') for k, score in all_pairs]
) or "β€”"
else:
combined = ", ".join(
list(sorted(t.replace("_", " ") for t in out.get("character", []))) +
[tag.replace("_", " ") for tag in out.get("ip", [])] +
list(sorted(t.replace("_", " ") for t in out.get("feature", [])))
) or "β€”"
meta = {
"device": handler.device,
"latency_s_total": latency,
**out.get("_timings", {}),
"params": out.get("_params", {}),
}
return features, characters, ips, combined, meta, out
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="violet", radius_size="lg",)
with gr.Blocks(title="PixAI Tagger v0.9 — Demo", fill_height=True, theme=theme, analytics_enabled=False) as demo:
gr.Markdown(
"""
# PixAI Tagger v0.9 — Gradio Demo
"""
)
with gr.Row():
gr.Markdown(f"**{DEVICE_LABEL}** β€” adjust thresholds or switch to Top-K mode.")
with gr.Accordion("Settings", open=False):
mode = gr.Radio(
choices=["threshold", "topk"], value="threshold", label="Mode"
)
with gr.Group(visible=True) as threshold_group:
general_threshold = gr.Slider(
minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
)
character_threshold = gr.Slider(
minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
)
with gr.Group(visible=False) as topk_group:
topk_general = gr.Slider(
minimum=0, maximum=100, step=1, value=25, label="Top-K general"
)
topk_character = gr.Slider(
minimum=0, maximum=100, step=1, value=10, label="Top-K character"
)
include_scores = gr.Checkbox(value=False, label="Include scores in output")
underscore_mode = gr.Checkbox(value=False, label="Underscore-separated output")
def toggle_mode(selected):
return (
gr.update(visible=(selected == "threshold")),
gr.update(visible=(selected == "topk")),
)
mode.change(toggle_mode, inputs=[mode], outputs=[threshold_group, topk_group])
with gr.Row(variant="panel"):
with gr.Column(scale=2):
image = gr.Image(label="Upload image", type="pil", visible=True, height="420px")
url = gr.Textbox(label="Image URL", placeholder="https://…", visible=True)
with gr.Column(scale=3):
# No source choice; show both inputs and two run buttons
with gr.Row():
run_image_btn = gr.Button("Run from image", variant="primary")
run_url_btn = gr.Button("Run from URL")
clear_btn = gr.Button("Clear")
gr.Markdown("### Combined Output (character β†’ IP β†’ general)")
combined_out = gr.Textbox(label="Combined tags", lines=10,)
copy_combined = gr.Button("Copy combined")
with gr.Row():
with gr.Column():
gr.Markdown("### Character / General / IP")
with gr.Row():
with gr.Column():
characters_out = gr.Textbox(label="Character tags", lines=5,)
with gr.Column():
features_out = gr.Textbox(label="General tags", lines=5,)
with gr.Column():
ip_out = gr.Textbox(label="IP tags", lines=5,)
with gr.Row():
copy_characters = gr.Button("Copy character")
copy_features = gr.Button("Copy general")
copy_ip = gr.Button("Copy IP")
with gr.Accordion("Metadata & Raw Output", open=False):
with gr.Row():
with gr.Column():
meta_out = gr.JSON(label="Timings/Device")
raw_out = gr.JSON(label="Raw JSON")
copy_raw = gr.Button("Copy raw JSON")
examples = gr.Examples(
label="Examples (URL mode)",
examples=[
[None, "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg", 0.30, 0.85, "threshold", 25, 10, False, False],
],
inputs=[image, url, general_threshold, character_threshold, mode, topk_general, topk_character, include_scores, underscore_mode],
cache_examples=False,
)
    def clear():
        # Order matches the outputs list below: image, url, general_threshold,
        # character_threshold, features, characters, ip, combined, meta, raw.
        return (None, "", 0.30, 0.85, "", "", "", "", {}, {})
# Bind buttons separately with a flag for source
run_url_btn.click(
run_inference,
inputs=[
gr.State("url"), image, url,
general_threshold, character_threshold,
mode, topk_general, topk_character, include_scores, underscore_mode,
],
outputs=[features_out, characters_out, ip_out, combined_out, meta_out, raw_out],
api_name="predict_url",
)
run_image_btn.click(
run_inference,
inputs=[
gr.State("image"), image, url,
general_threshold, character_threshold,
mode, topk_general, topk_character, include_scores, underscore_mode,
],
outputs=[features_out, characters_out, ip_out, combined_out, meta_out, raw_out],
api_name="predict_image",
)
    # "Copy" buttons: each handler simply echoes the current value back into its textbox
copy_combined.click(lambda x: x, inputs=[combined_out], outputs=[combined_out])
copy_characters.click(lambda x: x, inputs=[characters_out], outputs=[characters_out])
copy_features.click(lambda x: x, inputs=[features_out], outputs=[features_out])
copy_ip.click(lambda x: x, inputs=[ip_out], outputs=[ip_out])
copy_raw.click(lambda x: x, inputs=[raw_out], outputs=[raw_out])
clear_btn.click(
clear,
inputs=None,
outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, combined_out, meta_out, raw_out
],
)
if __name__ == "__main__":
demo.queue(max_size=8).launch()
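# ------------------------------------------------------------------------------
# Offline usage sketch (for reference only; assumes the same handler contract
# this app already relies on, and a hypothetical local file "example.jpg"):
#
#   from PIL import Image
#   from handler import EndpointHandler
#
#   h = EndpointHandler("./assets")
#   out = h({
#       "inputs": Image.open("example.jpg"),
#       "parameters": {
#           "general_threshold": 0.30,
#           "character_threshold": 0.85,
#           "mode": "threshold",
#           "topk_general": 25,
#           "topk_character": 10,
#           "include_scores": False,
#       },
#   })
#   print(out.get("character"), out.get("ip"), out.get("feature"))
# ------------------------------------------------------------------------------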