from torch import nn | |
def Conv2dSame(dim_in, dim_out, kernel_size, bias=True):
    """Build a stride-1 2D convolution that preserves the input's spatial size.

    The input is explicitly zero-padded so that exactly ``k - 1`` pixels are
    added along each spatial axis before an unpadded ``nn.Conv2d``. That makes
    the output height/width equal to the input's, including for even kernel
    sizes. For even ``k``, the extra pixel of padding goes on the left/top side.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        kernel_size: Either an ``int`` (square kernel) or a
            ``(kernel_height, kernel_width)`` pair.
        bias: Whether the convolution has a learnable bias.

    Returns:
        An ``nn.Sequential`` of ``nn.ZeroPad2d`` followed by ``nn.Conv2d``.
    """
    if isinstance(kernel_size, int):
        kernel_h = kernel_w = kernel_size
    else:
        kernel_h, kernel_w = kernel_size

    # Total padding per axis is k - 1; the "before" side gets k // 2, so for
    # even k it receives one more pixel than the "after" side.
    top = kernel_h // 2
    bottom = kernel_h - 1 - top
    left = kernel_w // 2
    right = kernel_w - 1 - left

    return nn.Sequential(
        # ZeroPad2d takes (left, right, top, bottom): width padding first.
        nn.ZeroPad2d((left, right, top, bottom)),
        nn.Conv2d(dim_in, dim_out, kernel_size, bias=bias),
    )