from torch import nn


def Conv2dSame(dim_in, dim_out, kernel_size, bias=True):
    """Build a zero-padded 2-D convolution whose output keeps the input's size.

    The returned module first zero-pads the input and then applies an
    unpadded ``nn.Conv2d`` with its default stride and dilation. The total
    padding along each spatial axis is ``kernel - 1``, so the spatial size
    is preserved. For an even kernel the padding cannot be split evenly; the
    extra row/column goes on the top/left side.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        kernel_size: Either a single int (square kernel) or a
            ``(height, width)`` pair for a non-square kernel.
        bias: Whether the convolution has a learnable bias term.

    Returns:
        An ``nn.Sequential`` of ``nn.ZeroPad2d`` followed by ``nn.Conv2d``.
    """
    if isinstance(kernel_size, (tuple, list)):
        kernel_h, kernel_w = kernel_size
    else:
        kernel_h = kernel_w = kernel_size

    # k // 2 before and (k - 1) // 2 after sum to k - 1 for any k >= 1;
    # for even k the leading side receives the one extra element.
    pad_top, pad_bottom = kernel_h // 2, (kernel_h - 1) // 2
    pad_left, pad_right = kernel_w // 2, (kernel_w - 1) // 2

    return nn.Sequential(
        # ZeroPad2d expects (left, right, top, bottom): width first, then height.
        nn.ZeroPad2d((pad_left, pad_right, pad_top, pad_bottom)),
        nn.Conv2d(dim_in, dim_out, (kernel_h, kernel_w), bias=bias),
    )