Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from torchvision.models import mobilenet_v3_small | |
class MobileNetV3SmallClassifier(nn.Module):
    """MobileNetV3-Small network with a replaced output layer.

    The wrapped torchvision model is stored as ``self.backbone``; the
    entry at index 3 of its ``classifier`` is swapped for a new
    ``nn.Linear`` that emits ``num_classes`` values.
    """

    def __init__(self, num_classes=5, pretrained=True):
        """Build the backbone and install the new linear head.

        Args:
            num_classes: Output width of the replacement linear layer.
            pretrained: Passed through unchanged to
                ``mobilenet_v3_small(pretrained=...)``.
        """
        super().__init__()

        # NOTE(review): newer torchvision releases appear to deprecate
        # ``pretrained`` in favour of ``weights`` -- confirm against the
        # installed version before migrating.
        network = mobilenet_v3_small(pretrained=pretrained)

        # Reuse the input width of the existing layer so the new head
        # plugs into the preceding classifier layers unchanged.
        old_head = network.classifier[3]
        network.classifier[3] = nn.Linear(old_head.in_features, num_classes)

        self.backbone = network

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        output = self.backbone(x)
        return output