# segmentanything / text_to_sam_clip.py
# Uploaded by hbazai ("Upload folder using huggingface_hub"), commit a1687ef (verified).
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import time
import sys
import argparse
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))
import paddle
import paddle.nn.functional as F
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.modeling.clip_paddle import build_clip_model, _transform
from segment_anything.utils.sample_tokenizer import tokenize
from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list
# Path of the bundled demo image (not referenced elsewhere in this script).
ID_PHOTO_IMAGE_DEMO = "./examples/cityscapes_demo.png"
# Directory that download() writes rendered results into.
CACHE_DIR = ".temp"

# Pretrained checkpoint URLs keyed by model type. The `vit_*` entries are the
# SAM backbones selectable with --model-type; `clip_b_32` is the CLIP model
# used for text/image matching.
# NOTE(review): the vit_t URL uses a different host prefix and ends in
# `.pdparam` (no trailing "s") unlike its siblings -- confirm it is correct.
model_link = dict([
    ('vit_h',
     "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams"),
    ('vit_l',
     "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams"),
    ('vit_b',
     "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams"),
    ('vit_t',
     "https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam"),
    ('clip_b_32',
     "https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams"),
])
# Command-line options for the demo.
parser = argparse.ArgumentParser(description=(
    "Launches a Gradio demo that segments an image with Segment Anything "
    "(SAM) and highlights the masks that best match a text prompt, as "
    "scored by CLIP."))
