import torch
import torch.nn as nn


class FaceClassifier(nn.Module):
    """Face classification model with a configurable head."""

    def __init__(self, base_model, num_classes, model_name, model_configs):
        super().__init__()
        self.base_model = base_model
        self.model_name = model_name

        # Probe the backbone with a dummy input to determine its output shape.
        # Run the probe in eval mode so BatchNorm running statistics in a
        # pretrained backbone are not perturbed by the all-zeros input.
        was_training = base_model.training
        base_model.eval()
        with torch.no_grad():
            resolution = model_configs[model_name]['resolution']
            dummy_input = torch.zeros(1, 3, resolution, resolution)
            features = base_model(dummy_input)
        base_model.train(was_training)

        if features.dim() == 4:
            # Spatial feature map (batch, channels, height, width)
            self.feature_type = 'spatial'
            self.feature_dim = features.shape[1]
        elif features.dim() == 2:
            # Flattened feature vector (batch, features)
            self.feature_type = 'flat'
            self.feature_dim = features.shape[1]
        else:
            raise ValueError(
                f"Unexpected feature shape from base model {model_name}: {features.shape}"
            )

        # Define the classifier head based on the feature type.
        if self.feature_type == 'flat' or 'vit' in model_name:
            self.conv_head = nn.Sequential(
                nn.Linear(self.feature_dim, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),
                nn.Linear(256, num_classes),
            )
        else:
            self.conv_head = nn.Sequential(
                nn.Conv2d(self.feature_dim, 512, kernel_size=3, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.Dropout2d(0.5),
                nn.Conv2d(512, 256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(256, num_classes),
            )

    def forward(self, x):
        features = self.base_model(x)
        output = self.conv_head(features)
        return output
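

# Illustrative usage sketch. The stand-in backbone and the 'toy_cnn' entry in
# `model_configs` below are assumptions made for this example only; substitute
# the real backbone and model_configs used elsewhere in the project.
if __name__ == "__main__":
    # Minimal convolutional backbone returning a spatial feature map (B, 64, H/4, W/4),
    # which exercises the 'spatial' branch of FaceClassifier's head selection.
    toy_backbone = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
    )
    model_configs = {'toy_cnn': {'resolution': 112}}

    clf = FaceClassifier(toy_backbone, num_classes=5,
                         model_name='toy_cnn', model_configs=model_configs)
    logits = clf(torch.randn(2, 3, 112, 112))
    print(logits.shape)  # torch.Size([2, 5])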