import gradio as gr import torch import torch.nn as nn from torchvision import transforms import timm import numpy as np from PIL import Image, ImageDraw, ImageFont import matplotlib.pyplot as plt import cv2 from ultralytics import YOLO import warnings import os import json import pandas as pd from datetime import datetime import io import base64 warnings.filterwarnings('ignore') class GradioLettuceAnalysisPipeline: def __init__(self, detection_model_path, growth_model_path, health_classification_model_path): """ Initialize the complete lettuce analysis pipeline for Gradio interface """ self.detection_model_path = detection_model_path self.growth_model_path = growth_model_path self.health_classification_model_path = health_classification_model_path # Fixed confidence thresholds (no longer adjustable via sliders) self.detection_confidence = 0.5 self.growth_confidence = 0.25 # Load all models self.load_models() # Health classification transforms self.health_classification_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def load_models(self): """Load all three models""" try: # 1. Load detection model self.detection_model = YOLO(self.detection_model_path) # 2. Load growth stage model self.growth_model = YOLO(self.growth_model_path) # 3. Load health classification model self.load_health_classification_model() return "â All models loaded successfully!" except Exception as e: return f"â Error loading models: {e}" def load_health_classification_model(self): """Load the health classification model (ViT)""" checkpoint = torch.load(self.health_classification_model_path, map_location='cpu') self.health_model_name = checkpoint['model_name'] self.health_class_names = checkpoint['class_names'] # Create health classification model self.health_classification_model = timm.create_model( self.health_model_name, pretrained=False, num_classes=len(self.health_class_names) ) self.health_classification_model.load_state_dict(checkpoint['model_state_dict']) self.health_classification_model.eval() def detect_lettuce(self, image_path): """Stage 1: Detect lettuce in the image""" results = self.detection_model(image_path, conf=self.detection_confidence) detections = [] for result in results: boxes = result.boxes if boxes is not None: for box in boxes: x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() conf = box.conf[0].cpu().numpy() cls = int(box.cls[0].cpu().numpy()) detections.append({ 'bbox': [int(x1), int(y1), int(x2), int(y2)], 'confidence': float(conf), 'class': cls, 'class_name': self.detection_model.names[cls] if hasattr(self.detection_model, 'names') else 'lettuce' }) return detections def classify_growth_stage(self, image_path, bbox): """Stage 2: Classify growth stage""" try: image = Image.open(image_path) x1, y1, x2, y2 = bbox # Add padding padding = 20 x1 = max(0, x1 - padding) y1 = max(0, y1 - padding) x2 = min(image.width, x2 + padding) y2 = min(image.height, y2 + padding) # Crop and save temporary image cropped_image = image.crop((x1, y1, x2, y2)) temp_crop_path = "temp_lettuce_crop.jpg" cropped_image.save(temp_crop_path) # Run growth stage classification results = self.growth_model.predict( source=temp_crop_path, conf=self.growth_confidence, save=False, imgsz=640, verbose=False ) growth_results = [] for result in results: boxes = result.boxes if boxes is not None: for box in boxes: cls_id = int(box.cls[0]) conf = float(box.conf[0]) growth_stage = self.growth_model.names[cls_id] growth_results.append({ 'growth_stage': growth_stage, 'confidence': conf }) # Clean up if os.path.exists(temp_crop_path): os.remove(temp_crop_path) if growth_results: best_growth = max(growth_results, key=lambda x: x['confidence']) return best_growth['growth_stage'], best_growth['confidence'] else: return "Unknown", 0.0 except Exception as e: return "Error", 0.0 def classify_health(self, image, bbox): """Stage 3: Classify health status""" try: x1, y1, x2, y2 = bbox cropped_image = image.crop((x1, y1, x2, y2)) input_tensor = self.health_classification_transform(cropped_image).unsqueeze(0) with torch.no_grad(): output = self.health_classification_model(input_tensor) probabilities = torch.softmax(output, dim=1) confidence, predicted_idx = torch.max(probabilities, 1) predicted_class = self.health_class_names[predicted_idx.item()] confidence_score = confidence.item() return predicted_class, confidence_score except Exception as e: return "Unknown", 0.0 def process_image_gradio(self, image, show_boxes, show_labels): """ Process image for Gradio interface """ if image is None: return None, "Please upload an image first!", None, None try: # Save uploaded image temporarily temp_image_path = "temp_uploaded_image.jpg" image.save(temp_image_path) # Convert to RGB if needed if image.mode != 'RGB': image = image.convert('RGB') # Stage 1: Detect lettuce detections = self.detect_lettuce(temp_image_path) if not detections: # Clean up if os.path.exists(temp_image_path): os.remove(temp_image_path) return image, "No lettuce detected in the image!", None, None # Process each detection complete_results = [] annotated_image = image.copy() draw = ImageDraw.Draw(annotated_image) # Font setup try: font = ImageFont.truetype("arial.ttf", 16) small_font = ImageFont.truetype("arial.ttf", 12) except: font = ImageFont.load_default() small_font = ImageFont.load_default() colors = ['#FF0000', '#0000FF', '#00FF00', '#FFA500', '#800080', '#FFFF00', '#00FFFF', '#FF00FF'] for i, detection in enumerate(detections): bbox = detection['bbox'] det_conf = detection['confidence'] # Stage 2: Growth stage growth_stage, growth_conf = self.classify_growth_stage(temp_image_path, bbox) # Stage 3: Health status health_status, health_conf = self.classify_health(image, bbox) # Store results result = { 'lettuce_id': i + 1, 'bbox': bbox, 'detection_confidence': det_conf, 'growth_stage': growth_stage, 'growth_confidence': growth_conf, 'health_status': health_status, 'health_confidence': health_conf } complete_results.append(result) # Draw annotations if requested if show_boxes or show_labels: x1, y1, x2, y2 = bbox color = colors[i % len(colors)] if show_boxes: # Draw bounding box draw.rectangle([x1, y1, x2, y2], outline=color, width=3) if show_labels: # Create label label_lines = [ f"Lettuce {i+1}", f"{growth_stage}", f"{health_status}", f"{health_conf:.2f}" ] # Calculate label size max_width = 0 total_height = 0 for line in label_lines: bbox_text = draw.textbbox((0, 0), line, font=small_font) line_width = bbox_text[2] - bbox_text[0] line_height = bbox_text[3] - bbox_text[1] max_width = max(max_width, line_width) total_height += line_height + 2 # Position label label_y = y1 - total_height - 8 if label_y < 0: label_y = y2 + 4 # Draw label background draw.rectangle([x1, label_y, x1 + max_width + 8, label_y + total_height + 4], fill=color, outline=None) # Draw label text current_y = label_y + 2 for line in label_lines: draw.text((x1 + 4, current_y), line, fill='white', font=small_font) bbox_text = draw.textbbox((0, 0), line, font=small_font) current_y += (bbox_text[3] - bbox_text[1]) + 2 # Clean up if os.path.exists(temp_image_path): os.remove(temp_image_path) # Create results summary summary = self.create_results_summary(complete_results) # Create detailed results table results_df = self.create_results_dataframe(complete_results) return annotated_image, summary, results_df, complete_results except Exception as e: return None, f"Error processing image: {str(e)}", None, None def create_results_summary(self, results): """Create a formatted summary of results""" if not results: return "No results to display" summary = f"**LETTUCE ANALYSIS RESULTS**\n\n" summary += f"**Summary:**\n" summary += f"- Total lettuce detected: **{len(results)}**\n" # Growth stages summary growth_stages = [r['growth_stage'] for r in results] growth_counts = {stage: growth_stages.count(stage) for stage in set(growth_stages)} summary += f"- Growth stages: {dict(growth_counts)}\n" # Health status summary health_statuses = [r['health_status'] for r in results] health_counts = {status: health_statuses.count(status) for status in set(health_statuses)} summary += f"- Health statuses: {dict(health_counts)}\n\n" # Detailed results summary += f"đ **Detailed Results:**\n\n" for result in results: summary += f"**Lettuce {result['lettuce_id']}:**\n" summary += f"- Growth Stage: {result['growth_stage']} ({result['growth_confidence']:.3f})\n" summary += f"- Health Status: {result['health_status']} ({result['health_confidence']:.3f})\n" summary += f"- Location: {result['bbox']}\n\n" return summary def create_results_dataframe(self, results): """Create a pandas DataFrame for results table""" if not results: return pd.DataFrame() df_data = [] for result in results: df_data.append({ 'Lettuce ID': result['lettuce_id'], 'Growth Stage': result['growth_stage'], 'Growth Confidence': f"{result['growth_confidence']:.3f}", 'Health Status': result['health_status'], 'Health Confidence': f"{result['health_confidence']:.3f}", 'Detection Confidence': f"{result['detection_confidence']:.3f}", 'Bounding Box': str(result['bbox']) }) return pd.DataFrame(df_data) # Initialize the pipeline try: pipeline = GradioLettuceAnalysisPipeline( detection_model_path='detection.pt', growth_model_path='growth_detection.pt', health_classification_model_path='vit_lettuce_classifier_vit_small_patch16_224.pth' ) model_status = "All models loaded successfully!" except Exception as e: model_status = f"Error loading models: {e}" pipeline = None def process_image_wrapper(image, show_boxes, show_labels): """Wrapper function for Gradio interface""" if pipeline is None: return None, "Models not loaded properly!", None return pipeline.process_image_gradio(image, show_boxes, show_labels) def download_results(results): """Create downloadable results""" if not results: return None # Create detailed JSON report report = { 'timestamp': datetime.now().isoformat(), 'total_lettuce_detected': len(results), 'results': results, 'summary': { 'growth_stages': {}, 'health_statuses': {} } } # Add summary statistics growth_stages = [r['growth_stage'] for r in results] health_statuses = [r['health_status'] for r in results] for stage in set(growth_stages): report['summary']['growth_stages'][stage] = growth_stages.count(stage) for status in set(health_statuses): report['summary']['health_statuses'][status] = health_statuses.count(status) # Save to JSON file filename = f"lettuce_analysis_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" with open(filename, 'w') as f: json.dump(report, f, indent=2) return filename # Custom CSS for styling and logo custom_css = """ .logo-container { text-align: center; margin-bottom: 20px; } .logo-container img { max-height: 100px; width: auto; } .company-header { text-align: center; margin-bottom: 30px; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white; } .analyze-button { background: linear-gradient(45deg, #4CAF50, #45a049) !important; color: white !important; border: none !important; padding: 15px 30px !important; font-size: 16px !important; font-weight: bold !important; border-radius: 8px !important; cursor: pointer !important; transition: all 0.3s ease !important; } .analyze-button:hover { transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; } .settings-container { background: #f8f9fa; padding: 20px; border-radius: 10px; margin-bottom: 20px; } .footer-info { background: #f1f3f4; padding: 20px; border-radius: 10px; margin-top: 20px; } """ # Create Gradio interface with gr.Blocks(title="Lettuce Analysis Pipeline", theme=gr.themes.Soft(), css=custom_css) as demo: # Company Header with Logo with gr.Row(): gr.HTML("""
Powered by AI âĸ Precision Agriculture Solutions