# `required=True` previously made the "vit_h" default unreachable; the option
# is now optional, and `choices` rejects unknown model types up front instead
# of letting them fail later at the model registry / checkpoint lookup.
parser.add_argument(
    "--model-type",
    type=str,
    default="vit_h",
    choices=['vit_h', 'vit_l', 'vit_b', 'vit_t'],
    help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']", )
def download(img):
    """Save ``img`` as a PNG file inside ``CACHE_DIR`` and return its path.

    Args:
        img: Object with a PIL-style ``save(path, format)`` method; the caller
            in this file passes the RGBA result of ``Image.alpha_composite``.

    Returns:
        str: Path of the newly written ``.png`` file. Its name starts with the
        current Unix timestamp and is guaranteed unique within ``CACHE_DIR``.
    """
    import tempfile

    os.makedirs(CACHE_DIR, exist_ok=True)
    # mkstemp atomically creates a uniquely named file. Two concurrent
    # requests can therefore never pick the same name (the old
    # exists()-then-save check could race), and nobody has to sleep waiting
    # for the clock to tick.
    fd, tmp_name = tempfile.mkstemp(
        suffix='.png', prefix=str(int(time.time())) + '_', dir=CACHE_DIR)
    os.close(fd)
    # The data is PNG-encoded (PNG keeps the alpha channel of an RGBA
    # image), so the extension must be .png, not the previous .jpg.
    img.save(tmp_name, 'png')
    return tmp_name
def segment_image(image, segment_mask):
    """Return a copy of ``image`` with everything outside the mask painted gray.

    Args:
        image (PIL.Image.Image): Source image.
        segment_mask (np.ndarray): Per-pixel selector with the same height and
            width as ``image``. It is used as a NumPy index, so it should be a
            boolean array.

    Returns:
        PIL.Image.Image: RGB image of ``image.size`` showing the source pixels
        where the mask is set and mid-gray (128, 128, 128) everywhere else.
    """
    pixels = np.array(image)

    # Keep only the masked pixels. The rest stay zero, but the paste below
    # ignores them anyway because its mask is 0 there.
    foreground = np.zeros_like(pixels)
    foreground[segment_mask] = pixels[segment_mask]

    # 8-bit paste mask: 255 inside the segment, 0 outside.
    alpha = np.zeros_like(segment_mask, dtype=np.uint8)
    alpha[segment_mask] = 255

    canvas = Image.new("RGB", image.size, (128, 128, 128))
    canvas.paste(Image.fromarray(foreground),
                 mask=Image.fromarray(alpha, mode='L'))
    return canvas
def image_text_match(cropped_objects, text_query):
    """Score how well each cropped object matches a text query using CLIP.

    Relies on the module-level CLIP ``model`` and its preprocessing
    ``transform``, both returned by ``build_clip_model`` at start-up.

    Args:
        cropped_objects (list[PIL.Image.Image]): Object crops to score.
        text_query (str): Free-form text prompt.

    Returns:
        paddle.Tensor: 1-D tensor with one probability per crop. The values
        are a softmax taken across the crops, so they sum to 1. The tensor is
        empty when ``cropped_objects`` is empty.
    """
    if not cropped_objects:
        # paddle.stack cannot stack an empty list; no crops means no scores.
        return paddle.zeros([0], dtype='float32')
    # Pure inference: disable autograd so no backward graph (and the
    # activation memory it pins) is kept for every crop in the batch.
    with paddle.no_grad():
        batch_images = paddle.stack(
            [transform(image) for image in cropped_objects])
        tokenized_text = tokenize([text_query])
        image_features = model.encode_image(batch_images)
        print("encode_image done!")
        text_features = model.encode_text(tokenized_text)
        print("encode_text done!")
        # L2-normalise so the matrix product below yields cosine similarities.
        image_features /= image_features.norm(axis=-1, keepdim=True)
        text_features /= text_features.norm(axis=-1, keepdim=True)
        if len(text_features.shape) == 3:
            text_features = text_features.squeeze(0)
        # Fixed temperature of 100, presumably standing in for CLIP's learned
        # logit scale.
        probs = 100. * image_features @ text_features.T
        # Column 0 is the single text query; normalise across the crops.
        return F.softmax(probs[:, 0], axis=0)
def masks2pseudomap(masks):
    """Rasterise SAM masks into one label map plus its pseudo-color rendering.

    Args:
        masks (list[dict]): Mask records, each with a ``"segmentation"`` array.
            All arrays must have the same shape.

    Returns:
        tuple: ``(label_map, pseudo_map)``.
            - ``label_map`` is a uint8 array holding 255 for uncovered pixels
              and a label in 1..254 for covered ones. Later masks overwrite
              earlier ones.
            - ``pseudo_map`` is ``get_pseudo_color_map(label_map)``.

    Raises:
        ValueError: If ``masks`` is empty, because the output shape is then
            unknown.
    """
    if not masks:
        raise ValueError("masks2pseudomap needs at least one mask.")
    label_map = np.full(masks[0]["segmentation"].shape, 255, dtype=np.uint8)
    for i, mask_data in enumerate(masks):
        # Cycle labels through 1..254. The map is uint8 and 255 is reserved
        # for background, so a plain `i + 1` would collide with the background
        # at the 255th mask and overflow uint8 after that. For up to 254 masks
        # this is exactly `i + 1`.
        label_map[mask_data["segmentation"] == 1] = i % 254 + 1
    return label_map, get_pseudo_color_map(label_map)
def visualize(image, result, color_map, weight=0.6):
    """
    Blend a pseudo-colored label map onto an image.

    Args:
        image (np.ndarray): Image to draw on, an (H, W, 3) uint8 array. It must
            match the pseudo-colored map in size and type for cv2.addWeighted.
        result (np.ndarray): Label map of shape (H, W), 8-bit (required by
            cv2.LUT). Each value selects a color from ``color_map``.
        color_map (list): Flat list of color triplets, e.g. from
            ``get_color_map_list(256)``. It must yield exactly 256 colors
            (768 ints), because cv2.LUT needs a 256-entry table.
        weight (float): Weight of ``image`` in the blend; the pseudo-colored
            labels get ``1 - weight``. Default: 0.6.

    Returns:
        np.ndarray: The blended (H, W, 3) uint8 image.
    """
    # Regroup the flat list into one row per label: shape (len // 3, 3).
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    color_map = np.array(color_map).astype("uint8")
    # cv2.LUT maps every label in `result` through a 256-entry table,
    # one color channel at a time.
    c1 = cv2.LUT(result, color_map[:, 0])
    c2 = cv2.LUT(result, color_map[:, 1])
    c3 = cv2.LUT(result, color_map[:, 2])
    # Channels are stacked in reverse order (third component first),
    # presumably to match OpenCV's BGR layout -- TODO confirm palette order.
    pseudo_img = np.dstack((c3, c2, c1))
    vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
    return vis_result
def _crop_masked_objects(image_pil, masks):
    """Return one crop per mask: the mask's bbox cut from a gray-backed copy."""
    crops = []
    for mask in masks:
        # The bbox is (x, y, width, height); PIL's crop() takes
        # (left, upper, right, lower).
        x, y, w, h = mask["bbox"]
        crops.append(
            segment_image(image_pil, mask["segmentation"]).crop(
                (x, y, x + w, y + h)))
    return crops


def _overlay_masks(image_pil, masks, color=(255, 0, 0, 180)):
    """Composite ``color`` over ``image_pil`` wherever any of ``masks`` is set."""
    overlay = Image.new('RGBA', image_pil.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    for mask in masks:
        bitmap = Image.fromarray(mask["segmentation"].astype('uint8') * 255)
        draw.bitmap((0, 0), bitmap, fill=color)
    return Image.alpha_composite(image_pil.convert('RGBA'), overlay)


def get_id_photo_output(image, text, score_threshold=0.05):
    """Segment ``image`` with SAM and highlight the masks that match ``text``.

    This is the Gradio callback of the demo. The name is a leftover from an
    ID-photo app whose arguments the previous docstring described.

    Args:
        image (np.ndarray): Input image as delivered by ``gr.Image`` with its
            default settings, i.e. an RGB uint8 array.
        text (str): Text prompt describing the objects to highlight.
        score_threshold (float): Minimum CLIP probability (a softmax across
            all masks) for a mask to be highlighted. Default: 0.05.

    Returns:
        tuple: ``(result_image, added_pseudo_map, download_path)``.
            - ``result_image``: the input with matching masks overlaid in
              translucent red, as an RGBA PIL image.
            - ``added_pseudo_map``: the input blended with a pseudo-colored
              map of all SAM masks.
            - ``download_path``: path of the saved ``result_image``.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("Please upload an input image.")
    image_ori = image.copy()
    # gr.Image already yields RGB, which is what SAM and CLIP expect. The old
    # cv2.COLOR_BGR2RGB call actually swapped the channels to BGR.
    masks = mask_generator.generate(image)
    if not masks:
        # Nothing was segmented: hand back the input unchanged.
        result_image = Image.fromarray(image_ori).convert('RGBA')
        return result_image, image_ori, download(result_image)

    pred_result, _ = masks2pseudomap(masks)
    added_pseudo_map = visualize(
        image, pred_result, color_map=get_color_map_list(256))

    cropped_objects = _crop_masked_objects(Image.fromarray(image), masks)
    scores = image_text_match(cropped_objects, str(text))
    matching_masks = [
        mask for mask, score in zip(masks, scores)
        if float(score) >= score_threshold
    ]

    result_image = _overlay_masks(Image.fromarray(image_ori), matching_masks)
    return result_image, added_pseudo_map, download(result_image)
def gradio_display(server_name="0.0.0.0", server_port=8078, share=True):
    """Build the SAM+CLIP text-to-segment Gradio app and launch it.

    Args:
        server_name (str): Host/interface passed to ``demo.launch``.
            Default: "0.0.0.0".
        server_port (int): Port passed to ``demo.launch``. Default: 8078.
        share (bool): ``share`` flag passed to ``demo.launch``. Default: True.
    """
    import gradio as gr
    examples_sam = [["./examples/cityscapes_demo.png", "a photo of car"],
                    ["examples/dog.jpg", "dog"],
                    ["examples/zixingche.jpeg", "kid"]]
    demo_mask_sam = gr.Interface(
        fn=get_id_photo_output,
        inputs=[
            gr.Image(label="Input image", height=400),
            gr.Textbox(label="Input text prompt", value="a car"),
        ],
        # NOTE: get_id_photo_output also returns the saved file path as a
        # third value. No component is declared for it, so Gradio drops that
        # extra value (presumably with a warning).
        outputs=[
            gr.Image(label="Output based on text", height=300),
            gr.Image(label="Output mask", height=300)
        ],
        examples=examples_sam,
        # The stray closing </strong> that followed 'submit' has been removed.
        description=(
            "<p> "
            "<strong>SAM+CLIP: Text prompt for segmentation. </strong> <br>"
            "Choose an example below; Or, upload by yourself: <br>"
            "1. Upload images to be tested to 'input image'. "
            "2. Input a text prompt to 'input text prompt' and click 'submit'. <br>"
            "</p>"),
        cache_examples=False,
        flagging_mode="never"
    )
    demo = gr.TabbedInterface(
        [demo_mask_sam],
        ['SAM+CLIP(Text to Segment)'],
        title=" 🔥 Text to Segment Anything with PaddleSeg 🔥"
    )
    demo.launch(
        server_name=server_name,
        server_port=server_port,
        share=share
    )
if __name__ == "__main__":
    # Everything below has side effects: argument parsing, device selection,
    # loading checkpoints from the URLs in `model_link`, and launching the
    # web server. It therefore runs only when the file is executed as a
    # script. The names bound here (`mask_generator`, `model`, `transform`)
    # are still module globals, which the callbacks above rely on.
    args = parser.parse_args()
    print("Loading model...")
    paddle.set_device("gpu" if paddle.is_compiled_with_cuda() else "cpu")
    sam = sam_model_registry[args.model_type](
        checkpoint=model_link[args.model_type])
    mask_generator = SamAutomaticMaskGenerator(sam)
    model, transform = build_clip_model(model_link["clip_b_32"])
    gradio_display()