# JasonSmithSO's picture
# Upload 777 files
# 0034848 verified
# Source: https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/network.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from custom_controlnet_aux.util import custom_torch_download
class UNet(nn.Module):
    """U-Net style segmentation network built on a MobileNetV2 encoder.

    The encoder reuses feature blocks 0-16 of torchvision's MobileNetV2,
    initialised from the ``mobilenet_v2-b0353104.pth`` checkpoint. The decoder
    has five 2x-upsampling stages; every stage after the first consumes the
    previous decoder output concatenated (along the channel axis) with the
    matching encoder output. The last stage emits ``NUM_SEG_CLASSES`` channels
    passed through ``Softmax2d``.
    """

    @staticmethod
    def _decoder_stage(in_ch: int, out_ch: int) -> nn.Sequential:
        """Build one upsample -> 3x3 conv -> instance norm -> LeakyReLU -> dropout stage."""
        return nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2),
        )

    def __init__(self):
        """Create the encoder from MobileNetV2 weights and a fresh decoder."""
        super().__init__()
        self.NUM_SEG_CLASSES = 7  # Background, hair, face, eye, mouth, skin, clothes

        # Calling the factory with no arguments yields an uninitialised network
        # on both old (``pretrained=False`` default) and new (``weights=None``
        # default) torchvision releases, without the deprecation warning that
        # passing ``pretrained=`` triggers on newer versions.
        mobilenet_v2 = torchvision.models.mobilenet_v2()

        # NOTE(review): torch.load unpickles the file; custom_torch_download
        # presumably returns a local path to a downloaded checkpoint, so it must
        # come from a trusted source - confirm.
        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict copies into the parameters.
        checkpoint = torch.load(
            custom_torch_download(filename="mobilenet_v2-b0353104.pth"),
            map_location="cpu",
        )
        mobilenet_v2.load_state_dict(checkpoint, strict=True)
        mob_blocks = mobilenet_v2.features

        # Encoder. The slices are unpacked into a new Sequential so that each
        # stage's children are re-indexed from 0, exactly as if they had been
        # listed one by one (this keeps state_dict keys unchanged).
        self.en_block0 = nn.Sequential(*mob_blocks[0:2])    # in_ch=3 out_ch=16
        self.en_block1 = nn.Sequential(*mob_blocks[2:4])    # in_ch=16 out_ch=24
        self.en_block2 = nn.Sequential(*mob_blocks[4:7])    # in_ch=24 out_ch=32
        self.en_block3 = nn.Sequential(*mob_blocks[7:14])   # in_ch=32 out_ch=96
        self.en_block4 = nn.Sequential(*mob_blocks[14:17])  # in_ch=96 out_ch=160

        # Decoder. Input widths are doubled where the stage receives a skip
        # connection concatenated with the previous decoder output.
        self.de_block4 = self._decoder_stage(160, 96)     # in_ch=160 out_ch=96
        self.de_block3 = self._decoder_stage(96 * 2, 32)  # in_ch=96x2 out_ch=32
        self.de_block2 = self._decoder_stage(32 * 2, 24)  # in_ch=32x2 out_ch=24
        self.de_block1 = self._decoder_stage(24 * 2, 16)  # in_ch=24x2 out_ch=16
        self.de_block0 = nn.Sequential(                   # in_ch=16x2 out_ch=7
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(16 * 2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
            nn.Softmax2d(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, then decode with skip connections.

        Args:
            x: Input batch. Assumes shape (N, 3, H, W); H and W presumably
                need to be multiples of 32 so that every upsampled decoder map
                matches its encoder counterpart for concatenation - TODO confirm.

        Returns:
            The output of the final decoder stage, which ends in ``Softmax2d``
            over ``NUM_SEG_CLASSES`` channels.
        """
        # Encoder path; every intermediate output is kept for a skip connection.
        e0 = self.en_block0(x)
        e1 = self.en_block1(e0)
        e2 = self.en_block2(e1)
        e3 = self.en_block3(e2)
        e4 = self.en_block4(e3)

        # Decoder path; dim 1 is the channel axis used for concatenation.
        d4 = self.de_block4(e4)
        d3 = self.de_block3(torch.cat((d4, e3), 1))
        d2 = self.de_block2(torch.cat((d3, e2), 1))
        d1 = self.de_block1(torch.cat((d2, e1), 1))
        return self.de_block0(torch.cat((d1, e0), 1))