import timm
import torch
import torch.nn as nn
from PIL import Image
from torch.nn.functional import softmax
from torchvision import transforms
from transformers import PretrainedConfig, PreTrainedModel

# Class indices predicted by the model map to these labels, in order.
LABEL_MAP = ["blur", "smoke", "clear", "fluid", "oob"]

class ClassifierConfig(PretrainedConfig):
    model_type = "classifier"

    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.num_classes = num_classes

class ClassifierModel(nn.Module):
    def __init__(self, model_name="mobilenetv2_100", num_classes=len(LABEL_MAP), pretrained=True):
        super().__init__()
        self.base_model = timm.create_model(model_name, pretrained=pretrained)
        num_features = self.base_model.classifier.in_features

        # Swap the stock classification head for a linear layer sized to our label set.
        self.base_model.classifier = nn.Sequential(
            nn.Linear(num_features, num_classes)
        )

        # Expose the layer Grad-CAM should hook; only MobileNetV2 backbones are supported.
        if "mobilenetv2" in model_name:
            self.target_layer = self.base_model.conv_head
        else:
            raise NotImplementedError(f"Grad-CAM target layer not defined for model: {model_name}")

    def forward(self, x):
        return self.base_model(x)

class ClassifierWrapper(PreTrainedModel):
    config_class = ClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone weights are expected to come from a checkpoint (e.g. via from_pretrained),
        # so the timm model is created without ImageNet weights here.
        self.model = ClassifierModel(
            model_name=config.model_name,
            num_classes=config.num_classes,
            pretrained=False,
        )

        # Resize and normalize input tensors before classification.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.6075, 0.4093, 0.3609], std=[0.2066, 0.2036, 0.1991]),
        ])

    def forward(self, input):
        # Accept a single PIL image or a (C, H, W) / (B, C, H, W) float tensor.
        if isinstance(input, Image.Image):
            x = transforms.ToTensor()(input).unsqueeze(0)
        elif isinstance(input, torch.Tensor):
            if input.dim() == 3:
                x = input.unsqueeze(0)
            elif input.dim() == 4:
                x = input
            else:
                raise ValueError("Unsupported tensor shape.")
        else:
            raise TypeError(f"Unsupported input type: {type(input)}. Expected PIL.Image or torch.Tensor.")

        x = self.transform(x)
        outputs = self.model(x)

        # Convert logits to per-class probabilities and pick the top class per sample.
        confs = softmax(outputs, dim=1)
        preds = torch.argmax(confs, dim=1)

        results = []
        for i in range(len(preds)):
            label = LABEL_MAP[preds[i].item()]
            confidences = {
                LABEL_MAP[j]: round(float(confs[i][j]), 3) for j in range(len(LABEL_MAP))
            }
            results.append({"label": label, "confidences": confidences})
        return results
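

# Minimal usage sketch (illustrative, not part of the module above): the image
# filename below is a placeholder, and in practice the weights would be loaded
# with ClassifierWrapper.from_pretrained(...) rather than built from a fresh config.
if __name__ == "__main__":
    config = ClassifierConfig()
    model = ClassifierWrapper(config)  # placeholder: untrained weights, wiring check only
    model.eval()

    image = Image.open("example_frame.png").convert("RGB")  # placeholder image path
    with torch.no_grad():
        predictions = model(image)
    print(predictions[0]["label"], predictions[0]["confidences"])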