File size: 6,580 Bytes
9fc6414 ec08466 9fc6414 5df6d0e 9fc6414 5df6d0e e8d546b 5df6d0e e8d546b 9fc6414 e17c970 9fc6414 5df6d0e e8d546b 5df6d0e e8d546b 9fc6414 e8d546b 9fc6414 e8d546b 9fc6414 e8d546b 9fc6414 e8d546b 9fc6414 e8d546b 9fc6414 e8d546b 9fc6414 a7c4bde e17c970 ec08466 e17c970 ec08466 e17c970 ec08466 e17c970 a7c4bde e17c970 a7c4bde e17c970 ec08466 e17c970 a7c4bde ec08466 e17c970 a7c4bde e17c970 a7c4bde ec08466 a7c4bde 9fc6414 a7c4bde ec08466 a7c4bde 9fc6414 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import os
# Location of the exported learner used for inference
model_path = "export.pkl"
# Remains None whenever the model file is missing or cannot be loaded
learn_inf = None

if not os.path.exists(model_path):
    print("Warning: export.pkl not found. Please upload your trained model file.")
else:
    try:
        # Imported here so a missing or broken fastai install is reported
        # by the handler below instead of stopping the script.
        from fastai.vision.all import load_learner

        # NOTE(review): load_learner presumably unpickles the file, so only
        # trusted exports should be placed at model_path - confirm.
        learn_inf = load_learner(model_path)
        print("✅ Model loaded successfully!")
    except Exception as exc:
        print(f"❌ Error loading model: {exc}")
        print("This might be due to fastai version compatibility issues.")
        learn_inf = None
# Overlay color (RGB) for each land-cover class, listed in class order
class_colors = {
    "AnnualCrop": (255, 255, 0),
    "Forest": (34, 139, 34),
    "HerbaceousVegetation": (144, 238, 144),
    "Highway": (128, 128, 128),
    "Industrial": (255, 165, 0),
    "Pasture": (173, 255, 47),
    "PermanentCrop": (0, 255, 0),
    "Residential": (255, 0, 0),
    "River": (0, 191, 255),
    "SeaLake": (0, 0, 139),
}

# Class names, in the same order as the color table
classes = list(class_colors)

# Debug: print class information
print(f"Total classes defined: {len(classes)}")
print(f"Total colors defined: {len(class_colors)}")
print("Classes:", classes)
print("Colors:", list(class_colors))

# Sliding-window geometry: square window edge and step, both in pixels
patch_size = 128
stride = 128
def classify_image(img: Image.Image):
    """Classify an image window by window and overlay the predicted classes.

    The image is cut into ``patch_size`` x ``patch_size`` windows taken every
    ``stride`` pixels. Each window is classified with ``learn_inf`` and painted
    with its entry from ``class_colors``; the painted mask is then blended over
    the original pixels.

    Args:
        img: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(blended_image, stats_text)`` tuple. On success ``blended_image``
        is a PIL image and ``stats_text`` lists one percentage per class. When
        classification is not possible ``blended_image`` is ``None`` and
        ``stats_text`` holds an error message.
    """
    if learn_inf is None:
        return None, "❌ Model not loaded! Please check the console for error messages. You may need to retrain your model with a compatible fastai version or use fastai<2.8.0."
    # Clicking the button without an upload passes None; say so explicitly
    # instead of surfacing an AttributeError message.
    if img is None:
        return None, "❌ No image provided. Please upload a satellite image first."
    try:
        img_np = np.array(img.convert("RGB"))
        h, w, _ = img_np.shape
        # With less than one full window nothing would be classified and
        # every percentage would silently read 0.00%.
        if h < patch_size or w < patch_size:
            return None, (
                f"❌ Image is too small ({w}x{h} px). It must be at least "
                f"{patch_size}x{patch_size} px to be classified."
            )

        mask = np.zeros((h, w, 3), dtype=np.uint8)
        # Tracks which pixels were covered by a window, so borders that no
        # window reaches are not tinted by the (black) empty mask.
        classified = np.zeros((h, w), dtype=bool)
        coverage = {cls: 0 for cls in classes}

        # Sliding window classification
        for y in range(0, h - patch_size + 1, stride):
            for x in range(0, w - patch_size + 1, stride):
                patch = img_np[y:y + patch_size, x:x + patch_size]
                patch_pil = Image.fromarray(patch).convert("RGB")
                pred, _, _ = learn_inf.predict(patch_pil)
                mask[y:y + patch_size, x:x + patch_size] = class_colors[pred]
                classified[y:y + patch_size, x:x + patch_size] = True
                coverage[pred] += patch_size * patch_size

        # Blend original + mask on classified pixels only; the rest keeps
        # its original value.
        overlay = 0.6 * img_np + 0.4 * mask
        blended = np.where(classified[..., None], overlay, img_np).astype(np.uint8)
        blended_img = Image.fromarray(blended)

        # Percentages are relative to the whole image, so they sum to less
        # than 100 when the image size is not a multiple of the window size.
        total_pixels = h * w
        coverage_pct = {cls: (coverage[cls] / total_pixels) * 100 for cls in classes}

        # Prepare statistics text
        stats_text = "\n".join(f"{cls}: {coverage_pct[cls]:.2f}%" for cls in classes)
        return blended_img, stats_text
    except Exception as e:
        # UI boundary: report the failure in the stats box rather than raising.
        return None, f"❌ Error during classification: {str(e)}"
