# derm-ai/app/services/image_classification_vit.py
import os
from io import BytesIO

import requests
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import AutoConfig, AutoModelForImageClassification

# Read the Hugging Face token from the environment (.env is loaded if present).
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")


class SkinDiseaseClassifier:
    # Class labels in the order used by the fine-tuned classification head.
    CLASS_NAMES = [
        "Acne", "Basal Cell Carcinoma", "Benign Keratosis-like Lesions", "Chickenpox", "Eczema", "Healthy Skin",
        "Measles", "Melanocytic Nevi", "Melanoma", "Monkeypox", "Psoriasis Lichen Planus and related diseases",
        "Seborrheic Keratoses and other Benign Tumors", "Tinea Ringworm Candidiasis and other Fungal Infections",
        "Vitiligo", "Warts Molluscum and other Viral Infections"
    ]

    def __init__(self, repo_id="muhammadnoman76/skin-disease-classifier"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.repo_id = repo_id
        self.model = self.load_trained_model()
        self.transform = self.get_inference_transform()

    def load_trained_model(self):
        # Download the fine-tuned checkpoint from the Hugging Face Hub.
        model_path = hf_hub_download(repo_id=self.repo_id, filename="healthy.pth", token=HUGGINGFACE_TOKEN)
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)

        # Infer the number of classes from the saved classifier head
        # (classifier.3 is the final Linear layer of the custom head rebuilt below).
        classifier_weight = checkpoint['model_state_dict']['classifier.3.weight']
        num_classes = classifier_weight.size(0)

        # Build a ViT backbone with a label count matching the checkpoint.
        config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=num_classes)
        model = AutoModelForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            config=config,
            ignore_mismatched_sizes=True
        )

        # Replace the default classifier with the custom head used during fine-tuning,
        # then load the fine-tuned weights.
        in_features = model.classifier.in_features
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, num_classes)
        )
        model.load_state_dict(checkpoint['model_state_dict'])

        model = model.to(self.device)
        if self.device.type == 'cuda':
            # Run in half precision on GPU to reduce memory use and speed up inference.
            model = model.half()
        model.eval()
        return model

    def get_inference_transform(self):
        # Standard ImageNet-style preprocessing: resize, center-crop to 224x224, normalize.
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def load_image(self, image_input):
        # Accepts a PIL image, a URL, a local file path, or a file-like object.
        try:
            if isinstance(image_input, Image.Image):
                image = image_input
            elif isinstance(image_input, str):
                if image_input.startswith(('http://', 'https://')):
                    response = requests.get(image_input, timeout=30)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                else:
                    if not os.path.exists(image_input):
                        raise FileNotFoundError(f"Image file not found: {image_input}")
                    image = Image.open(image_input)
            elif hasattr(image_input, 'read'):
                image = Image.open(image_input)
            else:
                raise ValueError("Unsupported image input type")
            # Convert to RGB so grayscale/RGBA inputs match the model's expected 3 channels.
            return image.convert('RGB')
        except Exception as e:
            raise RuntimeError(f"Error loading image: {e}") from e

    def predict(self, image_input, confidence_threshold=0.3):
        # Note: confidence_threshold is accepted but not currently applied;
        # the top-scoring class is always returned.
        try:
            image = self.load_image(image_input)
            image_tensor = self.transform(image).unsqueeze(0)
            if self.device.type == 'cuda':
                # Match the half-precision weights used on GPU.
                image_tensor = image_tensor.half()
            image_tensor = image_tensor.to(self.device)

            with torch.inference_mode():
                outputs = self.model(pixel_values=image_tensor).logits
                probabilities = F.softmax(outputs, dim=1)
                confidence, predicted = torch.max(probabilities, 1)

            confidence = confidence.item()
            predicted_class_idx = predicted.item()
            confidence_percentage = round(confidence * 100, 2)
            predicted_class_name = self.CLASS_NAMES[predicted_class_idx]
            return predicted_class_name, confidence_percentage
        except Exception as e:
            raise RuntimeError(f"Error during prediction: {e}") from e