from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np

# Load SegFormer for hair segmentation
processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")

def extract_hair(image: Image.Image) -> Image.Image:
    """
    Return an RGBA image where hair pixels have alpha=255 and
    all other pixels have alpha=0.
    """
    rgb = image.convert("RGB")
    arr = np.array(rgb)
    h, w = arr.shape[:2]

    # Segment hair
    inputs = processor(images=rgb, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()
    up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
    seg = up.argmax(dim=1)[0].numpy()
    hair_mask = (seg == 2).astype(np.uint8)

    # Build RGBA
    alpha = (hair_mask * 255).astype(np.uint8)
    rgba = np.dstack([arr, alpha])
    return Image.fromarray(rgba)
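
For reference, a minimal usage sketch of extract_hair. The file names portrait.jpg and hair_only.png are placeholders, not part of the original snippet; PNG is used for the output because it preserves the alpha channel.

if __name__ == "__main__":
    img = Image.open("portrait.jpg")   # placeholder input path
    hair = extract_hair(img)           # RGBA image: hair opaque, everything else transparent
    hair.save("hair_only.png")         # PNG keeps the alpha channel intact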