Spaces:
Running
Running
File size: 2,151 Bytes
13aa528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import numpy as np
from pathlib import Path
from typing import List, Tuple
import logging
from PIL import Image
from torchvision.transforms.v2 import ToPILImage
from scripts.parse_cut_from_page import extract_cutbox_coordinates
from src.detectors import AnimeDetector
from src.pipelines import TagAndFilteringPipeline
from src.taggers import WaifuDiffusionTagger
from src.utils.device import determine_accelerator
# Module-level converter; process_single_image() uses it to turn cropped
# numpy arrays back into PIL images.
topil = ToPILImage()
logger = logging.getLogger(__name__)
# Logging is configured at import time with the INFO level.
logging.basicConfig(level=logging.INFO)
# 1. Initialize the filtering pipeline
# NOTE: everything below runs once, when this module is imported; the
# resulting `filtering_pipeline` object is shared by process_single_image().
device = determine_accelerator()
logger.info(f"Using device: {device}")
logger.info("Initializing filtering pipeline...")
# Detector half of the pipeline.
# NOTE(review): repo_id / model_name look like Hugging Face Hub identifiers
# and hf_token=None presumably means anonymous access -- confirm against
# AnimeDetector.
detector = AnimeDetector(
    repo_id="deepghs/anime_face_detection",
    model_name="face_detect_v1.4_s",
    hf_token=None,
)
# Tagger half of the pipeline, created for the device selected above.
tagger = WaifuDiffusionTagger(device=device)
# Combined tagger + detector object called on batches of cropped images.
filtering_pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)
def process_single_image(lineart_pil_img: Image.Image) -> List[Tuple[int, Tuple[int, int, int, int]]]:
    """
    Extract cut boxes from one line-art page and keep those that pass filtering.

    The page is converted to RGB, candidate boxes are obtained from
    ``extract_cutbox_coordinates``, every box is cropped out of the page and
    the crops are passed through the module-level ``filtering_pipeline``.

    Args:
        lineart_pil_img: The line-art page as a PIL image.

    Returns:
        A list of ``(index, (left, top, width, height))`` tuples, one for each
        box whose filter flag is truthy. ``index`` is 1-based and counts only
        the boxes that were kept (it is not the position in the unfiltered
        list). An empty list is returned when the image cannot be converted
        to RGB or when no boxes are found.
    """
    try:
        line_img = lineart_pil_img.convert("RGB")
    except Exception as e:
        # Best-effort worker: report the failure and hand back "no boxes"
        # so the result always matches the declared list return type.
        logger.error(f"Error loading images for {lineart_pil_img}: {e}")
        return []

    # 2. Locate the candidate boxes and crop each one out of the page.
    # Materialize once: the boxes are iterated again after filtering.
    bounding_boxes = list(extract_cutbox_coordinates(line_img))
    if not bounding_boxes:
        # Nothing to crop, so skip running the models on an empty batch.
        logger.info("No cut boxes found; nothing to filter.")
        return []

    # Convert the page to an array a single time instead of once per box.
    page_pixels = np.array(line_img)
    images = [
        topil(page_pixels[top:bottom, left:right])
        for (left, top, right, bottom) in bounding_boxes
    ]

    # 3. Filter images using the filtering pipeline
    logger.info("Filtering images...")
    filter_output = filtering_pipeline(
        images,
        batch_size=32,
        tag_threshold=0.3,
        conf_threshold=0.3,
        iou_threshold=0.7,
    )
    filter_flags = filter_output.filter_flags
    logger.info(f"Filtered {sum(filter_flags)} images out of {len(images)}.")

    # 4. Keep the flagged boxes and convert (left, top, right, bottom)
    # corners into (left, top, width, height), numbering them from 1.
    filtered_bboxes = [bb for bb, flag in zip(bounding_boxes, filter_flags) if flag]
    return [
        (index, (left, top, right - left, bottom - top))
        for index, (left, top, right, bottom) in enumerate(filtered_bboxes, start=1)
    ]
|