# real-cugan-oneflow / upcunet.py
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
import numpy as np
import os
import sys
root_path = os.path.abspath('.')
sys.path.append(root_path)
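
# Squeeze-and-Excitation channel attention block. The global average pool is
# computed in fp32 when the input is half precision to avoid fp16 accumulation
# error. forward_mean() takes a precomputed pooled mean instead, which lets
# tiled inference (UpCunet below) share whole-image statistics across crops.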
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction=8, bias=False):
super(SEBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // reduction, 1, 1, 0, bias=bias)
self.conv2 = nn.Conv2d(in_channels // reduction, in_channels, 1, 1, 0, bias=bias)
def forward(self, x):
dim = (2, 3)
if 'Half' in x.type():
x0 = flow.mean(x.float(), dim=dim, keepdim=True).half()
else:
x0 = flow.mean(x, dim=dim, keepdim=True)
x0 = self.conv1(x0)
x0 = F.relu(x0, inplace=True)
x0 = self.conv2(x0)
x0 = F.sigmoid(x0)
x = flow.mul(x, x0)
return x
def forward_mean(self, x, x0):
x0 = self.conv1(x0)
x0 = F.relu(x0, inplace=True)
x0 = self.conv2(x0)
x0 = F.sigmoid(x0)
x = flow.mul(x, x0)
return x
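
# Basic block of both U-Nets: two unpadded 3x3 convolutions with LeakyReLU,
# optionally followed by an SEBlock. Valid convolutions shrink each spatial
# side by 4, which is why the forward passes below crop skip connections with
# negative F.pad values.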
class UNetConv(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels, se):
super(UNetConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
nn.LeakyReLU(0.1, inplace=True),
nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
nn.LeakyReLU(0.1, inplace=True)
)
if se:
            self.seblock = SEBlock(out_channels, reduction=8, bias=True)
else:
self.seblock = None
def forward(self, x):
z = self.conv(x)
if self.seblock is not None:
z = self.seblock(z)
return z
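
# Shallow two-level U-Net. forward_a()/forward_b() expose the same computation
# as forward(), split at conv2's SE block so tiled inference can inject a
# globally pooled mean via SEBlock.forward_mean().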
class UNet1(nn.Module):
def __init__(self, in_channels, out_channels, deconv):
super(UNet1, self).__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
if deconv:
self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
else:
self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv1_down(x1)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x2 = self.conv2(x2)
x2 = self.conv2_up(x2)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x1 = F.pad(x1, (-4, -4, -4, -4))
x3 = self.conv3(x1 + x2)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
z = self.conv_bottom(x3)
return z
def forward_a(self, x):
x1 = self.conv1(x)
x2 = self.conv1_down(x1)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x2 = self.conv2.conv(x2)
return x1, x2
def forward_b(self, x1, x2):
x2 = self.conv2_up(x2)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x1 = F.pad(x1, (-4, -4, -4, -4))
x3 = self.conv3(x1 + x2)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
z = self.conv_bottom(x3)
return z
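
# Identical topology to UNet1 except for the bottom layer: a kernel-5,
# stride-3 transposed convolution that upsamples 3x instead of 2x.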
class UNet1x3(nn.Module):
def __init__(self, in_channels, out_channels, deconv):
super(UNet1x3, self).__init__()
self.conv1 = UNetConv(in_channels, 32, 64, se=False)
self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
self.conv2 = UNetConv(64, 128, 64, se=True)
self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
if deconv:
self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 5, 3, 2)
else:
self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv1_down(x1)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x2 = self.conv2(x2)
x2 = self.conv2_up(x2)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x1 = F.pad(x1, (-4, -4, -4, -4))
x3 = self.conv3(x1 + x2)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
z = self.conv_bottom(x3)
return z
def forward_a(self, x):
x1 = self.conv1(x)
x2 = self.conv1_down(x1)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x2 = self.conv2.conv(x2)
return x1, x2
def forward_b(self, x1, x2):
x2 = self.conv2_up(x2)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x1 = F.pad(x1, (-4, -4, -4, -4))
x3 = self.conv3(x1 + x2)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
z = self.conv_bottom(x3)
return z
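
# Deeper three-level U-Net that refines UNet1's output. forward() is likewise
# split into forward_a()..forward_d() at each SE block so tiled inference can
# substitute shared channel statistics.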
class UNet2(nn.Module):
def __init__(self, in_channels, out_channels, deconv):
super(UNet2, self).__init__()
self.conv1 = UNetConv(in_channels, 32, 64, se=False)
self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
self.conv2 = UNetConv(64, 64, 128, se=True)
self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
self.conv3 = UNetConv(128, 256, 128, se=True)
self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
self.conv4 = UNetConv(128, 64, 64, se=True)
self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
if deconv:
self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
else:
self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv1_down(x1)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x2 = self.conv2(x2)
x3 = self.conv2_down(x2)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
x3 = self.conv3(x3)
x3 = self.conv3_up(x3)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
x2 = F.pad(x2, (-4, -4, -4, -4))
x4 = self.conv4(x2 + x3)
x4 = self.conv4_up(x4)
x4 = F.leaky_relu(x4, 0.1, inplace=True)
x1 = F.pad(x1, (-16, -16, -16, -16))
x5 = self.conv5(x1 + x4)
x5 = F.leaky_relu(x5, 0.1, inplace=True)
z = self.conv_bottom(x5)
return z
def forward_a(self, x):
x1 = self.conv1(x)
x2 = self.conv1_down(x1)
x2 = F.leaky_relu(x2, 0.1, inplace=True)
x2 = self.conv2.conv(x2)
return x1, x2
def forward_b(self, x2):
x3 = self.conv2_down(x2)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
x3 = self.conv3.conv(x3)
return x3
def forward_c(self, x2, x3):
x3 = self.conv3_up(x3)
x3 = F.leaky_relu(x3, 0.1, inplace=True)
x2 = F.pad(x2, (-4, -4, -4, -4))
x4 = self.conv4.conv(x2 + x3)
return x4
def forward_d(self, x1, x4):
x4 = self.conv4_up(x4)
x4 = F.leaky_relu(x4, 0.1, inplace=True)
x1 = F.pad(x1, (-16, -16, -16, -16))
x5 = self.conv5(x1 + x4)
x5 = F.leaky_relu(x5, 0.1, inplace=True)
z = self.conv_bottom(x5)
return z
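
# Full Real-CUGAN generator: UNet1 upsamples, UNet2 refines, and the two are
# combined through a residual add. For scale_factor 4 the cascade is followed
# by a 64->12 channel conv plus PixelShuffle(2) and a nearest-neighbour skip
# of the input.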
class UpCunet(nn.Module):
def __init__(self, scale_factor, in_channels=3, out_channels=3):
super(UpCunet, self).__init__()
self.scalef = scale_factor
self.unet1 = UNet1(in_channels, out_channels if scale_factor == 2 else 64, deconv=True)
self.unet2 = UNet2(
in_channels if scale_factor == 2 else 64,
out_channels if scale_factor == 2 else 64,
deconv=False
)
if scale_factor == 4:
self.ps = nn.PixelShuffle(2)
self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)
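
    # tile_mode 0 runs the whole padded image in one pass. tile_modes 1-4 cut
    # it into overlapping crops (context margin 18 px at 2x, 19 px at 4x) and
    # run the network in stages, averaging SE statistics over all crops so the
    # tiled output matches the single-pass output.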
def forward(self, x, tile_mode):
n, c, h0, w0 = x.shape
x00 = x
if tile_mode == 0:
ph = ((h0 - 1) // 2 + 1) * 2
pw = ((w0 - 1) // 2 + 1) * 2
mx = 18 if self.scalef == 2 else 19
x = F.pad(x, (mx, mx + pw - w0, mx, mx + ph - h0), 'reflect')
x = self.unet1.forward(x)
x0 = self.unet2.forward(x)
x1 = F.pad(x, (-20, -20, -20, -20))
x = flow.add(x0, x1)
if self.scalef == 4:
x = self.conv_final(x)
x = F.pad(x, (-1, -1, -1, -1))
x = self.ps(x)
if (w0 != pw or h0 != ph):
x = x[:, :, :h0 * self.scalef, :w0 * self.scalef]
if self.scalef == 4:
x += F.interpolate(x00, scale_factor=4, mode='nearest')
return x
        elif tile_mode == 1:  # split along the longer edge only
            if w0 >= h0:
                crop_size_w = ((w0 - 1) // 4 * 4 + 4) // 2
                crop_size_h = (h0 - 1) // 2 * 2 + 2
            else:
                crop_size_h = ((h0 - 1) // 4 * 4 + 4) // 2
                crop_size_w = (w0 - 1) // 2 * 2 + 2
            crop_size = (crop_size_h, crop_size_w)
        elif tile_mode == 2:  # 2x2 crops
            crop_size = (((h0 - 1) // 4 * 4 + 4) // 2, ((w0 - 1) // 4 * 4 + 4) // 2)
        elif tile_mode == 3:  # 3x3 crops
            crop_size = (((h0 - 1) // 6 * 6 + 6) // 3, ((w0 - 1) // 6 * 6 + 6) // 3)
        elif tile_mode == 4:  # 4x4 crops
            crop_size = (((h0 - 1) // 8 * 8 + 8) // 4, ((w0 - 1) // 8 * 8 + 8) // 4)
        else:
            raise ValueError('tile_mode must be 0-4, got %r' % (tile_mode,))
ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
mx = 18 if self.scalef == 2 else 19
x = F.pad(x, (mx, mx + pw - w0, mx, mx + ph - h0), 'reflect')
n, c, h, w = x.shape
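        # Pass 1: run the first half of unet1 on every crop and accumulate the
        # channel means its SE block will share across all crops.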
        se_mean0 = flow.zeros((n, 64, 1, 1), dtype=x.dtype, device=x.device)  # match input dtype so half-precision runs work
n_patch = 0
tmp_dict = {}
opt_res_dict = {}
scale_range = 36 if self.scalef == 2 else 38
for i in range(0, h - scale_range, crop_size[0]):
tmp_dict[i] = {}
for j in range(0, w - scale_range, crop_size[1]):
x_crop = x[:, :, i:i + crop_size[0] + scale_range, j:j + crop_size[1] + scale_range]
n, c1, h1, w1 = x_crop.shape
tmp0, x_crop = self.unet1.forward_a(x_crop)
tmp_se_mean = flow.mean(x_crop, dim=(2, 3), keepdim=True)
se_mean0 += tmp_se_mean
n_patch += 1
tmp_dict[i][j] = (tmp0, x_crop)
se_mean0 /= n_patch
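        # Pass 2: apply unet1's SE block with the shared mean, finish unet1,
        # start unet2 (forward_a) and pool means for unet2.conv2's SE block.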
        se_mean1 = flow.zeros((n, 128, 1, 1), dtype=x.dtype, device=x.device)
for i in range(0, h - scale_range, crop_size[0]):
for j in range(0, w - scale_range, crop_size[1]):
tmp0, x_crop = tmp_dict[i][j]
x_crop = self.unet1.conv2.seblock.forward_mean(x_crop, se_mean0)
opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
tmp_se_mean = flow.mean(tmp_x2, dim=(2, 3), keepdim=True)
se_mean1 += tmp_se_mean
tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2)
se_mean1 /= n_patch
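        # Pass 3: unet2.conv2 SE with shared mean, then forward_b; pool means
        # for unet2.conv3's SE block (se_mean0 is reused as the accumulator).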
        se_mean0 = flow.zeros((n, 128, 1, 1), dtype=x.dtype, device=x.device)
for i in range(0, h - scale_range, crop_size[0]):
for j in range(0, w - scale_range, crop_size[1]):
opt_unet1, tmp_x1, tmp_x2 = tmp_dict[i][j]
tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_mean1)
tmp_x3 = self.unet2.forward_b(tmp_x2)
tmp_se_mean = flow.mean(tmp_x3, dim=(2, 3), keepdim=True)
se_mean0 += tmp_se_mean
tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
se_mean0 /= n_patch
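        # Pass 4: unet2.conv3 SE with shared mean, then forward_c; pool means
        # for unet2.conv4's SE block.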
        se_mean1 = flow.zeros((n, 64, 1, 1), dtype=x.dtype, device=x.device)
for i in range(0, h - scale_range, crop_size[0]):
for j in range(0, w - scale_range, crop_size[1]):
opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tmp_dict[i][j]
tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_mean0)
tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
tmp_se_mean = flow.mean(tmp_x4, dim=(2, 3), keepdim=True)
se_mean1 += tmp_se_mean
tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x4)
se_mean1 /= n_patch
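        # Final pass: unet2.conv4 SE, forward_d, and the UNet1 -> UNet2
        # residual add for every crop.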
for i in range(0, h - scale_range, crop_size[0]):
opt_res_dict[i] = {}
for j in range(0, w - scale_range, crop_size[1]):
opt_unet1, tmp_x1, tmp_x4 = tmp_dict[i][j]
tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_mean1)
x0 = self.unet2.forward_d(tmp_x1, tmp_x4)
x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
                x_crop = flow.add(x0, x1)
                if self.scalef == 4:
                    # mirror the tile_mode==0 path: final conv, trim, pixel shuffle
                    # (without this, crops keep 64 channels and cannot be written
                    # into the 3-channel result tensor below)
                    x_crop = self.conv_final(x_crop)
                    x_crop = F.pad(x_crop, (-1, -1, -1, -1))
                    x_crop = self.ps(x_crop)
                opt_res_dict[i][j] = x_crop
del tmp_dict
flow.cuda.empty_cache()
cal = 72 if self.scalef == 2 else 152
        res = flow.zeros((n, c, h * self.scalef - cal, w * self.scalef - cal), dtype=x.dtype, device=x.device)
for i in range(0, h - scale_range, crop_size[0]):
for j in range(0, w - scale_range, crop_size[1]):
res[:, :, i * self.scalef:i * self.scalef + h1 * self.scalef - cal, j * self.scalef:j * self.scalef + w1 * self.scalef - cal] = opt_res_dict[i][j]
del opt_res_dict
flow.cuda.empty_cache()
if (w0 != pw or h0 != ph):
res = res[:, :, :h0 * self.scalef, :w0 * self.scalef]
if self.scalef == 4:
res += F.interpolate(x00, scale_factor=4, mode='nearest')
return res
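
# Thin inference wrapper: loads a OneFlow checkpoint, moves the model to the
# requested device (optionally in fp16), and converts between HWC uint8 numpy
# frames and normalized NCHW tensors.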
class RealWaifuUpScaler(object):
def __init__(self, scalef, weight_path, half, device):
weight = flow.load(weight_path)
self.scalef = scalef
        self.model = UpCunet(scale_factor=scalef)
if half:
self.model = self.model.half().to(device)
else:
self.model = self.model.to(device)
self.model.load_state_dict(weight)
self.model.eval()
self.half = half
self.device = device
    def np2tensor(self, frame):
        tensor = flow.from_numpy(np.transpose(frame, (2, 0, 1))).unsqueeze(0).to(self.device)
        return (tensor.half() if self.half else tensor.float()) / 255
    def tensor2np(self, tensor):
        t = tensor.data.squeeze().float()  # .float() is a no-op for fp32, upcasts fp16
        return np.transpose((t * 255.0).round().clamp_(0, 255).to(flow.uint8).cpu().numpy(), (1, 2, 0))
def __call__(self, frame, tile):
with flow.no_grad():
tensor = self.np2tensor(frame)
result = self.tensor2np(self.model(tensor, tile))
return result
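
if __name__ == '__main__':
    # Minimal usage sketch. 'path/to/weights' and 'input.png' are illustrative
    # placeholders, not files shipped with this repo; cv2 is only needed here.
    import cv2
    upscaler = RealWaifuUpScaler(scalef=2, weight_path='path/to/weights', half=False, device='cuda:0')
    frame = cv2.imread('input.png')[:, :, ::-1].copy()  # BGR -> RGB, HWC uint8
    out = upscaler(frame, 2)  # tile_mode 2: 2x2 overlapping crops
    cv2.imwrite('output.png', out[:, :, ::-1])  # RGB -> BGR for cv2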