"""Gradio web app that runs a three-stage lettuce analysis pipeline.

Stage 1 locates lettuce with a YOLO detector, stage 2 classifies each plant's
growth stage with a second YOLO model, and stage 3 classifies its health with
a timm image classifier.
"""
import base64
import io
import json
import os
import tempfile
import warnings
from collections import Counter
from datetime import datetime

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from ultralytics import YOLO

warnings.filterwarnings('ignore')


def _make_temp_path(suffix=".jpg"):
    """Create a new empty temporary file and return its path (caller deletes it)."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    # Only the path is needed; PIL reopens the file when saving.
    os.close(fd)
    return path


def _remove_quietly(path):
    """Delete ``path``; a file that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


class GradioLettuceAnalysisPipeline:
    """Detection -> growth stage -> health classification pipeline for a Gradio UI."""

    # Per-detection annotation colours, cycled when there are more detections.
    ANNOTATION_COLORS = ['#FF0000', '#0000FF', '#00FF00', '#FFA500',
                         '#800080', '#FFFF00', '#00FFFF', '#FF00FF']

    def __init__(self, detection_model_path, growth_model_path, health_classification_model_path):
        """
        Initialize the complete lettuce analysis pipeline for Gradio interface.

        Args:
            detection_model_path: Path to the YOLO weights used to locate lettuce.
            growth_model_path: Path to the YOLO weights used for growth-stage classification.
            health_classification_model_path: Path to a torch checkpoint dict with the keys
                'model_name', 'class_names' and 'model_state_dict'.

        Raises:
            Any exception raised while loading a model propagates, so callers never
            receive a partially initialised pipeline.
        """
        self.detection_model_path = detection_model_path
        self.growth_model_path = growth_model_path
        self.health_classification_model_path = health_classification_model_path

        # Fixed confidence thresholds (no longer adjustable via sliders)
        self.detection_confidence = 0.5
        self.growth_confidence = 0.25

        # Load all models (raises on failure)
        self.load_models()

        # Health classification transforms: 224x224 resize + standard ImageNet mean/std
        self.health_classification_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def load_models(self):
        """Load the detection, growth-stage and health-classification models.

        Returns:
            A human-readable success message.

        Raises:
            Whatever the underlying loaders raise (e.g. a missing weights file).
            Errors used to be returned as a string that ``__init__`` ignored,
            which left the pipeline half-built while the UI reported success.
        """
        # 1. Detection model
        self.detection_model = YOLO(self.detection_model_path)
        # 2. Growth stage model
        self.growth_model = YOLO(self.growth_model_path)
        # 3. Health classification model
        self.load_health_classification_model()
        return "✅ All models loaded successfully!"

    def load_health_classification_model(self):
        """Rebuild the timm health classifier from its checkpoint and switch it to eval mode."""
        # The checkpoint is a local, trusted model file loaded onto the CPU.
        checkpoint = torch.load(self.health_classification_model_path, map_location='cpu')
        self.health_model_name = checkpoint['model_name']
        self.health_class_names = checkpoint['class_names']

        self.health_classification_model = timm.create_model(
            self.health_model_name,
            pretrained=False,
            num_classes=len(self.health_class_names),
        )
        self.health_classification_model.load_state_dict(checkpoint['model_state_dict'])
        self.health_classification_model.eval()

    def detect_lettuce(self, image_path):
        """Stage 1: detect lettuce in the image at ``image_path``.

        Returns:
            A list of dicts with 'bbox' ([x1, y1, x2, y2] as ints), 'confidence',
            'class' (class index) and 'class_name'.
        """
        results = self.detection_model(image_path, conf=self.detection_confidence)
        has_names = hasattr(self.detection_model, 'names')

        detections = []
        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue
            for box in boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                conf = box.conf[0].cpu().numpy()
                cls = int(box.cls[0].cpu().numpy())
                detections.append({
                    'bbox': [int(x1), int(y1), int(x2), int(y2)],
                    'confidence': float(conf),
                    'class': cls,
                    'class_name': self.detection_model.names[cls] if has_names else 'lettuce',
                })
        return detections

    def classify_growth_stage(self, image_path, bbox):
        """Stage 2: classify the growth stage of one detected lettuce.

        The bbox is expanded by 20 px on every side (clamped to the image), cropped,
        written to a temporary JPEG and passed to the growth-stage YOLO model.

        Returns:
            (stage_name, confidence) of the highest-confidence growth box,
            ("Unknown", 0.0) when the model finds nothing, or ("Error", 0.0) on failure.
        """
        temp_crop_path = None
        try:
            x1, y1, x2, y2 = bbox
            padding = 20
            with Image.open(image_path) as image:
                crop_box = (
                    max(0, x1 - padding),
                    max(0, y1 - padding),
                    min(image.width, x2 + padding),
                    min(image.height, y2 + padding),
                )
                # JPEG cannot hold alpha/palette modes, so normalise to RGB before saving.
                cropped_image = image.crop(crop_box).convert('RGB')

            temp_crop_path = _make_temp_path(".jpg")
            cropped_image.save(temp_crop_path)

            results = self.growth_model.predict(
                source=temp_crop_path,
                conf=self.growth_confidence,
                save=False,
                imgsz=640,
                verbose=False,
            )

            growth_results = []
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    continue
                for box in boxes:
                    cls_id = int(box.cls[0])
                    growth_results.append({
                        'growth_stage': self.growth_model.names[cls_id],
                        'confidence': float(box.conf[0]),
                    })
        except Exception:
            return "Error", 0.0
        finally:
            # Always remove the crop, even when prediction fails.
            if temp_crop_path is not None:
                _remove_quietly(temp_crop_path)

        if not growth_results:
            return "Unknown", 0.0
        best_growth = max(growth_results, key=lambda r: r['confidence'])
        return best_growth['growth_stage'], best_growth['confidence']

    def classify_health(self, image, bbox):
        """Stage 3: classify the health status of the lettuce inside ``bbox``.

        Args:
            image: The full PIL image.
            bbox: [x1, y1, x2, y2] crop box in pixel coordinates.

        Returns:
            (class_name, softmax_probability), or ("Unknown", 0.0) if classification fails.
        """
        try:
            x1, y1, x2, y2 = bbox
            # The Normalize transform expects exactly 3 channels.
            cropped_image = image.crop((x1, y1, x2, y2)).convert('RGB')
            input_tensor = self.health_classification_transform(cropped_image).unsqueeze(0)

            with torch.no_grad():
                output = self.health_classification_model(input_tensor)
                probabilities = torch.softmax(output, dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)

            return self.health_class_names[predicted_idx.item()], confidence.item()
        except Exception:
            return "Unknown", 0.0

    @staticmethod
    def _load_label_font():
        """Return Arial at size 12, or Pillow's default font when Arial is unavailable."""
        try:
            return ImageFont.truetype("arial.ttf", 12)
        except OSError:
            return ImageFont.load_default()

    @staticmethod
    def _draw_label(draw, bbox, color, label_lines, font):
        """Draw a filled multi-line label above ``bbox``, or below it if it would leave the image top."""
        x1, y1, _, y2 = bbox
        line_gap = 2

        line_sizes = []
        for line in label_lines:
            left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
            line_sizes.append((right - left, bottom - top))
        max_width = max(width for width, _ in line_sizes)
        total_height = sum(height + line_gap for _, height in line_sizes)

        label_y = y1 - total_height - 8
        if label_y < 0:
            label_y = y2 + 4

        draw.rectangle([x1, label_y, x1 + max_width + 8, label_y + total_height + 4],
                       fill=color, outline=None)

        current_y = label_y + 2
        for line, (_, height) in zip(label_lines, line_sizes):
            draw.text((x1 + 4, current_y), line, fill='white', font=font)
            current_y += height + line_gap

    def process_image_gradio(self, image, show_boxes, show_labels):
        """
        Run all three stages on an uploaded PIL image for the Gradio interface.

        Args:
            image: Uploaded PIL image, or None if nothing was uploaded.
            show_boxes: Draw a bounding box around each detection.
            show_labels: Draw an id / growth stage / health label for each detection.

        Returns:
            (annotated_image, markdown_summary, results_dataframe, results_list).
            The last two are None when nothing was detected or an error occurred.
        """
        if image is None:
            return None, "Please upload an image first!", None, None

        temp_image_path = None
        try:
            # Convert BEFORE saving: JPEG cannot store RGBA/palette images.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # Unique temp file so calls don't clobber each other or litter the CWD.
            temp_image_path = _make_temp_path(".jpg")
            image.save(temp_image_path)

            # Stage 1: Detect lettuce
            detections = self.detect_lettuce(temp_image_path)
            if not detections:
                return image, "No lettuce detected in the image!", None, None

            complete_results = []
            annotated_image = image.copy()
            draw = ImageDraw.Draw(annotated_image)
            label_font = self._load_label_font()

            for i, detection in enumerate(detections):
                bbox = detection['bbox']

                # Stage 2: Growth stage
                growth_stage, growth_conf = self.classify_growth_stage(temp_image_path, bbox)
                # Stage 3: Health status
                health_status, health_conf = self.classify_health(image, bbox)

                complete_results.append({
                    'lettuce_id': i + 1,
                    'bbox': bbox,
                    'detection_confidence': detection['confidence'],
                    'growth_stage': growth_stage,
                    'growth_confidence': growth_conf,
                    'health_status': health_status,
                    'health_confidence': health_conf,
                })

                color = self.ANNOTATION_COLORS[i % len(self.ANNOTATION_COLORS)]
                if show_boxes:
                    draw.rectangle(bbox, outline=color, width=3)
                if show_labels:
                    label_lines = [
                        f"Lettuce {i+1}",
                        f"{growth_stage}",
                        f"{health_status}",
                        f"{health_conf:.2f}",
                    ]
                    self._draw_label(draw, bbox, color, label_lines, label_font)

            summary = self.create_results_summary(complete_results)
            results_df = self.create_results_dataframe(complete_results)
            return annotated_image, summary, results_df, complete_results

        except Exception as e:
            return None, f"Error processing image: {str(e)}", None, None
        finally:
            # Remove the temp upload on every path, including errors.
            if temp_image_path is not None:
                _remove_quietly(temp_image_path)

    def create_results_summary(self, results):
        """Build a Markdown summary with overall counts followed by per-lettuce details."""
        if not results:
            return "No results to display"

        growth_counts = dict(Counter(r['growth_stage'] for r in results))
        health_counts = dict(Counter(r['health_status'] for r in results))

        parts = [
            "**LETTUCE ANALYSIS RESULTS**\n\n",
            "**Summary:**\n",
            f"- Total lettuce detected: **{len(results)}**\n",
            f"- Growth stages: {growth_counts}\n",
            f"- Health statuses: {health_counts}\n\n",
            "📋 **Detailed Results:**\n\n",
        ]
        for result in results:
            parts.append(f"**Lettuce {result['lettuce_id']}:**\n")
            parts.append(f"- Growth Stage: {result['growth_stage']} ({result['growth_confidence']:.3f})\n")
            parts.append(f"- Health Status: {result['health_status']} ({result['health_confidence']:.3f})\n")
            parts.append(f"- Location: {result['bbox']}\n\n")
        return "".join(parts)

    def create_results_dataframe(self, results):
        """Build a display DataFrame with one row per lettuce; confidences are formatted to 3 decimals."""
        if not results:
            return pd.DataFrame()

        return pd.DataFrame([
            {
                'Lettuce ID': r['lettuce_id'],
                'Growth Stage': r['growth_stage'],
                'Growth Confidence': f"{r['growth_confidence']:.3f}",
                'Health Status': r['health_status'],
                'Health Confidence': f"{r['health_confidence']:.3f}",
                'Detection Confidence': f"{r['detection_confidence']:.3f}",
                'Bounding Box': str(r['bbox']),
            }
            for r in results
        ])


