# BiRefNet background-removal Gradio app (Hugging Face Spaces).
import spaces
import os
import gradio as gr
import gc

# moviepy is optional: the video tab is only created when it imports cleanly.
# Catch Exception (not a bare except) so SystemExit/KeyboardInterrupt still
# propagate, while any import-time failure in moviepy just disables the tab.
try:
    import moviepy.editor as mp
    got_mp = True
except Exception:
    got_mp = False

from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import glob
import pathlib
from PIL import Image
import numpy

# Module-level state: populated by common_setup() / load_model(),
# reset to None by unload().
transform_image = None
birefnet = None
def load_model(model):
    """Swap the global BiRefNet model for the named checkpoint.

    model: checkpoint name under the "ZhengPeng7/" Hugging Face namespace
           (e.g. "BiRefNet_HR"), as chosen in the options tab.
    """
    global birefnet
    # Release the previous model first so its memory can be reclaimed
    # before the new checkpoint is downloaded/loaded.
    birefnet = None
    gc.collect()
    torch.cuda.empty_cache()
    repo_id = "ZhengPeng7/" + model
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        repo_id, trust_remote_code=True
    )
    # Inference-only, half precision; Spaces moves it to GPU on forward().
    birefnet.eval()
    birefnet.half()
    spaces.automatically_move_to_gpu_when_forward(birefnet)
# Load the default model at import time, capturing the GPU objects it
# creates so the Spaces runtime can manage their placement.
with spaces.capture_gpu_object() as birefnet_gpu_obj:
    load_model("BiRefNet_HR")
def common_setup(w, h):
    """Build the preprocessing pipeline at the requested processing size.

    w, h: target width and height in pixels (from the UI sliders).
    Stores the resulting transform in the module-global ``transform_image``.
    """
    global transform_image
    transform_image = transforms.Compose(
        [
            # torchvision's Resize expects (height, width); the previous
            # (w, h) order silently swapped the dimensions whenever W != H.
            transforms.Resize((h, w)),
            transforms.ToTensor(),
            # ImageNet mean/std, as expected by the BiRefNet backbone.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
def process(image, save_flat, bg_colour):
    """Remove the background from one image.

    image:     anything load_img accepts (PIL image, path, URL, ...).
    save_flat: when True, composite onto bg_colour and return RGB;
               otherwise return RGBA with the predicted alpha mask.
    bg_colour: "#RRGGBB" hex string from the colour picker.
    """
    pil = load_img(image, output_type="pil").convert("RGB")
    orig_size = pil.size
    image = load_img(pil)
    # Preprocess to the configured size and the model's half precision.
    tensor = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
    # Prediction: last output is the final mask; bring it back to CPU.
    with torch.no_grad():
        preds = birefnet(tensor)[-1].sigmoid().cpu()
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(orig_size)
    image.putalpha(mask)
    if save_flat:
        # "#RRGGBB" + "FF" -> parse as opaque RGBA, composite, drop alpha.
        rgba_hex = bg_colour + "FF"
        colour = tuple(int(rgba_hex[i:i + 2], 16) for i in (1, 3, 5, 7))
        backdrop = Image.new("RGBA", orig_size, colour)
        image = Image.alpha_composite(backdrop, image).convert("RGB")
    return image
# video processing based on https://huggingface.co/spaces/brokerrobin/video-background-removal/blob/main/app.py
def video_process(video, bg_colour):
    """Replace the background of every frame of a video with bg_colour.

    video:     path to the input video file (from the gr.Video component).
    bg_colour: "#RRGGBB" hex string; frames are composited onto it.
    Returns the path of the written mp4 (input filename + "-birefnet.mp4").
    NOTE(review): all processed frames are held in memory before encoding —
    presumably fine for short clips; long videos may exhaust RAM.
    """
    # Load the video using moviepy
    video = mp.VideoFileClip(video)
    fps = video.fps
    # Extract audio from the video
    audio = video.audio
    # Extract frames at the specified FPS
    frames = video.iter_frames(fps=fps)
    # Process each frame for background removal
    processed_frames = []
    for i, frame in enumerate(frames):
        print (f"birefnet [video]: frame {i+1}", end='\r', flush=True)
        image = Image.fromarray(frame)
        if i == 0:
            # First frame fixes the size and the opaque background layer
            # reused for every subsequent composite.
            image_size = image.size
            colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5))
            background = Image.new("RGBA", image_size, colour_rgb + (255,))
        # Preprocess to the configured size and the model's half precision.
        input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
        # Prediction
        with torch.no_grad():
            preds = birefnet(input_image)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
        # Apply mask and composite
        image.putalpha(mask)
        processed_image = Image.alpha_composite(background, image)
        processed_frames.append(numpy.array(processed_image))
    # Create a new video from the processed frames
    processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
    # Add the original audio back to the processed video
    processed_video = processed_video.set_audio(audio)
    # Save the processed video using modified original filename (goes to gradio temp)
    filename, _ = os.path.splitext(video.filename)
    filename += "-birefnet.mp4"
    processed_video.write_videofile(filename, codec="libx264")
    return filename
