# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This implementation refers to: https://github.com/facebookresearch/segment-anything

import os
import sys

# Make the parent directory importable so `segment_anything` resolves when
# this file is run as a script.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))

import time
import cv2  # type: ignore
import argparse

import numpy as np  # type: ignore
import paddle

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list

ID_PHOTO_IMAGE_DEMO = "examples/cityscapes_demo.png"
CACHE_DIR = ".temp"

# Cached results older than this many seconds are removed by delete_result().
_CACHE_MAX_AGE_SECONDS = 10000

# Value the label map is initialised with; pixels not covered by any mask keep it.
_BACKGROUND_LABEL = 255

# NOTE(review): the 'vit_t' URL ends in `.pdparam` while the others end in
# `.pdparams` -- confirm this is the published file name.
model_link = {
    'vit_h':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
    'vit_l':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
    'vit_b':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
    'vit_t':
    "https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam"
}

parser = argparse.ArgumentParser(description=(
    "Runs automatic mask generation on an input image or directory of images, "
    "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
    "as well as pycocotools if saving in RLE format."))

# `required=True` was dropped: combined with a default it made the default
# unreachable. `choices` rejects unknown model names at parse time instead of
# failing later with a KeyError on `model_link`.
parser.add_argument(
    "--model-type",
    type=str,
    default="vit_l",
    choices=sorted(model_link),
    help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']",
)

parser.add_argument(
    "--convert-to-rle",
    action="store_true",
    help=(
        "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
        "Requires pycocotools."),
)

amg_settings = parser.add_argument_group("AMG Settings")

amg_settings.add_argument(
    "--points-per-side",
    type=int,
    default=None,
    help="Generate masks by sampling a grid over the image with this many points to a side.",
)

amg_settings.add_argument(
    "--points-per-batch",
    type=int,
    default=None,
    help="How many input points to process simultaneously in one batch.",
)

amg_settings.add_argument(
    "--pred-iou-thresh",
    type=float,
    default=None,
    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-thresh",
    type=float,
    default=None,
    help="Exclude masks with a stability score lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-offset",
    type=float,
    default=None,
    help="Larger values perturb the mask more when measuring stability score.",
)

amg_settings.add_argument(
    "--box-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding a duplicate mask.",
)

amg_settings.add_argument(
    "--crop-n-layers",
    type=int,
    default=None,
    help=(
        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
        "The value sets how many different scales to crop at."),
)

amg_settings.add_argument(
    "--crop-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding duplicate masks across different crops.",
)

# Parsed as float: an overlap ratio is fractional, and `type=int` rejected
# values such as 0.3. Integer input still parses.
amg_settings.add_argument(
    "--crop-overlap-ratio",
    type=float,
    default=None,
    help="Larger numbers mean image crops will overlap more.",
)

amg_settings.add_argument(
    "--crop-n-points-downscale-factor",
    type=int,
    default=None,
    help="The number of points-per-side in each layer of crop is reduced by this factor.",
)

amg_settings.add_argument(
    "--min-mask-region-area",
    type=int,
    default=None,
    help=(
        "Disconnected mask regions or holes with area smaller than this value "
        "in pixels are removed by postprocessing."),
)


def get_amg_kwargs(args):
    """
    Collect the AMG settings that were explicitly given on the command line.

    Args:
        args (argparse.Namespace): The parsed command line arguments.

    Returns:
        dict: Keyword arguments for `SamAutomaticMaskGenerator`; options left
            at `None` are omitted so the generator's own defaults apply.
    """
    amg_kwargs = {
        "points_per_side": args.points_per_side,
        "points_per_batch": args.points_per_batch,
        "pred_iou_thresh": args.pred_iou_thresh,
        "stability_score_thresh": args.stability_score_thresh,
        "stability_score_offset": args.stability_score_offset,
        "box_nms_thresh": args.box_nms_thresh,
        "crop_n_layers": args.crop_n_layers,
        "crop_nms_thresh": args.crop_nms_thresh,
        "crop_overlap_ratio": args.crop_overlap_ratio,
        "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
        "min_mask_region_area": args.min_mask_region_area,
    }
    amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
    return amg_kwargs


def delete_result():
    """
    Clear old results in `CACHE_DIR`.

    Files are named after the Unix timestamp at which they were written (see
    `download`); those older than `_CACHE_MAX_AGE_SECONDS` are removed. A
    missing cache directory and files whose name is not an integer timestamp
    are ignored.
    """
    # The directory only exists after the first result has been saved.
    if not os.path.isdir(CACHE_DIR):
        return

    now = int(time.time())
    for res in sorted(os.listdir(CACHE_DIR)):
        try:
            created = int(os.path.splitext(res)[0])
        except ValueError:
            # Not one of our timestamp-named files; leave it alone.
            continue
        if now - created > _CACHE_MAX_AGE_SECONDS:
            os.remove(os.path.join(CACHE_DIR, res))


