# Copyright (c) OpenMMLab. All rights reserved.
import types
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmpose.utils.typing import OptConfigType


class RepVGGBlock(BaseModule):
    """A block in RepVGG architecture, supporting optional normalization in
    the identity branch.

    This block consists of 3x3 and 1x1 convolutions, with an optional
    identity shortcut branch that includes normalization.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): The stride of the block. Defaults to 1.
        padding (int): The padding of the block. Defaults to 1.
        dilation (int): The dilation of the block. Defaults to 1.
        groups (int): The groups of the block. Defaults to 1.
        padding_mode (str): The padding mode of the block. Defaults to
            'zeros'.
        norm_cfg (dict): The config dict for normalization layers.
            Defaults to dict(type='BN').
        act_cfg (dict): The config dict for activation layers.
            Defaults to dict(type='ReLU').
        without_branch_norm (bool): Whether to skip branch_norm.
            Defaults to True.
        init_cfg (dict): The config dict for initialization. Defaults to
            None.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 padding: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 padding_mode: str = 'zeros',
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU'),
                 without_branch_norm: bool = True,
                 init_cfg: OptConfigType = None):
        super(RepVGGBlock, self).__init__(init_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # Judge whether the input and output shapes are the same.
        # If so, add a normalized identity shortcut.
        self.branch_norm = None
        if out_channels == in_channels and stride == 1 and \
                padding == dilation and not without_branch_norm:
            self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]

        self.branch_3x3 = ConvModule(
            self.in_channels,
            self.out_channels,
            3,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
            dilation=self.dilation,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.branch_1x1 = ConvModule(
            self.in_channels,
            self.out_channels,
            1,
            groups=self.groups,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.act = build_activation_layer(act_cfg)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the RepVGG block.

        The output is the sum of the 3x3 and 1x1 convolution outputs, along
        with the normalized identity branch output, followed by activation.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        if self.branch_norm is None:
            branch_norm_out = 0
        else:
            branch_norm_out = self.branch_norm(x)

        out = self.branch_3x3(x) + self.branch_1x1(x) + branch_norm_out
        out = self.act(out)

        return out

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Pad a 1x1 kernel tensor to 3x3.

        Args:
            kernel1x1 (Tensor): The 1x1 kernel to be padded.

        Returns:
            Tensor: The padded 3x3 kernel.
        """
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
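
    # For example, padding an ``(out_channels, in_channels, 1, 1)`` kernel
    # above yields an ``(out_channels, in_channels, 3, 3)`` kernel whose only
    # non-zero entries sit at the spatial center, so the padded kernel applied
    # with padding 1 matches the original 1x1 convolution with padding 0.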

    def _fuse_bn_tensor(self, branch: nn.Module) -> tuple:
        """Derive the equivalent kernel and bias of a specific branch layer.

        Args:
            branch (nn.Module): The layer that needs to be equivalently
                transformed, which can be a ConvModule or nn.BatchNorm2d.

        Returns:
            tuple: Equivalent kernel and bias.
        """
        if branch is None:
            return 0, 0

        if isinstance(branch, ConvModule):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, (nn.SyncBatchNorm, nn.BatchNorm2d))
            # Represent the identity branch as a 3x3 conv whose kernel is the
            # identity mapping, so the BN statistics can be folded into it.
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps

        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
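
    # The fusion above relies on the identity
    #   bn(conv(x)) = gamma * (W * x - running_mean) / std + beta
    #               = (gamma / std) * W * x + (beta - running_mean * gamma / std)
    # so scaling the kernel by ``gamma / std`` and shifting the bias by
    # ``beta - running_mean * gamma / std`` reproduces conv followed by BN.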

    def get_equivalent_kernel_bias(self):
        """Derive the equivalent kernel and bias in a differentiable way.

        Returns:
            tuple: Equivalent kernel and bias.
        """
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.branch_3x3)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.branch_1x1)
        kernelid, biasid = (0, 0) if self.branch_norm is None else \
            self._fuse_bn_tensor(self.branch_norm)

        return (kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
                bias3x3 + bias1x1 + biasid)

    def switch_to_deploy(self, test_cfg: Optional[Dict] = None):
        """Switch the block to deployment mode.

        In deployment mode, the block uses a single convolution operation
        derived from the equivalent kernel and bias, replacing the original
        branches. This reduces computational complexity during inference.
        """
        if getattr(self, 'deploy', False):
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv_reparam = nn.Conv2d(
            in_channels=self.branch_3x3.conv.in_channels,
            out_channels=self.branch_3x3.conv.out_channels,
            kernel_size=self.branch_3x3.conv.kernel_size,
            stride=self.branch_3x3.conv.stride,
            padding=self.branch_3x3.conv.padding,
            dilation=self.branch_3x3.conv.dilation,
            groups=self.branch_3x3.conv.groups,
            bias=True)
        self.conv_reparam.weight.data = kernel
        self.conv_reparam.bias.data = bias

        for para in self.parameters():
            para.detach_()
        self.__delattr__('branch_3x3')
        self.__delattr__('branch_1x1')
        if hasattr(self, 'branch_norm'):
            self.__delattr__('branch_norm')

        def _forward(self, x):
            return self.act(self.conv_reparam(x))

        self.forward = types.MethodType(_forward, self)

        self.deploy = True
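

# A minimal usage sketch under assumed conditions: a freshly constructed
# block is run in eval mode (so BN uses its running statistics, which makes
# the fusion exact), then reparameterized; the two forward passes should
# agree numerically. The shapes and tolerance below are illustrative choices,
# not values from the original module.
if __name__ == '__main__':
    block = RepVGGBlock(16, 16)
    block.eval()
    x = torch.rand(1, 16, 32, 32)
    with torch.no_grad():
        out_multi_branch = block(x)
        block.switch_to_deploy()
        out_reparam = block(x)
    # The reparameterized single conv should reproduce the multi-branch
    # output up to floating-point error.
    assert torch.allclose(out_multi_branch, out_reparam, atol=1e-5)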