|
import base64
import io
import json
import logging
import os
import tempfile
import warnings
from collections import Counter
from datetime import datetime

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from ultralytics import YOLO
|
# Silence all warnings process-wide (every category, every module).
warnings.simplefilter('ignore')
|
|
|
class GradioLettuceAnalysisPipeline:
    """Three-stage lettuce analysis pipeline backing the Gradio interface.

    Stage 1 detects lettuce with a YOLO model, stage 2 assigns a growth stage
    to each detection with a second YOLO model, and stage 3 classifies the
    health status of each detection with a timm image classifier.
    """

    # Outline / label-panel colors, cycled per detection.
    _BOX_COLORS = ['#FF0000', '#0000FF', '#00FF00', '#FFA500', '#800080', '#FFFF00', '#00FFFF', '#FF00FF']

    # Extra pixels kept around a detection box before growth-stage classification.
    _GROWTH_CROP_PADDING = 20

    _logger = logging.getLogger(__name__)

    def __init__(self, detection_model_path, growth_model_path, health_classification_model_path):
        """Initialize the complete lettuce analysis pipeline for the Gradio interface.

        Args:
            detection_model_path: Path to the YOLO lettuce-detection weights.
            growth_model_path: Path to the YOLO growth-stage weights.
            health_classification_model_path: Path to a torch checkpoint dict
                holding 'model_name', 'class_names' and 'model_state_dict'.

        Raises:
            Exception: Whatever error made model loading fail. It is re-raised
                so callers never receive a half-initialized pipeline.
        """
        self.detection_model_path = detection_model_path
        self.growth_model_path = growth_model_path
        self.health_classification_model_path = health_classification_model_path

        # Minimum confidence for a detection / growth-stage box to be kept.
        self.detection_confidence = 0.5
        self.growth_confidence = 0.25

        self.load_models()
        if self.load_error is not None:
            # load_models() reports failures as a status string; surface the
            # underlying error instead of silently continuing without models.
            raise self.load_error

        # 224x224 input normalized with ImageNet mean/std for the health classifier.
        self.health_classification_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def load_models(self):
        """Load the detection, growth-stage and health-classification models.

        Returns:
            A human-readable status string. The exception that caused a failure
            is also stored in ``self.load_error`` (``None`` on success).
        """
        try:
            self.detection_model = YOLO(self.detection_model_path)
            self.growth_model = YOLO(self.growth_model_path)
            self.load_health_classification_model()
        except Exception as e:
            self.load_error = e
            return f"Error loading models: {e}"
        self.load_error = None
        return "All models loaded successfully!"

    def load_health_classification_model(self):
        """Build the timm classifier described by the checkpoint and load its weights."""
        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        checkpoint = torch.load(self.health_classification_model_path, map_location='cpu')
        self.health_model_name = checkpoint['model_name']
        self.health_class_names = checkpoint['class_names']

        self.health_classification_model = timm.create_model(
            self.health_model_name,
            pretrained=False,
            num_classes=len(self.health_class_names),
        )
        self.health_classification_model.load_state_dict(checkpoint['model_state_dict'])
        self.health_classification_model.eval()

    def detect_lettuce(self, image_path):
        """Stage 1: detect lettuce in the image stored at *image_path*.

        Returns:
            A list of dicts with 'bbox' ([x1, y1, x2, y2] ints), 'confidence'
            (float), 'class' (int) and 'class_name' (str).
        """
        results = self.detection_model(image_path, conf=self.detection_confidence)
        detections = []
        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue
            for box in boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                cls = int(box.cls[0].cpu().numpy())
                detections.append({
                    'bbox': [int(x1), int(y1), int(x2), int(y2)],
                    'confidence': float(box.conf[0].cpu().numpy()),
                    'class': cls,
                    'class_name': self.detection_model.names[cls] if hasattr(self.detection_model, 'names') else 'lettuce',
                })
        return detections

    def classify_growth_stage(self, image_path, bbox):
        """Stage 2: classify the growth stage of one detected lettuce.

        The detection box is padded by ``_GROWTH_CROP_PADDING`` pixels (clamped
        to the image bounds). The crop is written to a uniquely named temporary
        JPEG and run through the growth YOLO model, and the highest-confidence
        box wins.

        Args:
            image_path: Path of the full image the detection came from.
            bbox: ``[x1, y1, x2, y2]`` pixel box from :meth:`detect_lettuce`.

        Returns:
            ``(growth_stage, confidence)``. Returns ``("Unknown", 0.0)`` when the
            growth model finds nothing and ``("Error", 0.0)`` if any step fails.
        """
        padding = self._GROWTH_CROP_PADDING
        temp_crop_path = None
        try:
            x1, y1, x2, y2 = bbox
            # The context manager closes the source file once the crop is taken.
            with Image.open(image_path) as image:
                crop_box = (
                    max(0, x1 - padding),
                    max(0, y1 - padding),
                    min(image.width, x2 + padding),
                    min(image.height, y2 + padding),
                )
                # convert() returns an independent in-memory image, and RGB keeps
                # the JPEG save below valid whatever the source mode was.
                cropped_image = image.crop(crop_box).convert('RGB')

            temp_crop_path = self._make_temp_path('lettuce_crop_')
            cropped_image.save(temp_crop_path)

            results = self.growth_model.predict(
                source=temp_crop_path,
                conf=self.growth_confidence,
                save=False,
                imgsz=640,
                verbose=False,
            )

            growth_results = [
                (self.growth_model.names[int(box.cls[0])], float(box.conf[0]))
                for result in results
                if result.boxes is not None
                for box in result.boxes
            ]
            if not growth_results:
                return "Unknown", 0.0
            # max() keeps the first of equally confident boxes.
            return max(growth_results, key=lambda item: item[1])
        except Exception:
            self._logger.exception("Growth-stage classification failed for bbox %s", bbox)
            return "Error", 0.0
        finally:
            if temp_crop_path is not None and os.path.exists(temp_crop_path):
                os.remove(temp_crop_path)

    def classify_health(self, image, bbox):
        """Stage 3: classify the health status of one detected lettuce.

        Args:
            image: PIL image the detection came from.
            bbox: ``[x1, y1, x2, y2]`` pixel box, cropped without padding.

        Returns:
            ``(class_name, softmax_probability)``, or ``("Unknown", 0.0)`` on failure.
        """
        try:
            x1, y1, x2, y2 = bbox
            cropped_image = image.crop((x1, y1, x2, y2)).convert('RGB')
            input_tensor = self.health_classification_transform(cropped_image).unsqueeze(0)

            with torch.no_grad():
                output = self.health_classification_model(input_tensor)
                probabilities = torch.softmax(output, dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)

            return self.health_class_names[predicted_idx.item()], confidence.item()
        except Exception:
            self._logger.exception("Health classification failed for bbox %s", bbox)
            return "Unknown", 0.0

    @staticmethod
    def _make_temp_path(prefix):
        """Return the path of a new, empty, uniquely named temporary ``.jpg`` file."""
        fd, path = tempfile.mkstemp(prefix=prefix, suffix='.jpg')
        os.close(fd)
        return path

    @staticmethod
    def _load_label_font():
        """Return 12 pt Arial, or Pillow's default font when Arial is unavailable."""
        try:
            return ImageFont.truetype("arial.ttf", 12)
        except (OSError, ImportError):
            return ImageFont.load_default()

    def _annotate_detection(self, draw, index, bbox, growth_stage, health_status, health_conf,
                            show_boxes, show_labels, font):
        """Draw the bounding box and/or label panel for detection *index* onto *draw*."""
        if not (show_boxes or show_labels):
            return
        x1, y1, x2, y2 = bbox
        color = self._BOX_COLORS[index % len(self._BOX_COLORS)]

        if show_boxes:
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        if not show_labels:
            return

        label_lines = [
            f"Lettuce {index + 1}",
            f"{growth_stage}",
            f"{health_status}",
            f"{health_conf:.2f}",
        ]
        line_sizes = []
        for line in label_lines:
            left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
            line_sizes.append((right - left, bottom - top))
        max_width = max(width for width, _ in line_sizes)
        total_height = sum(height + 2 for _, height in line_sizes)

        # Place the panel above the box. If it would leave the top of the
        # image, put it just below the box instead.
        label_y = y1 - total_height - 8
        if label_y < 0:
            label_y = y2 + 4

        draw.rectangle([x1, label_y, x1 + max_width + 8, label_y + total_height + 4],
                       fill=color, outline=None)

        current_y = label_y + 2
        for line, (_, line_height) in zip(label_lines, line_sizes):
            draw.text((x1 + 4, current_y), line, fill='white', font=font)
            current_y += line_height + 2

    def process_image_gradio(self, image, show_boxes, show_labels):
        """Run the full pipeline on an uploaded image for the Gradio interface.

        Args:
            image: PIL image from the upload component, or None.
            show_boxes: Draw a colored bounding box around each detection.
            show_labels: Draw a label panel (id, growth stage, health status,
                health confidence) next to each detection.

        Returns:
            A 4-tuple ``(annotated_image, summary_markdown, results_df,
            results_list)``. On early exit or failure, the slots that were not
            produced are None and the summary holds the message.
        """
        if image is None:
            return None, "Please upload an image first!", None, None

        temp_image_path = None
        try:
            # Convert *before* saving: JPEG cannot store RGBA/P-mode uploads
            # such as transparent PNGs.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # A unique name keeps concurrent requests from overwriting each other's upload.
            temp_image_path = self._make_temp_path('lettuce_upload_')
            image.save(temp_image_path)

            detections = self.detect_lettuce(temp_image_path)
            if not detections:
                return image, "No lettuce detected in the image!", None, None

            annotated_image = image.copy()
            draw = ImageDraw.Draw(annotated_image)
            label_font = self._load_label_font()

            complete_results = []
            for i, detection in enumerate(detections):
                bbox = detection['bbox']
                growth_stage, growth_conf = self.classify_growth_stage(temp_image_path, bbox)
                health_status, health_conf = self.classify_health(image, bbox)

                complete_results.append({
                    'lettuce_id': i + 1,
                    'bbox': bbox,
                    'detection_confidence': detection['confidence'],
                    'growth_stage': growth_stage,
                    'growth_confidence': growth_conf,
                    'health_status': health_status,
                    'health_confidence': health_conf,
                })
                self._annotate_detection(draw, i, bbox, growth_stage, health_status, health_conf,
                                         show_boxes, show_labels, label_font)

            summary = self.create_results_summary(complete_results)
            results_df = self.create_results_dataframe(complete_results)
            return annotated_image, summary, results_df, complete_results

        except Exception as e:
            self._logger.exception("Image processing failed")
            return None, f"Error processing image: {str(e)}", None, None
        finally:
            # Runs on every exit path, including errors, so uploads never pile up on disk.
            if temp_image_path is not None and os.path.exists(temp_image_path):
                os.remove(temp_image_path)

    def create_results_summary(self, results):
        """Return a Markdown summary (totals, per-label counts, per-lettuce details)."""
        if not results:
            return "No results to display"

        # Counter keeps first-seen order, so the output is deterministic.
        growth_counts = dict(Counter(r['growth_stage'] for r in results))
        health_counts = dict(Counter(r['health_status'] for r in results))

        parts = [
            "**LETTUCE ANALYSIS RESULTS**\n\n",
            "**Summary:**\n",
            f"- Total lettuce detected: **{len(results)}**\n",
            f"- Growth stages: {growth_counts}\n",
            f"- Health statuses: {health_counts}\n\n",
            "**Detailed Results:**\n\n",
        ]
        for result in results:
            parts.append(
                f"**Lettuce {result['lettuce_id']}:**\n"
                f"- Growth Stage: {result['growth_stage']} ({result['growth_confidence']:.3f})\n"
                f"- Health Status: {result['health_status']} ({result['health_confidence']:.3f})\n"
                f"- Location: {result['bbox']}\n\n"
            )
        return "".join(parts)

    def create_results_dataframe(self, results):
        """Return a pandas DataFrame (one row per lettuce) for the results table."""
        if not results:
            return pd.DataFrame()

        return pd.DataFrame([
            {
                'Lettuce ID': r['lettuce_id'],
                'Growth Stage': r['growth_stage'],
                'Growth Confidence': f"{r['growth_confidence']:.3f}",
                'Health Status': r['health_status'],
                'Health Confidence': f"{r['health_confidence']:.3f}",
                'Detection Confidence': f"{r['detection_confidence']:.3f}",
                'Bounding Box': str(r['bbox']),
            }
            for r in results
        ])
