File size: 3,035 Bytes
86fe406
 
 
 
 
 
3659121
fd943d1
648e4b5
 
86fe406
 
 
 
 
 
 
fd943d1
86fe406
 
 
 
 
 
 
 
648e4b5
 
 
 
 
 
 
 
 
 
 
 
 
 
3659121
 
648e4b5
 
 
86fe406
 
 
3659121
86fe406
 
 
 
 
 
3659121
86fe406
 
 
 
 
 
 
3659121
 
 
 
 
 
 
9b5c5ae
3659121
 
 
 
86fe406
 
 
 
 
 
 
fd943d1
86fe406
fd943d1
 
3659121
fd943d1
 
86fe406
 
fd943d1
 
 
3659121
fd943d1
 
 
86fe406
 
fd943d1
 
 
86fe406
 
 
9b5c5ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from gradio_imageslider import ImageSlider
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import io
from PIL import Image
import requests
from io import BytesIO

# Select the "high" float32 matmul precision setting for the whole process.
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation model (its modelling code is fetched from the
# Hub, hence trust_remote_code) and move it to the GPU in one go.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
).to("cuda")

# Pre-processing applied to every input image: resize to a fixed 1024x1024,
# convert to a tensor, then normalise each channel with the given mean / std.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def load_img(image, output_type="pil"):
    """Load an image from a URL, a file path, a PIL image or a NumPy array.

    Args:
        image: One of

            * a ``str`` starting with ``http://`` or ``https://`` - the image
              is downloaded with ``requests``;
            * any other ``str`` - treated as a local file path;
            * a ``PIL.Image.Image``;
            * a ``numpy.ndarray`` of pixel data that ``Image.fromarray`` can
              interpret (presumably the uint8 HxWxC arrays a ``gr.Image``
              input delivers - TODO confirm).
        output_type: ``"pil"`` to get a ``PIL.Image.Image`` back, ``"numpy"``
            to get a ``numpy.ndarray``.

    Returns:
        The image converted to RGB, in the requested representation.

    Raises:
        ValueError: If ``output_type`` or the type of ``image`` is not
            supported.
        requests.RequestException: If the download fails, times out or the
            server answers with an error status.
    """
    # Validate the cheap argument first so a typo fails before any download
    # or disk access happens.
    if output_type not in ("pil", "numpy"):
        raise ValueError("Unsupported output type")

    # Imported here because the module does not import numpy at file level.
    import numpy as np

    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Without a timeout a stalled server would block the worker
            # forever; raise_for_status keeps error pages out of the decoder.
            response = requests.get(image, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            # convert() returns an independent image, so the file handle can
            # be released as soon as the conversion is done.
            with Image.open(image) as source:
                image = source.convert("RGB")
    elif isinstance(image, Image.Image):
        image = image.convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert("RGB")
    else:
        raise ValueError("Unsupported image type")

    if output_type == "numpy":
        return np.array(image)
    return image

@spaces.GPU
def fn(image):
    """Remove the background of ``image`` with BiRefNet.

    Args:
        image: Anything ``load_img`` accepts; the Gradio tabs pass either the
            uploaded image or the pasted URL. ``None`` or a blank string means
            "no input".

    Returns:
        A tuple ``((cutout, original), png_path)``:

        * ``cutout`` - the input as a PIL image whose alpha channel is the
          predicted foreground mask;
        * ``original`` - the untouched RGB input as a PIL image;
        * ``png_path`` - path of a temporary PNG file containing ``cutout``.

        ``(None, None)`` is returned when there is no input.
    """
    import tempfile

    # Only strings are checked for emptiness: PIL images have no len(), so a
    # generic len() test would raise on a perfectly valid input.
    if image is None or (isinstance(image, str) and not image.strip()):
        return None, None

    # load_img already yields an RGB PIL image; keep it as the "before"
    # picture and work on a copy so adding the alpha channel leaves it intact.
    origin = load_img(image, output_type="pil")
    image_size = origin.size
    image = origin.copy()
    input_images = transform_image(image).unsqueeze(0).to("cuda")

    # Prediction: the last element of the model output is turned into
    # per-pixel probabilities with a sigmoid.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # assumes preds is (1, 1, H, W) so squeeze() leaves a 2-D mask - TODO confirm
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # The model works at a fixed resolution; scale the mask back to the
    # size of the input before using it as the alpha channel.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)

    # A gr.File output is fed with a file path, not raw bytes, so the PNG is
    # written to a temporary file that outlives this call (delete=False).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")

    return (image, origin), tmp.name

def create_download_component(img_bytes):
    """Build the ``gr.File`` download widget for a processed image.

    Args:
        img_bytes: The encoded image as ``bytes`` / ``bytearray``, a path to
            an already existing file, or ``None`` when there is no result.

    Returns:
        A visible ``gr.File`` labelled "Download Result" when there is data,
        otherwise a hidden ``gr.File``.
    """
    if img_bytes is None:
        return gr.File(visible=False)

    if isinstance(img_bytes, (bytes, bytearray)):
        import tempfile

        # gr.File takes a file path as its value, so raw bytes are persisted
        # first. The ".png" suffix assumes PNG data - TODO confirm with callers.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(img_bytes)
        img_bytes = tmp.name

    return gr.File(value=img_bytes, visible=True, label="Download Result")

# Output sliders, one per tab.
slider1 = ImageSlider(type="pil", label="birefnet")
slider2 = ImageSlider(type="pil", label="birefnet")

# Inputs: an uploaded picture for the first tab, a URL for the second.
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")

# Example inputs offered in the UI: a bundled file and a remote picture.
chameleon = load_img("butterfly.jpg", output_type="pil")
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# Tab 1: background removal for an uploaded image.
tab1 = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=[slider1, gr.File(label="Download Result")],
    examples=[chameleon],
    api_name="image",
)

# Tab 2: background removal for an image referenced by URL.
tab2 = gr.Interface(
    fn=fn,
    inputs=text,
    outputs=[slider2, gr.File(label="Download Result")],
    examples=[url],
    api_name="text",
)

# Both interfaces combined into a single tabbed application.
demo = gr.TabbedInterface(
    interface_list=[tab1, tab2],
    tab_names=["image", "text"],
    title="birefnet for background removal",
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()