# Gemma3n-challenge-demo / classifier.py
# Author: HMWCS
# Last commit: "fix: fix failed uploads" (0be64b3, verified)
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch
import logging
from typing import Union, Tuple
from config import Config
from knowledge_base import GarbageClassificationKnowledge
import re
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Preprocess image to meet Gemma3n requirements (512x512).

    The image is converted to RGB, resized so its longer side fits 512 px
    while preserving aspect ratio, then centered on a white 512x512 canvas
    (letterboxing) so content is neither cropped nor distorted.

    Args:
        image: Source image in any PIL-supported mode.

    Returns:
        A new 512x512 RGB image with the resized input centered on a
        white background.
    """
    # Convert to RGB if necessary (handles RGBA, P, L, ... modes)
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize to 512x512 as required by Gemma3n
    target_size = (512, 512)

    # Aspect-ratio-preserving target dimensions
    original_width, original_height = image.size
    aspect_ratio = original_width / original_height
    if aspect_ratio > 1:
        # Width is larger: pin width, scale height down
        new_width = target_size[0]
        new_height = int(target_size[0] / aspect_ratio)
    else:
        # Height is larger or equal: pin height, scale width down
        new_height = target_size[1]
        new_width = int(target_size[1] * aspect_ratio)

    # Bug fix: clamp to at least 1 px so extreme aspect ratios (e.g. a
    # 2000x1 strip) cannot produce a zero dimension, which PIL's resize
    # rejects with a ValueError.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    # High-quality downscale maintaining aspect ratio
    image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Create a white 512x512 canvas and paste the resized image centered
    processed_image = Image.new("RGB", target_size, (255, 255, 255))
    x_offset = (target_size[0] - new_width) // 2
    y_offset = (target_size[1] - new_height) // 2
    processed_image.paste(image, (x_offset, y_offset))
    return processed_image
class GarbageClassifier:
    """Garbage classifier backed by a Gemma3n vision-language model.

    Wraps model/processor loading, image preprocessing, prompt
    construction, and post-processing of the model's free-text answer
    into a structured (classification, reasoning, confidence) triple.
    """

    def __init__(self, config: Config = None):
        """Create a classifier; the model is NOT loaded until load_model().

        Args:
            config: Optional Config instance; a default Config is built
                when None is given.
        """
        self.config = config or Config()
        self.knowledge = GarbageClassificationKnowledge()
        # Populated lazily by load_model()
        self.processor = None
        self.model = None

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_model(self):
        """Load the model and processor named in the config.

        Raises:
            Exception: re-raises whatever ``from_pretrained`` fails with,
                after logging the error.
        """
        try:
            self.logger.info(f"Loading model: {self.config.MODEL_NAME}")

            # Pass an auth token only when one is configured (needed for
            # gated/private repositories)
            kwargs = {}
            if self.config.HF_TOKEN:
                kwargs["token"] = self.config.HF_TOKEN

            # Load processor
            self.processor = AutoProcessor.from_pretrained(
                self.config.MODEL_NAME, **kwargs
            )

            # Load model. Bug fix: forward the same auth kwargs here too —
            # previously only the processor received the token, so loading
            # a gated/private model would fail.
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.config.MODEL_NAME,
                torch_dtype=self.config.TORCH_DTYPE,
                device_map=self.config.DEVICE_MAP,
                **kwargs,
            )

            self.logger.info("Model loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading model: {str(e)}")
            raise

    def classify_image(self, image: Union[str, Image.Image]) -> Tuple[str, str, int]:
        """
        Classify garbage in the image.

        Args:
            image: PIL Image or path to image file

        Returns:
            Tuple of (classification_result, detailed_analysis,
            confidence_score). On any processing failure this returns
            ("Error", <message>, 0) instead of raising.

        Raises:
            RuntimeError: if load_model() has not been called yet.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            # Load and validate the image input
            if isinstance(image, str):
                image = Image.open(image)
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be a PIL Image or file path")

            # Preprocess image to meet Gemma3n requirements (512x512 canvas)
            processed_image = preprocess_image(image)

            # System turn carries the classification standards; the user
            # turn carries the image plus the actual question.
            messages = [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": self.knowledge.get_system_prompt(),
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": processed_image},
                        {
                            "type": "text",
                            "text": "Please classify what you see in this image. If it shows garbage/waste items, classify them according to the garbage classification standards. If it shows people, living things, or other non-waste items, classify it as 'Unable to classify' and explain why it's not garbage. Also provide a confidence score from 1-10 indicating how certain you are about your classification.",
                        },
                    ],
                },
            ]

            # Apply chat template and tokenize. NOTE(review): the dtype
            # argument only recasts floating-point tensors (pixel values),
            # not input_ids — presumably relying on BatchFeature.to
            # semantics; confirm against the installed transformers version.
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=self.model.dtype)

            # Remember prompt length so only newly generated tokens are
            # decoded below.
            input_len = inputs["input_ids"].shape[-1]

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.config.MAX_NEW_TOKENS,
                disable_compile=True,
            )
            response = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
            )[0]

            # Post-process the free-text response into structured fields
            classification = self._extract_classification(response)
            reasoning = self._extract_reasoning(response)
            confidence_score = self._extract_confidence_score(response, classification)

            return classification, reasoning, confidence_score
        except Exception as e:
            # Best-effort API: log and report the failure instead of
            # raising so a UI caller can still render something.
            self.logger.error(f"Error during classification: {str(e)}")
            import traceback

            traceback.print_exc()
            return "Error", f"Classification failed: {str(e)}", 0

    def _calculate_confidence_heuristic(self, response_lower: str, classification: str) -> int:
        """Calculate confidence based on response content and classification type.

        Starts from a neutral baseline of 5 and nudges it up or down based
        on assertive vs. hedging language and classification-specific cues.

        Args:
            response_lower: The model response, already lowercased.
            classification: The category chosen by _extract_classification.

        Returns:
            An integer confidence score clamped to [1, 10].
        """
        base_confidence = 5

        # Confidence indicators (increase confidence)
        high_confidence_words = ["clearly", "obviously", "definitely", "certainly", "exactly"]
        medium_confidence_words = ["appears", "seems", "likely", "probably"]
        # Uncertainty indicators (decrease confidence)
        uncertainty_words = ["might", "could", "possibly", "maybe", "unclear", "difficult"]

        # Each word group adjusts the score at most once (break on first hit)
        for word in high_confidence_words:
            if word in response_lower:
                base_confidence += 2
                break
        for word in medium_confidence_words:
            if word in response_lower:
                base_confidence += 1
                break
        for word in uncertainty_words:
            if word in response_lower:
                base_confidence -= 2
                break

        # Classification-specific adjustments
        if classification == "Unable to classify":
            if any(indicator in response_lower for indicator in ["person", "people", "human", "living"]):
                base_confidence += 1  # High confidence when clearly not waste
            else:
                base_confidence -= 1  # Lower confidence for unclear items
        elif classification == "Error":
            base_confidence = 1
        else:
            # Specific material mentions increase confidence
            specific_materials = ["aluminum", "plastic", "glass", "metal", "cardboard", "paper"]
            if any(material in response_lower for material in specific_materials):
                base_confidence += 1

        return min(max(base_confidence, 1), 10)

    def _extract_confidence_score(self, response: str, classification: str) -> int:
        """Extract an explicit confidence score from the response, or fall
        back to the heuristic when none is stated.

        Returns:
            An integer confidence score clamped to [1, 10].
        """
        response_lower = response.lower()

        # Look for explicit confidence scores in the response
        confidence_patterns = [
            r'\*\*confidence score\*\*[:\s]*(\d+)',  # For **Confidence Score**: format
            r'confidence[:\s]*(\d+)',
            r'confident[:\s]*(\d+)',
            r'certainty[:\s]*(\d+)',
            r'score[:\s]*(\d+)',
            r'(\d+)/10',
            r'(\d+)\s*out\s*of\s*10'
        ]
        for pattern in confidence_patterns:
            match = re.search(pattern, response_lower)
            if match:
                score = int(match.group(1))
                return min(max(score, 1), 10)  # Clamp between 1-10

        # No explicit score found: estimate from wording and classification
        return self._calculate_confidence_heuristic(response_lower, classification)

    def _extract_classification(self, response: str) -> str:
        """Extract the main classification from the response - trust Gemma 3n
        intelligence more.

        Tries, in order: an explicit (non-negated) category name, mixed-waste
        warnings, keyword-based material/food/hazard detection, and finally
        non-garbage / unable-to-classify cues. Defaults to
        "Unable to classify".
        """
        response_lower = response.lower()

        # Primary: trust explicit category mentions from Gemma 3n
        categories = self.knowledge.get_categories()
        for category in categories:
            if category.lower() in response_lower:
                # Simple negation check on the text just before the mention
                category_index = response_lower.find(category.lower())
                context_before = response_lower[max(0, category_index - 20):category_index]
                if not any(neg in context_before[-10:] for neg in ["not", "cannot", "isn't"]):
                    return category

        # Secondary: explicit mixed-garbage warnings from the model
        mixed_warnings = [
            "multiple garbage types detected",
            "separate items",
            "different garbage types",
            "mixed together"
        ]
        if any(warning in response_lower for warning in mixed_warnings):
            return "Unable to classify"

        # Tertiary: basic material detection (simplified)
        if any(material in response_lower for material in
               ["recyclable", "aluminum", "plastic", "glass", "metal", "cardboard"]):
            # Contaminated recyclables are treated as food waste
            if any(cont in response_lower for cont in ["obvious food", "substantial residue", "chunks", "liquids"]):
                return "Food/Kitchen Waste"
            return "Recyclable Waste"
        if any(food in response_lower for food in ["food", "organic", "kitchen", "fruit", "vegetable"]):
            return "Food/Kitchen Waste"
        if any(hazard in response_lower for hazard in ["battery", "hazardous", "chemical", "toxic"]):
            return "Hazardous Waste"
        if any(other in response_lower for other in ["cigarette", "ceramic", "styrofoam"]):
            return "Other Waste"

        # Non-garbage detection (people, animals, living things)
        if any(non_garbage in response_lower for non_garbage in ["person", "people", "human", "living", "animal"]):
            return "Unable to classify"

        # Final fallback - let Gemma 3n's reasoning guide us
        if any(unable in response_lower for unable in ["unable to classify", "cannot classify", "not garbage"]):
            return "Unable to classify"

        # Default to Unable to classify if unclear
        return "Unable to classify"

    def _extract_reasoning(self, response: str) -> str:
        """Extract only the reasoning content, removing all formatting
        markers and classification/category labels.

        Returns:
            A cleaned, capitalized, punctuated reasoning string, or
            "Analysis not available" when nothing meaningful remains.
        """
        # NOTE: `re` is imported at module level; the redundant local
        # import was removed.

        # Remove all formatting markers
        cleaned_response = response.replace("**Classification**:", "")
        cleaned_response = cleaned_response.replace("**Reasoning**:", "")
        cleaned_response = re.sub(r'\*\*.*?\*\*:', '', cleaned_response)  # Remove any **text**: patterns
        cleaned_response = cleaned_response.replace("**", "")  # Remove remaining ** markers

        # Remove category names that might appear at the beginning
        categories = self.knowledge.get_categories()
        for category in categories:
            if cleaned_response.strip().startswith(category):
                cleaned_response = cleaned_response.replace(category, "", 1)
                break

        # Common material names that might appear at the beginning
        material_names = [
            "Glass", "Plastic", "Metal", "Paper", "Cardboard", "Aluminum",
            "Steel", "Iron", "Tin", "Foil", "Wood", "Ceramic", "Fabric",
            "Recyclable Waste", "Food/Kitchen Waste", "Hazardous Waste", "Other Waste"
        ]

        cleaned_response = cleaned_response.strip()

        # Remove material names at the beginning
        for material in material_names:
            if cleaned_response.startswith(material):
                # Strip the material name and any following punctuation/whitespace
                cleaned_response = cleaned_response[len(material):].lstrip(" .,;:")
                break

        # Split into sentences, keeping the terminating punctuation
        sentences = []
        parts = re.split(r'([.!?])\s+', cleaned_response)

        # Rejoin text/punctuation pairs to reconstruct full sentences
        reconstructed_parts = []
        for i in range(0, len(parts), 2):
            if i < len(parts):
                sentence = parts[i]
                if i + 1 < len(parts):
                    sentence += parts[i + 1]  # Add the punctuation back
                reconstructed_parts.append(sentence)

        for part in reconstructed_parts:
            part = part.strip()
            if not part:
                continue
            # Skip parts that are just category names or material names
            if part in categories or part.rstrip(".,;:") in material_names:
                continue
            # Skip parts that start with category names or material names
            is_category_line = False
            for item in categories + material_names:
                if part.startswith(item):
                    is_category_line = True
                    break
            if is_category_line:
                continue
            # Remove "Category:" type prefixes from the sentence
            part = re.sub(r'^[A-Za-z\s]+:', '', part).strip()
            if part and len(part) > 3:  # Only keep meaningful content
                sentences.append(part)

        reasoning = ' '.join(sentences)

        # Final cleanup - drop a remaining standalone material word at the
        # beginning. Bug fix: compare case-insensitively — the first word of
        # a sentence is usually capitalized, so the original comparison
        # against a lowercased list almost never matched.
        reasoning_words = reasoning.split()
        lowered_materials = {m.lower() for m in material_names}
        if reasoning_words and reasoning_words[0].lower() in lowered_materials:
            reasoning_words = reasoning_words[1:]
            reasoning = ' '.join(reasoning_words)

        if reasoning:
            # Ensure proper capitalization
            reasoning = reasoning[0].upper() + reasoning[1:] if len(reasoning) > 1 else reasoning.upper()
            # Ensure proper punctuation
            if not reasoning.endswith(('.', '!', '?')):
                reasoning += '.'

        return reasoning if reasoning else "Analysis not available"

    def get_categories_info(self):
        """Get information about all categories (delegates to the knowledge base)."""
        return self.knowledge.get_category_descriptions()