File size: 5,656 Bytes
13aa528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar

import cv2
import imgutils
import numpy as np
from imgutils.generic.yolo import (
    _image_preprocess,
    _rtdetr_postprocess,
    _yolo_postprocess,
    rgb_encode,
)
from PIL import Image, ImageDraw

# Bounding-box coordinate type: integer pixel indices or float coordinates.
T = TypeVar("T", int, float)

# Hugging Face Hub repository IDs of the supported anime detection models,
# keyed by detection target.
REPO_IDS = {
    "head": "deepghs/anime_head_detection",
    "face": "deepghs/anime_face_detection",
    "eye": "deepghs/anime_eye_detection",
}


@dataclass
class DetectorOutput(Generic[T]):
    """Result container for :class:`AnimeDetector`.

    Attributes:
        bboxes: Detected bounding boxes, each as ``[x0, y0, x1, y1]``.
        masks: One binary ("L" mode) mask image per bounding box.
        confidences: Detection confidence score per bounding box.
        previews: Per-detection preview images (input masked to the bbox
            region), or ``None`` when no detection was produced.
    """

    bboxes: list[list[T]] = field(default_factory=list)
    masks: list[Image.Image] = field(default_factory=list)
    confidences: list[float] = field(default_factory=list)
    # Fixed annotation: AnimeDetector.__call__ assigns a *list* of preview
    # images here, so the type is list[Image.Image], not a single image.
    # The default stays None to preserve existing behavior for callers that
    # test `output.previews is None`.
    previews: Optional[list[Image.Image]] = None


class AnimeDetector:
    """
    A class used to perform object detection on anime images.
    Please refer to the `imgutils` documentation for more information on the available models.
    """

    def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
        """
        Load an ONNX detection model from a Hugging Face repository.

        Args:
            repo_id (str): Hugging Face repository ID (see ``REPO_IDS`` for known ones).
            model_name (str): Name of the model inside the repository.
            hf_token (Optional[str], optional): Hugging Face token for private repos. Defaults to None.
        """
        # NOTE(review): relies on imgutils private helpers (_open_models_for_repo_id,
        # _open_model, _get_model_type) — pin the imgutils version to avoid breakage.
        model_manager = imgutils.generic.yolo._open_models_for_repo_id(
            repo_id, hf_token=hf_token
        )
        model, max_infer_size, labels = model_manager._open_model(model_name)

        self.model = model
        self.max_infer_size = max_infer_size
        self.labels = labels
        # Either "yolo" or "rtdetr"; selects the postprocessing branch in __call__.
        self.model_type = model_manager._get_model_type(model_name)

    def __call__(
        self,
        image: Image.Image,
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.7,
        allow_dynamic: bool = False,
    ) -> DetectorOutput[float]:
        """
        Perform object detection on the given image.

        Args:
            image (Image.Image): The input image on which to perform detection.
            conf_threshold (float, optional): Confidence threshold for detection. Defaults to 0.3.
            iou_threshold (float, optional): Intersection over Union (IoU) threshold for detection. Defaults to 0.7.
            allow_dynamic (bool, optional): Whether to allow dynamic resizing of the image. Defaults to False.

        Returns:
            DetectorOutput[float]: The detection results, including bounding boxes, masks,
            confidences, and one preview image per detection (empty output when nothing
            is detected).

        Raises:
            ValueError: If the model type is unknown.
        """
        # Preprocessing: letterbox/resize to the model's inference size.
        new_image, old_size, new_size = _image_preprocess(
            image, self.max_infer_size, allow_dynamic=allow_dynamic
        )
        data = rgb_encode(new_image)[None, ...]

        # Start detection
        (output,) = self.model.run(["output0"], {"images": data})

        # Postprocessing: map raw network output back to original image coordinates.
        if self.model_type == "yolo":
            output = _yolo_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=self.labels,
            )
        elif self.model_type == "rtdetr":
            output = _rtdetr_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=self.labels,
            )
        else:
            raise ValueError(
                f"Unknown object detection model type - {self.model_type!r}."
            )  # pragma: no cover

        if len(output) == 0:
            return DetectorOutput()

        # Each postprocessed entry is (bbox, label, confidence).
        bboxes = [x[0] for x in output]  # [x0, y0, x1, y1]
        masks = create_mask_from_bbox(bboxes, image.size)
        confidences = [x[2] for x in output]

        # Create one preview image per mask (input masked to the detected region).
        # Hoisted: the source image only needs converting to ndarray once, not
        # once per detection.
        np_image = np.array(image)
        previews = []
        for mask in masks:
            np_mask = np.array(mask)
            # The mask is single-channel; replicating it to 3 channels lets
            # bitwise_and zero out everything outside the bbox.
            preview = cv2.bitwise_and(
                np_image, cv2.cvtColor(np_mask, cv2.COLOR_GRAY2BGR)
            )
            preview = Image.fromarray(preview)
            previews.append(preview)

        return DetectorOutput(
            bboxes=bboxes, masks=masks, confidences=confidences, previews=previews
        )


def create_mask_from_bbox(
    bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Creates a list of binary masks from bounding boxes.

    Args:
        bboxes (list[list[float]]): A list of bounding boxes, where each bounding box is represented
                                    by a list of four float values [x_min, y_min, x_max, y_max].
        shape (tuple[int, int]): The size of the mask as (width, height) — PIL's
                                 ``Image.new`` size convention; callers pass ``image.size``.

    Returns:
        list[Image.Image]: A list of PIL Image objects representing the binary masks
                           (mode "L": 255 inside the box, 0 elsewhere).
    """
    masks = []
    for bbox in bboxes:
        mask = Image.new("L", shape, 0)
        mask_draw = ImageDraw.Draw(mask)
        mask_draw.rectangle(bbox, fill=255)
        masks.append(mask)
    return masks


def create_bbox_from_mask(
    masks: list[Image.Image], shape: tuple[int, int]
) -> list[list[int]]:
    """
    Create bounding boxes from a list of mask images.

    Args:
        masks (list[Image.Image]): A list of PIL Image objects representing the masks.
        shape (tuple[int, int]): A tuple representing the desired shape (width, height) to resize the masks.

    Returns:
        list[list[int]]: A list of bounding boxes, where each bounding box is represented as a list of four integers [left, upper, right, lower].
    """
    bboxes = []
    for mask in masks:
        mask = mask.resize(shape)
        bbox = mask.getbbox()
        if bbox is not None:
            bboxes.append(list(bbox))
    return bboxes