# danhtran2mind: "Upload 164 files" (commit b7f710c, verified)
import torch
from PIL import Image
from typing import Union, List, Tuple
from . import mtcnn
from .face_yolo import face_yolo_detection
# Device configuration: use the first CUDA device when torch reports CUDA as
# available, otherwise fall back to the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Initialize the MTCNN model once at import time; detect_faces_mtcnn reuses this
# single instance for every call.
# NOTE(review): crop_size is presumably the size of the aligned face crops the
# model produces -- confirm against mtcnn.MTCNN.
MTCNN_MODEL = mtcnn.MTCNN(device=DEVICE, crop_size=(112, 112))
def add_image_padding(pil_img: Image.Image, top: int, right: int, bottom: int, left: int,
                      color: Tuple[int, int, int] = (0, 0, 0)) -> Image.Image:
    """Return a new image with ``pil_img`` placed on a larger solid-colour canvas.

    Args:
        pil_img: Image to pad. It is pasted onto the canvas, not modified.
        top: Number of pixels added above the image.
        right: Number of pixels added to the right of the image.
        bottom: Number of pixels added below the image.
        left: Number of pixels added to the left of the image.
        color: Fill colour used for the canvas (defaults to ``(0, 0, 0)``).

    Returns:
        A new image in the same mode as ``pil_img`` whose size is the original
        size grown by the requested padding on each side.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)

    # Create the filled canvas first, then drop the source at the top-left
    # offset so the requested margins remain around it.
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas
def detect_faces_mtcnn(image: Union[str, Image.Image]) -> Tuple[Union[list, None], Union[Image.Image, None]]:
    """Detect and align at most one face in an image using the shared MTCNN model.

    Args:
        image: A PIL image, or a path to an image file. A path is opened and
            converted to RGB before detection.

    Returns:
        A ``(bbox, face)`` tuple holding the first bounding box and the first
        aligned face returned by ``MTCNN_MODEL.align_multi(image, limit=1)``.
        An element is ``None`` when the model returned nothing for it, and both
        are ``None`` when detection raised an exception.

    Raises:
        TypeError: If ``image`` is neither a string nor a PIL image.

    Note:
        Errors raised while opening a path (for example a missing file) are not
        caught here and propagate to the caller.
    """
    if isinstance(image, str):
        # Use a context manager so the underlying file handle is released as
        # soon as the RGB copy exists, instead of waiting for garbage collection.
        with Image.open(image) as opened:
            image = opened.convert('RGB')
    if not isinstance(image, Image.Image):
        raise TypeError("Input must be a PIL Image or path to an image")
    try:
        bboxes, faces = MTCNN_MODEL.align_multi(image, limit=1)
        # Check emptiness with len() rather than truthiness: if the model hands
        # back a NumPy array, `if bboxes` raises "truth value of an array is
        # ambiguous", which the handler below would turn into a false
        # "detection failed" result even though a face was found.
        bbox = bboxes[0] if bboxes is not None and len(bboxes) > 0 else None
        face = faces[0] if faces is not None and len(faces) > 0 else None
        return bbox, face
    except Exception as e:
        # Deliberate best-effort: a detection failure on one image is reported
        # and mapped to (None, None) so batch callers can keep going.
        print(f"MTCNN face detection failed: {e}")
        return None, None
def get_aligned_face(image_input: Union[str, List[str]],
                     algorithm: str = 'mtcnn') -> List[Tuple[Union[list, None], Union[Image.Image, None]]]:
    """Run face detection/alignment over one or more images.

    Args:
        image_input: A single image path, or a list of image paths.
        algorithm: Detector to use, either ``'mtcnn'`` (default) or ``'yolo'``.

    Returns:
        For ``'mtcnn'``, a list with one ``detect_faces_mtcnn`` result per input,
        in input order. For ``'yolo'``, the output of ``face_yolo_detection``
        materialised as a list.
        NOTE(review): the element shape of the YOLO results is not visible from
        this module -- confirm it matches the annotated return type.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
        TypeError: If ``image_input`` is neither a string nor a list.
    """
    if algorithm not in ('mtcnn', 'yolo'):
        raise ValueError("Algorithm must be 'mtcnn' or 'yolo'")

    # Normalise the input so both back-ends always receive a list.
    if isinstance(image_input, str):
        image_paths = [image_input]
    elif isinstance(image_input, list):
        image_paths = image_input
    else:
        raise TypeError("Input must be a string or list of strings")

    if algorithm == 'yolo':
        # YOLO handles the whole list in a single batched call.
        detections = face_yolo_detection(image_paths, use_batch=True, device=DEVICE)
        return list(detections)

    # MTCNN processes the images one at a time.
    aligned = []
    for path in image_paths:
        aligned.append(detect_faces_mtcnn(path))
    return aligned