Plant Disease Segmentation Model
This repository contains a `UnetPlusPlus` model with a `timm-efficientnet-b4` encoder, trained for segmenting diseased regions on plant leaves.
The model was trained on the PlantDisease dataset and fine-tuned as described in the training notebook.
How to Get Started
Below is a complete example of how to load and use the model for inference on a new image. You will need `torch`, `segmentation-models-pytorch`, `albumentations`, and `opencv-python`; the example also imports `huggingface_hub`, `numpy`, and `matplotlib`.
import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
import numpy as np
import matplotlib.pyplot as plt
# === 1. Define Helper Functions (Preprocessing and Visualization) ===
def get_preprocessing_transform(height=512, width=512):
    """Build the minimal preprocessing pipeline for validation and testing.

    Args:
        height: Target image height in pixels handed to ``A.Resize``.
        width: Target image width in pixels handed to ``A.Resize``.

    Returns:
        An ``A.Compose`` pipeline that resizes the image, normalizes it with
        the mean/std constants below, and converts it with ``ToTensorV2``.
    """
    return A.Compose([
        A.Resize(height, width),  # Or use CenterCrop if you prefer
        # Same mean/std constants that ``denormalize`` reverses.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
def denormalize(t):
    """Reverse image normalization for visualization.

    Args:
        t: Normalized image tensor whose channel axis broadcasts against a
            ``(3, 1, 1)`` tensor, e.g. ``(3, H, W)``.

    Returns:
        A CPU tensor of the same shape with values clamped to ``[0, 1]``.
        The result is detached from any autograd graph so it can be
        converted with ``.numpy()``.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    # detach() first: a tensor that requires grad cannot be turned into a
    # NumPy array later, and a visualization helper never needs gradients.
    return torch.clamp(t.detach().cpu() * std + mean, 0, 1)
def colorize_mask(mask):
    """Apply a simple colormap to a 2D class-index mask for visualization.

    Args:
        mask: Array of integer class indices (any array-like accepted by
            ``np.asarray``).

    Returns:
        A ``uint8`` array of shape ``(*mask.shape, 3)`` holding one RGB
        colour per pixel. Class indices without an entry in the colormap
        are left black.
    """
    mask = np.asarray(mask)
    # Row i is the RGB colour of class i:
    # black for class 0, green for the "disease" class (class 1).
    color_map = np.array([[0, 0, 0], [0, 255, 0]], dtype=np.uint8)
    color_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for class_idx, color in enumerate(color_map):
        color_mask[mask == class_idx] = color
    return color_mask
# === 2. Load the Model from Hugging Face Hub ===
# Set device
# Prefer CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Repository id and file name handed to hf_hub_download below.
repo_id = "KM-Alee/nwrd-tukl"
model_filename = "plant-disease-segmentation-unpp-b4.pth"
# Download the packaged model file
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
# Load the package
# NOTE(review): torch.load relies on pickle unless weights_only is in effect,
# so only load checkpoints from sources you trust — confirm whether the
# installed torch version supports/defaults to weights_only=True.
model_package = torch.load(model_path, map_location=device)
# The package is read as a mapping with a 'config' entry (model constructor
# settings) and a 'state_dict' entry (trained weights).
model_config = model_package['config']
model_state_dict = model_package['state_dict']
# Re-create the model architecture
model = smp.UnetPlusPlus(
    encoder_name=model_config['encoder_name'],
    encoder_weights=None, # Weights are loaded below, no need to download again
    in_channels=model_config['in_channels'],
    classes=model_config['classes'],
)
# Load the trained weights
model.load_state_dict(model_state_dict)
model.to(device)
# Switch to evaluation mode before running inference.
model.eval()
print("Model loaded successfully and moved to device:", device)
# === 3. Run Inference on a Sample Image ===
# Load and preprocess the image
# Replace this with the path to your own image
try:
    # Try to download a sample image from the repo if it exists
    image_path = hf_hub_download(repo_id=repo_id, filename="sample_image.jpg")
except Exception:
    # Deliberate best-effort fallback: any download failure switches to a
    # user-supplied path. You need to provide a path to an image to test this.
    image_path = "path/to/your/image.jpg"
    print(f"Sample image not found in repo. Please provide a path to an image at '{image_path}'")

# cv2.imread signals an unreadable/missing file by returning None instead of
# raising, which would otherwise surface as an obscure cvtColor error.
bgr_image = cv2.imread(image_path)
if bgr_image is None:
    raise FileNotFoundError(f"Could not read an image from '{image_path}'")
# OpenCV loads images as BGR; convert to RGB before preprocessing/plotting.
image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

preprocessing = get_preprocessing_transform()
augmented = preprocessing(image=image)
# Add a leading batch dimension and move the tensor to the model's device.
input_tensor = augmented['image'].unsqueeze(0).to(device)

# Get prediction
with torch.no_grad():
    output = model(input_tensor)
    # Pick the highest-scoring class per pixel (dim=1) and drop the batch dim.
    pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
# === 4. Visualize the Results ===
color_pred_mask = colorize_mask(pred_mask)

# Turn the normalized network input back into a channels-last uint8 image
# so it can be blended with the colour mask.
restored_chw = denormalize(input_tensor.squeeze(0))
restored_uint8 = (restored_chw.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
# Blend: 60% image, 40% mask.
overlayed_image = cv2.addWeighted(restored_uint8, 0.6, color_pred_mask, 0.4, 0)

plt.figure(figsize=(12, 6))
# One (title, picture) pair per column of the 1x3 grid, left to right.
panels = (
    ("Original Image", image),
    ("Predicted Mask", color_pred_mask),
    ("Overlay", overlayed_image),
)
for slot, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 3, slot)
    plt.imshow(picture)
    plt.title(caption)
    plt.axis("off")
plt.tight_layout()
plt.show()
- Downloads last month: 14
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋 Ask for provider support