import base64 import io import json import logging import os import time from pathlib import Path from typing import Any import requests import timm import torch import torchvision.transforms as transforms from PIL import Image class TaggingHead(torch.nn.Module): def __init__(self, input_dim, num_classes): super().__init__() self.input_dim = input_dim self.num_classes = num_classes self.head = torch.nn.Sequential(torch.nn.Linear(input_dim, num_classes)) def forward(self, x): logits = self.head(x) probs = torch.nn.functional.sigmoid(logits) return probs def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]: with tags_file.open("r", encoding="utf-8") as f: tag_info = json.load(f) tag_map = tag_info["tag_map"] tag_split = tag_info["tag_split"] gen_tag_count = tag_split["gen_tag_count"] character_tag_count = tag_split["character_tag_count"] return tag_map, gen_tag_count, character_tag_count def get_character_ip_mapping(mapping_file: Path): with mapping_file.open("r", encoding="utf-8") as f: mapping = json.load(f) return mapping def get_encoder(): base_model_repo = "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3" encoder = timm.create_model(base_model_repo, pretrained=False) encoder.reset_classifier(0) return encoder def get_decoder(): decoder = TaggingHead(1024, 13461) return decoder def get_model(): encoder = get_encoder() decoder = get_decoder() model = torch.nn.Sequential(encoder, decoder) return model def load_model(weights_file, device): model = get_model() states_dict = torch.load(weights_file, map_location=device, weights_only=True) model.load_state_dict(states_dict) model.to(device) model.eval() return model def pure_pil_alpha_to_color_v2( image: Image.Image, color: tuple[int, int, int] = (255, 255, 255) ) -> Image.Image: """ Convert a PIL image with an alpha channel to a RGB image. This is a workaround for the fact that the model expects a RGB image, but the image may have an alpha channel. This function will convert the image to a RGB image, and fill the alpha channel with the given color. The alpha channel is the 4th channel of the image. """ image.load() # needed for split() background = Image.new("RGB", image.size, color) background.paste(image, mask=image.split()[3]) # 3 is the alpha channel return background def pil_to_rgb(image: Image.Image) -> Image.Image: if image.mode == "RGBA": image = pure_pil_alpha_to_color_v2(image) elif image.mode == "P": image = pure_pil_alpha_to_color_v2(image.convert("RGBA")) else: image = image.convert("RGB") return image class EndpointHandler: def __init__(self, path: str): repo_path = Path(path) assert repo_path.is_dir(), f"Model directory not found: {repo_path}" weights_file = repo_path / "model_v0.9.pth" tags_file = repo_path / "tags_v0.9_13k.json" mapping_file = repo_path / "char_ip_map.json" if not weights_file.exists(): raise FileNotFoundError(f"Model file not found: {weights_file}") if not tags_file.exists(): raise FileNotFoundError(f"Tags file not found: {tags_file}") if not mapping_file.exists(): raise FileNotFoundError(f"Mapping file not found: {mapping_file}") # Robust device selection: prefer CPU unless CUDA is truly usable force_cpu = os.environ.get("FORCE_CPU", "0") in {"1", "true", "TRUE", "yes", "on"} if not force_cpu and torch.cuda.is_available(): try: # Probe that CUDA can actually be used (driver present) torch.zeros(1).to("cuda") self.device = "cuda" except Exception: self.device = "cpu" else: self.device = "cpu" self.model = load_model(str(weights_file), self.device) self.transform = transforms.Compose( [ transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ] ) self.fetch_image_timeout = 5.0 self.default_general_threshold = 0.3 self.default_character_threshold = 0.85 tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file) # Invert the tag_map for efficient index-to-tag lookups self.index_to_tag_map = {v: k for k, v in tag_map.items()} self.character_ip_mapping = get_character_ip_mapping(mapping_file) def __call__(self, data: dict[str, Any]) -> dict[str, Any]: inputs = data.pop("inputs", data) fetch_start_time = time.time() if isinstance(inputs, Image.Image): image = inputs elif image_url := inputs.pop("url", None): with requests.get( image_url, stream=True, timeout=self.fetch_image_timeout ) as res: res.raise_for_status() image = Image.open(res.raw) elif image_base64_encoded := inputs.pop("image", None): image = Image.open(io.BytesIO(base64.b64decode(image_base64_encoded))) else: raise ValueError(f"No image or url provided: {data}") # remove alpha channel if it exists image = pil_to_rgb(image) fetch_time = time.time() - fetch_start_time parameters = data.pop("parameters", {}) general_threshold = parameters.pop( "general_threshold", self.default_general_threshold ) character_threshold = parameters.pop( "character_threshold", self.default_character_threshold ) # Optional behavior controls mode = parameters.pop("mode", "threshold") # "threshold" | "topk" include_scores = bool(parameters.pop("include_scores", False)) topk_general = int(parameters.pop("topk_general", 25)) topk_character = int(parameters.pop("topk_character", 10)) inference_start_time = time.time() with torch.inference_mode(): # Preprocess image on CPU image_tensor = self.transform(image).unsqueeze(0) # Pin memory and use non_blocking transfer only when using CUDA if self.device == "cuda": image_tensor = image_tensor.pin_memory().to(self.device, non_blocking=True) else: image_tensor = image_tensor.to(self.device) # Run model on GPU probs = self.model(image_tensor)[0] # Get probs for the single image if mode == "topk": # Select top-k by category, independent of thresholds gen_slice = probs[: self.gen_tag_count] char_slice = probs[self.gen_tag_count :] k_gen = max(0, min(int(topk_general), self.gen_tag_count)) k_char = max(0, min(int(topk_character), self.character_tag_count)) gen_scores, gen_idx = (torch.tensor([]), torch.tensor([], dtype=torch.long)) char_scores, char_idx = (torch.tensor([]), torch.tensor([], dtype=torch.long)) if k_gen > 0: gen_scores, gen_idx = torch.topk(gen_slice, k_gen) if k_char > 0: char_scores, char_idx = torch.topk(char_slice, k_char) char_idx = char_idx + self.gen_tag_count # Merge for unified post-processing combined_indices = torch.cat((gen_idx, char_idx)).cpu() combined_scores = torch.cat((gen_scores, char_scores)).cpu() else: # Perform thresholding directly on the GPU general_mask = probs[: self.gen_tag_count] > general_threshold character_mask = probs[self.gen_tag_count :] > character_threshold # Get the indices of positive tags on the GPU general_indices = general_mask.nonzero(as_tuple=True)[0] character_indices = ( character_mask.nonzero(as_tuple=True)[0] + self.gen_tag_count ) # Combine indices and move the small result tensor to the CPU combined_indices = torch.cat((general_indices, character_indices)).cpu() combined_scores = probs[combined_indices].detach().float().cpu() inference_time = time.time() - inference_start_time post_process_start_time = time.time() cur_gen_tags = [] cur_char_tags = [] gen_scores_out: dict[str, float] = {} char_scores_out: dict[str, float] = {} # Use the efficient pre-computed map for lookups for pos, i in enumerate(combined_indices): idx = int(i.item()) tag = self.index_to_tag_map[idx] if idx < self.gen_tag_count: cur_gen_tags.append(tag) if include_scores: score = float(combined_scores[pos].item()) gen_scores_out[tag] = score else: cur_char_tags.append(tag) if include_scores: score = float(combined_scores[pos].item()) char_scores_out[tag] = score ip_tags = [] for tag in cur_char_tags: if tag in self.character_ip_mapping: ip_tags.extend(self.character_ip_mapping[tag]) ip_tags = sorted(set(ip_tags)) post_process_time = time.time() - post_process_start_time logging.info( f"Timing - Fetch: {fetch_time:.3f}s, Inference: {inference_time:.3f}s, Post-process: {post_process_time:.3f}s, Total: {fetch_time + inference_time + post_process_time:.3f}s" ) out: dict[str, Any] = { "feature": cur_gen_tags, "character": cur_char_tags, "ip": ip_tags, "_timings": { "fetch_s": round(fetch_time, 4), "inference_s": round(inference_time, 4), "post_process_s": round(post_process_time, 4), "total_s": round(fetch_time + inference_time + post_process_time, 4), }, "_params": { "mode": mode, "general_threshold": general_threshold, "character_threshold": character_threshold, "topk_general": topk_general, "topk_character": topk_character, }, } if include_scores: out["feature_scores"] = gen_scores_out out["character_scores"] = char_scores_out return out