import torch
import torch.nn as nn
from timm import create_model


class VisionTransformerClassifier(nn.Module):
    def __init__(self, model_name='vit_small_patch16_224', num_classes=5, pretrained=True, dropout=0.2):
        super().__init__()
        # Backbone with its classification head removed (num_classes=0 makes timm return pooled features)
        self.backbone = create_model(model_name, pretrained=pretrained, num_classes=0)

        # Infer the backbone's output feature dimension with a dummy forward pass
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            self.feature_dim = self.backbone(dummy).shape[1]

        # Classification head: layer norm, dropout, and a two-layer MLP with GELU
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, num_classes)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        # Truncated-normal init for the head's linear layers, zero-initialized biases
        for m in self.classifier:
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

    def get_features(self, x):
        # Expose the backbone embeddings without passing through the classification head
        return self.backbone(x)
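

# Minimal usage sketch (not part of the original module): instantiating the
# classifier and running a forward pass on a dummy batch. pretrained=False is
# used here only to avoid a weight download in this illustration; the batch
# size and class count are arbitrary.
if __name__ == "__main__":
    model = VisionTransformerClassifier(num_classes=5, pretrained=False)
    model.eval()

    images = torch.randn(2, 3, 224, 224)            # dummy batch of two 224x224 RGB images
    with torch.no_grad():
        logits = model(images)                      # shape: (2, num_classes)
        embeddings = model.get_features(images)     # shape: (2, feature_dim)

    print(logits.shape, embeddings.shape)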