import timm
import torch
import torch.nn as nn
from PIL import Image
from torch.nn.functional import softmax
from torchvision import transforms
from transformers import PretrainedConfig, PreTrainedModel

LABEL_MAP = ["blur", "smoke", "clear", "fluid", "oob"]


class ClassifierConfig(PretrainedConfig):
    model_type = "classifier"

    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.num_classes = num_classes


class ClassifierModel(nn.Module):
    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), pretrained=True):
        super().__init__()
        self.base_model = timm.create_model(model_name, pretrained=pretrained)
        num_features = self.base_model.classifier.in_features
        # Use Sequential to match the saved model structure.
        self.base_model.classifier = nn.Sequential(
            nn.Linear(num_features, num_classes)
        )
        if "mobilenetv2" in model_name:
            self.target_layer = self.base_model.conv_head
        else:
            raise NotImplementedError(f"Grad-CAM target layer not defined for model: {model_name}")

    def forward(self, x):
        return self.base_model(x)


class ClassifierWrapper(PreTrainedModel):
    config_class = ClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierModel(
            model_name=config.model_name,
            num_classes=config.num_classes,
            pretrained=False,  # Weights are loaded by from_pretrained
        )
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.6075, 0.4093, 0.3609],
                                 std=[0.2066, 0.2036, 0.1991]),
        ])

    def forward(self, inputs):
        # Accept either a single PIL image or a (batched) tensor.
        if isinstance(inputs, Image.Image):
            x = transforms.ToTensor()(inputs).unsqueeze(0)  # PIL -> (1, C, H, W) tensor
        elif isinstance(inputs, torch.Tensor):
            if inputs.dim() == 3:
                x = inputs.unsqueeze(0)  # Single image -> batch of one
            elif inputs.dim() == 4:
                x = inputs  # Already a batch
            else:
                raise ValueError("Unsupported tensor shape.")
        else:
            raise TypeError(f"Unsupported input type: {type(inputs)}. Expected PIL.Image or torch.Tensor.")

        # Resize and normalize (both operate on tensors).
        x = self.transform(x)

        # Forward pass, then per-class probabilities.
        outputs = self.model(x)
        confs = softmax(outputs, dim=1)
        preds = torch.argmax(confs, dim=1)

        # Build one {label, confidences} dict per image in the batch.
        results = []
        for i in range(len(preds)):
            label = LABEL_MAP[preds[i]]
            confidences = {
                LABEL_MAP[j]: round(float(confs[i][j]), 3)
                for j in range(len(LABEL_MAP))
            }
            results.append({"label": label, "confidences": confidences})
        return results
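

# Minimal usage sketch. Assumptions not present in this module: a checkpoint
# directory ("path/to/checkpoint") previously written with
# ClassifierWrapper.save_pretrained, and a local image file ("frame.png");
# both paths are hypothetical placeholders.
if __name__ == "__main__":
    # from_pretrained reads config.json from the checkpoint directory and
    # restores the weights into the wrapped ClassifierModel.
    model = ClassifierWrapper.from_pretrained("path/to/checkpoint")
    model.eval()  # inference mode: freezes dropout / batch-norm statistics

    image = Image.open("frame.png").convert("RGB")
    with torch.no_grad():  # no gradients needed for plain inference
        results = model(image)

    # forward returns one dict per image: a predicted label plus the full
    # per-class confidence map.
    print(results[0]["label"], results[0]["confidences"])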