File size: 4,451 Bytes
75e2b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoConfig
import requests
from io import BytesIO
import os
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv


# Let python-dotenv populate the process environment before any variable is read.
load_dotenv()

# Access token handed to the Hugging Face Hub download call; None when the
# HUGGINGFACE_TOKEN environment variable is not set.
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")


class SkinDiseaseClassifier:
    """Classify a skin image into one of ``CLASS_NAMES`` with a fine-tuned ViT.

    On construction the checkpoint ``healthy.pth`` is downloaded from the
    Hugging Face Hub repository ``repo_id``, loaded into a
    ``google/vit-base-patch16-224-in21k`` backbone with a custom two-layer
    head, and put in eval mode on CUDA (fp16) when available, else on CPU.
    """

    # Human-readable label for each output index of the classification head.
    # NOTE(review): presumably listed in the order used at training time --
    # verify against the checkpoint before editing.
    CLASS_NAMES = [
        "Acne", "Basal Cell Carcinoma", "Benign Keratosis-like Lesions", "Chickenpox", "Eczema", "Healthy Skin",
        "Measles", "Melanocytic Nevi", "Melanoma", "Monkeypox", "Psoriasis Lichen Planus and related diseases",
        "Seborrheic Keratoses and other Benign Tumors", "Tinea Ringworm Candidiasis and other Fungal Infections",
        "Vitiligo", "Warts Molluscum and other Viral Infections"
    ]

    # Seconds to wait for a remote image before giving up; without a timeout
    # ``requests.get`` can block indefinitely on an unresponsive host.
    REQUEST_TIMEOUT = 15

    def __init__(self, repo_id="muhammadnoman76/skin-disease-classifier"):
        """Select the device, then download and load the model from ``repo_id``."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.repo_id = repo_id
        self.model = self.load_trained_model()
        self.transform = self.get_inference_transform()

    def load_trained_model(self):
        """Download the checkpoint and return the model ready for inference.

        Returns:
            The image-classification model with the checkpoint's weights
            loaded, moved to ``self.device`` (converted to fp16 on CUDA) and
            switched to eval mode.
        """
        model_path = hf_hub_download(repo_id=self.repo_id, filename="healthy.pth", token=HUGGINGFACE_TOKEN)

        # weights_only=True restricts unpickling to tensors/primitive containers.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        state_dict = checkpoint['model_state_dict']

        # The number of output classes is read off the final Linear layer of
        # the saved head instead of being hard-coded.
        num_classes = state_dict['classifier.3.weight'].size(0)

        config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=num_classes)
        model = AutoModelForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            config=config,
            ignore_mismatched_sizes=True
        )

        # Swap the stock single-layer head for the Linear -> ReLU -> Dropout ->
        # Linear head so that parameter names (classifier.0.*, classifier.3.*)
        # line up with the keys stored in the checkpoint.
        in_features = model.classifier.in_features
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, num_classes)
        )

        model.load_state_dict(state_dict)
        model = model.to(self.device)
        if self.device.type == 'cuda':
            # Half precision is only used on GPU.
            model = model.half()

        model.eval()
        return model

    def get_inference_transform(self):
        """Return the preprocessing pipeline applied to every input image.

        Resizes the shorter side to 256, center-crops to 224x224, converts to
        a tensor and normalizes each channel.
        """
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # ImageNet channel statistics; assumed to match what the
            # checkpoint was fine-tuned with -- TODO confirm.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def load_image(self, image_input):
        """Load ``image_input`` and return it as an RGB PIL image.

        Args:
            image_input: A ``PIL.Image.Image``, an ``http(s)://`` URL, a local
                file path, or any object exposing ``read()``.

        Returns:
            A new PIL image in ``'RGB'`` mode.

        Raises:
            Exception: For any failure (unsupported type, missing file, HTTP
                error, undecodable data). The message is prefixed with
                ``"Error loading image: "`` and the original error is chained
                as ``__cause__``.
        """
        try:
            if isinstance(image_input, Image.Image):
                # The caller owns this image: convert() yields a new object
                # and the input is deliberately left open.
                return image_input.convert('RGB')

            if isinstance(image_input, str):
                if image_input.startswith(('http://', 'https://')):
                    response = requests.get(image_input, timeout=self.REQUEST_TIMEOUT)
                    # Fail on 4xx/5xx instead of handing an error page to PIL.
                    response.raise_for_status()
                    source = BytesIO(response.content)
                else:
                    if not os.path.exists(image_input):
                        raise FileNotFoundError(f"Image file not found: {image_input}")
                    source = image_input
            elif hasattr(image_input, 'read'):
                source = image_input
            else:
                raise ValueError("Unsupported image input type")

            # convert() forces the pixel data to be read, so the handle PIL
            # opened can be released as soon as the block exits.
            with Image.open(source) as image:
                return image.convert('RGB')
        except Exception as e:
            raise Exception(f"Error loading image: {str(e)}") from e

    def _classify_tensor(self, image_tensor):
        """Run the model on a preprocessed batch-of-one tensor and decode it.

        Args:
            image_tensor: Output of ``self.transform`` with a leading batch
                dimension added.

        Returns:
            ``(class_name, confidence_percentage)`` for the top-scoring class,
            the confidence being a percentage rounded to two decimals.
        """
        if self.device.type == 'cuda':
            # The model itself was converted to fp16 on CUDA.
            image_tensor = image_tensor.half()
        image_tensor = image_tensor.to(self.device)

        with torch.inference_mode():
            logits = self.model(pixel_values=image_tensor).logits
            # Softmax in fp32: fp16 probabilities are too coarse for a
            # confidence reported to two decimal places.
            probabilities = F.softmax(logits.float(), dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        confidence_percentage = round(confidence.item() * 100, 2)
        predicted_class_name = self.CLASS_NAMES[predicted.item()]
        return predicted_class_name, confidence_percentage

    def predict(self, image_input, confidence_threshold=0.3):
        """Classify one image.

        Args:
            image_input: Anything accepted by ``load_image``.
            confidence_threshold: Accepted for backward compatibility but not
                applied; the top class is always returned.
                NOTE(review): decide whether low-confidence results should be
                filtered here or by the caller.

        Returns:
            ``(class_name, confidence_percentage)``.

        Raises:
            Exception: For any failure while loading or classifying. The
                message is prefixed with ``"Error during prediction: "`` and
                the original error is chained as ``__cause__``.
        """
        try:
            image = self.load_image(image_input)
            image_tensor = self.transform(image).unsqueeze(0)
            return self._classify_tensor(image_tensor)
        except Exception as e:
            raise Exception(f"Error during prediction: {str(e)}") from e