# Spaces status (page header captured together with this file): Runtime error
import cv2 | |
import gradio as gr | |
import os | |
import requests | |
from PIL import Image | |
import numpy as np | |
import torch | |
from torch.autograd import Variable | |
from torchvision import transforms | |
import torch.nn.functional as F | |
# Automatically download required files

# (connect, read) timeouts in seconds. Without a timeout a stalled connection
# would block start-up indefinitely.
_DOWNLOAD_TIMEOUT = (10, 120)


def _download_if_missing(url, dest, name, start_message):
    """Download *url* to *dest* unless *dest* already exists.

    The body is streamed into ``dest + ".part"`` and moved into place only
    after the transfer finishes, so an interrupted download never leaves a
    truncated file that a later start-up would treat as complete.

    Args:
        url: Address to fetch.
        dest: Path the file is stored at.
        name: Short file name used in the failure message.
        start_message: Line printed before the download begins.

    Raises:
        requests.RequestException: If the request fails or the server answers
            with an error status; the failure is printed and re-raised.
    """
    if os.path.exists(dest):
        return
    print(start_message)
    partial_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, dest)
    except requests.RequestException as e:
        print(f"Failed to download {name}: {e}")
        raise
    finally:
        # Only present when the transfer did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# NOTE(review): the two .py files below are fetched from the `main` branch and
# then imported, i.e. executed. Consider pinning them to a commit hash.
# 1. data_loader_cache.py from GitHub
_download_if_missing(
    "https://raw.githubusercontent.com/xuebinqin/DIS/main/DIS/IS-Net/data_loader_cache.py",
    "data_loader_cache.py",
    "data_loader_cache.py",
    "Downloading data_loader_cache.py...",
)
# 2. models.py from GitHub
_download_if_missing(
    "https://raw.githubusercontent.com/xuebinqin/DIS/main/DIS/IS-Net/models.py",
    "models.py",
    "models.py",
    "Downloading models.py...",
)
# 3. isnet.pth from Hugging Face Git LFS (direct URL from screenshot)
os.makedirs("saved_models", exist_ok=True)
isnet_path = "saved_models/isnet.pth"
# NOTE(review): the object id in this URL is 63 hex characters, one short of a
# 64-character SHA-256 digest -- possibly a transcription error. Verify the URL
# against the model repository if this download fails.
lfs_url = "https://cdn-lfs.huggingface.co/repos/e0/a8/e0a889743a78391b48db7c4c0b4de1963ee320cb10934c75a32481dc5af9c61/e0a889743a78391b48db7c4c0b4de1963ee320cb10934c75a32481dc5af9c61?download=true"
_download_if_missing(
    lfs_url,
    isnet_path,
    "isnet.pth",
    "Downloading isnet.pth from Hugging Face Git LFS...",
)
# Project imports | |
from data_loader_cache import normalize, im_reader, im_preprocess | |
from models import * | |
# Helpers
# Device name handed to build_model() and predict(); fixed to CPU here.
device = 'cpu'
class GOSNormalize(object):
    """Callable transform that applies ``normalize`` with a stored mean and std.

    ``normalize`` is the function imported from ``data_loader_cache``; this
    class only remembers the statistics so it can be used inside
    ``transforms.Compose``.
    """

    def __init__(self, mean=None, std=None):
        """Store the normalization statistics.

        Args:
            mean: Sequence of means; defaults to ``[0.485, 0.456, 0.406]``.
            std: Sequence of standard deviations; defaults to
                ``[0.229, 0.224, 0.225]``.
        """
        # The defaults are created per instance. List defaults in the
        # signature would be a single shared object, so mutating one
        # instance's ``mean``/``std`` would silently change all the others.
        self.mean = [0.485, 0.456, 0.406] if mean is None else mean
        self.std = [0.229, 0.224, 0.225] if std is None else std

    def __call__(self, image):
        """Return ``normalize(image, self.mean, self.std)``."""
        return normalize(image, self.mean, self.std)