def batch_process(input_folder, output_folder, save_png, save_flat, bg_colour):
    """Background-remove every supported image found in input_folder.

    input_folder:  directory scanned (non-recursively) for images.
    output_folder: created if missing; outputs keep the source filename.
    save_png:      force a .png extension on outputs that lack one.
    save_flat:     composite onto bg_colour and save RGB instead of RGBA.
    bg_colour:     "#RRGGBB" hex string from the colour picker.
    Returns the list of written file paths; per-image failures are logged
    to the console and skipped.
    """
    # Ensure output folder exists
    os.makedirs(output_folder, exist_ok=True)
    # Supported image extensions, matched case-insensitively so files like
    # IMG.JPG are picked up on case-sensitive filesystems (the previous
    # lowercase globs missed them). Sorted for a deterministic order.
    image_extensions = {'.jpg', '.jpeg', '.jfif', '.png', '.bmp', '.webp', '.avif'}
    input_images = sorted(
        str(p) for p in pathlib.Path(input_folder).glob('*')
        if p.is_file() and p.suffix.lower() in image_extensions
    )
    if save_flat:
        # "#RRGGBB" + "FF" parses as an opaque #RRGGBBAA colour.
        bg_colour += "FF"
        colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
    # Process each image
    processed_images = []
    for i, image_path in enumerate(input_images):
        print (f"birefnet [batch]: image {i+1}", end='\r', flush=True)
        try:
            # Load image and normalise to RGB
            im = load_img(image_path, output_type="pil")
            im = im.convert("RGB")
            image_size = im.size
            image = load_img(im)
            # Prepare image for processing (configured size, half precision)
            input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
            # Prediction
            with torch.no_grad():
                preds = birefnet(input_image)[-1].sigmoid().cpu()
            pred = preds[0].squeeze()
            pred_pil = transforms.ToPILImage()(pred)
            mask = pred_pil.resize(image_size)
            # Apply mask
            image.putalpha(mask)
            # Save processed image under the original filename
            output_filename = os.path.join(output_folder, pathlib.Path(image_path).name)
            if save_flat:
                background = Image.new("RGBA", image_size, colour_rgb)
                image = Image.alpha_composite(background, image)
                image = image.convert("RGB")
            elif output_filename.lower().endswith((".jpg", ".jpeg")):
                # jpegs don't support alpha channel, so add .png extension (not change, to avoid potential overwrites)
                output_filename += ".png"
            if save_png and not output_filename.lower().endswith(".png"):
                output_filename += ".png"
            image.save(output_filename)
            processed_images.append(output_filename)
        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
    return processed_images
def unload():
    """Drop the model and transform globals and reclaim their memory.

    Registered as the Blocks unload handler; leaves the app able to
    reload via load_model()/common_setup() later.
    """
    global birefnet, transform_image
    birefnet = None
    transform_image = None
    # Force collection now, then return any cached CUDA blocks to the driver.
    gc.collect()
    torch.cuda.empty_cache()
