# sagar250277's picture
# Fix color legend to show all 10 classes with proper dimensions and debugging info
# e17c970
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import os
# Locate the exported learner. The app still starts when it is missing or
# cannot be loaded; `learn_inf` then stays None.
model_path = "export.pkl"
learn_inf = None

if not os.path.exists(model_path):
    print("Warning: export.pkl not found. Please upload your trained model file.")
else:
    try:
        # Imported here so a missing or incompatible fastai install is caught
        # and reported by the handler below.
        from fastai.vision.all import load_learner

        # NOTE(review): load_learner is understood to unpickle the file, so
        # only trusted exports should be placed at model_path -- confirm.
        learn_inf = load_learner(model_path)
        print("✅ Model loaded successfully!")
    except Exception as err:
        print(f"❌ Error loading model: {err}")
        print("This might be due to fastai version compatibility issues.")
        # Reset in case the failure happened after the learner was assigned.
        learn_inf = None
# Land-cover palette: one (class name, RGB colour) pair per class, listed in
# the order used for the statistics text and the legend rows.
_palette = (
    ("AnnualCrop", (255, 255, 0)),
    ("Forest", (34, 139, 34)),
    ("HerbaceousVegetation", (144, 238, 144)),
    ("Highway", (128, 128, 128)),
    ("Industrial", (255, 165, 0)),
    ("Pasture", (173, 255, 47)),
    ("PermanentCrop", (0, 255, 0)),
    ("Residential", (255, 0, 0)),
    ("River", (0, 191, 255)),
    ("SeaLake", (0, 0, 139)),
)

# Class names, in palette order.
classes = [label for label, _rgb in _palette]

# Class name -> RGB colour used for visualization.
class_colors = dict(_palette)

# Debug output: confirm both containers describe the same set of classes.
print(f"Total classes defined: {len(classes)}")
print(f"Total colors defined: {len(class_colors)}")
print("Classes:", classes)
print("Colors:", list(class_colors))

# Sliding-window geometry in pixels: window side length and step between windows.
patch_size = stride = 128
def classify_image(img: Image.Image):
    """Classify an image patch by patch and overlay the predicted classes.

    The image is cut into ``patch_size`` x ``patch_size`` windows taken every
    ``stride`` pixels. Each window is passed to ``learn_inf.predict`` and
    painted with the colour that ``class_colors`` assigns to the prediction.

    Args:
        img: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(image, text)`` tuple. On success ``image`` is the original with
        the class colours blended in (60% original / 40% colour) over every
        classified pixel, and ``text`` holds one ``"<class>: <pct>%"`` line per
        class, plus an "Unclassified" line when some pixels fall outside all
        windows. On failure ``image`` is ``None`` and ``text`` is an error
        message.
    """
    if learn_inf is None:
        return None, "❌ Model not loaded! Please check the console for error messages. You may need to retrain your model with a compatible fastai version or use fastai<2.8.0."
    # The button can be clicked before anything is uploaded.
    if img is None:
        return None, "❌ No image provided. Please upload a satellite image first."
    try:
        img_np = np.array(img.convert("RGB"))
        h, w, _ = img_np.shape
        # Without at least one full window nothing would be classified.
        if h < patch_size or w < patch_size:
            return None, (
                f"❌ Image is too small ({w}x{h}). "
                f"It must be at least {patch_size}x{patch_size} pixels."
            )

        mask = np.zeros((h, w, 3), dtype=np.uint8)
        # Tracks which pixels received a prediction, so leftover pixels are
        # neither tinted nor silently dropped from the statistics.
        classified = np.zeros((h, w), dtype=bool)
        coverage = dict.fromkeys(classes, 0)

        # Sliding window classification
        for y in range(0, h - patch_size + 1, stride):
            for x in range(0, w - patch_size + 1, stride):
                rows = slice(y, y + patch_size)
                cols = slice(x, x + patch_size)
                patch_pil = Image.fromarray(img_np[rows, cols]).convert("RGB")
                pred, _, _ = learn_inf.predict(patch_pil)
                mask[rows, cols] = class_colors[pred]
                classified[rows, cols] = True
                coverage[pred] += patch_size * patch_size

        # Blend original + mask, but only where a prediction exists; blending
        # the empty (black) mask elsewhere would just darken those pixels.
        blended = img_np.copy()
        blended[classified] = (
            0.6 * img_np[classified] + 0.4 * mask[classified]
        ).astype(np.uint8)
        blended_img = Image.fromarray(blended)

        # Coverage percentages are relative to the whole image.
        total_pixels = h * w
        stats_lines = [
            f"{cls}: {coverage[cls] / total_pixels * 100:.2f}%" for cls in classes
        ]
        unclassified = total_pixels - int(classified.sum())
        if unclassified:
            stats_lines.append(
                f"Unclassified (outside patches): {unclassified / total_pixels * 100:.2f}%"
            )
        return blended_img, "\n".join(stats_lines)
    except Exception as e:
        # UI boundary: report the failure in the stats box instead of raising.
        return None, f"❌ Error during classification: {str(e)}"
