from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch
import logging
from typing import Optional, Tuple, Union

from config import Config
from knowledge_base import GarbageClassificationKnowledge


class GarbageClassifier:
    def __init__(self, config: Optional[Config] = None):
        self.config = config or Config()
        self.knowledge = GarbageClassificationKnowledge()
        self.processor = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set up logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the model and processor."""
        try:
            self.logger.info(f"Loading model: {self.config.MODEL_NAME}")

            # Pass the Hugging Face token only if one is configured; Gemma
            # checkpoints are gated, so both the processor and the model
            # loads need it
            kwargs = {}
            if self.config.HF_TOKEN:
                kwargs["token"] = self.config.HF_TOKEN

            self.processor = AutoProcessor.from_pretrained(
                self.config.MODEL_NAME, **kwargs
            )
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.config.MODEL_NAME,
                torch_dtype=self.config.TORCH_DTYPE,
                device_map=self.config.DEVICE_MAP,
                **kwargs,
            )
            self.logger.info("Model loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading model: {str(e)}")
            raise

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """
        Preprocess an image to meet the Gemma 3n input size (512x512):
        resize preserving aspect ratio, then letterbox onto a white canvas.
        """
        # Convert to RGB if necessary
        if image.mode != "RGB":
            image = image.convert("RGB")

        target_size = (512, 512)

        # Compute an aspect-ratio-preserving size that fits inside the target
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height
        if aspect_ratio > 1:
            # Wider than tall: pin the width
            new_width = target_size[0]
            new_height = int(target_size[0] / aspect_ratio)
        else:
            # Taller than wide (or square): pin the height
            new_height = target_size[1]
            new_width = int(target_size[1] * aspect_ratio)

        # Resize, then center the result on a white 512x512 canvas
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        processed_image = Image.new("RGB", target_size, (255, 255, 255))
        x_offset = (target_size[0] - new_width) // 2
        y_offset = (target_size[1] - new_height) // 2
        processed_image.paste(image, (x_offset, y_offset))
        return processed_image
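
    # Worked example of the letterbox math above (illustrative numbers, not
    # runtime output): a 1024x768 input has aspect ratio 4/3 > 1, so it is
    # resized to 512x384 and pasted at offset (0, 64), leaving 64-pixel white
    # bands at the top and bottom.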

    def classify_image(self, image: Union[str, Image.Image]) -> Tuple[str, str]:
        """
        Classify garbage in the image.

        Args:
            image: PIL Image or path to an image file

        Returns:
            Tuple of (classification_result, formatted_response)
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            # Load the image if a file path was given
            if isinstance(image, str):
                image = Image.open(image)
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be a PIL Image or file path")

            # Preprocess the image to meet Gemma 3n requirements
            processed_image = self.preprocess_image(image)

            # Prepare messages with the system prompt and user query
            messages = [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": self.knowledge.get_system_prompt(),
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": processed_image},
                        {
                            "type": "text",
                            "text": (
                                "Please classify what you see in this image. "
                                "If it shows garbage/waste items, classify them "
                                "according to the garbage classification "
                                "standards. If it shows people, living things, "
                                "or other non-waste items, classify it as "
                                "'Unable to classify' and explain why it's not "
                                "garbage."
                            ),
                        },
                    ],
                },
            ]

            # Apply the chat template and tokenize; move tensors to the model
            # device and cast floating-point inputs (the pixel values) to the
            # model dtype
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=self.model.dtype)

            # Decode only the newly generated tokens, not the echoed prompt
            input_len = inputs["input_ids"].shape[-1]
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.config.MAX_NEW_TOKENS,
                disable_compile=True,
            )
            response = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
            )[0]

            # Extract the category and build the formatted response
            classification = self._extract_classification(response)
            formatted_response = self._format_response(classification, response)
            return classification, formatted_response
        except Exception as e:
            self.logger.exception("Error during classification")
            return "Error", f"Classification failed: {str(e)}"

    def _extract_classification(self, response: str) -> str:
        """Extract the main classification from the model response."""
        categories = self.knowledge.get_categories()

        # Match case-insensitively
        response_lower = response.lower()

        # First check for "Unable to classify" indicators
        unable_indicators = [
            "unable to classify", "cannot classify", "not garbage", "not waste",
            "person", "people", "human", "face", "living", "alive", "animal",
            "functioning", "in use", "working", "furniture", "appliance",
            "electronic device",
        ]
        if any(indicator in response_lower for indicator in unable_indicators):
            return "Unable to classify"

        # Look for exact category-name matches ("Unable to classify" was
        # handled above)
        waste_categories = [cat for cat in categories if cat != "Unable to classify"]
        for category in waste_categories:
            if category.lower() in response_lower:
                return category

        # Fall back to keyword matching when no category name appears verbatim
        category_keywords = {
            "Recyclable Waste": [
                "recyclable", "recycle", "plastic", "paper", "metal", "glass",
                "bottle", "can", "aluminum", "cardboard",
            ],
            "Food/Kitchen Waste": [
                "food", "kitchen", "organic", "fruit", "vegetable", "leftovers",
                "scraps", "peel", "core", "bone",
            ],
            "Hazardous Waste": [
                "hazardous", "dangerous", "toxic", "battery", "chemical",
                "medicine", "paint", "pharmaceutical",
            ],
            "Other Waste": [
                "other", "general", "trash", "garbage", "waste", "cigarette",
                "ceramic", "dust",
            ],
        }
        for category, keywords in category_keywords.items():
            if any(keyword in response_lower for keyword in keywords):
                return category

        # Default when no clear classification is found
        return "Unable to classify"

    def _format_response(self, classification: str, full_response: str) -> str:
        """Format the response with classification and reasoning."""
        if not full_response.strip():
            return (
                f"**Classification**: {classification}\n"
                "**Reasoning**: No detailed analysis available."
            )

        # If the response is already structured, return it as-is
        if "**Classification**" in full_response and "**Reasoning**" in full_response:
            return full_response

        # Otherwise, wrap it in the structured format
        return f"**Classification**: {classification}\n\n**Reasoning**: {full_response}"

    def get_categories_info(self):
        """Get information about all categories."""
        return self.knowledge.get_category_descriptions()
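

if __name__ == "__main__":
    # Minimal usage sketch. Assumes a default Config pointing at a Gemma 3n
    # checkpoint; "test.jpg" is a placeholder path, not part of the original
    # module.
    classifier = GarbageClassifier()
    classifier.load_model()
    category, details = classifier.classify_image("test.jpg")
    print(f"Category: {category}")
    print(details)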