# Preprocessing pipeline applied by load_image(): a single GOSNormalize step
# configured with mean 0.5 and std 1.0 for each of its three entries.
transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
def load_image(im_path, hypar):
    """Read the image at *im_path* and turn it into a batched model input.

    Args:
        im_path: Path handed to ``im_reader``.
        hypar: Settings dict; ``hypar["cache_size"]`` is passed to
            ``im_preprocess``.

    Returns:
        A pair ``(batch, shape)``. ``batch`` is the preprocessed image divided
        by 255, run through the module-level ``transform`` and given a leading
        batch dimension. ``shape`` is the shape value reported by
        ``im_preprocess`` as a tensor, also with a leading batch dimension.
    """
    raw = im_reader(im_path)
    resized, reported_shape = im_preprocess(raw, hypar["cache_size"])
    scaled = torch.divide(resized, 255.0)
    shape_tensor = torch.from_numpy(np.array(reported_shape))
    batch = transform(scaled).unsqueeze(0)
    return batch, shape_tensor.unsqueeze(0)
def build_model(hypar, device):
    """Prepare the network stored in *hypar* for inference.

    Args:
        hypar: Settings dict. ``hypar["model"]`` is the network instance,
            ``hypar["restore_model"]`` the checkpoint file name (falsy to skip
            loading) and ``hypar["model_path"]`` the directory holding it.
        device: Device the network is moved to and the checkpoint mapped onto.

    Returns:
        The same network object, moved to *device*, with the checkpoint loaded
        when one is configured, and switched to evaluation mode.
    """
    model = hypar["model"]
    model.to(device)
    checkpoint_name = hypar["restore_model"]
    if checkpoint_name:
        checkpoint_path = os.path.join(hypar["model_path"], checkpoint_name)
        state = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state)
    model.eval()
    return model
def predict(net, inputs_val, shapes_val, hypar, device):
    """Run *net* on one input batch and return the mask as a uint8 array.

    Args:
        net: Network whose call result is indexed as ``net(x)[0][0]`` to get
            the prediction used here.
        inputs_val: Input batch; converted to float32 and moved to *device*.
        shapes_val: Indexable as ``shapes_val[0][0]`` and ``shapes_val[0][1]``,
            giving the two output dimensions the prediction is resized to.
        hypar: Unused; kept so the signature stays compatible with callers.
        device: Device the forward pass runs on.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``: the first prediction of the first
        batch item, bilinearly resized and min-max scaled to 0..255.
    """
    net.eval()
    inputs_val = inputs_val.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        ds_val = net(inputs_val)[0]
        pred_val = ds_val[0][0, :, :, :]
        # Plain Python ints: interpolate's `size` expects integers, not
        # zero-dimensional tensors.
        target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))
        # F.upsample is deprecated; interpolate with align_corners=False is
        # what upsample(mode='bilinear') resolved to by default.
        pred_val = torch.squeeze(
            F.interpolate(
                torch.unsqueeze(pred_val, 0),
                size=target_size,
                mode='bilinear',
                align_corners=False,
            )
        )
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        span = ma - mi
        if span > 0:
            pred_val = (pred_val - mi) / span
        else:
            # A constant prediction would divide 0 by 0 and fill the mask with
            # NaN; return an all-zero mask instead.
            pred_val = torch.zeros_like(pred_val)
    return (pred_val.cpu().numpy() * 255).astype(np.uint8)
# Set Parameters
# Settings shared by load_image(), build_model() and predict().
hypar = dict(
    model_path="saved_models",
    restore_model="isnet.pth",
    cache_size=[512, 512],
    input_size=[512, 512],
    crop_size=[512, 512],
    model=ISNetDIS(),
)
# Build Model
# Built once at import time and reused by every inference() call.
net = build_model(hypar, device)
def inference(image):
    """Segment the uploaded image and return the cut-out plus its mask.

    Args:
        image: File path of the uploaded image, or ``None`` when the form is
            submitted without an image.

    Returns:
        ``[im_rgba, pil_mask]``: the input image with the predicted mask as
        its alpha channel, and the mask itself as an ``'L'``-mode image.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message instead of failing inside the image reader.
        raise gr.Error("Please upload an image first.")
    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)
    pil_mask = Image.fromarray(mask).convert('L')
    # convert() returns a new, fully loaded image, so the file handle can be
    # closed as soon as the conversion is done.
    with Image.open(image) as source:
        im_rgba = source.convert("RGB")
    im_rgba.putalpha(pil_mask)
    return [im_rgba, pil_mask]
title = "Dichotomous Image Segmentation"
description = "Upload an image to remove its background."
# Keep the Interface object itself in `interface`; chaining .launch() onto the
# constructor would bind the name to launch()'s return value instead.
interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath'),
    outputs=[gr.Image(type='filepath', format="png"), gr.Image(type='filepath', format="png")],
    title=title,
    description=description,
    flagging_mode="never",
    cache_mode="lazy",
)
interface.launch()