aicomp_demo / src /models /__init__.py
ceasonen
我的视网膜检测网站
04103fb
from .efficientnet import EfficientNetClassifier
from .resnet import ResNetClassifier
from .vit import VisionTransformerClassifier
from .mobilenetv3 import MobileNetV3SmallClassifier
from .shufflenetv2 import ShuffleNetV2Classifier
import torch
import torch.nn as nn
def create_model(config: dict) -> nn.Module:
    """Build a classifier from the ``'model'`` section of a configuration.

    Args:
        config: Mapping whose ``'model'`` entry must provide ``'name'`` and
            ``'num_classes'``. ``'pretrained'`` (default ``True``) and
            ``'dropout'`` (default ``0.2``) are optional.

    Returns:
        An instance of the first classifier family whose keyword
        (``efficientnet``, ``resnet``, ``vit``, ``mobilenetv3``,
        ``shufflenetv2``, checked in that order) occurs in the lower-cased
        model name.

    Raises:
        KeyError: If ``'model'``, ``'name'`` or ``'num_classes'`` is missing.
        ValueError: If the name contains none of the supported keywords.
    """
    settings = config['model']
    model_name = settings['name']
    classes = settings['num_classes']
    use_pretrained = settings.get('pretrained', True)
    drop_rate = settings.get('dropout', 0.2)

    # Ordered (keyword, factory) pairs. The first keyword found in the name
    # wins, so this order is significant. Factories are lambdas so that a
    # class is only referenced once its branch has been selected.
    builders = (
        ('efficientnet',
         lambda: EfficientNetClassifier(model_name, classes, use_pretrained, drop_rate)),
        ('resnet',
         lambda: ResNetClassifier(model_name, classes, use_pretrained, drop_rate)),
        ('vit',
         lambda: VisionTransformerClassifier(model_name, classes, use_pretrained, drop_rate)),
        # The lightweight backbones are constructed without the model name.
        ('mobilenetv3',
         lambda: MobileNetV3SmallClassifier(
             num_classes=classes, pretrained=use_pretrained, dropout=drop_rate)),
        ('shufflenetv2',
         lambda: ShuffleNetV2Classifier(
             num_classes=classes, pretrained=use_pretrained, dropout=drop_rate)),
    )

    lowered = model_name.lower()
    for keyword, build in builders:
        if keyword in lowered:
            return build()
    raise ValueError(f'不支持的模型类型: {model_name}')
def count_parameters(model: nn.Module) -> int:
    """Return the total element count of the module's trainable parameters.

    A parameter contributes ``numel()`` elements only when its
    ``requires_grad`` flag is set. Frozen parameters are skipped. Buffers are
    not counted because ``model.parameters()`` does not yield them.
    """
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total
def model_size_mb(model: nn.Module) -> float:
    """Estimate the memory occupied by a module's parameters and buffers.

    Each tensor yielded by ``model.parameters()`` or ``model.buffers()``
    contributes ``numel() * element_size()`` bytes. The byte total is then
    divided by 1024 twice, so 1 MB here means 1024 * 1024 bytes. Gradients
    are not counted.
    """
    tensors = [*model.parameters(), *model.buffers()]
    total_bytes = sum(t.numel() * t.element_size() for t in tensors)
    return total_bytes / 1024 / 1024