# Scraped from a Hugging Face Spaces page (Space status: Running).
import logging
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
from PIL import Image
from torchvision.transforms.v2 import ToPILImage

from scripts.parse_cut_from_page import extract_cutbox_coordinates
from src.detectors import AnimeDetector
from src.pipelines import TagAndFilteringPipeline
from src.taggers import WaifuDiffusionTagger
from src.utils.device import determine_accelerator
# Shared converter used by process_single_image to turn numpy crops of the
# RGB page back into PIL images.
topil = ToPILImage()

logger = logging.getLogger(__name__)
# Configure root logging at INFO so this module's progress messages are shown.
logging.basicConfig(level=logging.INFO)
# Build the filtering pipeline once at import time; process_single_image
# reuses this module-level `filtering_pipeline` on every call.
# (The constructors presumably download/load model weights — confirm in
# src.detectors / src.taggers.)
device = determine_accelerator()
logger.info("Using device: %s", device)
logger.info("Initializing filtering pipeline...")
detector = AnimeDetector(
    repo_id="deepghs/anime_face_detection",
    model_name="face_detect_v1.4_s",
    hf_token=None,  # presumably anonymous access to a public repo — confirm
)
# NOTE(review): only the tagger receives `device`; confirm whether
# AnimeDetector selects its own device.
tagger = WaifuDiffusionTagger(device=device)
filtering_pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)
def process_single_image(
    lineart_pil_img: Image.Image,
    *,
    batch_size: int = 32,
    tag_threshold: float = 0.3,
    conf_threshold: float = 0.3,
    iou_threshold: float = 0.7,
) -> Optional[List[Tuple[int, Tuple[int, int, int, int]]]]:
    """
    Detect cut boxes on a line-art page and keep those the filtering pipeline accepts.

    The page is converted to RGB and passed to ``extract_cutbox_coordinates``.
    Each returned box is unpacked as ``(left, top, right, bottom)`` and cropped
    out of the page. The crops are then run through the module-level
    ``filtering_pipeline``.

    Args:
        lineart_pil_img: The page image to process.
        batch_size: Forwarded to ``filtering_pipeline``.
        tag_threshold: Forwarded to ``filtering_pipeline``.
        conf_threshold: Forwarded to ``filtering_pipeline``.
        iou_threshold: Forwarded to ``filtering_pipeline``.
        The defaults are the values this function previously hard-coded.

    Returns:
        A list of ``(index, (left, top, width, height))`` tuples for the boxes
        whose filter flag is truthy. ``index`` is 1-based and counts only the
        kept boxes. An empty list is returned when no cut boxes are found.
        ``None`` is returned if converting the input to RGB raises.
    """
    try:
        line_img = lineart_pil_img.convert("RGB")
    except Exception:
        # Best-effort: log with traceback and signal failure with None.
        logger.exception("Error loading images for %s", lineart_pil_img)
        return None

    # Materialize so the boxes can be iterated twice (cropping, then zipping
    # with the filter flags) even if the extractor returns an iterator.
    bounding_boxes = list(extract_cutbox_coordinates(line_img))
    if not bounding_boxes:
        logger.info("No cut boxes found; nothing to filter.")
        return []

    # Convert the page to an array once instead of once per box.
    page = np.array(line_img)
    images = [
        topil(page[top:bottom, left:right])
        for (left, top, right, bottom) in bounding_boxes
    ]

    # Filter the cropped images with the shared pipeline.
    logger.info("Filtering images...")
    filter_output = filtering_pipeline(
        images,
        batch_size=batch_size,
        tag_threshold=tag_threshold,
        conf_threshold=conf_threshold,
        iou_threshold=iou_threshold,
    )
    filter_flags = filter_output.filter_flags
    logger.info("Filtered %s images out of %s.", sum(filter_flags), len(images))

    filtered_bboxes = [bb for bb, flag in zip(bounding_boxes, filter_flags) if flag]
    # Convert corner coordinates to (left, top, width, height) and number the
    # kept boxes starting at 1.
    return [
        (index, (left, top, right - left, bottom - top))
        for index, (left, top, right, bottom) in enumerate(filtered_bboxes, start=1)
    ]