def create_color_legend():
    """Create a color legend image showing all land cover classes and their colors.

    One row is drawn per entry of ``class_colors`` (colour swatch, numbered
    class name, RGB value), preceded by a title and followed by a footer. The
    canvas height is derived from the number of classes so every row fits.

    Returns:
        A PIL ``Image`` in RGB mode containing the rendered legend.
    """
    legend_width = 450
    header_height = 55   # y of the first row; title lines sit above it
    row_height = 35      # vertical distance between consecutive rows
    swatch_height = 25
    footer_height = 45   # gap + footer text + bottom margin
    class_count = len(class_colors)
    legend_height = header_height + class_count * row_height + footer_height

    legend_img = Image.new('RGB', (legend_width, legend_height), 'white')
    draw = ImageDraw.Draw(legend_img)

    # Try known TrueType locations (macOS, then Debian/Ubuntu) and fall back
    # to Pillow's built-in font when none can be loaded.
    font = None
    for font_path in (
        "/System/Library/Fonts/Arial.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    ):
        try:
            font = ImageFont.truetype(font_path, 14)
            break
        except (OSError, ImportError):
            # Font file unreadable, or FreeType support missing in Pillow.
            continue
    if font is None:
        font = ImageFont.load_default()

    # Title. Kept free of emoji: the legacy built-in bitmap font cannot encode
    # characters outside Latin-1, and the TrueType fonts above lack the glyph.
    draw.text((10, 10), "Color Legend - Land Cover Classes", fill='black', font=font)
    draw.text((10, 30), f"Total Classes: {class_count}", fill='blue', font=font)

    # Row position comes from the index alone. Advancing a running offset as
    # well would double the spacing and push later rows off the canvas.
    for i, (cls, color) in enumerate(class_colors.items()):
        y_start = header_height + i * row_height
        y_end = y_start + swatch_height
        # Colour swatch
        draw.rectangle([20, y_start, 50, y_end], fill=color, outline='black', width=2)
        # Numbered class name
        draw.text((60, y_start + 2), f"{i+1}. {cls}", fill='black', font=font)
        # RGB values for reference
        draw.text((280, y_start + 2), f"RGB{color}", fill='gray', font=font)

    # Footer with total count, placed just below the last row
    footer_y = header_height + class_count * row_height + 10
    draw.text((10, footer_y), f"Legend shows {class_count} land cover classes", fill='green', font=font)
    return legend_img
# Gradio UI: upload + legend on the first row, results + explanation on the
# second. Left-hand columns are twice as wide as right-hand ones.
with gr.Blocks(title="🌍 Satellite Land Cover Classifier") as iface:
    gr.Markdown("# 🌍 Satellite Land Cover Classifier")
    gr.Markdown("Upload a satellite image and get land-cover classification heatmap + coverage stats.")

    with gr.Row():
        with gr.Column(scale=2):
            # Image upload and the button that triggers classification
            gr.Markdown("## 📤 Upload Image")
            input_image = gr.Image(type="pil", label="Upload Satellite Image")
            classify_btn = gr.Button("🔍 Classify Image", variant="primary")

        with gr.Column(scale=1):
            # Static legend, rendered once while the layout is built
            gr.Markdown("## 🎨 Color Legend")
            legend_image = gr.Image(
                value=create_color_legend(),
                label="Land Cover Classes",
                interactive=False,
            )

    with gr.Row():
        with gr.Column(scale=2):
            # Blended overlay and the per-class coverage text
            gr.Markdown("## 📊 Results")
            output_image = gr.Image(type="pil", label="Classification Result")
            output_stats = gr.Textbox(label="Coverage Statistics", lines=10)

        with gr.Column(scale=1):
            # Short explanation of the approach
            gr.Markdown("## ℹ️ How it works")
            gr.Markdown("""
This classifier uses a sliding window approach:
- **Patch Size**: 128x128 pixels
- **Stride**: 128 pixels (no overlap)
- **Output**: Color-coded heatmap + coverage percentages
**Note**: You need to upload your trained model file (export.pkl) to use this classifier.
""")

    # Button click: uploaded image in, (overlay image, stats text) out
    classify_btn.click(classify_image, input_image, [output_image, output_stats])

iface.launch()