from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar

import cv2
import imgutils
import numpy as np
from imgutils.generic.yolo import (
    _image_preprocess,
    _rtdetr_postprocess,
    _yolo_postprocess,
    rgb_encode,
)
from PIL import Image, ImageDraw

T = TypeVar("T", int, float)

REPO_IDS = {
    "head": "deepghs/anime_head_detection",
    "face": "deepghs/anime_face_detection",
    "eye": "deepghs/anime_eye_detection",
}


@dataclass
class DetectorOutput(Generic[T]):
    bboxes: list[list[T]] = field(default_factory=list)
    masks: list[Image.Image] = field(default_factory=list)
    confidences: list[float] = field(default_factory=list)
    previews: list[Image.Image] = field(default_factory=list)


class AnimeDetector:
    """
    A class used to perform object detection on anime images.

    Please refer to the `imgutils` documentation for more information on the available models.
    """

    def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
        model_manager = imgutils.generic.yolo._open_models_for_repo_id(
            repo_id, hf_token=hf_token
        )
        model, max_infer_size, labels = model_manager._open_model(model_name)
        self.model = model
        self.max_infer_size = max_infer_size
        self.labels = labels
        self.model_type = model_manager._get_model_type(model_name)

    def __call__(
        self,
        image: Image.Image,
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.7,
        allow_dynamic: bool = False,
    ) -> DetectorOutput[float]:
        """
        Perform object detection on the given image.

        Args:
            image (Image.Image): The input image on which to perform detection.
            conf_threshold (float, optional): Confidence threshold for detections. Defaults to 0.3.
            iou_threshold (float, optional): Intersection over Union (IoU) threshold used to suppress overlapping detections. Defaults to 0.7.
            allow_dynamic (bool, optional): Whether to allow dynamic resizing of the image during preprocessing. Defaults to False.

        Returns:
            DetectorOutput[float]: The detection results, including bounding boxes, masks, confidences, and preview images.

        Raises:
            ValueError: If the model type is unknown.
        """
        # Preprocessing
        new_image, old_size, new_size = _image_preprocess(
            image, self.max_infer_size, allow_dynamic=allow_dynamic
        )
        data = rgb_encode(new_image)[None, ...]

        # Start detection
        (output,) = self.model.run(["output0"], {"images": data})

        # Postprocessing
        if self.model_type == "yolo":
            output = _yolo_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=self.labels,
            )
        elif self.model_type == "rtdetr":
            output = _rtdetr_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=self.labels,
            )
        else:
            raise ValueError(
                f"Unknown object detection model type - {self.model_type!r}."
            )  # pragma: no cover

        if len(output) == 0:
            return DetectorOutput()

        # Each detection is a (bbox, label, confidence) tuple
        bboxes = [x[0] for x in output]  # [x0, y0, x1, y1]
        masks = create_mask_from_bbox(bboxes, image.size)
        confidences = [x[2] for x in output]

        # Create preview images by keeping only the masked region of the original image
        np_image = np.array(image)
        previews = []
        for mask in masks:
            np_mask = np.array(mask)
            preview = cv2.bitwise_and(
                np_image, cv2.cvtColor(np_mask, cv2.COLOR_GRAY2BGR)
            )
            preview = Image.fromarray(preview)
            previews.append(preview)

        return DetectorOutput(
            bboxes=bboxes, masks=masks, confidences=confidences, previews=previews
        )


def create_mask_from_bbox(
    bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Creates a list of binary masks from bounding boxes.

    Args:
        bboxes (list[list[float]]): A list of bounding boxes, where each bounding box is represented
            by a list of four float values [x_min, y_min, x_max, y_max].
        shape (tuple[int, int]): The size of the mask (width, height), e.g. `Image.size`.

    Returns:
        list[Image.Image]: A list of PIL Image objects representing the binary masks.
    """
    masks = []
    for bbox in bboxes:
        mask = Image.new("L", shape, 0)
        mask_draw = ImageDraw.Draw(mask)
        mask_draw.rectangle(bbox, fill=255)
        masks.append(mask)
    return masks


def create_bbox_from_mask(
    masks: list[Image.Image], shape: tuple[int, int]
) -> list[list[int]]:
    """
    Create bounding boxes from a list of mask images.

    Args:
        masks (list[Image.Image]): A list of PIL Image objects representing the masks.
        shape (tuple[int, int]): A tuple representing the desired shape (width, height) to resize the masks to.

    Returns:
        list[list[int]]: A list of bounding boxes, where each bounding box is represented as a list of four integers [left, upper, right, lower].
    """
    bboxes = []
    for mask in masks:
        mask = mask.resize(shape)
        bbox = mask.getbbox()
        if bbox is not None:
            bboxes.append(list(bbox))
    return bboxes
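

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of how these pieces are typically wired together: build an
# AnimeDetector from one of the REPO_IDS entries, run it on an image, and save
# the preview crops. The model name "face_detect_v1.4_s" and the input path
# "input.png" are assumptions for illustration only; check the chosen Hugging
# Face repo for the model files it actually ships.
if __name__ == "__main__":
    image = Image.open("input.png").convert("RGB")  # placeholder input path
    detector = AnimeDetector(
        repo_id=REPO_IDS["face"],
        model_name="face_detect_v1.4_s",  # assumed model name; verify against the repo
    )
    result = detector(image, conf_threshold=0.3, iou_threshold=0.7)
    for i, (bbox, conf, preview) in enumerate(
        zip(result.bboxes, result.confidences, result.previews)
    ):
        print(f"detection {i}: bbox={bbox}, confidence={conf:.3f}")
        preview.save(f"preview_{i}.png")

    # Round-trip: the binary masks can be converted back to integer bounding boxes.
    print(create_bbox_from_mask(result.masks, image.size))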