|
import gradio as gr |
|
from PIL import Image, ImageDraw, ImageFont |
|
import numpy as np |
|
import os |
|
|
|
|
|
# Location of the exported fastai Learner; the app still starts without it,
# but classify_image will refuse to run while learn_inf stays None.
model_path = "export.pkl"
learn_inf = None

if not os.path.exists(model_path):
    print("Warning: export.pkl not found. Please upload your trained model file.")
else:
    try:
        # Imported lazily so a missing/broken fastai install is reported here
        # instead of preventing the whole app from starting.
        from fastai.vision.all import load_learner

        # NOTE(review): load_learner deserializes a pickled Learner — only load
        # model files from a trusted source.
        learn_inf = load_learner(model_path)
        print("✅ Model loaded successfully!")
    except Exception as load_error:
        print(f"❌ Error loading model: {load_error}")
        print("This might be due to fastai version compatibility issues.")
        learn_inf = None
|
|
|
|
|
# Land-cover labels the model is expected to predict.
# NOTE(review): presumably identical to the training vocab (learn_inf.dls.vocab) — confirm.
classes = [
    "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway", "Industrial",
    "Pasture", "PermanentCrop", "Residential", "River", "SeaLake",
]

# RGB overlay color for each class; the key set must match `classes` exactly.
class_colors = {
    "AnnualCrop": (255, 255, 0),
    "Forest": (34, 139, 34),
    "HerbaceousVegetation": (144, 238, 144),
    "Highway": (128, 128, 128),
    "Industrial": (255, 165, 0),
    "Pasture": (173, 255, 47),
    "PermanentCrop": (0, 255, 0),
    "Residential": (255, 0, 0),
    "River": (0, 191, 255),
    "SeaLake": (0, 0, 139),
}

# Fail fast at startup if the two tables drift apart; otherwise the mismatch
# only shows up later as a KeyError while classifying an image.
_color_mismatch = sorted(set(classes) ^ set(class_colors))
if _color_mismatch:
    raise ValueError(f"classes and class_colors disagree on: {_color_mismatch}")

print(f"Total classes defined: {len(classes)}")
print(f"Total colors defined: {len(class_colors)}")
print("Classes:", classes)
print("Colors:", list(class_colors.keys()))
|
|
|
# Side length, in pixels, of each square tile passed to the classifier.
patch_size = 128
# Distance between consecutive tile origins; equal to patch_size, so tiles abut without overlap.
stride = patch_size
|
|
|
def _patch_origins(length, size, step):
    """Return the start offsets of the tiles covering one image axis.

    Offsets advance by ``step`` from 0. If that grid would leave an
    unclassified strip at the far edge, one extra tile flush with the edge is
    appended. Returns an empty list when ``length < size``.
    """
    last = length - size
    if last < 0:
        return []
    origins = list(range(0, last + 1, step))
    if origins[-1] != last:
        origins.append(last)
    return origins


def classify_image(img: Image.Image):
    """Classify a satellite image tile-by-tile and build a color-coded overlay.

    The image is split into ``patch_size`` x ``patch_size`` tiles spaced
    ``stride`` pixels apart. Extra tiles are aligned to the right/bottom edges
    so border strips are not skipped when the size is not a multiple of the
    stride. Each tile is classified with ``learn_inf.predict`` and painted with
    its color from ``class_colors``.

    Args:
        img: Uploaded image (converted to RGB), or None if nothing was uploaded.

    Returns:
        ``(blended_img, stats_text)`` on success. ``blended_img`` is a 60/40
        blend of the input and the class-color mask. ``stats_text`` lists each
        class's share of the image area, one ``"Class: xx.xx%"`` line per class.
        On any failure returns ``(None, error_message)``.
    """
    if learn_inf is None:
        return None, "❌ Model not loaded! Please check the console for error messages. You may need to retrain your model with a compatible fastai version or use fastai<2.8.0."
    if img is None:
        return None, "❌ No image provided. Please upload a satellite image first."

    try:
        img_np = np.array(img.convert("RGB"))
        h, w, _ = img_np.shape
        if h < patch_size or w < patch_size:
            # No tile would fit; previously this silently returned a darkened
            # image with 0% coverage for every class.
            return None, (
                f"❌ Image is too small ({w}x{h} px); it must be at least "
                f"{patch_size}x{patch_size} px."
            )

        class_index = {cls: i for i, cls in enumerate(classes)}
        mask = np.zeros((h, w, 3), dtype=np.uint8)
        # Per-pixel class index (-1 = not covered by any tile). Counting labels
        # instead of summing tile areas keeps coverage correct where the
        # edge-aligned tiles overlap their neighbours.
        labels = np.full((h, w), -1, dtype=np.intp)

        for y in _patch_origins(h, patch_size, stride):
            for x in _patch_origins(w, patch_size, stride):
                patch = img_np[y:y + patch_size, x:x + patch_size]
                pred, _, _ = learn_inf.predict(Image.fromarray(patch))
                if pred not in class_index:
                    raise ValueError(f"model predicted unknown class {pred!r}")
                mask[y:y + patch_size, x:x + patch_size] = class_colors[pred]
                labels[y:y + patch_size, x:x + patch_size] = class_index[pred]

        blended_img = Image.fromarray((0.6 * img_np + 0.4 * mask).astype(np.uint8))

        counts = np.bincount(labels[labels >= 0], minlength=len(classes))
        total_pixels = h * w
        stats_text = "\n".join(
            f"{cls}: {(counts[i] / total_pixels) * 100:.2f}%"
            for i, cls in enumerate(classes)
        )
        return blended_img, stats_text
    except Exception as e:
        # UI boundary: report the failure in the app rather than crashing the request.
        return None, f"❌ Error during classification: {str(e)}"
