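"""Gradio Space: background removal with BiRefNet.

Takes an uploaded image or an image URL, predicts a foreground mask with
BiRefNet, and returns the cutout (alongside the original for comparison)
plus a downloadable PNG.
"""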
import gradio as gr
from gradio_imageslider import ImageSlider
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import io
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import tempfile

# Allow TF32 matmul kernels for faster inference on recent GPUs
torch.set_float32_matmul_precision("high")

# Load BiRefNet and move it to the GPU
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

# Preprocessing: resize to the model input size and normalize with ImageNet statistics
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def load_img(image, output_type="pil"):
    """Load an image from a URL, a file path, a numpy array, or a PIL image."""
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            response = requests.get(image)
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            image = Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):
        # gr.Image passes uploads as numpy arrays by default
        image = Image.fromarray(image).convert("RGB")
    elif isinstance(image, Image.Image):
        image = image.convert("RGB")
    else:
        raise ValueError("Unsupported image type")

    if output_type == "pil":
        return image
    elif output_type == "numpy":
        return np.array(image)
    else:
        raise ValueError("Unsupported output type")


@spaces.GPU
def fn(image):
    # Nothing to process: empty upload or blank URL
    if image is None or len(image) == 0:
        return None, None

    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    image_size = im.size
    origin = im.copy()
    image = load_img(im)
    input_images = transform_image(image).unsqueeze(0).to("cuda")

    # Prediction: take the last output and apply sigmoid to get the foreground mask
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)

    # Write the cutout to a temporary PNG so gr.File can serve it for download
    tmp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    image.save(tmp_file.name, format="PNG")

    return (image, origin), tmp_file.name


def create_download_component(file_path):
    # Show the download widget only when a result file exists
    if file_path is not None:
        return gr.File(value=file_path, visible=True, label="Download Result")
    return gr.File(visible=False)


# UI components: a before/after slider per tab, plus the two input types
slider1 = ImageSlider(label="birefnet", type="pil")
slider2 = ImageSlider(label="birefnet", type="pil")
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")

# Example inputs for each tab
chameleon = load_img("butterfly.jpg", output_type="pil")
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

tab1 = gr.Interface(
    fn,
    inputs=image,
    outputs=[slider1, gr.File(label="Download Result")],
    examples=[chameleon],
    api_name="image",
)

tab2 = gr.Interface(
    fn,
    inputs=text,
    outputs=[slider2, gr.File(label="Download Result")],
    examples=[url],
    api_name="text",
)

demo = gr.TabbedInterface(
    [tab1, tab2],
    ["image", "text"],
    title="birefnet for background removal",
)

if __name__ == "__main__":
    demo.launch()