# venite's picture
# initial
# f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from torch import nn
from torch.nn import functional as F
import torch.hub
class DeepLabV2(nn.Module):
    """DeepLab-v2 (ResNet-101) semantic segmenter with COCO-Stuff 164k weights.

    Wraps ``kazuto1011/deeplab-pytorch``'s ``deeplabv2_resnet101`` loaded via
    ``torch.hub`` and maps a batch of images to per-pixel class indices.

    Args:
        n_classes: Stored as ``self.n_classes``; not used by this class
            itself. The hub model is always built with a 182-class head so
            that the downloaded checkpoint loads (``load_state_dict`` is
            strict by default).
        image_size: Target length of the longer spatial side; inputs are
            rescaled so that ``max(H, W)`` becomes this value.
        use_dont_care: Stored as ``self.use_dont_care``; not used by this
            class itself (presumably read by callers -- TODO confirm).
    """

    def __init__(self, n_classes=182, image_size=512, use_dont_care=True):
        super().__init__()
        # NOTE: fetches the hub repo and the checkpoint over the network
        # (or from the torch.hub cache) at construction time.
        # The head size is fixed at 182 rather than ``n_classes`` because
        # the checkpoint below must match it exactly.
        self.model = torch.hub.load(
            "kazuto1011/deeplab-pytorch", "deeplabv2_resnet101",
            pretrained=False, n_classes=182
        )
        state_dict = torch.hub.load_state_dict_from_url(
            'https://github.com/kazuto1011/deeplab-pytorch/releases/download/'
            'v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth',
            map_location="cpu"
        )
        self.model.load_state_dict(state_dict)
        self.image_size = image_size
        # Per-channel mean in BGR order on the 0-255 scale (see the RGB->BGR
        # flip in ``forward``). Registered as a non-persistent buffer instead
        # of a tensor pinned to "cuda": it then follows the module through
        # .to()/.cuda(), allows construction on CPU-only machines, and stays
        # out of the state_dict.
        self.register_buffer(
            "mean", torch.tensor([104.008, 116.669, 122.675]),
            persistent=False,
        )
        self.n_classes = n_classes
        self.use_dont_care = use_dont_care

    def forward(self, images, align_corners=True):
        """Predict a class index for every pixel.

        Args:
            images: Batch of RGB images with values in [-1, 1], laid out as
                (N, 3, H, W).
            align_corners: Passed to both bilinear resizes.

        Returns:
            Integer (long) tensor of shape (N, H', W') holding class indices.
            (H', W') is the *resized* resolution (longer side scaled to
            ``image_size``), not the input resolution.
        """
        scale = self.image_size / max(images.shape[2:])
        images = F.interpolate(
            images, scale_factor=scale, mode='bilinear',
            align_corners=align_corners
        )
        images = 255 * 0.5 * (images + 1)  # (-1, 1) -> (0, 255)
        images = images.flip(1)  # RGB to BGR
        # Cast the mean to the input's device and dtype so the subtraction
        # never mixes devices and keeps the input's precision.
        mean = self.mean.to(images)
        images = images - mean[None, :, None, None]
        _, _, H, W = images.shape
        logits = self.model(images)
        # Bring the logits back up to the (resized) input resolution.
        logits = F.interpolate(
            logits, size=(H, W), mode="bilinear",
            align_corners=align_corners
        )
        # softmax is monotonic, so taking argmax over the raw logits gives
        # the same labels (up to float ties) without an extra
        # exp/normalize pass.
        return torch.argmax(logits, dim=1)