from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np

# Hugging Face checkpoint used for face/hair segmentation. The processor and
# model are loaded once at import time so every call reuses them.
MODEL_ID = "VanNguyen1214/get_face_and_hair"

# Class index treated as "hair" in the model's per-pixel predictions.
# NOTE(review): value taken from the original code; confirm against
# model.config.id2label for this checkpoint.
HAIR_LABEL = 2

# Load SegFormer for hair segmentation
processor = SegformerImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)


def extract_hair(image: Image.Image, *, hair_label: int = HAIR_LABEL) -> Image.Image:
    """Cut the hair region out of an image as an RGBA layer.

    Args:
        image: Input PIL image, in any mode. It is converted to RGB before
            segmentation, so any existing alpha channel is discarded.
        hair_label: Segmentation class index to keep. It defaults to
            ``HAIR_LABEL``, the value the original code hard-coded.

    Returns:
        An RGBA image the same size as ``image``. It keeps the original RGB
        values. Alpha is 255 where the predicted class equals ``hair_label``
        and 0 everywhere else.
    """
    rgb = image.convert("RGB")
    arr = np.array(rgb)
    h, w = arr.shape[:2]

    # Segment. The inputs are moved to wherever the model lives, so this
    # also works if the caller has moved `model` to a GPU; logits are
    # brought back to CPU for the numpy step below.
    inputs = processor(images=rgb, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # The processor resizes the image, so the logits are resampled to the
    # original (h, w) before the per-pixel argmax. This keeps the mask
    # aligned with `arr`.
    upsampled = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
    seg = upsampled.argmax(dim=1)[0].numpy()

    # Build the RGBA output: original colours plus a binary hair alpha channel.
    alpha = np.where(seg == hair_label, 255, 0).astype(np.uint8)
    rgba = np.dstack([arr, alpha])
    return Image.fromarray(rgba)