# Initialize the pipeline; on failure the UI still starts and shows the error.
try:
    pipeline = GradioLettuceAnalysisPipeline(
        detection_model_path='detection.pt',
        growth_model_path='growth_detection.pt',
        health_classification_model_path='vit_lettuce_classifier_vit_small_patch16_224.pth',
    )
    model_status = "All models loaded successfully!"
except Exception as e:
    model_status = f"Error loading models: {e}"
    pipeline = None


def process_image_wrapper(image, show_boxes, show_labels):
    """Gradio click handler; returns 4 values to match the 4 wired output components."""
    if pipeline is None:
        # Must return one value per output (image, summary, table, state).
        return None, "Models not loaded properly!", None, None
    return pipeline.process_image_gradio(image, show_boxes, show_labels)


def download_results(results):
    """Write the stored results to a timestamped JSON report in the working directory.

    Args:
        results: The per-lettuce result dicts kept in the Gradio state
            (None/empty before any analysis has run).

    Returns:
        The written filename, or None when there is nothing to export.
    """
    if not results:
        return None

    # A single timestamp keeps the report field and the filename consistent.
    now = datetime.now()
    report = {
        'timestamp': now.isoformat(),
        'total_lettuce_detected': len(results),
        'results': results,
        'summary': {
            'growth_stages': dict(Counter(r['growth_stage'] for r in results)),
            'health_statuses': dict(Counter(r['health_status'] for r in results)),
        },
    }

    filename = f"lettuce_analysis_results_{now.strftime('%Y%m%d_%H%M%S')}.json"
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)
    return filename


