"""Inference wrapper around a fine-tuned ViT skin-disease image classifier."""

import os
from io import BytesIO

import requests
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoConfig

# Pull variables from a local .env file (if any) into the environment before
# reading the Hugging Face access token.
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")


class SkinDiseaseClassifier:
    """Load a checkpoint from the Hugging Face Hub and classify skin images.

    The checkpoint ``healthy.pth`` is downloaded from ``repo_id`` and loaded
    onto a ``google/vit-base-patch16-224-in21k`` backbone whose classifier
    head is replaced with a small two-layer MLP.
    """

    # Human-readable labels, indexed by the model's predicted class index.
    CLASS_NAMES = [
        "Acne",
        "Basal Cell Carcinoma",
        "Benign Keratosis-like Lesions",
        "Chickenpox",
        "Eczema",
        "Healthy Skin",
        "Measles",
        "Melanocytic Nevi",
        "Melanoma",
        "Monkeypox",
        "Psoriasis Lichen Planus and related diseases",
        "Seborrheic Keratoses and other Benign Tumors",
        "Tinea Ringworm Candidiasis and other Fungal Infections",
        "Vitiligo",
        "Warts Molluscum and other Viral Infections",
    ]

    # Seconds to wait for a remote image before giving up; without a timeout
    # ``requests.get`` can block indefinitely on an unresponsive server.
    REQUEST_TIMEOUT = 30

    def __init__(self, repo_id: str = "muhammadnoman76/skin-disease-classifier"):
        """Select the device, then download and prepare the model.

        Args:
            repo_id: Hugging Face Hub repository that hosts ``healthy.pth``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.repo_id = repo_id
        self.model = self.load_trained_model()
        self.transform = self.get_inference_transform()

    def load_trained_model(self):
        """Download the checkpoint and return the model in eval mode.

        Returns:
            The model moved to ``self.device``; converted to half precision
            when the device is CUDA.
        """
        model_path = hf_hub_download(
            repo_id=self.repo_id,
            filename="healthy.pth",
            token=HUGGINGFACE_TOKEN,
        )
        checkpoint = torch.load(
            model_path, map_location=self.device, weights_only=True
        )
        state_dict = checkpoint['model_state_dict']

        # 'classifier.3' is the final Linear layer of the Sequential head
        # built below; its row count is the number of output classes.
        num_classes = state_dict['classifier.3.weight'].size(0)

        config = AutoConfig.from_pretrained(
            "google/vit-base-patch16-224-in21k", num_labels=num_classes
        )
        # Tolerate a head-shape mismatch with the pretrained weights: the
        # head is replaced immediately afterwards anyway.
        model = AutoModelForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            config=config,
            ignore_mismatched_sizes=True,
        )

        in_features = model.classifier.in_features
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, num_classes),
        )
        model.load_state_dict(state_dict)

        model = model.to(self.device)
        if self.device.type == 'cuda':
            model = model.half()
        model.eval()
        return model

    def get_inference_transform(self):
        """Return the preprocessing pipeline applied to every input image.

        Resizes to 256, center-crops to 224, converts to a tensor and
        normalizes each channel with fixed mean/std constants.
        """
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def load_image(self, image_input):
        """Load an image from several input kinds and return it as RGB.

        Args:
            image_input: A ``PIL.Image.Image``, an ``http(s)://`` URL, a
                local file path, or any object with a ``read`` attribute.

        Returns:
            The image converted to RGB mode.

        Raises:
            Exception: If the input type is unsupported, the file is missing,
                the download fails, or the data cannot be decoded. The
                original error is attached as ``__cause__``.
        """
        try:
            if isinstance(image_input, Image.Image):
                image = image_input
            elif isinstance(image_input, str):
                if image_input.startswith(('http://', 'https://')):
                    response = requests.get(
                        image_input, timeout=self.REQUEST_TIMEOUT
                    )
                    # Fail on 4xx/5xx here rather than handing an HTML error
                    # page to the image decoder.
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                else:
                    if not os.path.exists(image_input):
                        raise FileNotFoundError(
                            f"Image file not found: {image_input}"
                        )
                    image = Image.open(image_input)
            elif hasattr(image_input, 'read'):
                image = Image.open(image_input)
            else:
                raise ValueError("Unsupported image input type")
            return image.convert('RGB')
        except Exception as e:
            raise Exception(f"Error loading image: {e}") from e

    def predict(self, image_input, confidence_threshold: float = 0.3):
        """Classify one image and return the top class with its confidence.

        Args:
            image_input: Anything accepted by :meth:`load_image`.
            confidence_threshold: Accepted for interface compatibility;
                currently not applied to the result.

        Returns:
            A ``(class_name, confidence_percentage)`` tuple where the
            confidence is the top softmax probability times 100, rounded to
            two decimals.

        Raises:
            Exception: If loading or inference fails. The original error is
                attached as ``__cause__``.
        """
        try:
            image = self.load_image(image_input)
            image_tensor = self.transform(image).unsqueeze(0)

            # Match the model's dtype: it is converted to half on CUDA.
            if self.device.type == 'cuda':
                image_tensor = image_tensor.half()
            image_tensor = image_tensor.to(self.device)

            with torch.inference_mode():
                outputs = self.model(pixel_values=image_tensor).logits
                probabilities = F.softmax(outputs, dim=1)
                confidence, predicted = torch.max(probabilities, 1)

            confidence = confidence.item()
            predicted_class_idx = predicted.item()

            confidence_percentage = round(confidence * 100, 2)
            predicted_class_name = self.CLASS_NAMES[predicted_class_idx]
            return predicted_class_name, confidence_percentage
        except Exception as e:
            raise Exception(f"Error during prediction: {e}") from e