def download(img):
    """
    Save an image into `CACHE_DIR` under a timestamp-based name.

    Args:
        img: An object with a `save(path, format)` method -- presumably a PIL
            image; confirm against `get_pseudo_color_map`.

    Returns:
        str: The path of the PNG file that was written.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)

    # Names have one-second resolution, so wait for a free name when a result
    # has already been written during the current second.
    while True:
        name = str(int(time.time()))
        # The data is written as PNG, so the extension must say so as well.
        tmp_name = os.path.join(CACHE_DIR, name + '.png')
        if not os.path.exists(tmp_name):
            break
        time.sleep(1)

    img.save(tmp_name, 'png')
    return tmp_name


def masks2pseudomap(masks):
    """
    Merge the generated masks into one label map and colorize it.

    Args:
        masks (list): Mask records; each record's "segmentation" entry is read
            as an array (`.shape` and element-wise comparison with 1).
            NOTE(review): with `--convert-to-rle` the generator is asked for
            'coco_rle' output, which presumably is not an array -- confirm
            that the demo works in that mode.

    Returns:
        tuple: `(pred_result, pseudo_map)` where `pred_result` is the uint8
            label map (255 where no mask applies) and `pseudo_map` is the
            value returned by `get_pseudo_color_map` for it.

    Raises:
        IndexError: If `masks` is empty, since the map shape is taken from
            the first mask.
    """
    result = np.ones(
        masks[0]["segmentation"].shape, dtype=np.uint8) * _BACKGROUND_LABEL
    for i, mask_data in enumerate(masks):
        # Keep labels in 1..254: they must fit in uint8 and must not collide
        # with the background value. Identical to `i + 1` for the first 254
        # masks; labels are reused beyond that.
        label = i % (_BACKGROUND_LABEL - 1) + 1
        result[mask_data["segmentation"] == 1] = label
    pred_result = result
    result = get_pseudo_color_map(result)
    return pred_result, result


def visualize(image, result, color_map, weight=0.6):
    """
    Convert predict result to color image, and blend it with the image.

    Args:
        image (np.ndarray): The origin image; it is passed to
            `cv2.addWeighted` together with the colorized result, so both
            must be compatible for that call.
        result (np.ndarray): The predict result of image.
        color_map (list): Flat list of color components, three per label.
        weight (float): The image weight of visual image, and the result
            weight is (1 - weight). Default: 0.6

    Returns:
        vis_result (np.ndarray): The blended image.
    """
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    color_map = np.array(color_map).astype("uint8")
    # Use OpenCV LUT for color mapping
    c1 = cv2.LUT(result, color_map[:, 0])
    c2 = cv2.LUT(result, color_map[:, 1])
    c3 = cv2.LUT(result, color_map[:, 2])
    # NOTE(review): the channels are stacked in reverse order (c3, c2, c1).
    # The caller in this file passes the Gradio input image, which is
    # presumably RGB -- confirm that the reversed order is intended.
    pseudo_img = np.dstack((c3, c2, c1))

    vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
    return vis_result


def gradio_display(generator):
    """
    Build and launch the Gradio demo.

    Args:
        generator: The mask generator; its `generate(img)` method is called
            for every submitted image.
    """
    import gradio as gr

    def clear_image_all():
        """Purge stale cached results and reset all four components."""
        delete_result()
        return None, None, None, None

    def get_id_photo_output(img):
        """
        Segment the input image.

        Args:
            img(numpy:ndarray): The image array.

        Returns:
            tuple: The pseudo color map, the input image blended with the
                mask colors, and the path of the saved pseudo color map.
        """
        masks = generator.generate(img)
        pred_result, pseudo_map = masks2pseudomap(masks)  # PIL Image
        added_pseudo_map = visualize(
            img, pred_result, color_map=get_color_map_list(256))
        res_download = download(pseudo_map)

        return pseudo_map, added_pseudo_map, res_download

    with gr.Blocks() as demo:
        gr.Markdown("""# Segment Anything (PaddleSeg) """)
        with gr.Tab("InputImage"):
            image_in = gr.Image(value=ID_PHOTO_IMAGE_DEMO, label="Input image")

            with gr.Row():
                image_clear_btn = gr.Button("Clear")
                image_submit_btn = gr.Button("Submit")

            with gr.Row():
                img_out1 = gr.Image(
                    label="Output image", interactive=False).style(height=300)
                img_out2 = gr.Image(
                    label="Output image with mask",
                    interactive=False).style(height=300)
            downloaded_img = gr.File(label='Image download').style(height=50)

        image_clear_btn.click(
            fn=clear_image_all,
            inputs=None,
            outputs=[image_in, img_out1, img_out2, downloaded_img])

        image_submit_btn.click(
            fn=get_id_photo_output,
            inputs=[image_in, ],
            outputs=[img_out1, img_out2, downloaded_img])

        gr.Markdown(
            "Tips: You can try segment the default image OR upload any "
            "images you want to segment by click on the clear button first.")

        gr.Markdown(
            "This is Segment Anything build with PaddlePaddle. "
            "We refer to the [SAM](https://github.com/facebookresearch/segment-anything) "
            "for code strucure and model architecture. "
            "If you have any question or feature request, welcome to raise issues on "
            "[GitHub](https://github.com/PaddlePaddle/PaddleSeg/issues). ")

        # NOTE(review): `style` is called on the class with 1 in place of an
        # instance; this looks like a no-op -- confirm and remove.
        gr.Button.style(1)

    # Listens on all interfaces and also requests a public share link.
    demo.launch(server_name="0.0.0.0", server_port=8017, share=True)


def main(args: argparse.Namespace) -> None:
    """
    Load the selected SAM model and serve the Gradio demo.

    Args:
        args (argparse.Namespace): The parsed command line arguments.
    """
    print("Loading model...")
    sam = sam_model_registry[args.model_type](
        checkpoint=model_link[args.model_type])

    if paddle.is_compiled_with_cuda():
        paddle.set_device("gpu")
    else:
        paddle.set_device("cpu")
    output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    amg_kwargs = get_amg_kwargs(args)
    generator = SamAutomaticMaskGenerator(
        sam, output_mode=output_mode, **amg_kwargs)

    gradio_display(generator)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)