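"""Custom EndpointHandler for a Hugging Face Inference Endpoint.

Wraps a timm EVA02-Large encoder (wd-eva02-large-tagger-v3 architecture,
weights loaded from a local state dict) with a linear multi-label tagging
head, and tags an input image with general ("feature"), character, and
franchise ("ip") labels.
"""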
import base64
import io
import json
import logging
import os
import time
from pathlib import Path
from typing import Any

import requests
import timm
import torch
import torchvision.transforms as transforms
from PIL import Image


class TaggingHead(torch.nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.head = torch.nn.Sequential(torch.nn.Linear(input_dim, num_classes))

    def forward(self, x):
        logits = self.head(x)
        # Sigmoid yields an independent probability per tag (multi-label)
        probs = torch.sigmoid(logits)
        return probs


def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
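    """Load the tag vocabulary and the general/character split.

    Expected JSON layout (a sketch inferred from the keys read below; the
    values are placeholders):

        {
          "tag_map": {"<tag name>": <index>, ...},
          "tag_split": {"gen_tag_count": <int>, "character_tag_count": <int>}
        }

    General tags occupy indices [0, gen_tag_count); character tags follow.
    """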
    with tags_file.open("r", encoding="utf-8") as f:
        tag_info = json.load(f)
    tag_map = tag_info["tag_map"]
    tag_split = tag_info["tag_split"]
    gen_tag_count = tag_split["gen_tag_count"]
    character_tag_count = tag_split["character_tag_count"]
    return tag_map, gen_tag_count, character_tag_count


def get_character_ip_mapping(mapping_file: Path):
    with mapping_file.open("r", encoding="utf-8") as f:
        mapping = json.load(f)
    return mapping


def get_encoder():
    # Architecture only: pretrained=False because the trained weights are
    # loaded later from the combined state dict in load_model()
    base_model_repo = "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3"
    encoder = timm.create_model(base_model_repo, pretrained=False)
    encoder.reset_classifier(0)  # drop the stock classifier, keep features
    return encoder


def get_decoder():
    # 1024-dim EVA02-Large features -> 13,461 tag probabilities
    decoder = TaggingHead(1024, 13461)
    return decoder
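
# Shape sketch (illustrative, not executed): the head maps EVA02-Large
# features to one independent probability per tag.
#   feats = torch.randn(2, 1024)
#   get_decoder()(feats).shape  # -> torch.Size([2, 13461])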


def get_model():
    encoder = get_encoder()
    decoder = get_decoder()
    model = torch.nn.Sequential(encoder, decoder)
    return model


def load_model(weights_file, device):
    model = get_model()
    states_dict = torch.load(weights_file, map_location=device, weights_only=True)
    model.load_state_dict(states_dict)
    model.to(device)
    model.eval()
    return model


def pure_pil_alpha_to_color_v2(
    image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
) -> Image.Image:
    """
    Convert a PIL image with an alpha channel to a RGB image.
    This is a workaround for the fact that the model expects a RGB image, but the image may have an alpha channel.
    This function will convert the image to a RGB image, and fill the alpha channel with the given color.
    The alpha channel is the 4th channel of the image.
    """
    image.load()  # needed for split()
    background = Image.new("RGB", image.size, color)
    background.paste(image, mask=image.split()[3])  # 3 is the alpha channel
    return background


def pil_to_rgb(image: Image.Image) -> Image.Image:
    if image.mode == "RGBA":
        image = pure_pil_alpha_to_color_v2(image)
    elif image.mode == "P":
        # Palette images may carry transparency; go through RGBA first
        image = pure_pil_alpha_to_color_v2(image.convert("RGBA"))
    else:
        image = image.convert("RGB")
    return image


class EndpointHandler:
    def __init__(self, path: str):
        repo_path = Path(path)
        if not repo_path.is_dir():
            raise NotADirectoryError(f"Model directory not found: {repo_path}")
        weights_file = repo_path / "model_v0.9.pth"
        tags_file = repo_path / "tags_v0.9_13k.json"
        mapping_file = repo_path / "char_ip_map.json"
        if not weights_file.exists():
            raise FileNotFoundError(f"Model file not found: {weights_file}")
        if not tags_file.exists():
            raise FileNotFoundError(f"Tags file not found: {tags_file}")
        if not mapping_file.exists():
            raise FileNotFoundError(f"Mapping file not found: {mapping_file}")

        # Robust device selection: prefer CPU unless CUDA is truly usable
        force_cpu = os.environ.get("FORCE_CPU", "0") in {"1", "true", "TRUE", "yes", "on"}
        if not force_cpu and torch.cuda.is_available():
            try:
                # Probe that CUDA can actually be used (driver present)
                torch.zeros(1).to("cuda")
                self.device = "cuda"
            except Exception:
                self.device = "cpu"
        else:
            self.device = "cpu"
        self.model = load_model(str(weights_file), self.device)
        self.transform = transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.fetch_image_timeout = 5.0
        self.default_general_threshold = 0.3
        self.default_character_threshold = 0.85

        tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)

        # Invert the tag_map for efficient index-to-tag lookups
        self.index_to_tag_map = {v: k for k, v in tag_map.items()}

        self.character_ip_mapping = get_character_ip_mapping(mapping_file)
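
    # Example request payload (a sketch; the field names follow what
    # __call__ reads below, the values are illustrative):
    # {
    #   "inputs": {"url": "https://example.com/image.png"},
    #   "parameters": {"mode": "topk", "include_scores": true, "topk_general": 25}
    # }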

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        inputs = data.pop("inputs", data)

        fetch_start_time = time.time()
        if isinstance(inputs, Image.Image):
            image = inputs
        elif image_url := inputs.pop("url", None):
            with requests.get(image_url, timeout=self.fetch_image_timeout) as res:
                res.raise_for_status()
                # Buffer the body before the response is released: PIL
                # decodes lazily, and a raw response stream becomes
                # unreadable once the connection closes at the end of
                # this block.
                image = Image.open(io.BytesIO(res.content))
        elif image_base64_encoded := inputs.pop("image", None):
            image = Image.open(io.BytesIO(base64.b64decode(image_base64_encoded)))
        else:
            raise ValueError(f"No image or url provided: {data}")
        # remove alpha channel if it exists
        image = pil_to_rgb(image)
        fetch_time = time.time() - fetch_start_time

        parameters = data.pop("parameters", {})
        general_threshold = parameters.pop(
            "general_threshold", self.default_general_threshold
        )
        character_threshold = parameters.pop(
            "character_threshold", self.default_character_threshold
        )
        # Optional behavior controls
        mode = parameters.pop("mode", "threshold")  # "threshold" | "topk"
        include_scores = bool(parameters.pop("include_scores", False))
        topk_general = int(parameters.pop("topk_general", 25))
        topk_character = int(parameters.pop("topk_character", 10))

        inference_start_time = time.time()
        with torch.inference_mode():
            # Preprocess image on CPU
            image_tensor = self.transform(image).unsqueeze(0)
            # Pin memory and use non_blocking transfer only when using CUDA
            if self.device == "cuda":
                image_tensor = image_tensor.pin_memory().to(self.device, non_blocking=True)
            else:
                image_tensor = image_tensor.to(self.device)

            # Run the model and take the probabilities for the single image
            probs = self.model(image_tensor)[0]

            if mode == "topk":
                # Select top-k by category, independent of thresholds
                gen_slice = probs[: self.gen_tag_count]
                char_slice = probs[self.gen_tag_count :]
                k_gen = max(0, min(int(topk_general), self.gen_tag_count))
                k_char = max(0, min(int(topk_character), self.character_tag_count))
                gen_scores, gen_idx = (torch.tensor([]), torch.tensor([], dtype=torch.long))
                char_scores, char_idx = (torch.tensor([]), torch.tensor([], dtype=torch.long))
                if k_gen > 0:
                    gen_scores, gen_idx = torch.topk(gen_slice, k_gen)
                if k_char > 0:
                    char_scores, char_idx = torch.topk(char_slice, k_char)
                    char_idx = char_idx + self.gen_tag_count

                # Merge for unified post-processing
                combined_indices = torch.cat((gen_idx, char_idx)).cpu()
                combined_scores = torch.cat((gen_scores, char_scores)).cpu()
            else:
                # Threshold each category directly on the device
                general_mask = probs[: self.gen_tag_count] > general_threshold
                character_mask = probs[self.gen_tag_count :] > character_threshold

                # Get the indices of positive tags
                general_indices = general_mask.nonzero(as_tuple=True)[0]
                character_indices = (
                    character_mask.nonzero(as_tuple=True)[0] + self.gen_tag_count
                )

                # Combine indices, gather scores while still on-device, then
                # move the small result tensors to the CPU
                combined_indices = torch.cat((general_indices, character_indices))
                combined_scores = probs[combined_indices].detach().float().cpu()
                combined_indices = combined_indices.cpu()

        inference_time = time.time() - inference_start_time

        post_process_start_time = time.time()

        cur_gen_tags = []
        cur_char_tags = []
        gen_scores_out: dict[str, float] = {}
        char_scores_out: dict[str, float] = {}

        # Use the efficient pre-computed map for lookups
        for pos, i in enumerate(combined_indices):
            idx = int(i.item())
            tag = self.index_to_tag_map[idx]
            if idx < self.gen_tag_count:
                cur_gen_tags.append(tag)
                if include_scores:
                    score = float(combined_scores[pos].item())
                    gen_scores_out[tag] = score
            else:
                cur_char_tags.append(tag)
                if include_scores:
                    score = float(combined_scores[pos].item())
                    char_scores_out[tag] = score

        ip_tags = []
        for tag in cur_char_tags:
            if tag in self.character_ip_mapping:
                ip_tags.extend(self.character_ip_mapping[tag])
        ip_tags = sorted(set(ip_tags))
        post_process_time = time.time() - post_process_start_time

        logging.info(
            f"Timing - Fetch: {fetch_time:.3f}s, Inference: {inference_time:.3f}s, Post-process: {post_process_time:.3f}s, Total: {fetch_time + inference_time + post_process_time:.3f}s"
        )

        out: dict[str, Any] = {
            "feature": cur_gen_tags,
            "character": cur_char_tags,
            "ip": ip_tags,
            "_timings": {
                "fetch_s": round(fetch_time, 4),
                "inference_s": round(inference_time, 4),
                "post_process_s": round(post_process_time, 4),
                "total_s": round(fetch_time + inference_time + post_process_time, 4),
            },
            "_params": {
                "mode": mode,
                "general_threshold": general_threshold,
                "character_threshold": character_threshold,
                "topk_general": topk_general,
                "topk_character": topk_character,
            },
        }

        if include_scores:
            out["feature_scores"] = gen_scores_out
            out["character_scores"] = char_scores_out

        return out
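

# Local smoke test (a sketch, not part of the endpoint contract). It assumes
# the weight/tag/mapping files referenced in EndpointHandler.__init__ sit in
# the current directory and that a local "example.png" exists; Inference
# Endpoints instantiates EndpointHandler itself and never runs this block.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    handler = EndpointHandler(".")
    with Image.open("example.png") as img:
        result = handler({"inputs": img, "parameters": {"include_scores": True}})
    print(json.dumps(result, indent=2, ensure_ascii=False))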