#!/usr/bin/env python3
"""MAE ViT-Base waste classifier for inference."""

import os

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import timm
from huggingface_hub import hf_hub_download


class MAEWasteClassifier:
    """Waste classifier using a finetuned MAE ViT-Base model."""

    def __init__(self, model_path=None, hf_model_id="ysfad/mae-waste-classifier", device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.hf_model_id = hf_model_id

        # Try to load model from different sources
        if model_path and os.path.exists(model_path):
            self.model_path = model_path
            print(f"📁 Using local model: {model_path}")
        else:
            # Try to download from HF Hub
            try:
                print(f"🌐 Downloading model from HF Hub: {hf_model_id}")
                self.model_path = hf_hub_download(
                    repo_id=hf_model_id,
                    filename="best_model.pth",
                    cache_dir="./hf_cache"
                )
                print(f"✅ Downloaded model to: {self.model_path}")
            except Exception as e:
                print(f"⚠️ Could not download from HF Hub: {e}")
                # Fallback to local path
                self.model_path = "output_simple_mae/best_model.pth"
                if not os.path.exists(self.model_path):
                    raise FileNotFoundError(
                        f"Model not found locally at {self.model_path} "
                        f"and could not download from HF Hub"
                    )

        # Class names from training
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal', 'Miscellaneous Trash',
            'Paper', 'Plastic', 'Textile Trash', 'Vegetation'
        ]

        # Disposal instructions per class
        self.disposal_instructions = {
            "Cardboard": "Flatten and place in recycling bin. Remove any tape or staples.",
            "Food Organics": "Compost in organic waste bin or home composter.",
            "Glass": "Rinse and place in glass recycling. Remove lids and caps.",
            "Metal": "Rinse aluminum/steel cans and place in recycling bin.",
            "Miscellaneous Trash": "Dispose in general waste bin. Cannot be recycled.",
            "Paper": "Place clean paper in recycling. Remove plastic windows from envelopes.",
            "Plastic": "Check recycling number. Rinse containers before recycling.",
            "Textile Trash": "Donate if reusable, otherwise dispose in textile recycling.",
            "Vegetation": "Compost in organic waste or use for mulch in garden."
        }

        # Load model
        self.model = self._load_model()

        # Image preprocessing (224x224 input, ImageNet normalization)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"✅ MAE Waste Classifier loaded on {self.device}")
        print(f"📊 Model: ViT-Base MAE, Classes: {len(self.class_names)}")

    def _load_model(self):
        """Load the finetuned MAE model."""
        try:
            # Create ViT model using timm
            model = timm.create_model('vit_base_patch16_224', pretrained=False,
                                      num_classes=len(self.class_names))

            # Load checkpoint
            checkpoint = torch.load(self.model_path, map_location=self.device)

            # Load state dict (handle both wrapped and bare checkpoints)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)

            model.to(self.device)
            model.eval()

            print(f"✅ Loaded finetuned MAE model from {self.model_path}")
            return model

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def classify_image(self, image, top_k=5):
        """Classify a waste image.
        Args:
            image: PIL Image or path to image
            top_k: Number of top predictions to return

        Returns:
            dict: Classification results
        """
        try:
            # Load and preprocess image
            if isinstance(image, str):
                image = Image.open(image).convert('RGB')
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be PIL Image or path string")

            # Preprocess
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Inference
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # Get top predictions
            top_probs, top_indices = torch.topk(
                probabilities, k=min(top_k, len(self.class_names)))

            top_predictions = []
            for prob, idx in zip(top_probs[0], top_indices[0]):
                top_predictions.append({
                    'class': self.class_names[idx.item()],
                    'confidence': prob.item()
                })

            # Best prediction
            best_pred = top_predictions[0]

            return {
                'success': True,
                'predicted_class': best_pred['class'],
                'confidence': best_pred['confidence'],
                'top_predictions': top_predictions
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_disposal_instructions(self, class_name):
        """Get disposal instructions for a waste class."""
        return self.disposal_instructions.get(class_name, "No specific instructions available.")

    def get_model_info(self):
        """Get information about the loaded model."""
        return {
            'model_name': 'ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': self.device,
            'model_path': self.model_path
        }


# Test the classifier
if __name__ == "__main__":
    print("🧪 Testing MAE Waste Classifier...")

    try:
        # Initialize classifier
        classifier = MAEWasteClassifier()

        # Test with a sample image if available
        test_images = [
            "fail_images/image.webp",
            "fail_images/IMG_9501.webp"
        ]

        for img_path in test_images:
            if os.path.exists(img_path):
                print(f"\n🔍 Testing with {img_path}")
                result = classifier.classify_image(img_path)

                if result['success']:
                    print(f"✅ Predicted: {result['predicted_class']} ({result['confidence']:.3f})")
                    print(f"📋 Instructions: {classifier.get_disposal_instructions(result['predicted_class'])}")
                    print("\n📊 Top predictions:")
                    for i, pred in enumerate(result['top_predictions'][:3], 1):
                        print(f"   {i}. {pred['class']}: {pred['confidence']:.3f}")
                else:
                    print(f"❌ Error: {result['error']}")
                break
        else:
            # for-else: runs only if the loop finished without a break,
            # i.e. none of the test images exist on disk.
            print("ℹ️ No test images found, but classifier loaded successfully!")

        # Print model info
        info = classifier.get_model_info()
        print("\n🤖 Model Info:")
        for key, value in info.items():
            print(f"   {key}: {value}")

        print("\nSuccess!")

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()