Plant Disease Segmentation Model

This repository contains a UNet++ (`smp.UnetPlusPlus`) model with a `timm-efficientnet-b4` encoder, trained to segment diseased regions on plant leaves.

The model was trained on the PlantDisease dataset and fine-tuned as described in the training notebook.

How to Get Started

Below is a complete example of how to load the model and use it for inference on a new image. You will need torch, segmentation-models-pytorch, albumentations, opencv-python, huggingface_hub, numpy, and matplotlib.

import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download
import numpy as np
import matplotlib.pyplot as plt

# === 1. Define Helper Functions (Preprocessing and Visualization) ===

def get_preprocessing_transform(size=512):
    """Build the minimal preprocessing pipeline for validation and testing.

    The pipeline resizes the image to ``size`` x ``size``, normalizes each
    channel with the fixed mean/std below (the same constants ``denormalize``
    reverses), and converts the result to a torch tensor.

    Args:
        size: Side length in pixels of the square resize. Defaults to 512,
            which matches the original pipeline.
            NOTE(review): smp encoders typically need spatial dims divisible
            by 32 -- confirm before choosing a different value.

    Returns:
        An ``albumentations.Compose``. Call it as ``transform(image=rgb_array)``;
        the ``'image'`` entry of the result is the model input tensor.
    """
    return A.Compose([
        A.Resize(size, size),  # Or use CenterCrop if you prefer
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

def denormalize(t):
    """Undo the per-channel normalization so a tensor can be displayed.

    Args:
        t: Normalized image tensor whose channel axis broadcasts against a
            ``(3, 1, 1)`` tensor (e.g. a CHW image). It may live on any
            device; it is copied to the CPU first.

    Returns:
        A CPU tensor computed as ``t * std + mean``, clamped to ``[0, 1]``.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    restored = t.cpu().mul(channel_std).add(channel_mean)
    # Clamp because values outside [0, 1] are not displayable.
    return restored.clamp(0, 1)

def colorize_mask(mask, color_map=None):
    """Map a 2D array of class indices to an RGB image for visualization.

    Args:
        mask: 2D integer array (or array-like) of per-pixel class indices,
            e.g. the argmax of the model output.
        color_map: Optional sequence of RGB triplets; entry ``i`` is the color
            painted on pixels of class ``i``. Defaults to black for class 0
            (presumably background) and green for class 1 ("disease").

    Returns:
        A ``uint8`` array of shape ``(*mask.shape, 3)``. Pixels whose class
        index has no entry in ``color_map`` are left black.
    """
    if color_map is None:
        # Class 0 -> black, class 1 ("disease") -> green.
        color_map = [[0, 0, 0], [0, 255, 0]]
    color_map = np.asarray(color_map, dtype=np.uint8)
    mask = np.asarray(mask)
    color_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for class_idx, color in enumerate(color_map):
        color_mask[mask == class_idx] = color
    return color_mask

# === 2. Load the Model from Hugging Face Hub ===

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Location of the packaged checkpoint on the Hugging Face Hub.
repo_id = "KM-Alee/nwrd-tukl"
model_filename = "plant-disease-segmentation-unpp-b4.pth"

# Fetch the checkpoint; hf_hub_download returns a path to the local copy.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

# The package is a dict holding the architecture config and the trained weights.
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code -- only load checkpoints from sources you trust. If the
# package holds only tensors and plain Python containers, consider passing
# weights_only=True.
model_package = torch.load(model_path, map_location=device)
model_config, model_state_dict = model_package['config'], model_package['state_dict']

# Rebuild the architecture from the stored config. encoder_weights=None skips
# downloading pretrained encoder weights; the trained weights are loaded next.
model = smp.UnetPlusPlus(
    encoder_name=model_config['encoder_name'],
    encoder_weights=None,
    in_channels=model_config['in_channels'],
    classes=model_config['classes'],
)
model.load_state_dict(model_state_dict)
model.to(device)
model.eval()  # inference mode for dropout / batch-norm layers

print("Model loaded successfully and moved to device:", device)

# === 3. Run Inference on a Sample Image ===

# Load and preprocess the image.
# Replace the fallback path below with the path to your own image.
try:
    # Use the sample image shipped in the repo, if there is one.
    image_path = hf_hub_download(repo_id=repo_id, filename="sample_image.jpg")
except Exception:
    # Deliberate best-effort: any download failure falls back to a
    # user-supplied path.
    image_path = "path/to/your/image.jpg"
    print(f"Sample image not found in repo. Please provide a path to an image at '{image_path}'")

# cv2.imread returns None instead of raising when the file is missing or
# unreadable; without this check the failure surfaces as an opaque assertion
# error inside cv2.cvtColor.
bgr_image = cv2.imread(image_path)
if bgr_image is None:
    raise FileNotFoundError(f"Could not read an image from '{image_path}'")
image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
preprocessing = get_preprocessing_transform()
augmented = preprocessing(image=image)
# Add a leading batch dimension before moving the tensor to the model's device.
input_tensor = augmented['image'].unsqueeze(0).to(device)

# Get prediction: per-pixel argmax over the class dimension (dim=1) of the
# model output, then drop the batch dimension.
with torch.no_grad():
    output = model(input_tensor)
    pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

# === 4. Visualize the Results ===
color_pred_mask = colorize_mask(pred_mask)
# Bring the (resized) network input back to displayable HWC uint8 so it
# matches the color mask's shape and dtype for blending.
display_image = denormalize(input_tensor.squeeze(0)).permute(1, 2, 0).numpy()
display_image = (display_image * 255).astype(np.uint8)
# 60% image + 40% mask.
overlayed_image = cv2.addWeighted(display_image, 0.6, color_pred_mask, 0.4, 0)


# Show the input, the predicted mask, and their blend side by side.
plt.figure(figsize=(12, 6))
panels = (
    (image, "Original Image"),
    (color_pred_mask, "Predicted Mask"),
    (overlayed_image, "Overlay"),
)
for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(picture)
    plt.title(caption)
    plt.axis("off")

plt.tight_layout()
plt.show()
Downloads last month
14
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support