Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,038 Bytes
e546fea 13a0890 9bb32c5 e546fea 958511f e546fea 958511f b218be6 958511f e546fea 5cb992e 6438ac6 958511f 3e75999 b218be6 2d64873 5cb992e 6438ac6 2d64873 958511f e546fea 3e75999 2d64873 b218be6 13a0890 6438ac6 b218be6 2d64873 e546fea 12472ea b218be6 20a2fe0 b218be6 b2b24c7 b218be6 958511f b218be6 e546fea 958511f b218be6 958511f e546fea 6438ac6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from typing import Union, Tuple
from PIL import Image
torch.set_float32_matmul_precision(["high", "highest"][0])
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
"""
Remove the background from an image and return both the transparent version and the original.
This function performs background removal using a BiRefNet segmentation model. It is intended for use
with image input (either uploaded or from a URL). The function returns a transparent PNG version of the image
with the background removed, along with the original RGB version for comparison.
Args:
image (PIL.Image or str): The input image, either as a PIL object or a filepath/URL string.
Returns:
tuple:
- processed_image (PIL.Image): The input image with the background removed and transparency applied.
- origin (PIL.Image): The original RGB image, unchanged.
"""
im = load_img(image, output_type="pil")
im = im.convert("RGB")
origin = im.copy()
processed_image = process(im)
return (processed_image, origin)
@spaces.GPU
def process(image: Image.Image) -> Image.Image:
"""
Apply BiRefNet-based image segmentation to remove the background.
This function preprocesses the input image, runs it through a BiRefNet segmentation model to obtain a mask,
and applies the mask as an alpha (transparency) channel to the original image.
Args:
image (PIL.Image): The input RGB image.
Returns:
PIL.Image: The image with the background removed, using the segmentation mask as transparency.
"""
image_size = image.size
input_images = transform_image(image).unsqueeze(0).to("cuda")
# Prediction
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return image
def process_file(f: str) -> str:
"""
Load an image file from disk, remove the background, and save the output as a transparent PNG.
Args:
f (str): Filepath of the image to process.
Returns:
str: Path to the saved PNG image with background removed.
"""
name_path = f.rsplit(".", 1)[0] + ".png"
im = load_img(f, output_type="pil")
im = im.convert("RGB")
transparent = process(im)
transparent.save(name_path)
return name_path
slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
image_upload = gr.Image(label="Upload an image")
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")
# Example images
chameleon = load_img("butterfly.jpg", output_type="pil")
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
demo = gr.TabbedInterface(
[tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
)
if __name__ == "__main__":
demo.launch(show_error=True, mcp_server=True) |