Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d layer into the Conv2d layer that feeds it.

    Build a new ``nn.Conv2d`` that computes, with ``bn``'s running statistics,
    the same result as ``bn(conv(x))``:

        scale = gamma / sqrt(running_var + eps)
        W'    = W * scale            (per output channel)
        b'    = scale * (b_conv - running_mean) + beta

    Reference: https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    Args:
        conv: the convolution whose output is normalized by ``bn``.
        bn: the batch-norm layer applied to ``conv``'s output. Its
            ``running_mean`` / ``running_var`` buffers are used.
            ``weight`` / ``bias`` may be None (``affine=False``).

    Returns:
        A new ``nn.Conv2d`` with ``bias=True`` and parameters that do not
        require grad, on the same device and dtype as ``conv.weight``.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            # dilation/padding_mode must be carried over, otherwise dilated or
            # non-zero-padded convolutions produce different outputs.
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(device=conv.weight.device, dtype=conv.weight.dtype)
    )

    # Without no_grad, copy_ from grad-requiring sources would attach autograd
    # history (and references to conv/bn parameters) to the fused parameters.
    with torch.no_grad():
        gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)
        scale = gamma / torch.sqrt(bn.running_var + bn.eps)

        # Dim 0 of a Conv2d weight is out_channels (also when groups > 1), so
        # scaling each output filter is a broadcast, no diagonal matrix needed.
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        b_conv = conv.bias if conv.bias is not None else torch.zeros_like(scale)
        fusedconv.bias.copy_(scale * (b_conv - bn.running_mean) + beta)

    return fusedconv
def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from ``b`` onto ``a``.

    Every entry of ``b.__dict__`` is copied with ``setattr`` unless:
      * ``include`` is non-empty and the name is not listed in it,
      * the name starts with an underscore, or
      * the name is listed in ``exclude``.

    Args:
        a: destination object; attributes are set on it in place.
        b: source object; only its ``__dict__`` entries are considered.
        include: if non-empty, only these attribute names may be copied.
        exclude: attribute names that are never copied.
    """
    for name, value in b.__dict__.items():
        selected = not include or name in include
        if selected and not name.startswith("_") and name not in exclude:
            setattr(a, name, value)