|
|
|
# TrueType fonts tried in order for the legend; Pillow's default font is the fallback.
_LEGEND_FONT_PATHS = (
    "/System/Library/Fonts/Arial.ttf",
    "/System/Library/Fonts/Supplemental/Arial.ttf",
    "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
)


def _load_legend_font(size):
    """Return the first loadable font from _LEGEND_FONT_PATHS, else Pillow's default."""
    for path in _LEGEND_FONT_PATHS:
        try:
            return ImageFont.truetype(path, size)
        except OSError:
            # Font file missing or unreadable on this system; try the next one.
            continue
    return ImageFont.load_default()


def create_color_legend():
    """Create a color legend image showing all land cover classes and their colors.

    Returns:
        A white RGB ``PIL.Image`` with a header, one row per entry in
        ``class_colors`` (color swatch, numbered class name, RGB tuple), and a
        footer. The image is 450 px wide and at least 450 px tall. It grows
        taller if needed so every row fits.
    """
    legend_width = 450
    header_height = 55  # space taken by the two header lines
    row_height = 35     # vertical pitch between legend rows
    swatch_height = 25

    n_rows = len(class_colors)
    footer_y = header_height + n_rows * row_height + 10
    legend_height = max(450, footer_y + 30)

    legend_img = Image.new('RGB', (legend_width, legend_height), 'white')
    draw = ImageDraw.Draw(legend_img)
    font = _load_legend_font(14)

    # Plain-text title: the fonts above have no emoji glyphs, and Pillow's
    # legacy bitmap default font can only encode Latin-1 text.
    draw.text((10, 10), "Color Legend - Land Cover Classes", fill='black', font=font)
    draw.text((10, 30), f"Total Classes: {n_rows}", fill='blue', font=font)

    for i, (cls, color) in enumerate(class_colors.items()):
        # Row position depends only on the index; the old code also advanced a
        # running offset, which doubled the spacing and pushed rows off-canvas.
        y_start = header_height + i * row_height
        draw.rectangle(
            [20, y_start, 50, y_start + swatch_height],
            fill=color, outline='black', width=2,
        )
        draw.text((60, y_start + 2), f"{i+1}. {cls}", fill='black', font=font)
        draw.text((280, y_start + 2), f"RGB{color}", fill='gray', font=font)

    draw.text((10, footer_y), f"Legend shows {n_rows} land cover classes", fill='green', font=font)

    return legend_img
|
|
|
|
|
# Describe the sliding window from the real settings so the UI text cannot
# drift from what classify_image actually does.
_overlap_note = "no overlap" if stride >= patch_size else f"{patch_size - stride} px overlap"

with gr.Blocks(title="🌍 Satellite Land Cover Classifier") as iface:
    gr.Markdown("# 🌍 Satellite Land Cover Classifier")
    gr.Markdown("Upload a satellite image and get land-cover classification heatmap + coverage stats.")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 📤 Upload Image")
            input_image = gr.Image(type="pil", label="Upload Satellite Image")
            classify_btn = gr.Button("🔍 Classify Image", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("## 🎨 Color Legend")
            legend_image = gr.Image(value=create_color_legend(), label="Land Cover Classes", interactive=False)

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 📊 Results")
            output_image = gr.Image(type="pil", label="Classification Result")
            output_stats = gr.Textbox(label="Coverage Statistics", lines=10)

        with gr.Column(scale=1):
            gr.Markdown("## ℹ️ How it works")
            gr.Markdown(f"""
            This classifier uses a sliding window approach:
            - **Patch Size**: {patch_size}x{patch_size} pixels
            - **Stride**: {stride} pixels ({_overlap_note})
            - **Output**: Color-coded heatmap + coverage percentages

            **Note**: You need to upload your trained model file (export.pkl) to use this classifier.
            """)

    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=[output_image, output_stats],
    )

# Launch only when run as a script, so importing this module (tests, tooling)
# does not start a blocking web server.
if __name__ == "__main__":
    iface.launch()
|
|