def create_color_legend():
    """Create a color legend image showing all land cover classes and their colors.

    Each entry of ``class_colors`` gets one row with a color swatch, a numbered
    class name and the RGB tuple. A title is drawn above the rows and a footer
    with the class count below them.

    Returns:
        A PIL ``RGB`` image, 450 px wide and at least 450 px high.
    """
    legend_width = 450
    first_row_y = 55
    row_height = 35
    swatch_height = 25

    # The footer sits just below the last row; the canvas grows with the
    # number of classes but never shrinks below 450 px.
    footer_y = first_row_y + len(class_colors) * row_height + 10
    legend_height = max(450, footer_y + 30)

    legend_img = Image.new('RGB', (legend_width, legend_height), 'white')
    draw = ImageDraw.Draw(legend_img)

    # Try known TrueType font locations first, then fall back to PIL's
    # built-in font. ImportError is caught as well so a Pillow build without
    # FreeType support still reaches the fallback.
    font = None
    for font_path in (
        "/System/Library/Fonts/Arial.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    ):
        try:
            font = ImageFont.truetype(font_path, 14)
            break
        except (OSError, ImportError):
            continue
    if font is None:
        font = ImageFont.load_default()

    # Title
    # NOTE(review): the emoji may not have a glyph in these fonts - confirm
    # how it renders on the deployment target.
    draw.text((10, 10), "🎨 Color Legend - Land Cover Classes", fill='black', font=font)
    draw.text((10, 30), f"Total Classes: {len(class_colors)}", fill='blue', font=font)

    # Draw color squares with labels. The row position comes from the index
    # alone: also advancing a running offset on every iteration would double
    # the spacing and push the later rows off the canvas.
    for i, (cls, color) in enumerate(class_colors.items()):
        y_start = first_row_y + i * row_height
        y_end = y_start + swatch_height

        # Draw the colored rectangle
        draw.rectangle([20, y_start, 50, y_end], fill=color, outline='black', width=2)

        # Add class name with its 1-based index
        class_text = f"{i+1}. {cls}"
        draw.text((60, y_start + 2), class_text, fill='black', font=font)

        # Add RGB values for reference
        rgb_text = f"RGB{color}"
        draw.text((280, y_start + 2), rgb_text, fill='gray', font=font)

    # Add footer with total count
    draw.text((10, footer_y), f"Legend shows {len(class_colors)} land cover classes", fill='green', font=font)
    return legend_img
# Help text shown next to the results
how_it_works_text = """
This classifier uses a sliding window approach:
- **Patch Size**: 128x128 pixels
- **Stride**: 128 pixels (no overlap)
- **Output**: Color-coded heatmap + coverage percentages
**Note**: You need to upload your trained model file (export.pkl) to use this classifier.
"""

# Build Gradio interface
with gr.Blocks(title="🌍 Satellite Land Cover Classifier") as iface:
    gr.Markdown("# 🌍 Satellite Land Cover Classifier")
    gr.Markdown("Upload a satellite image and get land-cover classification heatmap + coverage stats.")

    # Top row: upload controls (wide column) beside the color legend (narrow column)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 📤 Upload Image")
            input_image = gr.Image(type="pil", label="Upload Satellite Image")
            classify_btn = gr.Button("🔍 Classify Image", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("## 🎨 Color Legend")
            legend_image = gr.Image(
                value=create_color_legend(),
                label="Land Cover Classes",
                interactive=False,
            )

    # Bottom row: classification output (wide column) beside the help text (narrow column)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 📊 Results")
            output_image = gr.Image(type="pil", label="Classification Result")
            output_stats = gr.Textbox(label="Coverage Statistics", lines=10)

        with gr.Column(scale=1):
            gr.Markdown("## ℹ️ How it works")
            gr.Markdown(how_it_works_text)

    # Run the classifier on the uploaded image when the button is pressed
    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=[output_image, output_stats],
    )

iface.launch()
|