# classifier-mix/classifier.py
import timm
import torch
import torch.nn as nn
from PIL import Image
from torch.nn.functional import softmax
from torchvision import transforms
from transformers import PretrainedConfig, PreTrainedModel
# Class labels, in the index order produced by the classifier head.
LABEL_MAP = ["blur", "smoke", "clear", "fluid", "oob"]

class ClassifierConfig(PretrainedConfig):
    model_type = "classifier"

    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.num_classes = num_classes

class ClassifierModel(nn.Module):
    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), pretrained=True):
        super().__init__()
        self.base_model = timm.create_model(model_name, pretrained=pretrained)
        num_features = self.base_model.classifier.in_features
        # Use Sequential to match the saved model structure.
        self.base_model.classifier = nn.Sequential(
            nn.Linear(num_features, num_classes)
        )
        # Last convolutional layer of the backbone, used as the Grad-CAM target.
        if "mobilenetv2" in model_name:
            self.target_layer = self.base_model.conv_head
        else:
            raise NotImplementedError(f"Grad-CAM target layer not defined for model: {model_name}")

    def forward(self, x):
        return self.base_model(x)
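
# A minimal Grad-CAM sketch showing how `target_layer` above is intended to be
# used, e.g. gradcam_heatmap(wrapper.model, x) on a preprocessed batch of one.
# The hook-based approach and the name `gradcam_heatmap` are assumptions for
# illustration; they are not part of the original module.
def gradcam_heatmap(model: ClassifierModel, x: torch.Tensor, class_idx=None):
    """Return an (H, W) relevance map for a preprocessed batch of one image."""
    activations, gradients = [], []
    # Capture the target layer's activations on the forward pass and the
    # gradient of the score w.r.t. those activations on the backward pass.
    fwd = model.target_layer.register_forward_hook(
        lambda _m, _i, out: activations.append(out)
    )
    bwd = model.target_layer.register_full_backward_hook(
        lambda _m, _gi, grad_out: gradients.append(grad_out[0])
    )
    try:
        logits = model(x)
        if class_idx is None:
            class_idx = logits.argmax(dim=1).item()
        model.zero_grad()
        logits[0, class_idx].backward()
    finally:
        fwd.remove()
        bwd.remove()
    # Weight each channel by its mean gradient, sum, and keep positive evidence.
    weights = gradients[0].mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((weights * activations[0]).sum(dim=1)).squeeze(0)
    return cam / (cam.max() + 1e-8)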

class ClassifierWrapper(PreTrainedModel):
    config_class = ClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierModel(
            model_name=config.model_name,
            num_classes=config.num_classes,
            pretrained=False,  # weights are loaded by from_pretrained
        )
        # Resize to the 224x224 input resolution and apply custom
        # normalization statistics (not the ImageNet defaults).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.6075, 0.4093, 0.3609], std=[0.2066, 0.2036, 0.1991]),
        ])

    def forward(self, inputs):
        # Accept a single PIL image, a single CHW tensor, or a BCHW batch.
        if isinstance(inputs, Image.Image):
            x = transforms.ToTensor()(inputs.convert("RGB")).unsqueeze(0)
        elif isinstance(inputs, torch.Tensor):
            if inputs.dim() == 3:
                x = inputs.unsqueeze(0)  # single image -> batch of one
            elif inputs.dim() == 4:
                x = inputs  # already a batch
            else:
                raise ValueError(f"Unsupported tensor shape: {tuple(inputs.shape)}")
        else:
            raise TypeError(f"Unsupported input type: {type(inputs)}. Expected PIL.Image or torch.Tensor.")
        # Resize and normalize, then run the backbone.
        x = self.transform(x)
        outputs = self.model(x)
        # Per-class probabilities; the top-1 class is the predicted label.
        confs = softmax(outputs, dim=1)
        preds = torch.argmax(confs, dim=1)
        results = []
        for i in range(len(preds)):
            confidences = {
                name: round(float(confs[i][j]), 3) for j, name in enumerate(LABEL_MAP)
            }
            results.append({"label": LABEL_MAP[preds[i].item()], "confidences": confidences})
        return results
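
# Minimal usage sketch; "path/to/checkpoint" and "frame.png" are placeholders,
# not verified paths. from_pretrained restores the fine-tuned weights that
# __init__ deliberately skips loading (pretrained=False).
if __name__ == "__main__":
    model = ClassifierWrapper.from_pretrained("path/to/checkpoint")
    model.eval()
    with torch.no_grad():
        results = model(Image.open("frame.png"))
    print(results[0]["label"], results[0]["confidences"])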