"""Gradio demo: segment hair, clothes or background with a MediaPipe model."""

import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation, label

from croper import Croper

BG_COLOR = np.array([0, 0, 0], dtype=np.uint8)  # black
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8)  # white

MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"
category_options = ["hair", "clothes", "background"]

# Category indices as they are used by the mask helpers in this file.
_BACKGROUND = 0
_HAIR = 1
_BODY_SKIN = 2
_FACE_SKIN = 3
_CLOTHES = 4

base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
segmenter = vision.ImageSegmenter.create_from_options(options)
labels = segmenter.labels


def get_session_token(request: gr.Request):
    """Return the ``x-ip-token`` request header, or ``""`` when unavailable.

    The header is not guaranteed to be present (e.g. when the app is run
    locally), so a missing header or a missing request yields an empty string
    instead of raising.
    """
    if request is None:
        return ""
    try:
        return request.headers["x-ip-token"]
    except KeyError:
        return ""


def segment(input_image, category):
    """Segment ``input_image`` and return the mask and the restored image.

    Args:
        input_image: PIL image supplied by the ``gr.Image(type='pil')`` input.
        category: One of ``category_options``; any value other than
            ``"hair"`` or ``"clothes"`` selects the background.

    Returns:
        A ``(mask_image, restore_image)`` tuple taken from ``Croper``:
        ``resized_square_mask_image`` and the result of ``restore_result``
        applied to ``resized_square_image``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # The segmenter input is declared as SRGB, so normalise other PIL modes
    # (RGBA, L, P, ...) to three channels. Croper still gets the original image.
    rgb_array = np.asarray(input_image.convert("RGB"))
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_array)
    segmentation_result = segmenter.segment(image)
    category_mask_np = segmentation_result.category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, should_dilate=True)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np)
    else:
        target_mask = category_mask_np == _BACKGROUND

    croper = Croper(input_image, target_mask)
    croper.corp_mask_image()
    restore_image = croper.restore_result(croper.resized_square_image)
    mask_image = croper.resized_square_mask_image
    return mask_image, restore_image


def get_clothes_mask(category_mask_np):
    """Return a boolean mask of body skin plus clothes, dilated 4 iterations."""
    combined_mask = np.logical_or(
        category_mask_np == _BODY_SKIN,
        category_mask_np == _CLOTHES,
    )
    return binary_dilation(combined_mask, iterations=4)


def get_hair_mask(category_mask_np, should_dilate=False):
    """Return a boolean hair mask, optionally grown down to the body/clothes.

    Args:
        category_mask_np: Category mask array from the segmenter
            (``category_mask.numpy_view()``); indexed as ``[row, column]``.
        should_dilate: When false, return the hair pixels dilated by 4
            iterations. When true, additionally expand the hair components
            that start at or above the top of the face until they touch the
            body-skin/clothes area below the face, then trim the result.

    Returns:
        Boolean array with the same shape as ``category_mask_np``. If the
        expansion cannot be performed (no hair, no face, or no body/clothes
        region below the face), the plain dilated hair mask is returned.
    """
    hair_mask = binary_dilation(category_mask_np == _HAIR, iterations=4)
    if not should_dilate:
        return hair_mask

    body_skin_mask = category_mask_np == _BODY_SKIN
    face_skin_mask = category_mask_np == _FACE_SKIN
    clothes_mask = category_mask_np == _CLOTHES

    # Without hair or a face there is nothing to anchor the expansion to;
    # np.min/np.max on empty index arrays would raise ValueError.
    if not hair_mask.any() or not face_skin_mask.any():
        return hair_mask

    face_rows = np.nonzero(face_skin_mask)[0]
    min_face_y = face_rows.min()
    max_face_y = face_rows.max()

    # Keep only hair components whose top is at or above the top of the face.
    labeled_hair, hair_features = label(hair_mask)
    top_hair_mask = np.zeros_like(hair_mask)
    for component_id in range(1, hair_features + 1):
        component_mask = labeled_hair == component_id
        if np.nonzero(component_mask)[0].min() <= min_face_y:
            top_hair_mask[component_mask] = True

    # Combine the reference masks (body, clothes)
    reference_mask = np.logical_or(body_skin_mask, clothes_mask)
    # Ignore reference pixels above 40 rows below the bottom of the face
    reference_mask[:max_face_y + 40, :] = False

    # Dilating an empty mask stays empty and an empty reference can never be
    # reached, so either case would make the loop below spin forever.
    if not top_hair_mask.any() or not reference_mask.any():
        return hair_mask

    # Expand the hair mask until it reaches the reference areas
    expanded_hair_mask = top_hair_mask.copy()
    while not np.any(np.logical_and(expanded_hair_mask, reference_mask)):
        expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)

    # Trim the expanded_hair_mask
    # 1. Remove everything more than 10 rows above the top of hair_mask.
    #    Clamp at 0: a negative stop would slice from the end of the array
    #    and erase almost the whole mask.
    min_hair_y = np.nonzero(hair_mask)[0].min()
    expanded_hair_mask[:max(min_hair_y - 10, 0), :] = False

    # 2. Remove the areas on both sides that exceed the clothing coordinates
    #    (skipped when no clothes were detected, i.e. the reference area
    #    consisted of body skin only).
    clothes_cols = np.nonzero(clothes_mask)[1]
    if clothes_cols.size:
        expanded_hair_mask[:, :clothes_cols.min()] = False
        expanded_hair_mask[:, clothes_cols.max() + 1:] = False

    # exclude the face-skin, body-skin and clothes areas
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~face_skin_mask)
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~body_skin_mask)
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~clothes_mask)

    # combine the hair mask with the expanded hair mask
    return np.logical_or(hair_mask, expanded_hair_mask)


with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='pil', label='Upload image')
            category = gr.Dropdown(
                label='Category',
                choices=category_options,
                value=category_options[0],
            )
            submit_btn = gr.Button(value='Submit', variant='primary')
            session_token = gr.Textbox(label='Session token', value='')
        with gr.Column():
            mask_image = gr.Image(type='pil', label='Segmentation mask')
            output_image = gr.Image(type='pil', label='Segmented image')

    submit_btn.click(
        fn=segment,
        inputs=[
            input_image,
            category,
        ],
        outputs=[mask_image, output_image]
    )
    # Populate the session token textbox when the page loads.
    app.load(get_session_token, None, session_token)

app.launch(debug=False, show_error=True)