Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
class ResNetClassifier(nn.Module):
    """Image classifier: a torchvision ResNet trunk followed by a small MLP head.

    The ResNet's final ``fc`` layer is dropped. Everything up to and including
    its global average pool is kept as ``self.backbone``. ``self.classifier``
    maps the pooled features to ``num_classes`` logits through
    Dropout -> Linear(feature_dim, 512) -> ReLU -> Dropout -> Linear(512, num_classes).

    Args:
        model_name: Which torchvision ResNet to build; one of ``_SUPPORTED_MODELS``.
        num_classes: Number of output logits.
        pretrained: If True, load torchvision's ImageNet-1k V1 weights into the
            backbone (the same weights the legacy ``pretrained=True`` loaded).
        dropout: Dropout probability before the hidden layer; half of it is
            applied before the output layer.

    Raises:
        ValueError: If ``model_name`` is not a supported ResNet variant.
    """

    # Accepted values for ``model_name``. These are plain strings resolved on
    # torchvision.models only inside __init__, so defining the class does not
    # touch torchvision.
    _SUPPORTED_MODELS = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')

    def __init__(self, model_name='resnet18', num_classes=5, pretrained=True, dropout=0.2):
        super().__init__()
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError(f'不支持的ResNet模型: {model_name}')
        resnet = self._create_backbone(model_name, pretrained)
        # Read the feature width from the model instead of hard-coding it
        # (512 for resnet18/34, 2048 for resnet50 and deeper).
        self.feature_dim = resnet.fc.in_features
        # Drop only the final fc layer. The retained children end with ResNet's
        # global average pool, so the backbone emits (B, feature_dim, 1, 1).
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.classifier = nn.Sequential(
            # Redundant after the backbone's own pooling, but kept so layer
            # indices (and therefore state_dict keys such as
            # 'classifier.3.weight') stay compatible with saved checkpoints.
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(512, num_classes),
        )
        self._initialize_weights()

    @staticmethod
    def _create_backbone(model_name, pretrained):
        """Build torchvision's ``model_name`` ResNet, optionally with ImageNet weights."""
        builder = getattr(models, model_name)
        try:
            # torchvision >= 0.13 deprecates ``pretrained`` in favour of ``weights``.
            # 'IMAGENET1K_V1' is exactly what ``pretrained=True`` used to load;
            # 'DEFAULT' would silently switch resnet50+ to the V2 weights.
            return builder(weights='IMAGENET1K_V1' if pretrained else None)
        except TypeError:
            # Older torchvision has no ``weights`` keyword: use the legacy flag.
            return builder(pretrained=pretrained)

    def _initialize_weights(self):
        """Initialise the Linear layers of the classification head.

        The hidden Linear layer (followed by ReLU) gets Kaiming-normal weights.
        The output layer has no ReLU after it and a fan-out of only
        ``num_classes``, so Kaiming fan-out init would give it a large std
        (sqrt(2 / num_classes)) and huge initial logits. It gets a small
        N(0, 0.01) init instead. All biases start at zero, and the backbone
        is not modified here.
        """
        linears = [m for m in self.classifier if isinstance(m, nn.Linear)]
        if not linears:
            return
        *hidden, output = linears
        for layer in hidden:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
        nn.init.normal_(output.weight, mean=0.0, std=0.01)
        if output.bias is not None:
            nn.init.constant_(output.bias, 0)

    def forward(self, x):
        """Return class logits of shape (B, num_classes).

        ``x`` is presumably a (B, 3, H, W) image batch normalised the way the
        backbone expects. TODO: confirm against the caller's preprocessing.
        """
        features = self.backbone(x)
        return self.classifier(features)

    def get_features(self, x):
        """Return the backbone's pooled features flattened to (B, feature_dim)."""
        features = self.backbone(x)
        return torch.flatten(features, 1)