# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time
import argparse

sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))

import cv2
import numpy as np
import paddle
import paddle.nn.functional as F
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.modeling.clip_paddle import build_clip_model, _transform
from segment_anything.utils.sample_tokenizer import tokenize
from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list

ID_PHOTO_IMAGE_DEMO = "./examples/cityscapes_demo.png"
CACHE_DIR = ".temp"

model_link = {
    'vit_h':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
    'vit_l':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
    'vit_b':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
    'vit_t':
    "https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam",
    'clip_b_32':
    "https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams"
}

parser = argparse.ArgumentParser(description=(
    "Runs automatic mask generation on an input image or directory of images, "
    "and outputs masks as either PNGs or COCO-style RLEs. Requires OpenCV, "
    "as well as pycocotools if saving in RLE format."))
parser.add_argument(
    "--model-type",
    type=str,
    default="vit_h",
    required=True,
    help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']",
)


def download(img):
    """Save a PIL image into the cache directory and return its path."""
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR)
    while True:
        name = str(int(time.time()))
        # Save as .png: the result image is RGBA, which JPEG cannot hold.
        tmp_name = os.path.join(CACHE_DIR, name + '.png')
        if not os.path.exists(tmp_name):
            break
        else:
            time.sleep(1)
    img.save(tmp_name, 'png')
    return tmp_name


def segment_image(image, segment_mask):
    """Paste the masked region of `image` onto a gray canvas."""
    image_array = np.array(image)
    gray_image = Image.new("RGB", image.size, (128, 128, 128))
    segmented_image_array = np.zeros_like(image_array)
    segmented_image_array[segment_mask] = image_array[segment_mask]
    segmented_image = Image.fromarray(segmented_image_array)
    transparency = np.zeros_like(segment_mask, dtype=np.uint8)
    transparency[segment_mask] = 255
    transparency_image = Image.fromarray(transparency, mode='L')
    gray_image.paste(segmented_image, mask=transparency_image)
    return gray_image


def image_text_match(cropped_objects, text_query):
    """Score each cropped object against the text query with CLIP."""
    transformed_images = [transform(image) for image in cropped_objects]
    tokenized_text = tokenize([text_query])
    batch_images = paddle.stack(transformed_images)
    image_features = model.encode_image(batch_images)
    print("encode_image done!")
    text_features = model.encode_text(tokenized_text)
    print("encode_text done!")
    image_features /= image_features.norm(axis=-1, keepdim=True)
    text_features /= text_features.norm(axis=-1, keepdim=True)
    if len(text_features.shape) == 3:
        text_features = text_features.squeeze(0)
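    # CLIP-style scoring: cosine similarities scaled by the logit scale
    # (~100). With a single text query, probs has shape [N, 1], so the
    # softmax over axis 0 gives a distribution over the N cropped objects.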
    probs = 100. * image_features @ text_features.T
    return F.softmax(probs[:, 0], axis=0)


def masks2pseudomap(masks):
    """Flatten SAM masks into a label map and its pseudo color map."""
    result = np.ones(masks[0]["segmentation"].shape, dtype=np.uint8) * 255
    for i, mask_data in enumerate(masks):
        result[mask_data["segmentation"] == 1] = i + 1
    pred_result = result
    result = get_pseudo_color_map(result)
    return pred_result, result


def visualize(image, result, color_map, weight=0.6):
    """
    Convert the predicted result to a color image and blend it with the input.

    Args:
        image (np.ndarray): The origin image.
        result (np.ndarray): The predicted label map of the image.
        color_map (list): The color used to save the prediction results.
        weight (float): The weight of the origin image in the blend; the
            result weight is (1 - weight). Default: 0.6

    Returns:
        vis_result (np.ndarray): The visualized result.
    """
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    color_map = np.array(color_map).astype("uint8")
    # Use OpenCV LUT for color mapping.
    c1 = cv2.LUT(result, color_map[:, 0])
    c2 = cv2.LUT(result, color_map[:, 1])
    c3 = cv2.LUT(result, color_map[:, 2])
    pseudo_img = np.dstack((c3, c2, c1))

    vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
    return vis_result


def get_id_photo_output(image, text):
    """
    Segment everything with SAM, then keep the masks whose cropped content
    matches the text prompt according to CLIP.

    Args:
        image (np.ndarray): The input image array.
        text (str): The text prompt used to select masks.
    """
    image_ori = image.copy()
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)

    pred_result, pseudo_map = masks2pseudomap(masks)  # PIL Image
    added_pseudo_map = visualize(
        image, pred_result, color_map=get_color_map_list(256))

    # Crop every mask out of the image so each object can be scored by CLIP.
    cropped_objects = []
    image_pil = Image.fromarray(image)
    for mask in masks:
        bbox = [
            mask["bbox"][0], mask["bbox"][1], mask["bbox"][0] + mask["bbox"][2],
            mask["bbox"][1] + mask["bbox"][3]
        ]
        cropped_objects.append(
            segment_image(image_pil, mask["segmentation"]).crop(bbox))

    # Keep only the masks whose matching score clears the threshold.
    scores = image_text_match(cropped_objects, str(text))
    text_matching_masks = []
    for idx, score in enumerate(scores):
        if score < 0.05:
            continue
        text_matching_mask = Image.fromarray(
            masks[idx]["segmentation"].astype('uint8') * 255)
        text_matching_masks.append(text_matching_mask)

    # Draw the matching masks as a semi-transparent red overlay.
    image_pil_ori = Image.fromarray(image_ori)
    alpha_image = Image.new('RGBA', image_pil_ori.size, (0, 0, 0, 0))
    alpha_color = (255, 0, 0, 180)

    draw = ImageDraw.Draw(alpha_image)
    for text_matching_mask in text_matching_masks:
        draw.bitmap((0, 0), text_matching_mask, fill=alpha_color)

    result_image = Image.alpha_composite(
        image_pil_ori.convert('RGBA'), alpha_image)
    res_download = download(result_image)
    return result_image, added_pseudo_map, res_download
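
# The Gradio front end: a single tab that takes an image and a text prompt,
# and returns the text-matched overlay, the full pseudo color map, and a
# downloadable copy of the result.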

def gradio_display():
    import gradio as gr
    examples_sam = [["./examples/cityscapes_demo.png", "a photo of car"],
                    ["examples/dog.jpg", "dog"],
                    ["examples/zixingche.jpeg", "kid"]]

    demo_mask_sam = gr.Interface(
        fn=get_id_photo_output,
        inputs=[
            gr.Image(label="Input image", height=400),
            gr.Textbox(label="Input text prompt", value="a car"),
        ],
        # get_id_photo_output returns three values, so three output
        # components are declared to match.
        outputs=[
            gr.Image(label="Output based on text", height=300),
            gr.Image(label="Output mask", height=300),
            gr.File(label="Download the result"),
        ],
        examples=examples_sam,
        description=(
            "SAM+CLIP: text prompt for segmentation.\n"
            "Choose an example below, or upload your own image:\n"
            "1. Upload the image to be tested to 'Input image'.\n"
            "2. Enter a text prompt in 'Input text prompt' and click 'Submit'."
        ),
", cache_examples=False, flagging_mode="never" ) demo = gr.TabbedInterface( [demo_mask_sam], ['SAM+CLIP(Text to Segment)'], title=" 🔥 Text to Segment Anything with PaddleSeg 🔥" ) demo.launch( server_name="0.0.0.0", server_port=8078, share=True ) args = parser.parse_args() print("Loading model...") if paddle.is_compiled_with_cuda(): paddle.set_device("gpu") else: paddle.set_device("cpu") sam = sam_model_registry[args.model_type]( checkpoint=model_link[args.model_type]) mask_generator = SamAutomaticMaskGenerator(sam) model, transform = build_clip_model(model_link["clip_b_32"]) gradio_display()