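from .efficientnet import EfficientNetClassifier
from .resnet import ResNetClassifier
from .vit import VisionTransformerClassifier
from .mobilenetv3 import MobileNetV3SmallClassifier
from .shufflenetv2 import ShuffleNetV2Classifier

import torch
import torch.nn as nn


def create_model(config: dict) -> nn.Module:
    """Create a model instance from the configuration.

    Expects ``config['model']`` to provide ``name`` and ``num_classes``,
    plus the optional keys ``pretrained`` (default ``True``) and
    ``dropout`` (default ``0.2``).

    Raises:
        ValueError: if ``name`` does not match any supported model family.
    """
    model_config = config['model']
    model_name = model_config['name']
    num_classes = model_config['num_classes']
    pretrained = model_config.get('pretrained', True)
    dropout = model_config.get('dropout', 0.2)

    # Dispatch on a case-insensitive substring match of the model name.
    # The first matching family wins, so the order of the checks matters.
    model_name_lower = model_name.lower()
    if 'efficientnet' in model_name_lower:
        return EfficientNetClassifier(model_name, num_classes, pretrained, dropout)
    if 'resnet' in model_name_lower:
        return ResNetClassifier(model_name, num_classes, pretrained, dropout)
    if 'vit' in model_name_lower:
        return VisionTransformerClassifier(model_name, num_classes, pretrained, dropout)
    # These two wrappers take no variant name: any 'mobilenetv3*' name builds
    # the Small variant, and any 'shufflenetv2*' name builds the single
    # ShuffleNetV2 wrapper.
    if 'mobilenetv3' in model_name_lower:
        return MobileNetV3SmallClassifier(num_classes=num_classes, pretrained=pretrained, dropout=dropout)
    if 'shufflenetv2' in model_name_lower:
        return ShuffleNetV2Classifier(num_classes=num_classes, pretrained=pretrained, dropout=dropout)

    raise ValueError(f'Unsupported model type: {model_name}')


def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters (those with ``requires_grad=True``)."""
    return sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)


def model_size_mb(model: nn.Module) -> float:
    """Estimate the memory occupied by the model's parameters and buffers, in MB (1 MB = 1024 * 1024 bytes)."""
    param_size = 0
    buffer_size = 0
    # Bytes per tensor = number of elements * bytes per element (e.g. 4 for float32).
    for parameter in model.parameters():
        param_size += parameter.nelement() * parameter.element_size()
    # Buffers (e.g. BatchNorm running statistics) are counted as well.
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    return (param_size + buffer_size) / 1024 / 1024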