Wopke's picture
Upload folder using huggingface_hub
77783a6 verified
import spaces
import os
import gradio as gr
import gc
# moviepy is optional: when it cannot be imported, got_mp is False and the
# video tab / handler are not created.
try:
    import moviepy.editor as mp
    got_mp = True
except Exception as exc:
    # Catch Exception rather than using a bare `except:` so that Ctrl-C and
    # SystemExit are not swallowed; any other import-time failure still just
    # disables video support, but the reason is reported instead of hidden.
    print(f"birefnet: moviepy unavailable, video processing disabled ({exc})")
    got_mp = False
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import glob
import pathlib
from PIL import Image
import numpy
# Module-level state shared by the handlers below:
#   transform_image - preprocessing pipeline, (re)built by common_setup()
#   birefnet        - the loaded segmentation model, set by load_model()
transform_image, birefnet = None, None
def load_model(model):
    """Load ``ZhengPeng7/<model>`` from the Hugging Face Hub into the global ``birefnet``.

    The reference to the currently loaded model is dropped and the Python / CUDA
    caches are flushed *before* the new model is created, so the old weights can
    be reclaimed first. The new model is switched to eval mode, converted to
    float16 and handed to ``spaces.automatically_move_to_gpu_when_forward``
    (presumably so it is placed on the GPU when called, per the helper's name).

    Args:
        model: Repository name under the ``ZhengPeng7`` namespace,
            e.g. ``"BiRefNet_HR"``.
    """
    global birefnet

    # Release the previous network before allocating the next one.
    birefnet = None
    gc.collect()
    torch.cuda.empty_cache()

    # trust_remote_code=True lets transformers run the modelling code that is
    # shipped inside the Hub repository.
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/" + model,
        trust_remote_code=True,
    )
    birefnet.eval()   # inference mode
    birefnet.half()   # float16 weights; inputs are cast to float16 to match
    spaces.automatically_move_to_gpu_when_forward(birefnet)
# Load the default model (same as the "Model" dropdown's default value) inside
# capture_gpu_object(); the resulting handle is what the @spaces.GPU decorators
# below receive via gpu_objects=[birefnet_gpu_obj].
# NOTE(review): presumably the context records the GPU objects created inside
# it — confirm against the spaces module.
with spaces.capture_gpu_object() as birefnet_gpu_obj:
    load_model(model="BiRefNet_HR")
