"""Gradio demo for Dichotomous Image Segmentation (IS-Net) background removal.

On start-up the script downloads the IS-Net helper modules and the
pretrained checkpoint if they are not already present, builds the model on
the CPU, and launches a Gradio interface that returns an RGBA cut-out and
the predicted mask for an uploaded image.
"""

import os

import cv2
import gradio as gr
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

# (connect, read) timeouts in seconds so a stalled server cannot hang start-up.
DOWNLOAD_TIMEOUT = (10, 60)


def download_file(url, dest, description=None, timeout=DOWNLOAD_TIMEOUT):
    """Stream ``url`` to ``dest``, replacing ``dest`` only on full success.

    The payload is first written to ``dest + ".part"`` and then moved into
    place, so an interrupted download never leaves a truncated file that a
    later ``os.path.exists(dest)`` check would mistake for a complete one.

    Args:
        url: Address to fetch.
        dest: Final path of the downloaded file.
        description: Text used in the "Downloading ..." message; defaults to
            the base name of ``dest``.
        timeout: Value forwarded to ``requests.get(timeout=...)``.

    Raises:
        requests.RequestException: On any HTTP or connection failure (the
            failure is printed first, then re-raised).
    """
    name = os.path.basename(dest)
    print(f"Downloading {description or name}...")
    tmp_path = f"{dest}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_path, dest)
    except requests.RequestException as e:
        print(f"Failed to download {name}: {e}")
        raise
    finally:
        # After a successful os.replace the temp file is gone; this only
        # cleans up after a failed or interrupted transfer.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Automatically download required files
# 1. data_loader_cache.py from GitHub
if not os.path.exists("data_loader_cache.py"):
    download_file(
        "https://raw.githubusercontent.com/xuebinqin/DIS/main/DIS/IS-Net/data_loader_cache.py",
        "data_loader_cache.py",
    )

# 2. models.py from GitHub
if not os.path.exists("models.py"):
    download_file(
        "https://raw.githubusercontent.com/xuebinqin/DIS/main/DIS/IS-Net/models.py",
        "models.py",
    )

# 3. isnet.pth from Hugging Face Git LFS (direct URL from screenshot)
os.makedirs("saved_models", exist_ok=True)
isnet_path = "saved_models/isnet.pth"
if not os.path.exists(isnet_path):
    lfs_url = "https://cdn-lfs.huggingface.co/repos/e0/a8/e0a889743a78391b48db7c4c0b4de1963ee320cb10934c75a32481dc5af9c61/e0a889743a78391b48db7c4c0b4de1963ee320cb10934c75a32481dc5af9c61?download=true"
    download_file(
        lfs_url,
        isnet_path,
        description="isnet.pth from Hugging Face Git LFS",
    )

# Project imports (must come after the downloads above, which may create them)
from data_loader_cache import normalize, im_reader, im_preprocess  # noqa: E402
from models import ISNetDIS  # noqa: E402

# Helpers
device = 'cpu'


class GOSNormalize(object):
    """Callable transform that applies ``normalize`` with a fixed mean/std."""

    def __init__(self, mean=None, std=None):
        # ``None`` sentinels avoid sharing one mutable default list between
        # instances; the effective defaults are unchanged.
        self.mean = [0.485, 0.456, 0.406] if mean is None else mean
        self.std = [0.229, 0.224, 0.225] if std is None else std

    def __call__(self, image):
        """Return ``image`` normalized with this instance's mean and std."""
        return normalize(image, self.mean, self.std)


transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])


def load_image(im_path, hypar):
    """Read an image file and prepare it as a batched model input.

    Args:
        im_path: Path of the image to read.
        hypar: Settings dict; only ``hypar["cache_size"]`` is used here.

    Returns:
        A ``(image, shape)`` pair, each with a leading batch dimension of 1:
        the preprocessed image scaled by 1/255 and passed through
        ``transform``, and the shape value returned by ``im_preprocess``
        (presumably the original height and width - confirm against
        ``data_loader_cache``).
    """
    im = im_reader(im_path)
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    return transform(im).unsqueeze(0), shape.unsqueeze(0)


def build_model(hypar, device):
    """Move ``hypar["model"]`` to ``device``, load its weights, set eval mode.

    Args:
        hypar: Settings dict providing ``"model"``, ``"model_path"`` and
            ``"restore_model"``. Weights are loaded only when
            ``"restore_model"`` is truthy.
        device: Target device for the network and the loaded weights.

    Returns:
        The network in evaluation mode.
    """
    net = hypar["model"]
    net.to(device)
    if hypar["restore_model"]:
        weights_path = os.path.join(hypar["model_path"], hypar["restore_model"])
        # NOTE(security): torch.load unpickles the checkpoint, and this file
        # was downloaded from the network. Consider passing
        # ``weights_only=True`` if the installed torch version supports it.
        net.load_state_dict(torch.load(weights_path, map_location=device))
    net.eval()
    return net


def predict(net, inputs_val, shapes_val, hypar, device):
    """Run the network on one image and return a uint8 mask.

    Args:
        net: Model to evaluate.
        inputs_val: Batched input tensor; only the first item is used.
        shapes_val: Batched target size; ``shapes_val[0][0]`` and
            ``shapes_val[0][1]`` give the output height and width.
        hypar: Settings dict (unused here, kept for interface compatibility).
        device: Device on which inference runs.

    Returns:
        A ``numpy.uint8`` array of the target size with values min-max
        scaled to 0..255. A constant prediction yields an all-zero mask
        instead of NaN values.
    """
    net.eval()
    inputs_val = inputs_val.type(torch.FloatTensor).to(device)
    target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))

    with torch.no_grad():
        ds_val = net(inputs_val)[0]
        pred_val = ds_val[0][0, :, :, :]
        # F.upsample is deprecated; F.interpolate with align_corners=False
        # is what it resolved to for bilinear mode.
        pred_val = torch.squeeze(
            F.interpolate(
                torch.unsqueeze(pred_val, 0),
                size=target_size,
                mode='bilinear',
                align_corners=False,
            )
        )

        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        # Guard against ma == mi, which would divide by zero and produce NaN
        # before the uint8 cast.
        span = (ma - mi).clamp_min(torch.finfo(pred_val.dtype).eps)
        pred_val = (pred_val - mi) / span

    return (pred_val.cpu().numpy() * 255).astype(np.uint8)


# Set Parameters
hypar = {
    "model_path": "saved_models",
    "restore_model": "isnet.pth",
    "cache_size": [512, 512],
    "input_size": [512, 512],
    "crop_size": [512, 512],
    "model": ISNetDIS(),
}

# Build Model
net = build_model(hypar, device)


def inference(image):
    """Gradio callback: segment the uploaded image.

    Args:
        image: File path of the uploaded image, or ``None`` when nothing
            was uploaded.

    Returns:
        ``[im_rgba, pil_mask]``: the input image with the predicted mask as
        its alpha channel, and the mask itself as a mode ``"L"`` image.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)
    pil_mask = Image.fromarray(mask).convert('L')

    # convert() already returns a new image, so the file can be closed here.
    with Image.open(image) as src:
        im_rgba = src.convert("RGB")
    im_rgba.putalpha(pil_mask)
    return [im_rgba, pil_mask]


title = "Dichotomous Image Segmentation"
description = "Upload an image to remove its background."

# Keep the Interface object itself in ``interface`` (not launch()'s return).
interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath'),
    outputs=[
        gr.Image(type='filepath', format="png"),
        gr.Image(type='filepath', format="png"),
    ],
    title=title,
    description=description,
    flagging_mode="never",
    cache_mode="lazy",
)
interface.launch()