Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from torchvision.models import shufflenet_v2_x1_0 | |
class ShuffleNetV2Classifier(nn.Module):
    """ShuffleNet V2 (1.0x) backbone whose final ``fc`` layer is resized to ``num_classes``.

    Args:
        num_classes: Number of output logits produced by the replacement
            classification head.
        pretrained: If truthy, initialise the backbone with torchvision's
            default pretrained weights (this may trigger a download);
            otherwise the backbone is randomly initialised. The replacement
            ``fc`` layer is always freshly initialised either way.
    """

    def __init__(self, num_classes: int = 5, pretrained: bool = True) -> None:
        super().__init__()
        self.backbone = self._build_backbone(pretrained)
        # Replace the original classifier head with one sized for this task,
        # keeping the feature width that the backbone produces.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)

    @staticmethod
    def _build_backbone(pretrained: bool) -> nn.Module:
        """Construct torchvision's ShuffleNet V2 x1.0 model.

        Prefers the ``weights=`` API (torchvision >= 0.13), where the legacy
        ``pretrained=`` keyword is deprecated and warns on every call. Falls
        back to ``pretrained=`` on older torchvision releases that do not
        provide the weights enum.
        """
        try:
            from torchvision.models import ShuffleNet_V2_X1_0_Weights
        except ImportError:
            # torchvision < 0.13: only the legacy keyword exists.
            return shufflenet_v2_x1_0(pretrained=pretrained)
        weights = ShuffleNet_V2_X1_0_Weights.DEFAULT if pretrained else None
        return shufflenet_v2_x1_0(weights=weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits the backbone (with the resized head) produces for ``x``."""
        return self.backbone(x)