Max Meyer
committed on
Fix example, load weights safely and remove extra whitespace (#2)
- Fix example, load weights safely, remove whitespace (e840827775a894eaf8241a70535c2921b36bd430)
- README.md +11 -11
- image2.png +0 -0
- model.py +27 -27
README.md
CHANGED

@@ -13,34 +13,34 @@ tags:

 # BEN - Background Erase Network (Beta Base Model)

-BEN is a deep learning model designed to automatically remove backgrounds from images, producing both a mask and a foreground image.
+BEN is a deep learning model designed to automatically remove backgrounds from images, producing both a mask and a foreground image.

 - MADE IN AMERICA

 ## Quick Start Code
+
 ```python
-
+import model
 from PIL import Image
 import torch


 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-file = "./image2.
+file = "./image2.png" # input image

-model = model.BEN_Base().to(device).eval() #init pipeline
+model = model.BEN_Base().to(device).eval() #init pipeline

-model.loadcheckpoints("./
+model.loadcheckpoints("./BEN_Base.pth")
 image = Image.open(file)
-
+with torch.no_grad():
+    mask, foreground = model.inference(image)

 mask.save("./mask.png")
 foreground.save("./foreground.png")
-
-
-
 ```
-
+
+# BEN SOA Benchmarks on Disk 5k Eval

 ![image/png](https://cdn-uploads.huggingface.co/production/uploads/6644bcb4dd9d5cd99d9aa98d/MzUj_kqsJGp8PCrRhMzBk.png)


@@ -84,4 +84,4 @@ foreground.save("./foreground.png")

 ## Installation
 1. Clone Repo
-2. Install requirements.txt
+2. Install requirements.txt
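With these fixes the README quick start runs end to end. A consolidated sketch of the updated example is shown below for reference; it assumes model.py, the BEN_Base.pth checkpoint, and an input image image2.png all sit in the working directory, and the instance is named net here only to avoid shadowing the imported model module.

```python
# Consolidated quick start from the updated README.
# Assumptions: model.py, BEN_Base.pth and image2.png are in the current directory.
import model
from PIL import Image
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

file = "./image2.png"  # input image

net = model.BEN_Base().to(device).eval()  # init pipeline
net.loadcheckpoints("./BEN_Base.pth")     # after this commit, loaded with weights_only=True

image = Image.open(file)
with torch.no_grad():                     # inference only, no gradients needed
    mask, foreground = net.inference(image)

mask.save("./mask.png")
foreground.save("./foreground.png")
```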
image2.png
ADDED
model.py
CHANGED

@@ -560,7 +560,7 @@ class SwinTransformer(nn.Module):
 # interpolate the position embedding to the corresponding size
 absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
 x = (x + absolute_pos_embed) # B Wh*Ww C
-
+
 outs = [x.contiguous()]
 x = x.flatten(2).transpose(1, 2)
 x = self.pos_drop(x)

@@ -634,7 +634,7 @@ class PositionEmbeddingSine:
 scale = 2 * math.pi
 self.scale = scale
 self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
-
+
 def __call__(self, b, h, w):
 device = self.dim_t.device
 mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)

@@ -646,18 +646,18 @@ class PositionEmbeddingSine:
 eps = 1e-6
 y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
 x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
-
+
 dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
 pos_x = x_embed[:, :, :, None] / dim_t
 pos_y = y_embed[:, :, :, None] / dim_t
-
+
 pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
 pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
-
-return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)


-class MCLM(nn.Module):
+
+class MCLM(nn.Module):
 def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
 super(MCLM, self).__init__()
 self.attention = nn.ModuleList([

@@ -688,10 +688,10 @@ class MCLM(nn.Module):
 l: 4,c,h,w
 g: 1,c,h,w
 """
-b, c, h, w = l.size()
+b, c, h, w = l.size()
 # 4,c,h,w -> 1,c,2h,2w
 concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
-
+
 pools = []
 for pool_ratio in self.pool_ratios:
 # b,c,h,w

@@ -734,7 +734,7 @@ class MCLM(nn.Module):
 l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
 l_hw_b_c = self.norm1(l_hw_b_c)
 l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
-l_hw_b_c = self.norm2(l_hw_b_c)
+l_hw_b_c = self.norm2(l_hw_b_c)

 l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
 return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)

@@ -770,42 +770,42 @@ class MCRM(nn.Module):

 def forward(self, x):
 device = x.device
-b, c, h, w = x.size()
+b, c, h, w = x.size()
 loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
-
+
 patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
-
+
 token_attention_map = self.sigmoid(self.sal_conv(glb))
 token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
 loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
-
+
 pools = []
 for pool_ratio in self.pool_ratios:
 tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
 pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
 pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
-
+
 pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
 loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
-
+
 outputs = []
 for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
 v = pools[i]
 k = v
 outputs.append(self.attention[i](q, k, v)[0])
-
-outputs = torch.cat(outputs, 1)
+
+outputs = torch.cat(outputs, 1)
 src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
 src = self.norm1(src)
 src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
 src = self.norm2(src)
 src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
 glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb
-
+
 return torch.cat((src, glb), 0), token_attention_map


-class BEN_Base(nn.Module):
+class BEN_Base(nn.Module):
 def __init__(self):
 super().__init__()


@@ -868,7 +868,7 @@ class BEN_Base(nn.Module):
 e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)

 e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
-e4 = self.conv4(e4)
+e4 = self.conv4(e4)
 e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
 e3 = self.conv3(e3)
 e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))

@@ -909,11 +909,11 @@ class BEN_Base(nn.Module):
 return blurred_mask, foreground

 def loadcheckpoints(self,model_path):
-model_dict = torch.load(model_path,map_location="cpu")
+model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
 self.load_state_dict(model_dict['model_state_dict'], strict=True)
 del model_path

-
+


 def rgb_loader_refiner( original_image):

@@ -923,16 +923,16 @@ def rgb_loader_refiner( original_image):
 # Convert to RGB if necessary
 if image.mode != 'RGB':
 image = image.convert('RGB')
-
+
 # Resize the image
 image = image.resize((1024, 1024), resample=Image.LANCZOS)

-return image.convert('RGB'), h, w,original_image
-
+return image.convert('RGB'), h, w,original_image
+
 # Define the image transformation
 img_transform = transforms.Compose([
 transforms.ToTensor(),
-transforms.ConvertImageDtype(torch.float32),
+transforms.ConvertImageDtype(torch.float32),
 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 ])
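The one functional change in model.py is in loadcheckpoints: the checkpoint is now read with torch.load(..., weights_only=True), which restricts unpickling to tensors and primitive containers instead of arbitrary Python objects, so a tampered checkpoint file cannot run code at load time. A minimal sketch of the same pattern follows; the load_weights_safely name and checkpoint_path parameter are illustrative, while the 'model_state_dict' key matches the one this repository's checkpoint uses.

```python
import torch

def load_weights_safely(module: torch.nn.Module, checkpoint_path: str) -> None:
    # weights_only=True (available in recent PyTorch releases) refuses to
    # unpickle arbitrary objects, limiting the file to plain tensor data.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    # The BEN checkpoint stores its weights under 'model_state_dict'.
    module.load_state_dict(checkpoint["model_state_dict"], strict=True)
```

Keeping strict=True means load_state_dict still fails loudly if the checkpoint and the model architecture ever drift apart.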