|
|
|
|
|
# Build the pipeline once at import time. If the constructor raises, the UI
# still starts: the status line shows the error and the Analyze handler
# sees pipeline = None.
try:
    pipeline = GradioLettuceAnalysisPipeline(
        detection_model_path='detection.pt',
        growth_model_path='growth_detection.pt',
        health_classification_model_path='vit_lettuce_classifier_vit_small_patch16_224.pth',
    )
except Exception as e:
    pipeline = None
    model_status = f"Error loading models: {e}"
else:
    model_status = "All models loaded successfully!"
|
|
|
def process_image_wrapper(image, show_boxes, show_labels):
    """Gradio click handler: run the pipeline, or report that models failed to load.

    Always returns a 4-tuple (annotated image, summary markdown, results table,
    results state) to match the four outputs of the Analyze button.
    """
    if pipeline is None:
        # Must match the 4 wired outputs. Returning 3 values makes Gradio raise
        # instead of showing this message.
        return None, "Models not loaded properly!", None, None

    return pipeline.process_image_gradio(image, show_boxes, show_labels)
|
|
|
def download_results(results):
    """Write a JSON report for *results* and return its path (None if nothing to report).

    The report holds a timestamp, the detection count, the raw per-lettuce
    results and counts per growth stage / health status. Each report goes into
    its own fresh temporary directory. Otherwise two requests in the same
    second would produce the same timestamped filename and overwrite each other.
    """
    if not results:
        return None

    now = datetime.now()
    report = {
        'timestamp': now.isoformat(),
        'total_lettuce_detected': len(results),
        'results': results,
        'summary': {
            'growth_stages': dict(Counter(r['growth_stage'] for r in results)),
            'health_statuses': dict(Counter(r['health_status'] for r in results)),
        },
    }

    filename = f"lettuce_analysis_results_{now.strftime('%Y%m%d_%H%M%S')}.json"
    report_path = os.path.join(tempfile.mkdtemp(prefix='lettuce_report_'), filename)
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)

    return report_path