# Custom CSS for styling and logo
custom_css = """
.logo-container {
    text-align: center;
    margin-bottom: 20px;
}
.logo-container img {
    max-height: 100px;
    width: auto;
}
.company-header {
    text-align: center;
    margin-bottom: 30px;
    padding: 20px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 10px;
    color: white;
}
.analyze-button {
    background: linear-gradient(45deg, #4CAF50, #45a049) !important;
    color: white !important;
    border: none !important;
    padding: 15px 30px !important;
    font-size: 16px !important;
    font-weight: bold !important;
    border-radius: 8px !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
}
.analyze-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
}
.settings-container {
    background: #f8f9fa;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.footer-info {
    background: #f1f3f4;
    padding: 20px;
    border-radius: 10px;
    margin-top: 20px;
}
"""

# Create Gradio interface
with gr.Blocks(title="Lettuce Analysis Pipeline", theme=gr.themes.Soft(), css=custom_css) as demo:
    # Company header (styled by the .company-header rule in custom_css)
    with gr.Row():
        gr.HTML("""
        <div class="company-header">
            <h1>Garden Of Babylon</h1>
            <h3>Advanced Lettuce Analysis Platform</h3>
            <p>Powered by AI • Precision Agriculture Solutions</p>
        </div>
        """)

    # Intro text (disabled in the original source; kept for reference):
    # gr.Markdown("""## Professional Lettuce Analysis Pipeline
    # Our advanced AI system performs comprehensive lettuce analysis in three automated stages:
    # - **Detection**: Automatically locates lettuce in your images
    # - **Growth Stage Classification**: Determines the growth stage of each lettuce plant
    # - **Health Assessment**: Evaluates the health condition of each plant
    #
    # Simply upload an image and let our AI do the rest!""")

    # Model status
    gr.Markdown(f"**System Status:** {model_status}")

    with gr.Row():
        # Left column - Input
        with gr.Column(scale=1):
            gr.Markdown("## Upload Image")

            input_image = gr.Image(
                type="pil",
                label="Upload Lettuce Image",
                sources=["upload"],
                interactive=True,
                height=300,
            )

            # Simplified settings
            with gr.Group():
                gr.Markdown("### Display Options")
                with gr.Row():
                    show_boxes = gr.Checkbox(label="Show Bounding Boxes", value=True)
                    show_labels = gr.Checkbox(label="Show Labels", value=True)

            process_btn = gr.Button(
                "🚀 Analyze Lettuce",
                variant="primary",
                size="lg",
                elem_classes="analyze-button",
            )

            # Info box (disabled in the original source):
            # gr.Markdown("""ℹ️ Analysis Settings""")

        # Right column - Output
        with gr.Column(scale=2):
            gr.Markdown("## Analysis Results")

            output_image = gr.Image(
                label="Analysis Results",
                type="pil",
                interactive=False,
                height=400,
            )

            results_summary = gr.Markdown(
                label="Analysis Summary",
                value="Upload an image and click 'Analyze Lettuce' to see results here.",
            )

    # Results table (a space after "##" is required for a Markdown heading)
    gr.Markdown("## Detailed Results")
    results_table = gr.Dataframe(
        label="Comprehensive Analysis Data",
        interactive=False,
        wrap=True,
    )

    # Download section
    with gr.Row():
        with gr.Column(scale=1):
            download_btn = gr.Button("Download Results (JSON)", variant="secondary")
        with gr.Column(scale=2):
            download_file = gr.File(label="Download Analysis Report", visible=False)

    # Hidden state holding the raw result dicts for the JSON export
    results_state = gr.State()

    # Event handlers
    process_btn.click(
        fn=process_image_wrapper,
        inputs=[input_image, show_boxes, show_labels],
        outputs=[output_image, results_summary, results_table, results_state],
    )

    download_btn.click(
        fn=download_results,
        inputs=[results_state],
        outputs=[download_file],
    ).then(
        lambda: gr.update(visible=True),
        outputs=[download_file],
    )

    # Footer (currently empty)
    gr.HTML(""" """)

# Launch the app
if __name__ == "__main__":
    demo.launch()