import logging
from pathlib import Path
from typing import List, Tuple

import numpy as np
from PIL import Image
from torchvision.transforms.v2 import ToPILImage

from scripts.parse_cut_from_page import extract_cutbox_coordinates
from src.detectors import AnimeDetector
from src.pipelines import TagAndFilteringPipeline
from src.taggers import WaifuDiffusionTagger
from src.utils.device import determine_accelerator

# Reusable array -> PIL converter (stateless, safe to share).
topil = ToPILImage()

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# 1. Initialize the filtering pipeline once at module level so every call to
#    process_single_image reuses the same loaded models.
device = determine_accelerator()
logger.info("Using device: %s", device)
logger.info("Initializing filtering pipeline...")
detector = AnimeDetector(
    repo_id="deepghs/anime_face_detection",
    model_name="face_detect_v1.4_s",
    hf_token=None,
)
tagger = WaifuDiffusionTagger(device=device)
filtering_pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)


def process_single_image(
    lineart_pil_img: Image.Image,
) -> List[Tuple[int, Tuple[int, int, int, int]]]:
    """Extract cut bounding boxes from a lineart page and keep those that pass
    the tag-and-detection filtering pipeline.

    Args:
        lineart_pil_img: The lineart page image to process.

    Returns:
        A list of ``(index, (left, top, width, height))`` tuples — a 1-based
        index plus an x/y/width/height box for each cut whose crop passed the
        filter. Returns an empty list when the image cannot be converted.
    """
    try:
        line_img = lineart_pil_img.convert("RGB")
    except Exception as e:
        logger.error("Error loading images for %s: %s", lineart_pil_img, e)
        # BUG FIX: the original did a bare `return` (None), contradicting the
        # annotated List return type and breaking callers that iterate/len()
        # the result. Return an empty list instead.
        return []

    # 2. Extract cut bounding boxes as (left, top, right, bottom) tuples.
    bounding_boxes = extract_cutbox_coordinates(line_img)

    # PERF: convert the page to an ndarray ONCE. The original called
    # np.array(line_img) inside the comprehension, copying the full page for
    # every bounding box.
    page = np.array(line_img)
    images = [
        topil(page[top:bottom, left:right])
        for (left, top, right, bottom) in bounding_boxes
    ]

    # 3. Filter the cropped images using the filtering pipeline.
    logger.info("Filtering images...")
    filter_output = filtering_pipeline(
        images,
        batch_size=32,
        tag_threshold=0.3,
        conf_threshold=0.3,
        iou_threshold=0.7,
    )
    filter_flags = filter_output.filter_flags
    # A truthy flag means the crop PASSED the filter and is kept (see the
    # list comprehension below); the original message claimed the sum was the
    # number "filtered", which counted survivors, not removals.
    logger.info("Kept %d of %d images after filtering.", sum(filter_flags), len(images))

    filtered_bboxes = [bb for bb, flag in zip(bounding_boxes, filter_flags) if flag]

    # Convert (left, top, right, bottom) -> (left, top, width, height) and
    # attach a 1-based index to each surviving box.
    return [
        (i + 1, (left, top, right - left, bottom - top))
        for i, (left, top, right, bottom) in enumerate(filtered_bboxes)
    ]