|
|
|
|
|
custom_css = """ |
|
.logo-container { |
|
text-align: center; |
|
margin-bottom: 20px; |
|
} |
|
|
|
.logo-container img { |
|
max-height: 100px; |
|
width: auto; |
|
} |
|
|
|
.company-header { |
|
text-align: center; |
|
margin-bottom: 30px; |
|
padding: 20px; |
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); |
|
border-radius: 10px; |
|
color: white; |
|
} |
|
|
|
.analyze-button { |
|
background: linear-gradient(45deg, #4CAF50, #45a049) !important; |
|
color: white !important; |
|
border: none !important; |
|
padding: 15px 30px !important; |
|
font-size: 16px !important; |
|
font-weight: bold !important; |
|
border-radius: 8px !important; |
|
cursor: pointer !important; |
|
transition: all 0.3s ease !important; |
|
} |
|
|
|
.analyze-button:hover { |
|
transform: translateY(-2px) !important; |
|
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; |
|
} |
|
|
|
.settings-container { |
|
background: #f8f9fa; |
|
padding: 20px; |
|
border-radius: 10px; |
|
margin-bottom: 20px; |
|
} |
|
|
|
.footer-info { |
|
background: #f1f3f4; |
|
padding: 20px; |
|
border-radius: 10px; |
|
margin-top: 20px; |
|
} |
|
""" |
|
|
|
|
|
def _download_report(results):
    """Build the JSON report and show the file component only when a file was produced."""
    report_path = download_results(results)
    return gr.update(value=report_path, visible=report_path is not None)


