# Page-extraction residue (not part of the module): "Spaces: Running Running"
import base64 | |
import io | |
import json | |
import logging | |
import os | |
import time | |
from pathlib import Path | |
from typing import Any | |
import requests | |
import timm | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
class TaggingHead(torch.nn.Module):
    """Linear multi-label head that turns feature vectors into probabilities.

    The single linear layer is wrapped in a ``Sequential`` stored as ``head``,
    so its parameters are named ``head.0.weight`` and ``head.0.bias``.
    """

    def __init__(self, input_dim, num_classes):
        """Create the head mapping ``input_dim`` features to ``num_classes`` outputs."""
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        projection = torch.nn.Linear(input_dim, num_classes)
        self.head = torch.nn.Sequential(projection)

    def forward(self, x):
        """Return the element-wise sigmoid of the head's output for ``x``."""
        return torch.nn.functional.sigmoid(self.head(x))
def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
    """Load the tag metadata JSON.

    Args:
        tags_file: Path to a UTF-8 JSON file with a ``"tag_map"`` entry and a
            ``"tag_split"`` entry holding ``"gen_tag_count"`` and
            ``"character_tag_count"``.

    Returns:
        ``(tag_map, gen_tag_count, character_tag_count)`` exactly as stored
        in the file.

    Raises:
        KeyError: If any of the expected keys is missing.
    """
    with tags_file.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    name_to_index = payload["tag_map"]
    split = payload["tag_split"]
    return name_to_index, split["gen_tag_count"], split["character_tag_count"]
def get_character_ip_mapping(mapping_file: Path):
    """Parse ``mapping_file`` as UTF-8 JSON and return the decoded object.

    The handler in this module looks character tags up in the result and
    extends a list with the value found, so it is presumably a dict of
    character tag -> list of IP tags (not enforced here).
    """
    with mapping_file.open("r", encoding="utf-8") as handle:
        return json.load(handle)
def get_encoder():
    """Instantiate the backbone used as the feature encoder.

    The architecture comes from the ``SmilingWolf/wd-eva02-large-tagger-v3``
    hub entry with ``pretrained=False`` (weights are loaded separately by
    ``load_model``). ``reset_classifier(0)`` is then called on it, presumably
    so the model outputs pooled features rather than class logits.
    """
    encoder = timm.create_model(
        "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3",
        pretrained=False,
    )
    encoder.reset_classifier(0)
    return encoder
def get_decoder():
    """Build the tagging head: 1024 input features mapped to 13461 outputs."""
    feature_dim, tag_count = 1024, 13461
    return TaggingHead(feature_dim, tag_count)
def get_model():
    """Chain the encoder and the tagging head into one ``Sequential`` model.

    The encoder is stage ``0`` and the decoder stage ``1``.
    """
    stages = (get_encoder(), get_decoder())
    return torch.nn.Sequential(*stages)
def load_model(weights_file, device):
    """Build the model, load its weights, and return it in eval mode on ``device``.

    Args:
        weights_file: Checkpoint path passed to ``torch.load`` with
            ``weights_only=True``.
        device: Used both as ``map_location`` for loading and as the target
            of ``model.to``.
    """
    network = get_model()
    checkpoint = torch.load(weights_file, map_location=device, weights_only=True)
    network.load_state_dict(checkpoint)
    # ``Module.to`` and ``Module.eval`` both return the module itself.
    return network.to(device).eval()
def pure_pil_alpha_to_color_v2(
    image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
) -> Image.Image:
    """Flatten an image that has an alpha band onto a solid RGB canvas.

    A new RGB image of the same size is filled with ``color`` and ``image``
    is pasted onto it using the image's fourth band as the paste mask. The
    caller must supply an image whose fourth band is alpha (e.g. RGBA).

    Args:
        image: Source image with at least four bands.
        color: RGB fill used for the canvas; defaults to white.

    Returns:
        A new RGB image.
    """
    image.load()  # pixel data must be loaded before the bands can be split
    canvas = Image.new("RGB", image.size, color)
    alpha_band = image.split()[3]  # band index 3 is alpha for RGBA
    canvas.paste(image, mask=alpha_band)
    return canvas
def pil_to_rgb(image: Image.Image) -> Image.Image:
    """Return an RGB version of ``image`` with transparency flattened onto white.

    RGBA images are composited directly. Palette images (``P``), grayscale or
    palette images with an alpha band (``LA``, ``PA``), and images carrying a
    ``transparency`` entry in ``image.info`` are first converted to RGBA and
    then composited, so transparent pixels become white regardless of the
    source mode. Previously only ``RGBA`` and ``P`` were composited; other
    transparent images went through ``convert("RGB")``, which discards the
    alpha information instead of blending against the background.

    Every other mode is converted with ``image.convert("RGB")``.
    """
    if image.mode == "RGBA":
        return pure_pil_alpha_to_color_v2(image)

    carries_transparency = (
        image.mode in ("P", "LA", "PA") or "transparency" in image.info
    )
    if carries_transparency:
        return pure_pil_alpha_to_color_v2(image.convert("RGBA"))

    return image.convert("RGB")
