from torch import nn

try:
    # Older torchvision versions exposed this helper under models.utils;
    # fall back to torch.hub otherwise.
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.hub import load_state_dict_from_url

import torch.nn.functional as F

__all__ = ['MobileNetV2', 'mobilenet_v2']

model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: original value (e.g. a channel count scaled by a width multiplier)
    :param divisor: the number the result must be divisible by
    :param min_value: lower bound for the result (defaults to ``divisor``)
    :return: the value rounded to the nearest multiple of ``divisor``
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not lose more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
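
# Illustrative examples (not part of the original module) of how _make_divisible
# rounds with divisor=8:
#   _make_divisible(32 * 0.75, 8) -> 24   (already a multiple of 8)
#   _make_divisible(33, 8)        -> 32   (rounds to the nearest multiple)
#   _make_divisible(9, 8)         -> 16   (rounding down to 8 would lose >10%)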


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
        # Note: padding is fixed to 0; InvertedResidual applies TF-style
        # "SAME" padding externally via fixed_padding before calling this.
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )
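
# Illustrative shape check (not part of the original module): with padding=0,
# a bare 3x3 convolution shrinks its input, which is why InvertedResidual pads
# the input first.
#
#   import torch
#   conv = ConvBNReLU(3, 32, kernel_size=3, stride=2)
#   print(conv(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 32, 111, 111])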


def fixed_padding(kernel_size, dilation):
    # TF-style "SAME" padding for the given kernel size and dilation,
    # in the (left, right, top, bottom) order expected by F.pad.
    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    return (pad_beg, pad_end, pad_beg, pad_end)
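
# Illustrative values (not part of the original module): the padding grows with
# the dilated (effective) kernel size:
#   fixed_padding(3, 1) -> (1, 1, 1, 1)
#   fixed_padding(3, 2) -> (2, 2, 2, 2)   (effective kernel size is 5)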


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, dilation, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pointwise expansion
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # depthwise 3x3 (input is padded in forward via fixed_padding)
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim),
            # pointwise-linear projection
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

        # "SAME" padding for the 3x3 depthwise convolution above.
        self.input_padding = fixed_padding(3, dilation)

    def forward(self, x):
        x_pad = F.pad(x, self.input_padding)
        if self.use_res_connect:
            return x + self.conv(x_pad)
        else:
            return self.conv(x_pad)
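
# Illustrative shape check (not part of the original module): with stride 1 and
# inp == oup the residual connection is used and the spatial size is preserved.
#
#   import torch
#   block = InvertedResidual(inp=24, oup=24, stride=1, dilation=1, expand_ratio=6)
#   print(block(torch.randn(1, 24, 56, 56)).shape)  # torch.Size([1, 24, 56, 56])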


class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            output_stride (int): Overall stride of the extracted features; once
                the network reaches this stride, further downsampling is replaced
                by dilation
            width_mult (float): Width multiplier - adjusts number of channels in
                each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be
                a multiple of this number. Set to 1 to turn off rounding.
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        self.output_stride = output_stride
        current_stride = 1
        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t (expand ratio), c (output channels), n (repeats), s (stride)
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming the user knows t, c, n, s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "and each element should be a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        current_stride *= 2
        dilation = 1
        previous_dilation = 1

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            previous_dilation = dilation
            if current_stride == output_stride:
                # the target output stride has been reached: trade any further
                # striding for dilation to keep the spatial resolution
                stride = 1
                dilation *= s
            else:
                stride = s
                current_stride *= s

            for i in range(n):
                if i == 0:
                    features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
                else:
                    features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
                input_channel = output_channel

        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)
        # global average pooling over the spatial dimensions
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x
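
# Illustrative check of output_stride (not part of the original module): with
# output_stride=8 the later stages dilate instead of downsampling, so the
# final feature map is roughly 1/8 of the input resolution rather than 1/32:
#
#   import torch
#   m = MobileNetV2(output_stride=8).eval()
#   with torch.no_grad():
#       feats = m.features(torch.randn(1, 3, 224, 224))
#   print(feats.shape)  # torch.Size([1, 1280, 28, 28])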


def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
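
# Illustrative usage (not part of the original module):
#
#   import torch
#   model = mobilenet_v2(pretrained=False, output_stride=16)
#   model.eval()
#   with torch.no_grad():
#       logits = model(torch.randn(1, 3, 224, 224))
#   print(logits.shape)  # torch.Size([1, 1000])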