# Widen the app container and hide the default Gradio footer.
css = """
.gradio-container {
    max-width: 1280px !important;
}
footer {
    display: none !important;
}
"""
# UI layout: one tab per input mode (image / URL / video / batch) plus an
# options tab for model choice, processing size and background colour.
with gr.Blocks(css=css, analytics_enabled=False) as demo:
    gr.Markdown("# birefnet for background removal")
    with gr.Tab("image"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload an image", type='pil', height=584)
                go_image = gr.Button("Remove background")
            with gr.Column():
                result1 = gr.Image(label="birefnet", type="pil", height=544)
    with gr.Tab("URL"):
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="URL to image, or local path to image", max_lines=1)
                go_text = gr.Button("Remove background")
            with gr.Column():
                result2 = gr.Image(label="birefnet", type="pil", height=544)
    # Video tab only exists when moviepy imported successfully.
    if got_mp:
        with gr.Tab("video"):
            with gr.Row():
                with gr.Column():
                    video = gr.Video(label="Upload a video", height=584)
                    go_video = gr.Button("Remove background")
                with gr.Column():
                    result4 = gr.Video(label="birefnet", height=544, show_share_button=False)
    with gr.Tab("batch"):
        with gr.Row():
            with gr.Column():
                input_dir = gr.Textbox(label="Input folder path", max_lines=1)
                output_dir = gr.Textbox(label="Output folder path (save images will overwrite)", max_lines=1)
                always_png = gr.Checkbox(label="Always save as PNG", value=True)
                go_batch = gr.Button("Remove background(s)")
            with gr.Column():
                result3 = gr.File(label="Processed image(s)", type="filepath", file_count="multiple")
    with gr.Tab("options"):
        gr.Markdown("*HR* : high resolution; *matting* : better with transparency; *lite* : faster.")
        model = gr.Dropdown(label="Model (download on selection, see console for progress)",
            choices=["BiRefNet_512x512", "BiRefNet", "BiRefNet_HR", "BiRefNet-matting", "BiRefNet_HR-matting", "BiRefNet_lite", "BiRefNet_lite-2K", "BiRefNet-portrait", "BiRefNet-COD", "BiRefNet-DIS5K", "BiRefNet-DIS5k-TR_TEs", "BiRefNet-HRSOD"], value="BiRefNet_HR", type="value")
        gr.Markdown("Regular models trained at 1024 \u00D7 1024; HR models trained at 2048 \u00D7 2048; 2K model trained at 2560 \u00D7 1440.")
        gr.Markdown("Greater processing image size will typically give more accurate results, but also requires more VRAM (shared memory works well).")
        with gr.Row():
            proc_sizeW = gr.Slider(label="birefnet processing image width",
                minimum=256, maximum=2560, value=2048, step=32)
            proc_sizeH = gr.Slider(label="birefnet processing image height",
                minimum=256, maximum=2048, value=2048, step=32)
        with gr.Row():
            save_flat = gr.Checkbox(label="Save flat (no mask)", value=False)
            bg_colour = gr.ColorPicker(label="Background colour for saving flat, and video", value="#00FF00", visible=True, interactive=True)
        model.change(fn=load_model, inputs=model, outputs=None)
        gr.Markdown("### https://github.com/ZhengPeng7/BiRefNet\n### https://huggingface.co/ZhengPeng7")
    # Each handler first rebuilds the transform at the current slider size,
    # then runs the matching processing function.
    go_image.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[image, save_flat, bg_colour], outputs=result1)
    go_text.click( fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[text, save_flat, bg_colour], outputs=result2)
    if got_mp:
        go_video.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
            fn=video_process, inputs=[video, bg_colour], outputs=result4)
    go_batch.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
        fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat, bg_colour], outputs=result3)
    # Free model memory when the client disconnects.
    demo.unload(unload)
if __name__ == "__main__":
    demo.launch(inbrowser=True)