segmentanything / amg_paddle.py
hbazai's picture
Upload folder using huggingface_hub
a1687ef verified
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This implementation refers to: https://github.com/facebookresearch/segment-anything
import os
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))
import time
import cv2 # type: ignore
import argparse
import numpy as np # type: ignore
import paddle
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list
ID_PHOTO_IMAGE_DEMO = "examples/cityscapes_demo.png"
CACHE_DIR = ".temp"
model_link = {
'vit_h':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
'vit_l':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
'vit_b':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
'vit_t':
"https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam"
}
parser = argparse.ArgumentParser(description=(
"Runs automatic mask generation on an input image or directory of images, "
"and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
"as well as pycocotools if saving in RLE format."))
parser.add_argument(
"--model-type",
type=str,
default="vit_l",
required=True,
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']", )
parser.add_argument(
"--convert-to-rle",
action="store_true",
help=(
"Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
"Requires pycocotools."), )
amg_settings = parser.add_argument_group("AMG Settings")
amg_settings.add_argument(
"--points-per-side",
type=int,
default=None,
help="Generate masks by sampling a grid over the image with this many points to a side.",
)
amg_settings.add_argument(
"--points-per-batch",
type=int,
default=None,
help="How many input points to process simultaneously in one batch.", )
amg_settings.add_argument(
"--pred-iou-thresh",
type=float,
default=None,
help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)
amg_settings.add_argument(
"--stability-score-thresh",
type=float,
default=None,
help="Exclude masks with a stability score lower than this threshold.", )
amg_settings.add_argument(
"--stability-score-offset",
type=float,
default=None,
help="Larger values perturb the mask more when measuring stability score.",
)
amg_settings.add_argument(
"--box-nms-thresh",
type=float,
default=None,
help="The overlap threshold for excluding a duplicate mask.", )
amg_settings.add_argument(
"--crop-n-layers",
type=int,
default=None,
help=(
"If >0, mask generation is run on smaller crops of the image to generate more masks. "
"The value sets how many different scales to crop at."), )
amg_settings.add_argument(
"--crop-nms-thresh",
type=float,
default=None,
help="The overlap threshold for excluding duplicate masks across different crops.",
)
amg_settings.add_argument(
"--crop-overlap-ratio",
type=int,
default=None,
help="Larger numbers mean image crops will overlap more.", )
amg_settings.add_argument(
"--crop-n-points-downscale-factor",
type=int,
default=None,
help="The number of points-per-side in each layer of crop is reduced by this factor.",
)
amg_settings.add_argument(
"--min-mask-region-area",
type=int,
default=None,
help=(
"Disconnected mask regions or holes with area smaller than this value "
"in pixels are removed by postprocessing."), )
def get_amg_kwargs(args):
amg_kwargs = {
"points_per_side": args.points_per_side,
"points_per_batch": args.points_per_batch,
"pred_iou_thresh": args.pred_iou_thresh,
"stability_score_thresh": args.stability_score_thresh,
"stability_score_offset": args.stability_score_offset,
"box_nms_thresh": args.box_nms_thresh,
"crop_n_layers": args.crop_n_layers,
"crop_nms_thresh": args.crop_nms_thresh,
"crop_overlap_ratio": args.crop_overlap_ratio,
"crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
"min_mask_region_area": args.min_mask_region_area,
}
amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
return amg_kwargs
def delete_result():
"""clear old result in `.temp`"""
results = sorted(os.listdir(CACHE_DIR))
for res in results:
if int(time.time()) - int(os.path.splitext(res)[0]) > 10000:
os.remove(os.path.join(CACHE_DIR, res))
def download(img):
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
while True:
name = str(int(time.time()))
tmp_name = os.path.join(CACHE_DIR, name + '.jpg')
if not os.path.exists(tmp_name):
break
else:
time.sleep(1)
img.save(tmp_name, 'png')
return tmp_name
def masks2pseudomap(masks):
result = np.ones(masks[0]["segmentation"].shape, dtype=np.uint8) * 255
for i, mask_data in enumerate(masks):
result[mask_data["segmentation"] == 1] = i + 1
pred_result = result
result = get_pseudo_color_map(result)
return pred_result, result
def visualize(image, result, color_map, weight=0.6):
"""
Convert predict result to color image, and save added image.
Args:
image (str): The path of origin image.
result (np.ndarray): The predict result of image.
color_map (list): The color used to save the prediction results.
save_dir (str): The directory for saving visual image. Default: None.
weight (float): The image weight of visual image, and the result weight is (1 - weight). Default: 0.6
Returns:
vis_result (np.ndarray): If `save_dir` is None, return the visualized result.
"""
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
color_map = np.array(color_map).astype("uint8")
# Use OpenCV LUT for color mapping
c1 = cv2.LUT(result, color_map[:, 0])
c2 = cv2.LUT(result, color_map[:, 1])
c3 = cv2.LUT(result, color_map[:, 2])
pseudo_img = np.dstack((c3, c2, c1))
# im = cv2.imread(image)
vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
return vis_result
def gradio_display(generator):
import gradio as gr
def clear_image_all():
delete_result()
return None, None, None, None
def get_id_photo_output(img):
"""
Get the special size and background photo.
Args:
img(numpy:ndarray): The image array.
size(str): The size user specified.
bg(str): The background color user specified.
download_size(str): The size for image saving.
"""
predictor = generator
masks = predictor.generate(img)
pred_result, pseudo_map = masks2pseudomap(masks) # PIL Image
added_pseudo_map = visualize(
img, pred_result, color_map=get_color_map_list(256))
res_download = download(pseudo_map)
return pseudo_map, added_pseudo_map, res_download
with gr.Blocks() as demo:
gr.Markdown("""# Segment Anything (PaddleSeg) """)
with gr.Tab("InputImage"):
image_in = gr.Image(value=ID_PHOTO_IMAGE_DEMO, label="Input image")
with gr.Row():
image_clear_btn = gr.Button("Clear")
image_submit_btn = gr.Button("Submit")
with gr.Row():
img_out1 = gr.Image(
label="Output image", interactive=False).style(height=300)
img_out2 = gr.Image(
label="Output image with mask",
interactive=False).style(height=300)
downloaded_img = gr.File(label='Image download').style(height=50)
image_clear_btn.click(
fn=clear_image_all,
inputs=None,
outputs=[image_in, img_out1, img_out2, downloaded_img])
image_submit_btn.click(
fn=get_id_photo_output,
inputs=[image_in, ],
outputs=[img_out1, img_out2, downloaded_img])
gr.Markdown(
"""<font color=Gray>Tips: You can try segment the default image OR upload any images you want to segment by click on the clear button first.</font>"""
)
gr.Markdown(
"""<font color=Gray>This is Segment Anything build with PaddlePaddle.
We refer to the [SAM](https://github.com/facebookresearch/segment-anything) for code strucure and model architecture.
If you have any question or feature request, welcome to raise issues on [GitHub](https://github.com/PaddlePaddle/PaddleSeg/issues). </font>"""
)
gr.Button.style(1)
demo.launch(server_name="0.0.0.0", server_port=8017, share=True)
def main(args: argparse.Namespace) -> None:
print("Loading model...")
sam = sam_model_registry[args.model_type](
checkpoint=model_link[args.model_type])
if paddle.is_compiled_with_cuda():
paddle.set_device("gpu")
else:
paddle.set_device("cpu")
output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
amg_kwargs = get_amg_kwargs(args)
generator = SamAutomaticMaskGenerator(
sam, output_mode=output_mode, **amg_kwargs)
gradio_display(generator)
if __name__ == "__main__":
args = parser.parse_args()
main(args)