File size: 2,151 Bytes
13aa528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
from pathlib import Path
from typing import List, Tuple
import logging
from PIL import Image
from torchvision.transforms.v2 import ToPILImage
from scripts.parse_cut_from_page import extract_cutbox_coordinates

from src.detectors import AnimeDetector
from src.pipelines import TagAndFilteringPipeline
from src.taggers import WaifuDiffusionTagger
from src.utils.device import determine_accelerator

# Shared converter: numpy array -> PIL image, used when re-wrapping crops.
topil = ToPILImage()
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# 1. Initialize the filtering pipeline
# Pick the best available accelerator (project helper; presumably returns
# something torch-device-like — confirm against src.utils.device).
device = determine_accelerator()
logger.info(f"Using device: {device}")

logger.info("Initializing filtering pipeline...")
# Anime-face detector backing the pipeline's detection stage.
detector = AnimeDetector(
    repo_id="deepghs/anime_face_detection",
    model_name="face_detect_v1.4_s",
    hf_token=None,  # NOTE(review): assumes a public HF repo (no auth) — confirm
)

tagger = WaifuDiffusionTagger(device=device)

# Combined tagger+detector pipeline used by process_single_image below.
filtering_pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)


def process_single_image(lineart_pil_img: Image.Image) -> List[Tuple[int, Tuple[int, int, int, int]]]:
    """
    Extract, crop, and filter cut regions from a single line-art page.

    Steps:
      1. Convert the page to RGB.
      2. Locate cut bounding boxes via ``extract_cutbox_coordinates``.
      3. Crop each box and run the crops through the tag-and-filter pipeline.
      4. Pair each retained crop with a 1-based index.

    Args:
        lineart_pil_img: The line-art page as a PIL image.

    Returns:
        A list of ``(index, (left, top, width, height))`` tuples for the
        cuts that passed filtering. Empty list when the image cannot be
        converted or no boxes survive.
    """
    try:
        line_img = lineart_pil_img.convert("RGB")
    except Exception as e:
        logger.error("Error converting page image to RGB: %s", e)
        # Bug fix: a bare `return` here previously yielded None, violating
        # the declared return type and breaking callers that iterate.
        return []

    bounding_boxes = extract_cutbox_coordinates(line_img)
    if not bounding_boxes:
        # Nothing to crop — skip the expensive filtering pipeline entirely.
        return []

    page = np.array(line_img)
    # Boxes are (left, top, right, bottom) pixel coordinates; numpy slicing
    # is [rows, cols] == [top:bottom, left:right].
    images = [topil(page[top:bottom, left:right]) for (left, top, right, bottom) in bounding_boxes]

    # 3. Filter images using the filtering pipeline
    logger.info("Filtering images...")
    filter_output = filtering_pipeline(
        images, batch_size=32, tag_threshold=0.3, conf_threshold=0.3, iou_threshold=0.7
    )
    filter_flags = filter_output.filter_flags
    # flag == True means the crop was KEPT; the old message ("Filtered X
    # images out of Y") read as if X had been discarded.
    logger.info("Kept %d of %d cut images after filtering.", sum(filter_flags), len(images))

    filtered_bboxes = [bb for bb, flag in zip(bounding_boxes, filter_flags) if flag]
    # 1-based index + (left, top, width, height) for downstream consumers.
    return [
        (i + 1, (left, top, right - left, bottom - top))
        for i, (left, top, right, bottom) in enumerate(filtered_bboxes)
    ]