import torch
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoConfig
import requests
from io import BytesIO
import os
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv

# Load the Hugging Face access token from a local .env file.
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

class SkinDiseaseClassifier:
    CLASS_NAMES = [
        "Acne", "Basal Cell Carcinoma", "Benign Keratosis-like Lesions", "Chickenpox", "Eczema", "Healthy Skin",
        "Measles", "Melanocytic Nevi", "Melanoma", "Monkeypox", "Psoriasis Lichen Planus and related diseases",
        "Seborrheic Keratoses and other Benign Tumors", "Tinea Ringworm Candidiasis and other Fungal Infections",
        "Vitiligo", "Warts Molluscum and other Viral Infections"
    ]

    def __init__(self, repo_id="muhammadnoman76/skin-disease-classifier"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.repo_id = repo_id
        self.model = self.load_trained_model()
        self.transform = self.get_inference_transform()
    def load_trained_model(self):
        # Download the fine-tuned checkpoint from the Hub.
        model_path = hf_hub_download(repo_id=self.repo_id, filename="healthy.pth", token=HUGGINGFACE_TOKEN)
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)

        # Infer the number of classes from the final linear layer of the saved classifier head.
        classifier_weight = checkpoint['model_state_dict']['classifier.3.weight']
        num_classes = classifier_weight.size(0)

        config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=num_classes)
        model = AutoModelForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            config=config,
            ignore_mismatched_sizes=True
        )

        # Swap the default ViT head for the MLP head the checkpoint was trained with.
        in_features = model.classifier.in_features
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, num_classes)
        )

        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(self.device)
        if self.device.type == 'cuda':
            model = model.half()  # use fp16 weights for faster GPU inference
        model.eval()
        return model
    def get_inference_transform(self):
        # Standard ImageNet preprocessing: resize, center-crop to 224x224, normalize.
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    def load_image(self, image_input):
        try:
            if isinstance(image_input, Image.Image):
                image = image_input
            elif isinstance(image_input, str):
                if image_input.startswith(('http://', 'https://')):
                    response = requests.get(image_input, timeout=30)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                else:
                    if not os.path.exists(image_input):
                        raise FileNotFoundError(f"Image file not found: {image_input}")
                    image = Image.open(image_input)
            elif hasattr(image_input, 'read'):
                image = Image.open(image_input)
            else:
                raise ValueError("Unsupported image input type")
            return image.convert('RGB')
        except Exception as e:
            raise Exception(f"Error loading image: {str(e)}") from e
    def predict(self, image_input, confidence_threshold=0.3):
        # Note: confidence_threshold is accepted here but is not currently applied
        # to the returned prediction.
        try:
            image = self.load_image(image_input)
            image_tensor = self.transform(image).unsqueeze(0)
            if self.device.type == 'cuda':
                image_tensor = image_tensor.half()  # match the fp16 model weights
            image_tensor = image_tensor.to(self.device)

            with torch.inference_mode():
                outputs = self.model(pixel_values=image_tensor).logits
                probabilities = F.softmax(outputs, dim=1)
                confidence, predicted = torch.max(probabilities, 1)

            confidence = confidence.item()
            predicted_class_idx = predicted.item()
            confidence_percentage = round(confidence * 100, 2)
            predicted_class_name = self.CLASS_NAMES[predicted_class_idx]
            return predicted_class_name, confidence_percentage
        except Exception as e:
            raise Exception(f"Error during prediction: {str(e)}") from e