import os

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont

# Check if model file exists
model_path = "export.pkl"
learn_inf = None

if os.path.exists(model_path):
    try:
        from fastai.vision.all import load_learner

        # Load your trained model.
        # NOTE(security): load_learner unpickles the file, and unpickling can
        # execute arbitrary code -- only ever load a model file you trust.
        learn_inf = load_learner(model_path)
        print("✅ Model loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        print("This might be due to fastai version compatibility issues.")
        learn_inf = None
else:
    print("Warning: export.pkl not found. Please upload your trained model file.")

# Classes (must match the labels the model predicts)
classes = [
    "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway", "Industrial",
    "Pasture", "PermanentCrop", "Residential", "River", "SeaLake",
]

# Assign colors for visualization
class_colors = {
    "AnnualCrop": (255, 255, 0),
    "Forest": (34, 139, 34),
    "HerbaceousVegetation": (144, 238, 144),
    "Highway": (128, 128, 128),
    "Industrial": (255, 165, 0),
    "Pasture": (173, 255, 47),
    "PermanentCrop": (0, 255, 0),
    "Residential": (255, 0, 0),
    "River": (0, 191, 255),
    "SeaLake": (0, 0, 139),
}

# Debug: Print class information
print(f"Total classes defined: {len(classes)}")
print(f"Total colors defined: {len(class_colors)}")
print("Classes:", classes)
print("Colors:", list(class_colors.keys()))

# Sliding-window geometry in pixels. With stride == patch_size the windows
# tile the image without overlap (except the edge-anchored windows below).
patch_size = 128
stride = 128


def _window_starts(length, size, step):
    """Return window start offsets along one axis of ``length`` pixels.

    Windows are ``size`` long and advance by ``step``. If the last regular
    window stops short of the edge, one extra window is anchored flush with
    the far edge, so the trailing strip is classified too. That extra window
    overlaps its predecessor. If ``length <= size``, a single window at 0 is
    returned and the caller's slice is simply clipped to the image.
    """
    if length <= size:
        return [0]
    starts = list(range(0, length - size + 1, step))
    if starts[-1] + size < length:
        starts.append(length - size)
    return starts


def classify_image(img: Image.Image):
    """Classify ``img`` patch-by-patch and return ``(overlay, stats_text)``.

    Each window is classified with ``learn_inf.predict``. The predicted
    class colour is painted into a mask, which is blended 60/40 with the
    original image. ``stats_text`` lists, per class, the percentage of image
    pixels assigned to that class.

    On any failure the function returns ``(None, error_message)`` so the
    Gradio UI shows the message instead of crashing.
    """
    if learn_inf is None:
        return None, "❌ Model not loaded! Please check the console for error messages. You may need to retrain your model with a compatible fastai version or use fastai<2.8.0."
    if img is None:
        # Gradio passes None when the button is clicked with no upload.
        return None, "❌ Please upload an image first."

    try:
        img_np = np.array(img.convert("RGB"))
        h, w, _ = img_np.shape
        class_index = {cls: i for i, cls in enumerate(classes)}

        # Per-pixel class index; -1 marks pixels no window covered. Later
        # (edge-anchored) windows overwrite earlier ones where they overlap,
        # so every pixel is counted exactly once in the coverage below.
        label_map = np.full((h, w), -1, dtype=np.int16)

        # Sliding window classification
        for y in _window_starts(h, patch_size, stride):
            for x in _window_starts(w, patch_size, stride):
                patch = img_np[y:y + patch_size, x:x + patch_size]
                # Small images yield a patch smaller than patch_size; this
                # presumably relies on the learner's item transforms to
                # resize it -- TODO confirm against the training pipeline.
                pred, _, _ = learn_inf.predict(Image.fromarray(patch))
                label = str(pred)
                if label not in class_index:
                    raise ValueError(
                        f"model predicted unknown class {label!r}; "
                        "update `classes`/`class_colors` to match the model"
                    )
                label_map[y:y + patch_size, x:x + patch_size] = class_index[label]

        # Colour the mask from the label map via a palette lookup.
        palette = np.array([class_colors[cls] for cls in classes], dtype=np.uint8)
        covered = label_map >= 0
        mask = np.zeros((h, w, 3), dtype=np.uint8)
        mask[covered] = palette[label_map[covered]]

        # Blend original + mask
        blended = (0.6 * img_np + 0.4 * mask).astype(np.uint8)
        blended_img = Image.fromarray(blended)

        # Compute coverage percentages from the actual per-pixel labels.
        counts = np.bincount(label_map[covered].astype(np.intp), minlength=len(classes))
        total_pixels = h * w
        stats_text = "\n".join(
            f"{cls}: {counts[i] / total_pixels * 100:.2f}%"
            for i, cls in enumerate(classes)
        )
        return blended_img, stats_text
    except Exception as e:
        # UI boundary: surface the error text rather than crash the app.
        return None, f"❌ Error during classification: {str(e)}"


def _load_font(size=14):
    """Return the first available TrueType font, else Pillow's default font."""
    candidates = (
        "/System/Library/Fonts/Arial.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    )
    for path in candidates:
        try:
            return ImageFont.truetype(path, size)
        except OSError:
            # Font file missing or unreadable on this system; try the next.
            continue
    return ImageFont.load_default()


def create_color_legend():
    """Create a color legend image showing all land cover classes and their colors."""
    legend_width = 450
    rows_top = 55        # y of the first swatch row
    row_height = 35      # vertical distance between consecutive rows
    swatch_height = 25
    n_rows = len(class_colors)
    # Size the canvas from the class count so every row and the footer fit.
    legend_height = max(450, rows_top + n_rows * row_height + 40)

    legend_img = Image.new("RGB", (legend_width, legend_height), "white")
    draw = ImageDraw.Draw(legend_img)
    font = _load_font(14)

    # Title. Plain text only: the fonts above lack emoji glyphs, and Pillow's
    # legacy bitmap default font cannot encode non-latin-1 characters.
    draw.text((10, 10), "Color Legend - Land Cover Classes", fill="black", font=font)
    draw.text((10, 30), f"Total Classes: {n_rows}", fill="blue", font=font)

    # One row per class: colour swatch, numbered name, RGB value.
    for i, (cls, color) in enumerate(class_colors.items()):
        y_start = rows_top + i * row_height
        y_end = y_start + swatch_height
        draw.rectangle([20, y_start, 50, y_end], fill=color, outline="black", width=2)
        draw.text((60, y_start + 2), f"{i + 1}. {cls}", fill="black", font=font)
        draw.text((280, y_start + 2), f"RGB{color}", fill="gray", font=font)

    # Footer sits just below the last row.
    footer_y = rows_top + n_rows * row_height + 10
    draw.text((10, footer_y), f"Legend shows {n_rows} land cover classes", fill="green", font=font)
    return legend_img


# Text for the "How it works" panel, derived from the real window settings.
_how_it_works = (
    "This classifier uses a sliding window approach:\n"
    f"- **Patch Size**: {patch_size}x{patch_size} pixels\n"
    f"- **Stride**: {stride} pixels{' (no overlap)' if stride == patch_size else ''}\n"
    "- **Output**: Color-coded heatmap + coverage percentages\n"
    "\n"
    "**Note**: You need to upload your trained model file (export.pkl) to use this classifier."
)

# Build Gradio interface
with gr.Blocks(title="🌍 Satellite Land Cover Classifier") as iface:
    gr.Markdown("# 🌍 Satellite Land Cover Classifier")
    gr.Markdown("Upload a satellite image and get land-cover classification heatmap + coverage stats.")

    with gr.Row():
        with gr.Column(scale=2):
            # Input section
            gr.Markdown("## 📤 Upload Image")
            input_image = gr.Image(type="pil", label="Upload Satellite Image")
            classify_btn = gr.Button("🔍 Classify Image", variant="primary")

        with gr.Column(scale=1):
            # Color legend section
            gr.Markdown("## 🎨 Color Legend")
            legend_image = gr.Image(value=create_color_legend(), label="Land Cover Classes", interactive=False)

    with gr.Row():
        with gr.Column(scale=2):
            # Output section
            gr.Markdown("## 📊 Results")
            output_image = gr.Image(type="pil", label="Classification Result")
            output_stats = gr.Textbox(label="Coverage Statistics", lines=10)

        with gr.Column(scale=1):
            # Additional info
            gr.Markdown("## ℹ️ How it works")
            gr.Markdown(_how_it_works)

    # Connect the button
    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=[output_image, output_stats],
    )

iface.launch()