with gr.Blocks(title="Lettuce Analysis Pipeline", theme=gr.themes.Soft(), css=custom_css) as demo:

    # Branded header.
    # NOTE(review): Gradio does not serve arbitrary files at "./GB_logo.jpg".
    # The logo probably needs the "/file=" (Gradio 4) or "/gradio_api/file="
    # (Gradio 5) route plus launch(allowed_paths=[...]). Confirm against the
    # installed Gradio version.
    with gr.Row():
        gr.HTML("""
        <div class="company-header"><div class="logo-container">
        <img src="./GB_logo.jpg" alt="Garden Of Babylon" />
        </div>
        <h1>Advanced Lettuce Analysis Platform</h1>
        <p>Powered by AI • Precision Agriculture Solutions</p>
        </div>
        """)

    gr.Markdown(f"**System Status:** {model_status}")

    with gr.Row():
        # Left column: upload and display options.
        with gr.Column(scale=1):
            gr.Markdown("## Upload Image")

            input_image = gr.Image(
                type="pil",
                label="Upload Lettuce Image",
                sources=["upload"],
                interactive=True,
                height=300,
            )

            with gr.Group():
                gr.Markdown("### Display Options")
                with gr.Row():
                    show_boxes = gr.Checkbox(label="Show Bounding Boxes", value=True)
                    show_labels = gr.Checkbox(label="Show Labels", value=True)

            process_btn = gr.Button(
                "Analyze Lettuce",
                variant="primary",
                size="lg",
                elem_classes="analyze-button",
            )

        # Right column: annotated image and Markdown summary.
        with gr.Column(scale=2):
            gr.Markdown("## Analysis Results")

            output_image = gr.Image(
                label="Analysis Results",
                type="pil",
                interactive=False,
                height=400,
            )

            results_summary = gr.Markdown(
                label="Analysis Summary",
                value="Upload an image and click 'Analyze Lettuce' to see results here.",
            )

    # Markdown headings need a space after the hashes to render as a heading.
    gr.Markdown("## Detailed Results")
    results_table = gr.Dataframe(
        label="Comprehensive Analysis Data",
        interactive=False,
        wrap=True,
    )

    with gr.Row():
        with gr.Column(scale=1):
            download_btn = gr.Button("Download Results (JSON)", variant="secondary")
        with gr.Column(scale=2):
            download_file = gr.File(label="Download Analysis Report", visible=False)

    # Raw per-lettuce result dicts from the last analysis, used for the JSON download.
    results_state = gr.State()

    process_btn.click(
        fn=process_image_wrapper,
        inputs=[input_image, show_boxes, show_labels],
        outputs=[output_image, results_summary, results_table, results_state],
    )

    # Set the file and its visibility in one step so an empty box is never shown.
    download_btn.click(
        fn=_download_report,
        inputs=[results_state],
        outputs=[download_file],
    )

    gr.HTML("""
    <div class="footer-info">
    <h3>System Features</h3>
    <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 20px;">
    <div>
    <h4>Precision Detection</h4>
    <p>Advanced YOLO-based detection with optimized confidence thresholds</p>
    </div>
    <div>
    <h4>Growth Analysis</h4>
    <p>Multi-stage classification for accurate growth assessment</p>
    </div>
    <div>
    <h4>Health Monitoring</h4>
    <p>Vision Transformer (ViT) powered health status evaluation</p>
    </div>
    <div>
    <h4>Comprehensive Reports</h4>
    <p>Detailed analysis with downloadable JSON reports</p>
    </div>
    </div>
    <hr style="margin: 20px 0;">
    <p style="text-align: center; color: #666;">
    <strong>Developed for Precision Agriculture</strong> |
    Optimized confidence thresholds for maximum accuracy |
    Support for multiple lettuce detection
    </p>
    </div>
    """)


if __name__ == "__main__":
    demo.launch()