def common_setup(w, h):
    """Build the global ``transform_image`` preprocessing pipeline.

    Called before every processing request with the "processing image width /
    height" sliders from the options tab.

    Args:
        w: Processing width in pixels.
        h: Processing height in pixels.
    """
    global transform_image
    transform_image = transforms.Compose(
        [
            # torchvision's Resize takes (height, width). The previous (w, h)
            # ordering transposed every non-square processing size. Slider
            # values are cast to int because Resize/PIL need integer sizes.
            transforms.Resize((int(h), int(w))),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean / std normalisation.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def process(image, save_flat, bg_colour):
    """Remove the background from a single image.

    Args:
        image: A PIL image (image tab) or a URL / local path string (URL tab);
            anything ``load_img`` accepts.
        save_flat: If true, composite the cut-out onto a solid ``bg_colour``
            background and return an RGB image; otherwise return the image
            with the predicted mask as its alpha channel.
        bg_colour: ``#RRGGBB`` hex colour used when ``save_flat`` is set.

    Returns:
        The processed PIL image.

    Raises:
        gr.Error: If no image / URL / path was supplied.
    """
    # Fail with a readable message instead of an opaque load_img traceback
    # when the button is pressed without an upload or with an empty URL box.
    if image is None or (isinstance(image, str) and not image.strip()):
        raise gr.Error("Please provide an image (upload, URL or local path) first.")
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    image_size = im.size
    image = load_img(im)
    input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
    # Prediction; the last output is used (presumably the final-stage map —
    # confirm against the BiRefNet model code)
    with torch.no_grad():
        preds = birefnet(input_image)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Scale the mask from processing resolution back to the input size
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    if save_flat:
        # Append an opaque alpha byte so the colour becomes an RGBA tuple
        bg_colour += "FF"
        colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
        background = Image.new("RGBA", image_size, colour_rgb)
        image = Image.alpha_composite(background, image)
        image = image.convert("RGB")
    return image
# video processing based on https://huggingface.co/spaces/brokerrobin/video-background-removal/blob/main/app.py
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def video_process(video, bg_colour):
    """Replace the background of every frame of a video with a flat colour.

    Note: all processed frames are held in memory until the output is written,
    so long / high-resolution videos need a lot of RAM.

    Args:
        video: Path to the uploaded video file (from ``gr.Video``).
        bg_colour: ``#RRGGBB`` hex colour for the new background.

    Returns:
        Path of ``<input name>-birefnet.mp4``, written next to the input file
        (i.e. in Gradio's temp folder for uploads). The original audio track is
        kept.

    Raises:
        gr.Error: If no video was supplied or no frames could be read.
    """
    if not video:
        raise gr.Error("Please upload a video first.")
    # Parse the colour once, outside the frame loop; opaque alpha appended.
    colour_rgba = tuple(int(bg_colour[k:k + 2], 16) for k in (1, 3, 5)) + (255,)

    # Load the video using moviepy; it must be closed afterwards, otherwise
    # its ffmpeg reader (and audio reader) stay open.
    clip = mp.VideoFileClip(video)
    try:
        fps = clip.fps
        background = None
        image_size = None
        processed_frames = []
        # Extract frames at the clip's own FPS and process each one
        for i, frame in enumerate(clip.iter_frames(fps=fps)):
            print(f"birefnet [video]: frame {i+1}", end='\r', flush=True)
            image = Image.fromarray(frame)
            if background is None:
                # All frames share the first frame's size
                image_size = image.size
                background = Image.new("RGBA", image_size, colour_rgba)
            input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
            # Prediction
            with torch.no_grad():
                preds = birefnet(input_image)[-1].sigmoid().cpu()
            pred = preds[0].squeeze()
            mask = transforms.ToPILImage()(pred).resize(image_size)
            # Apply mask and composite. The background is opaque, so the result
            # is fully opaque: store RGB to save memory and to stop
            # ImageSequenceClip from building a pointless mask from alpha.
            image.putalpha(mask)
            composited = Image.alpha_composite(background, image).convert("RGB")
            processed_frames.append(numpy.array(composited))

        if not processed_frames:
            raise gr.Error("No frames could be read from the video.")

        # Create a new video from the processed frames and add the original audio back
        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        processed_video = processed_video.set_audio(clip.audio)
        # Save using the modified original filename (goes to gradio temp)
        filename = os.path.splitext(clip.filename)[0] + "-birefnet.mp4"
        processed_video.write_videofile(filename, codec="libx264")
    finally:
        # Close only after writing: the output still reads the source audio.
        clip.close()
    return filename
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def batch_process(input_folder, output_folder, save_png, save_flat, bg_colour):
    """Remove the background from every supported image directly inside a folder.

    Processing is best-effort: a failure on one file is printed and the batch
    continues with the next file.

    Args:
        input_folder: Folder to scan (non-recursively) for images.
        output_folder: Destination folder, created if missing. Outputs keep the
            input file name, so existing files with that name are overwritten.
        save_png: If true, make sure every output name ends in ``.png``.
        save_flat: If true, composite onto a solid ``bg_colour`` and save RGB;
            otherwise save with the predicted mask as alpha channel.
        bg_colour: ``#RRGGBB`` hex colour used when ``save_flat`` is set.

    Returns:
        List of paths of the images that were written, in file-name order.
    """
    # Ensure output folder exists
    os.makedirs(output_folder, exist_ok=True)
    # Supported image extensions (matched case-insensitively)
    image_extensions = {'.jpg', '.jpeg', '.jfif', '.png', '.bmp', '.webp', '.avif'}
    # JPEG-family extensions: these formats cannot store an alpha channel
    jpeg_extensions = ('.jpg', '.jpeg', '.jfif')
    # Collect image files by listing the folder instead of glob(): glob treats
    # characters such as [ ] in the folder path as wildcards, and on
    # case-sensitive filesystems '*.jpg' misses names like 'IMG_0001.JPG'.
    if os.path.isdir(input_folder):
        with os.scandir(input_folder) as entries:
            input_images = sorted(
                entry.path
                for entry in entries
                if entry.is_file()
                and os.path.splitext(entry.name)[1].lower() in image_extensions
            )
    else:
        print(f"birefnet [batch]: input folder not found: {input_folder}")
        input_images = []
    if save_flat:
        # Append an opaque alpha byte so the colour becomes an RGBA tuple
        bg_colour += "FF"
        colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
    # Process each image
    processed_images = []
    for i, image_path in enumerate(input_images):
        print(f"birefnet [batch]: image {i+1}", end='\r', flush=True)
        try:
            # Load image
            im = load_img(image_path, output_type="pil")
            im = im.convert("RGB")
            image_size = im.size
            image = load_img(im)
            # Prepare image for processing
            input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
            # Prediction
            with torch.no_grad():
                preds = birefnet(input_image)[-1].sigmoid().cpu()
            pred = preds[0].squeeze()
            pred_pil = transforms.ToPILImage()(pred)
            mask = pred_pil.resize(image_size)
            # Apply mask
            image.putalpha(mask)
            # Save under the same file name in the output folder
            output_filename = os.path.join(output_folder, pathlib.Path(image_path).name)
            if save_flat:
                background = Image.new("RGBA", image_size, colour_rgb)
                image = Image.alpha_composite(background, image)
                image = image.convert("RGB")
            elif output_filename.lower().endswith(jpeg_extensions):
                # JPEG/JFIF can't hold an alpha channel, so append .png rather
                # than replacing the extension ("a.jpg" -> "a.jpg.png"), which
                # avoids colliding with the output of an "a.png" input.
                output_filename += ".png"
            if save_png and not output_filename.lower().endswith(".png"):
                output_filename += ".png"
            image.save(output_filename)
            processed_images.append(output_filename)
        except Exception as e:
            # Best-effort batch: report the failure and move on
            print(f"Error processing {image_path}: {str(e)}")
    return processed_images
def unload():
    """Drop the model and preprocessing pipeline and flush memory caches.

    Registered with ``demo.unload`` so it runs when a browser session ends.
    Both globals are reset to ``None``; garbage collection runs first so the
    released tensors are actually freed before the CUDA cache is emptied.
    """
    global transform_image, birefnet
    transform_image = None
    birefnet = None
    gc.collect()
    torch.cuda.empty_cache()
# Page-level CSS passed to gr.Blocks: caps the main container width at 1280px
# and hides the Gradio footer.
css = """
.gradio-container {
max-width: 1280px !important;
}
footer {
display: none !important;
}
"""
# Gradio UI. The tabs are: single image, image from URL/path, video (only when
# moviepy imported), batch folder processing, and options. Every "Remove
# background" button first rebuilds the preprocessing pipeline from the size
# sliders (common_setup) and then runs the matching handler.
# NOTE(review): nesting reconstructed from an indentation-stripped copy —
# confirm the placement of the final links Markdown (tab vs. page level).
with gr.Blocks(css=css, analytics_enabled=False) as demo:
    gr.Markdown("# birefnet for background removal")
    # --- single uploaded image ---
    with gr.Tab("image"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload an image", type='pil', height=584)
                go_image = gr.Button("Remove background")
            with gr.Column():
                result1 = gr.Image(label="birefnet", type="pil", height=544)
    # --- image given as URL or local path (passed as a string to process) ---
    with gr.Tab("URL"):
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="URL to image, or local path to image", max_lines=1)
                go_text = gr.Button("Remove background")
            with gr.Column():
                result2 = gr.Image(label="birefnet", type="pil", height=544)
    # --- video: only built when moviepy is available ---
    if got_mp:
        with gr.Tab("video"):
            with gr.Row():
                with gr.Column():
                    video = gr.Video(label="Upload a video", height=584)
                    go_video = gr.Button("Remove background")
                with gr.Column():
                    result4 = gr.Video(label="birefnet", height=544, show_share_button=False)
    # --- batch: every image in a server-side folder ---
    with gr.Tab("batch"):
        with gr.Row():
            with gr.Column():
                input_dir = gr.Textbox(label="Input folder path", max_lines=1)
                output_dir = gr.Textbox(label="Output folder path (save images will overwrite)", max_lines=1)
                always_png = gr.Checkbox(label="Always save as PNG", value=True)
                go_batch = gr.Button("Remove background(s)")
            with gr.Column():
                result3 = gr.File(label="Processed image(s)", type="filepath", file_count="multiple")
    # --- options shared by all tabs ---
    with gr.Tab("options"):
        gr.Markdown("*HR* : high resolution; *matting* : better with transparency; *lite* : faster.")
        # Each choice is a repository name under ZhengPeng7/ (see load_model);
        # the default matches the model loaded at startup.
        model = gr.Dropdown(label="Model (download on selection, see console for progress)",
            choices=["BiRefNet_512x512", "BiRefNet", "BiRefNet_HR", "BiRefNet-matting", "BiRefNet_HR-matting", "BiRefNet_lite", "BiRefNet_lite-2K", "BiRefNet-portrait", "BiRefNet-COD", "BiRefNet-DIS5K", "BiRefNet-DIS5k-TR_TEs", "BiRefNet-HRSOD"], value="BiRefNet_HR", type="value")
        gr.Markdown("Regular models trained at 1024 \u00D7 1024; HR models trained at 2048 \u00D7 2048; 2K model trained at 2560 \u00D7 1440.")
        gr.Markdown("Greater processing image size will typically give more accurate results, but also requires more VRAM (shared memory works well).")
        # Processing resolution fed to common_setup (multiples of 32)
        with gr.Row():
            proc_sizeW = gr.Slider(label="birefnet processing image width",
                minimum=256, maximum=2560, value=2048, step=32)
            proc_sizeH = gr.Slider(label="birefnet processing image height",
                minimum=256, maximum=2048, value=2048, step=32)
        with gr.Row():
            save_flat = gr.Checkbox(label="Save flat (no mask)", value=False)
            bg_colour = gr.ColorPicker(label="Background colour for saving flat, and video", value="#00FF00", visible=True, interactive=True)
        # Swap the global model whenever the dropdown changes.
        # NOTE(review): unlike the startup load, this runs outside
        # spaces.capture_gpu_object() — confirm the new model is still managed
        # via birefnet_gpu_obj.
        model.change(fn=load_model, inputs=model, outputs=None)
    gr.Markdown("### https://github.com/ZhengPeng7/BiRefNet\n### https://huggingface.co/ZhengPeng7")
    # Event wiring: rebuild the transform from the sliders, then process.
    go_image.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[image, save_flat, bg_colour], outputs=result1)
    go_text.click( fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[text, save_flat, bg_colour], outputs=result2)
    if got_mp:
        go_video.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
            fn=video_process, inputs=[video, bg_colour], outputs=result4)
    go_batch.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
        fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat, bg_colour], outputs=result3)
    # Free the model when a session ends.
    # NOTE(review): unload() sets birefnet to None and nothing visible here
    # reloads it, so a later session (e.g. after a page refresh) may fail until
    # a model is re-selected — confirm.
    demo.unload(unload)
def _main():
    """Start the Gradio server and open the UI in the default web browser."""
    demo.launch(inbrowser=True)


if __name__ == "__main__":
    _main()