Spaces:
Running
Running
File size: 5,656 Bytes
13aa528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar
import cv2
import imgutils
import numpy as np
from imgutils.generic.yolo import (
_image_preprocess,
_rtdetr_postprocess,
_yolo_postprocess,
rgb_encode,
)
from PIL import Image, ImageDraw
# Numeric type parameter for bounding-box coordinates in DetectorOutput.
T = TypeVar("T", int, float)

# Repository ids of the detection models (presumably Hugging Face Hub repos,
# given the hf_token plumbing in AnimeDetector), keyed by detection target.
REPO_IDS: dict[str, str] = dict(
    head="deepghs/anime_head_detection",
    face="deepghs/anime_face_detection",
    eye="deepghs/anime_eye_detection",
)
@dataclass
class DetectorOutput(Generic[T]):
    """
    Container for the results of one detection run.

    When produced by ``AnimeDetector.__call__``, the lists are index-aligned:
    entry ``i`` of each list describes the same detection.
    """

    # Bounding boxes as [x0, y0, x1, y1].
    bboxes: list[list[T]] = field(default_factory=list)
    # One binary ("L" mode, 0/255) mask per bounding box.
    masks: list[Image.Image] = field(default_factory=list)
    # Detection confidence per bounding box.
    confidences: list[float] = field(default_factory=list)
    # One preview image per bounding box. This is a *list* of images (that is what
    # AnimeDetector stores here); it stays None in the empty result returned when
    # nothing is detected.
    previews: Optional[list[Image.Image]] = None
class AnimeDetector:
    """
    A class used to perform object detection on anime images.

    Please refer to the `imgutils` documentation for more information on the available models.
    """

    def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
        """
        Load a detection model through imgutils' YOLO model manager.

        Args:
            repo_id (str): Repository id holding the models (e.g. "deepghs/anime_face_detection").
            model_name (str): Name of the model inside that repository.
            hf_token (Optional[str], optional): Token forwarded to the model loader. Defaults to None.
        """
        # NOTE: these are underscore-prefixed (private) imgutils helpers and may
        # change between imgutils releases.
        model_manager = imgutils.generic.yolo._open_models_for_repo_id(
            repo_id, hf_token=hf_token
        )
        model, max_infer_size, labels = model_manager._open_model(model_name)
        self.model = model
        self.max_infer_size = max_infer_size
        self.labels = labels
        self.model_type = model_manager._get_model_type(model_name)

    def __call__(
        self,
        image: Image.Image,
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.7,
        allow_dynamic: bool = False,
    ) -> DetectorOutput[float]:
        """
        Perform object detection on the given image.

        Args:
            image (Image.Image): The input image on which to perform detection.
            conf_threshold (float, optional): Confidence threshold for detection. Defaults to 0.3.
            iou_threshold (float, optional): Intersection over Union (IoU) threshold for detection. Defaults to 0.7.
            allow_dynamic (bool, optional): Whether to allow dynamic resizing of the image. Defaults to False.

        Returns:
            DetectorOutput[float]: The detection results. It holds the bounding boxes
            ([x0, y0, x1, y1], presumably mapped back to ``image`` coordinates by the
            imgutils postprocessing) and one binary mask per box. It also holds the
            confidences and a *list* of previews, one per box: an RGB copy of the
            image with everything outside the box set to black. An empty
            ``DetectorOutput`` is returned when nothing is detected.

        Raises:
            ValueError: If the model type is unknown.
        """
        # Pick the postprocessor first so an unknown model type fails before
        # spending time on inference.
        if self.model_type == "yolo":
            postprocess = _yolo_postprocess
        elif self.model_type == "rtdetr":
            postprocess = _rtdetr_postprocess
        else:
            raise ValueError(
                f"Unknown object detection model type - {self.model_type!r}."
            )  # pragma: no cover

        # Preprocessing
        new_image, old_size, new_size = _image_preprocess(
            image, self.max_infer_size, allow_dynamic=allow_dynamic
        )
        data = rgb_encode(new_image)[None, ...]

        # Start detection
        (output,) = self.model.run(["output0"], {"images": data})

        # Postprocessing
        output = postprocess(
            output=output[0],
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=self.labels,
        )
        if not output:
            return DetectorOutput()

        bboxes = [x[0] for x in output]  # [x0, y0, x1, y1]
        masks = create_mask_from_bbox(bboxes, image.size)
        confidences = [x[2] for x in output]

        # Create the previews. Convert to RGB once, outside the loop, so the image
        # always has the same 3-channel layout as the expanded mask. Without this,
        # RGBA / L / P inputs make cv2.bitwise_and fail on mismatched shapes.
        np_image = np.array(image.convert("RGB"))
        previews = []
        for mask in masks:
            np_mask = cv2.cvtColor(np.array(mask), cv2.COLOR_GRAY2BGR)
            previews.append(Image.fromarray(cv2.bitwise_and(np_image, np_mask)))

        return DetectorOutput(
            bboxes=bboxes, masks=masks, confidences=confidences, previews=previews
        )
def create_mask_from_bbox(
    bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Creates a list of binary masks from bounding boxes.

    Args:
        bboxes (list[list[float]]): A list of bounding boxes, where each bounding box is
            represented by four values [x_min, y_min, x_max, y_max].
        shape (tuple[int, int]): The size of each mask as (width, height). This is the
            ordering used by PIL's ``Image.new`` and ``Image.size``, so an image's
            ``.size`` can be passed directly.

    Returns:
        list[Image.Image]: One "L" mode mask per bounding box, in input order. Each
        mask is 255 inside the box and 0 elsewhere.
    """
    masks = []
    for bbox in bboxes:
        # Black canvas the size of the source image, then paint the box white.
        mask = Image.new("L", shape, 0)
        ImageDraw.Draw(mask).rectangle(bbox, fill=255)
        masks.append(mask)
    return masks
def create_bbox_from_mask(
    masks: list[Image.Image], shape: tuple[int, int]
) -> list[list[int]]:
    """
    Compute the bounding box of the non-zero region of each mask.

    Args:
        masks (list[Image.Image]): Mask images to measure.
        shape (tuple[int, int]): Size (width, height) each mask is resized to before measuring.

    Returns:
        list[list[int]]: Boxes as [left, upper, right, lower] in the resized
        coordinate space. Masks with no non-zero pixels are skipped, so the result
        can be shorter than ``masks`` and is then no longer index-aligned with it.
    """
    # Resizing uses Pillow's default resampling filter (no ``resample`` argument).
    # NOTE(review): a smoothing filter may add faint edge pixels that slightly
    # enlarge the box; confirm whether that matters for callers.
    resized_masks = (m.resize(shape) for m in masks)
    return [
        list(box)
        for box in (rm.getbbox() for rm in resized_masks)
        if box is not None
    ]
|