class EndpointHandler:
    """Callable endpoint that tags one image per request.

    Construction loads the model weights, the tag metadata and the
    character -> IP mapping from a model directory. Calling the instance
    with a request dict loads the image, runs the model and returns the
    selected general ("feature") tags, character tags and the IP tags
    mapped from those characters.
    """

    def __init__(self, path: str):
        """Load every artifact from the model directory ``path``.

        Raises:
            FileNotFoundError: If the directory or one of the three required
                files is missing.
        """
        repo_path = Path(path)
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if not repo_path.is_dir():
            raise FileNotFoundError(f"Model directory not found: {repo_path}")

        weights_file = repo_path / "model_v0.9.pth"
        tags_file = repo_path / "tags_v0.9_13k.json"
        mapping_file = repo_path / "char_ip_map.json"
        required_files = (
            ("Model", weights_file),
            ("Tags", tags_file),
            ("Mapping", mapping_file),
        )
        for label, required in required_files:
            if not required.exists():
                raise FileNotFoundError(f"{label} file not found: {required}")

        self.device = self._select_device()
        self.model = load_model(str(weights_file), self.device)

        # Resize to 448x448, scale to [0, 1], then normalize each channel
        # with mean 0.5 / std 0.5 (i.e. into [-1, 1]).
        self.transform = transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

        self.fetch_image_timeout = 5.0  # seconds, passed to requests.get
        self.default_general_threshold = 0.3
        self.default_character_threshold = 0.85

        tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)
        # Invert tag -> index so model output indices resolve to tag names.
        self.index_to_tag_map = {index: tag for tag, index in tag_map.items()}
        self.character_ip_mapping = get_character_ip_mapping(mapping_file)

    @staticmethod
    def _select_device() -> str:
        """Return ``"cuda"`` only when CUDA is usable and not disabled via FORCE_CPU."""
        flag = os.environ.get("FORCE_CPU", "0").strip().lower()
        if flag in {"1", "true", "yes", "on"} or not torch.cuda.is_available():
            return "cpu"
        try:
            # Probe that a tensor can actually be moved to the GPU.
            torch.zeros(1).to("cuda")
        except Exception:
            # Best-effort fallback, but leave a trace of why the GPU was skipped.
            logging.warning(
                "CUDA is reported available but unusable; falling back to CPU",
                exc_info=True,
            )
            return "cpu"
        return "cuda"

    def _fetch_image(self, inputs: Any, request: dict[str, Any]) -> "Image.Image":
        """Resolve the request inputs to a PIL image.

        ``inputs`` may be a PIL image, or a dict with a ``"url"`` to download
        or a base64-encoded ``"image"``. ``request`` is only used in the
        error message.

        Raises:
            ValueError: If no image source is present.
        """
        if isinstance(inputs, Image.Image):
            return inputs
        if isinstance(inputs, dict):
            if image_url := inputs.get("url"):
                # NOTE(review): the URL is caller-supplied and fetched as-is
                # (no host allow-list, no size cap) - confirm this is acceptable.
                with requests.get(
                    image_url, stream=True, timeout=self.fetch_image_timeout
                ) as res:
                    res.raise_for_status()
                    # ``content`` undoes any Content-Encoding (gzip/deflate);
                    # ``res.raw`` would hand the still-encoded bytes to PIL.
                    payload = res.content
                return Image.open(io.BytesIO(payload))
            if image_base64_encoded := inputs.get("image"):
                return Image.open(io.BytesIO(base64.b64decode(image_base64_encoded)))
        raise ValueError(f"No image or url provided: {request}")

    def _select_topk(
        self, probs: torch.Tensor, topk_general: int, topk_character: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick the highest-scoring general and character tags.

        Returns CPU tensors ``(indices, scores)``; character indices are
        offset by ``gen_tag_count`` so they index the full tag list.
        """
        gen_slice = probs[: self.gen_tag_count]
        char_slice = probs[self.gen_tag_count :]
        k_gen = max(0, min(topk_general, self.gen_tag_count, gen_slice.numel()))
        k_char = max(
            0, min(topk_character, self.character_tag_count, char_slice.numel())
        )

        # Empty placeholders must share device and dtype with the real
        # results, otherwise torch.cat fails when only one k is zero on CUDA.
        gen_scores = char_scores = probs.new_empty((0,))
        gen_idx = char_idx = probs.new_empty((0,), dtype=torch.long)
        if k_gen > 0:
            gen_scores, gen_idx = torch.topk(gen_slice, k_gen)
        if k_char > 0:
            char_scores, char_idx = torch.topk(char_slice, k_char)
            char_idx = char_idx + self.gen_tag_count

        indices = torch.cat((gen_idx, char_idx)).cpu()
        scores = torch.cat((gen_scores, char_scores)).float().cpu()
        return indices, scores

    def _select_by_threshold(
        self,
        probs: torch.Tensor,
        general_threshold: float,
        character_threshold: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick every tag whose probability is strictly above its threshold.

        Returns CPU tensors ``(indices, scores)`` with general tags first.
        """
        general_indices = (
            (probs[: self.gen_tag_count] > general_threshold)
            .nonzero(as_tuple=True)[0]
        )
        character_indices = (
            (probs[self.gen_tag_count :] > character_threshold)
            .nonzero(as_tuple=True)[0]
            + self.gen_tag_count
        )
        combined = torch.cat((general_indices, character_indices))
        # Gather scores on the model's device, then move the small results.
        return combined.cpu(), probs[combined].float().cpu()

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Tag the image described by ``data``.

        Args:
            data: Request dict. ``data["inputs"]`` (or ``data`` itself when
                that key is absent) is a PIL image or a dict with ``"url"``
                or base64 ``"image"``. Optional ``data["parameters"]`` keys:
                ``general_threshold``, ``character_threshold``, ``mode``
                (``"topk"`` selects top-k; any other value uses thresholds),
                ``include_scores``, ``topk_general``, ``topk_character``.
                The request dict is not modified.

        Returns:
            Dict with ``"feature"``, ``"character"`` and ``"ip"`` tag lists,
            ``"_timings"``, ``"_params"``, and, when ``include_scores`` is
            set, ``"feature_scores"`` / ``"character_scores"``.

        Raises:
            ValueError: If no image source is provided.
        """
        inputs = data.get("inputs", data)

        fetch_start_time = time.time()
        # Flatten any alpha channel; the transform expects a 3-channel image.
        image = pil_to_rgb(self._fetch_image(inputs, data))
        fetch_time = time.time() - fetch_start_time

        parameters = data.get("parameters") or {}
        general_threshold = float(
            parameters.get("general_threshold", self.default_general_threshold)
        )
        character_threshold = float(
            parameters.get("character_threshold", self.default_character_threshold)
        )
        mode = parameters.get("mode", "threshold")  # "threshold" | "topk"
        include_scores = bool(parameters.get("include_scores", False))
        topk_general = int(parameters.get("topk_general", 25))
        topk_character = int(parameters.get("topk_character", 10))

        inference_start_time = time.time()
        with torch.inference_mode():
            image_tensor = self.transform(image).unsqueeze(0)
            if self.device == "cuda":
                # Pinned memory is what makes the non_blocking copy possible.
                image_tensor = image_tensor.pin_memory().to(
                    self.device, non_blocking=True
                )
            else:
                image_tensor = image_tensor.to(self.device)

            probs = self.model(image_tensor)[0]  # probabilities for the one image

            if mode == "topk":
                combined_indices, combined_scores = self._select_topk(
                    probs, topk_general, topk_character
                )
            else:
                combined_indices, combined_scores = self._select_by_threshold(
                    probs, general_threshold, character_threshold
                )
        inference_time = time.time() - inference_start_time

        post_process_start_time = time.time()
        cur_gen_tags: list[str] = []
        cur_char_tags: list[str] = []
        gen_scores_out: dict[str, float] = {}
        char_scores_out: dict[str, float] = {}
        # ``tolist`` converts once to plain ints/floats instead of per-item ``.item()``.
        for idx, score in zip(combined_indices.tolist(), combined_scores.tolist()):
            tag = self.index_to_tag_map[idx]
            if idx < self.gen_tag_count:
                tags, scores = cur_gen_tags, gen_scores_out
            else:
                tags, scores = cur_char_tags, char_scores_out
            tags.append(tag)
            if include_scores:
                scores[tag] = score

        ip_tag_set: set[str] = set()
        for tag in cur_char_tags:
            ip_tag_set.update(self.character_ip_mapping.get(tag, ()))
        ip_tags = sorted(ip_tag_set)
        post_process_time = time.time() - post_process_start_time

        total_time = fetch_time + inference_time + post_process_time
        logging.info(
            "Timing - Fetch: %.3fs, Inference: %.3fs, Post-process: %.3fs, Total: %.3fs",
            fetch_time,
            inference_time,
            post_process_time,
            total_time,
        )

        out: dict[str, Any] = {
            "feature": cur_gen_tags,
            "character": cur_char_tags,
            "ip": ip_tags,
            "_timings": {
                "fetch_s": round(fetch_time, 4),
                "inference_s": round(inference_time, 4),
                "post_process_s": round(post_process_time, 4),
                "total_s": round(total_time, 4),
            },
            "_params": {
                "mode": mode,
                "general_threshold": general_threshold,
                "character_threshold": character_threshold,
                "topk_general": topk_general,
                "topk_character": topk_character,
            },
        }
        if include_scores:
            out["feature_scores"] = gen_scores_out
            out["character_scores"] = char_scores_out
        return out