import torch
import torch.nn as nn
from timm import create_model


class EfficientNetClassifier(nn.Module):
    """Image classifier with a timm backbone and two output heads.

    The backbone is created with ``num_classes=0`` and ``global_pool='avg'``
    so that it returns pooled features instead of class scores. Two small
    MLP heads consume the same features:

    * ``classifier_grading`` -- produces ``num_classes`` scores per sample.
    * ``classifier_diabetic`` -- produces a single score per sample.

    Neither head applies a final activation; callers receive the raw
    outputs of the last linear layer.

    Args:
        model_name: Name of the timm architecture passed to ``create_model``.
        num_classes: Output width of the grading head.
        pretrained: Forwarded to ``create_model`` to load pretrained weights.
        dropout: Dropout probability before the first linear layer of each
            head; the second dropout in each head uses half this value.
    """

    def __init__(self, model_name='efficientnet_b0', num_classes=5,
                 pretrained=True, dropout=0.2):
        super().__init__()
        self.backbone = create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg',
        )

        # Infer the feature width with a dummy forward pass.
        # The probe runs in eval mode: ``torch.no_grad()`` only disables
        # autograd, so in training mode BatchNorm layers would still fold
        # the random dummy input into their (pretrained) running statistics.
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.randn(1, 3, 224, 224)
                self.feature_dim = self.backbone(dummy).shape[1]
        finally:
            # Restore whatever mode the backbone was in before the probe.
            self.backbone.train(was_training)

        # Grading head (presumably diabetic-retinopathy grade -- confirm
        # against the training code).
        self.classifier_grading = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(512, num_classes),
        )

        # Binary classification head (single output per sample).
        self.classifier_diabetic = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(128, 1),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every linear layer in both heads.

        Weights use Kaiming-normal (``fan_out``, ReLU gain); biases are set
        to zero. The backbone is left untouched.
        """
        # Grading head first, then the binary head, so the RNG is consumed
        # in a fixed order.
        for head in (self.classifier_grading, self.classifier_diabetic):
            for layer in head:
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_normal_(
                        layer.weight, mode='fan_out', nonlinearity='relu'
                    )
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the backbone and both heads.

        Args:
            x: Input batch fed directly to the backbone.

        Returns:
            A dict with two entries:
            ``'grading'`` -- output of the grading head, last dimension
            ``num_classes``;
            ``'diabetic'`` -- output of the binary head with its trailing
            size-1 dimension squeezed away, i.e. shape ``(B,)``.
        """
        features = self.backbone(x)
        out_grading = self.classifier_grading(features)
        # Drop the trailing size-1 dimension: (B, 1) -> (B,)
        out_diabetic = self.classifier_diabetic(features).squeeze(-1)
        # Return a dict to stay compatible with the legacy usage.
        return {'grading': out_grading, 'diabetic': out_diabetic}

    def get_features(self, x):
        """Return the backbone output for ``x`` without applying any head."""
        return self.backbone(x)