# webtoon_cropper/src/detectors/imgutils_detector.py
# Author: wise-water
# Commit: "init commit" (13aa528)
from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar
import cv2
import imgutils
import numpy as np
from imgutils.generic.yolo import (
_image_preprocess,
_rtdetr_postprocess,
_yolo_postprocess,
rgb_encode,
)
from PIL import Image, ImageDraw
# Numeric type of the bounding-box coordinates stored in ``DetectorOutput``;
# constrained to ``int`` or ``float``.
T = TypeVar("T", int, float)
# Repository ids of the deepghs anime detection models, keyed by the kind of
# region each model detects.
# NOTE(review): not referenced elsewhere in this file - presumably callers pass
# one of these values as ``repo_id`` to ``AnimeDetector``; confirm.
REPO_IDS = {
    "head": "deepghs/anime_head_detection",
    "face": "deepghs/anime_face_detection",
    "eye": "deepghs/anime_eye_detection",
}
@dataclass
class DetectorOutput(Generic[T]):
    """
    Container for the results of a single detection pass.

    The class is generic over the coordinate type ``T`` (``int`` or ``float``)
    used in ``bboxes``. An instance created without arguments represents
    "nothing detected": every list is empty and ``previews`` is ``None``.
    """

    # One ``[x0, y0, x1, y1]`` box per detection.
    bboxes: list[list[T]] = field(default_factory=list)
    # One single-channel ("L") mask per detection.
    masks: list[Image.Image] = field(default_factory=list)
    # One confidence score per detection.
    confidences: list[float] = field(default_factory=list)
    # One preview image per detection, or ``None`` when nothing was detected.
    # ``AnimeDetector.__call__`` stores a *list* of images here, so the field
    # is annotated as an optional list rather than a single optional image.
    previews: Optional[list[Image.Image]] = None
class AnimeDetector:
    """
    A class used to perform object detection on anime images.

    The model, its maximum inference size, its label list and its model type
    are loaded once in the constructor; calling the instance runs detection on
    a single image and returns a ``DetectorOutput``.

    Please refer to the `imgutils` documentation for more information on the available models.
    """

    def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
        """
        Load a detection model.

        Args:
            repo_id (str): Id of the repository that hosts the model.
            model_name (str): Name of the model to open from that repository.
            hf_token (Optional[str], optional): Token forwarded to the model loader. Defaults to None.
        """
        # NOTE: the helpers used below are underscore-prefixed (private)
        # members of `imgutils`, so they may change between imgutils releases.
        model_manager = imgutils.generic.yolo._open_models_for_repo_id(
            repo_id, hf_token=hf_token
        )
        model, max_infer_size, labels = model_manager._open_model(model_name)
        self.model = model
        self.max_infer_size = max_infer_size
        self.labels = labels
        self.model_type = model_manager._get_model_type(model_name)

    def __call__(
        self,
        image: Image.Image,
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.7,
        allow_dynamic: bool = False,
    ) -> DetectorOutput[float]:
        """
        Perform object detection on the given image.

        Args:
            image (Image.Image): The input image on which to perform detection.
            conf_threshold (float, optional): Confidence threshold for detection. Defaults to 0.3.
            iou_threshold (float, optional): Intersection over Union (IoU) threshold for detection. Defaults to 0.7.
            allow_dynamic (bool, optional): Forwarded to the image preprocessing step
                (dynamic resizing of the image). Defaults to False.

        Returns:
            DetectorOutput[float]: The detection results: bounding boxes, one rectangular
                mask per box, confidences, and one preview image per box (the input image
                with everything outside the box blacked out). When nothing is detected an
                empty ``DetectorOutput`` is returned (empty lists, ``previews`` is ``None``).

        Raises:
            ValueError: If the model type is unknown.
        """
        # Preprocessing
        resized_image, old_size, new_size = _image_preprocess(
            image, self.max_infer_size, allow_dynamic=allow_dynamic
        )
        # ``[None, ...]`` prepends a leading (batch) axis to the encoded image.
        data = rgb_encode(resized_image)[None, ...]

        # Start detection
        (raw_output,) = self.model.run(["output0"], {"images": data})

        # Postprocessing: both supported model types take the same arguments,
        # so only the function to call differs.
        if self.model_type == "yolo":
            postprocess = _yolo_postprocess
        elif self.model_type == "rtdetr":
            postprocess = _rtdetr_postprocess
        else:
            raise ValueError(
                f"Unknown object detection model type - {self.model_type!r}."
            )  # pragma: no cover
        detections = postprocess(
            output=raw_output[0],
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=self.labels,
        )

        if len(detections) == 0:
            return DetectorOutput()

        # Each detection is indexed as: [0] -> bbox, [2] -> confidence.
        bboxes = [det[0] for det in detections]  # [x0, y0, x1, y1]
        masks = create_mask_from_bbox(bboxes, image.size)
        confidences = [det[2] for det in detections]

        # Create one preview image per mask. The source image is converted to
        # an array once, outside the loop, because it does not depend on the mask.
        np_image = np.array(image)
        previews = []
        for mask in masks:
            np_mask = np.array(mask)
            if np_image.ndim == 3:
                # Add a trailing channel axis so the single-channel mask
                # broadcasts over however many channels the image has
                # (RGB, RGBA, ...), instead of assuming exactly three.
                np_mask = np_mask[..., None]
            # The mask holds 0 or 255, so the bitwise AND keeps pixels inside
            # the box unchanged and zeroes everything outside it.
            previews.append(Image.fromarray(np.bitwise_and(np_image, np_mask)))

        return DetectorOutput(
            bboxes=bboxes, masks=masks, confidences=confidences, previews=previews
        )
def create_mask_from_bbox(
    bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Render every bounding box as its own binary mask image.

    Args:
        bboxes (list[list[float]]): Bounding boxes, each given as four floats
            [x_min, y_min, x_max, y_max].
        shape (tuple[int, int]): Size of every mask. It is handed to ``Image.new`` as
            the image size, i.e. it is interpreted as (width, height).

    Returns:
        list[Image.Image]: One "L"-mode image per bounding box, in input order, filled
            with 255 inside the box and 0 everywhere else.
    """

    def _render(box: list[float]) -> Image.Image:
        # Start from an all-black canvas and paint the box area white.
        canvas = Image.new("L", shape, 0)
        ImageDraw.Draw(canvas).rectangle(box, fill=255)
        return canvas

    return [_render(box) for box in bboxes]
def create_bbox_from_mask(
    masks: list[Image.Image], shape: tuple[int, int]
) -> list[list[int]]:
    """
    Derive a bounding box from each mask image.

    Every mask is first resized to ``shape``; the bounding box is then taken from the
    resized mask via ``getbbox()``. Masks for which ``getbbox()`` returns ``None`` are
    skipped, so the result may be shorter than ``masks``.

    Args:
        masks (list[Image.Image]): The mask images to measure.
        shape (tuple[int, int]): Target size (width, height) the masks are resized to
            before the bounding box is computed.

    Returns:
        list[list[int]]: One [left, upper, right, lower] box per mask that yielded a
            bounding box, in input order.
    """
    return [
        list(box)
        for mask in masks
        if (box := mask.resize(shape).getbbox()) is not None
    ]