from typing import Optional, Tuple, Union
import logging

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

from config import Config
from knowledge_base import GarbageClassificationKnowledge


class GarbageClassifier:
    def __init__(self, config: Optional[Config] = None):
        self.config = config or Config()
        self.knowledge = GarbageClassificationKnowledge()
        self.processor = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the model and processor."""
        try:
            self.logger.info(f"Loading model: {self.config.MODEL_NAME}")

            # Pass the Hugging Face token (if configured) to both the
            # processor and the model; gated checkpoints require it for
            # the model weights as well, not just the processor.
            kwargs = {}
            if self.config.HF_TOKEN:
                kwargs["token"] = self.config.HF_TOKEN

            # Load processor
            self.processor = AutoProcessor.from_pretrained(
                self.config.MODEL_NAME, **kwargs
            )

            # Load model
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.config.MODEL_NAME,
                torch_dtype=self.config.TORCH_DTYPE,
                device_map=self.config.DEVICE_MAP,
                **kwargs,
            )

            self.logger.info("Model loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading model: {str(e)}")
            raise

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Preprocess image to meet Gemma3n requirements (512x512)."""
        # Convert to RGB if necessary
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize to 512x512 as required by Gemma3n
        target_size = (512, 512)

        # Calculate aspect-ratio-preserving dimensions
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height

        if aspect_ratio > 1:
            # Width is larger
            new_width = target_size[0]
            new_height = int(target_size[0] / aspect_ratio)
        else:
            # Height is larger or equal
            new_height = target_size[1]
            new_width = int(target_size[1] * aspect_ratio)

        # Resize the image, maintaining aspect ratio
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Create a new image with the target size and paste the resized
        # image onto a white background
        processed_image = Image.new("RGB", target_size, (255, 255, 255))

        # Calculate position to center the image
        x_offset = (target_size[0] - new_width) // 2
        y_offset = (target_size[1] - new_height) // 2
        processed_image.paste(image, (x_offset, y_offset))

        return processed_image

    def classify_image(self, image: Union[str, Image.Image]) -> Tuple[str, str]:
        """
        Classify garbage in the image.

        Args:
            image: PIL Image or path to image file

        Returns:
            Tuple of (classification_result, full_response)
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            # Load and validate the image
            if isinstance(image, str):
                image = Image.open(image)
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be a PIL Image or file path")

            # Preprocess image to meet Gemma3n requirements
            processed_image = self.preprocess_image(image)

            # Prepare messages with system prompt and user query
            messages = [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": self.knowledge.get_system_prompt(),
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": processed_image},
                        {
                            "type": "text",
                            "text": (
                                "Please classify what you see in this image. "
                                "If it shows garbage/waste items, classify them "
                                "according to the garbage classification standards. "
                                "If it shows people, living things, or other "
                                "non-waste items, classify it as 'Unable to "
                                "classify' and explain why it's not garbage."
                            ),
                        },
                    ],
                },
            ]

            # Apply chat template and tokenize
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=self.model.dtype)

            input_len = inputs["input_ids"].shape[-1]

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.config.MAX_NEW_TOKENS,
                disable_compile=True,
            )

            # Decode only the newly generated tokens
            response = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
            )[0]

            # Extract classification from response
            classification = self._extract_classification(response)

            # Create formatted response
            formatted_response = self._format_response(classification, response)

            return classification, formatted_response

        except Exception as e:
            self.logger.error(f"Error during classification: {str(e)}")
            import traceback

            traceback.print_exc()
            return "Error", f"Classification failed: {str(e)}"

    def _extract_classification(self, response: str) -> str:
        """Extract the main classification from the response."""
        categories = self.knowledge.get_categories()

        # Convert response to lowercase for matching
        response_lower = response.lower()

        # First check for "Unable to classify" indicators
        unable_indicators = [
            "unable to classify",
            "cannot classify",
            "not garbage",
            "not waste",
            "person",
            "people",
            "human",
            "face",
            "living",
            "alive",
            "animal",
            "functioning",
            "in use",
            "working",
            "furniture",
            "appliance",
            "electronic device",
        ]
        if any(indicator in response_lower for indicator in unable_indicators):
            return "Unable to classify"

        # Look for exact category matches (excluding "Unable to classify",
        # since we handled it above)
        waste_categories = [cat for cat in categories if cat != "Unable to classify"]
        for category in waste_categories:
            if category.lower() in response_lower:
                return category

        # Look for key terms if there is no exact match
        category_keywords = {
            "Recyclable Waste": [
                "recyclable",
                "recycle",
                "plastic",
                "paper",
                "metal",
                "glass",
                "bottle",
                "can",
                "aluminum",
                "cardboard",
            ],
            "Food/Kitchen Waste": [
                "food",
                "kitchen",
                "organic",
                "fruit",
                "vegetable",
                "leftovers",
                "scraps",
                "peel",
                "core",
                "bone",
            ],
            "Hazardous Waste": [
                "hazardous",
                "dangerous",
                "toxic",
                "battery",
                "chemical",
                "medicine",
                "paint",
                "pharmaceutical",
            ],
            "Other Waste": [
                "other",
                "general",
                "trash",
                "garbage",
                "waste",
                "cigarette",
                "ceramic",
                "dust",
            ],
        }

        for category, keywords in category_keywords.items():
            if any(keyword in response_lower for keyword in keywords):
                return category

        # If no clear classification is found, default to "Unable to classify"
        return "Unable to classify"

    def _format_response(self, classification: str, full_response: str) -> str:
        """Format the response with classification and reasoning."""
        if not full_response.strip():
            return (
                f"**Classification**: {classification}\n"
                "**Reasoning**: No detailed analysis available."
            )

        # If the response already contains the structured format, return it as is
        if "**Classification**" in full_response and "**Reasoning**" in full_response:
            return full_response

        # Otherwise, format it
        return f"**Classification**: {classification}\n\n**Reasoning**: {full_response}"

    def get_categories_info(self):
        """Get information about all categories."""
        return self.knowledge.get_category_descriptions()
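

# Minimal usage sketch (illustrative only). It assumes config.py provides a
# Config whose MODEL_NAME, HF_TOKEN, TORCH_DTYPE, DEVICE_MAP, and
# MAX_NEW_TOKENS fields are already set, and "example_item.jpg" is a
# hypothetical local image path standing in for a real photo of a waste item.
if __name__ == "__main__":
    classifier = GarbageClassifier()
    classifier.load_model()
    label, details = classifier.classify_image("example_item.jpg")
    print(f"Category: {label}")
    print(details)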