Spaces:
Running
Running
File size: 1,963 Bytes
04103fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
from .efficientnet import EfficientNetClassifier
from .resnet import ResNetClassifier
from .vit import VisionTransformerClassifier
from .mobilenetv3 import MobileNetV3SmallClassifier
from .shufflenetv2 import ShuffleNetV2Classifier
import torch
import torch.nn as nn
def create_model(config: dict) -> nn.Module:
    """Create a classifier instance from a configuration dict.

    Args:
        config: Configuration mapping. ``config['model']`` is read and must
            contain ``'name'`` and ``'num_classes'``. It may also contain
            ``'pretrained'`` (default ``True``) and ``'dropout'`` (default
            ``0.2``).

    Returns:
        The classifier whose model family matches ``config['model']['name']``.

    Raises:
        KeyError: If ``'model'``, ``'name'`` or ``'num_classes'`` is missing.
        ValueError: If the name does not match any supported model family.
    """
    model_config = config['model']
    model_name = model_config['name']
    num_classes = model_config['num_classes']
    pretrained = model_config.get('pretrained', True)
    dropout = model_config.get('dropout', 0.2)

    # Match against a lower-cased name with '_' and '-' removed. Both the
    # compact spelling ('mobilenetv3', 'shufflenetv2') and the underscored
    # torchvision-style spelling ('mobilenet_v3_small', 'shufflenet_v2_x1_0')
    # are then recognised. Removing separators only widens the set of
    # accepted spellings; the check order below is unchanged.
    match_key = model_name.lower().replace('_', '').replace('-', '')

    # These wrappers receive the original, un-normalised name, presumably to
    # select the concrete variant (TODO confirm in the wrapper modules).
    if 'efficientnet' in match_key:
        return EfficientNetClassifier(model_name, num_classes, pretrained, dropout)
    if 'resnet' in match_key:
        return ResNetClassifier(model_name, num_classes, pretrained, dropout)
    if 'vit' in match_key:
        return VisionTransformerClassifier(model_name, num_classes, pretrained, dropout)
    # NOTE(review): any name containing 'mobilenetv3' (including a "large"
    # variant) builds MobileNetV3SmallClassifier; the name is not forwarded.
    if 'mobilenetv3' in match_key:
        return MobileNetV3SmallClassifier(num_classes=num_classes, pretrained=pretrained, dropout=dropout)
    if 'shufflenetv2' in match_key:
        return ShuffleNetV2Classifier(num_classes=num_classes, pretrained=pretrained, dropout=dropout)
    raise ValueError(f'不支持的模型类型: {model_name}')
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable scalar parameters in *model*.

    Only tensors with ``requires_grad`` set contribute, so frozen parameters
    are skipped. The result counts elements, not tensors.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            # Frozen tensors are not trained, so they are not counted.
            continue
        total += param.numel()
    return total
def model_size_mb(model: nn.Module) -> float:
    """Estimate the memory held by *model*'s parameters and buffers, in MB.

    Each tensor contributes ``nelement() * element_size()`` bytes. The byte
    total is divided by 1024 twice. Only parameters and buffers are counted;
    gradients, optimizer state and activations are not.
    """
    def _nbytes(tensors) -> int:
        # Raw storage size of an iterable of tensors, in bytes.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _nbytes(model.parameters()) + _nbytes(model.buffers())
    return total_bytes / 1024 / 1024
|