File size: 4,289 Bytes
db6e6ea
 
 
 
 
 
557f511
db6e6ea
56e8c63
db6e6ea
557f511
db6e6ea
56e8c63
db6e6ea
 
 
 
 
 
 
 
 
cba72cc
 
 
 
 
56e8c63
cba72cc
 
56e8c63
cba72cc
 
 
 
 
 
 
 
 
 
 
 
 
db6e6ea
 
 
 
557f511
 
db6e6ea
557f511
db6e6ea
cba72cc
db6e6ea
cba72cc
db6e6ea
cba72cc
db6e6ea
557f511
 
 
db6e6ea
245988b
557f511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56e8c63
557f511
 
 
 
 
 
 
 
 
 
 
 
 
 
db6e6ea
557f511
 
 
 
 
 
 
 
cba72cc
 
db6e6ea
 
56e8c63
cba72cc
db6e6ea
 
 
 
557f511
 
 
db6e6ea
 
 
 
 
 
 
 
 
 
 
557f511
 
 
 
 
 
db6e6ea
 
 
 
 
 
 
 
dc009a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from collections import Counter
import os
import tempfile


# Load the YOLO weights from "best.pt" once at import time; predict_image()
# reuses this module-level model for every request.
model = YOLO("best.pt")
# Log the class names reported by the loaded model.
print(f"Model loaded successfully. Class names: {model.names}")


# Colour triple per ripeness label.
# NOTE(review): nothing in the visible code reads this mapping (boxes are
# drawn by result.plot()), and the channel order (RGB vs BGR) is not
# established here -- confirm before using it for drawing.
class_colors = {
    'freshripe': (50, 205, 50),
    'freshunripe': (173, 255, 47),
    'overripe': (255, 165, 0),
    'ripe': (0, 128, 0),
    'rotten': (128, 0, 0),
    'unripe': (255, 255, 0)
}

def maintain_aspect_ratio_resize(image, target_size=640):
    """Letterbox *image* onto a white ``target_size`` x ``target_size`` canvas.

    The image is scaled so that its longer side equals ``target_size`` while
    the aspect ratio is kept, then centred on a white 3-channel uint8 square.

    Args:
        image: Array whose first two axes are height and width. It is pasted
            into a 3-channel canvas, so a 3-channel image is expected.
        target_size: Side length, in pixels, of the square output.

    Returns:
        A ``(target_size, target_size, 3)`` uint8 array containing the
        resized image surrounded by white padding.

    Raises:
        ValueError: If *image* has zero height or zero width.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(
            f"image must have non-zero height and width, got {h}x{w}"
        )

    aspect = w / h

    # Clamp the shorter side to at least one pixel: for very elongated images
    # the truncated value would be 0, and cv2.resize cannot produce a
    # zero-sized output.
    if aspect > 1:
        new_w = target_size
        new_h = max(1, int(target_size / aspect))
    else:
        new_h = target_size
        new_w = max(1, int(target_size * aspect))

    # cv2.resize takes the destination size as (width, height).
    resized = cv2.resize(image, (new_w, new_h))

    # White background for the padding around the resized image.
    square_img = np.full((target_size, target_size, 3), 255, dtype=np.uint8)

    # Centre the resized image on the canvas.
    offset_x = (target_size - new_w) // 2
    offset_y = (target_size - new_h) // 2
    square_img[offset_y:offset_y + new_h, offset_x:offset_x + new_w] = resized

    return square_img

def _format_detections(boxes, cls_names):
    """Build the text report (per-box lines plus class counts) for one image."""
    class_counts = Counter()
    parts = ["Prediction Results:\n"]

    for box in boxes:
        # Each row is unpacked as (x1, y1, x2, y2, confidence, class id).
        x1, y1, x2, y2, _conf, cls_id = box
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

        pred_class = cls_names[int(cls_id)]
        class_counts[pred_class] += 1

        parts.append(f"- Class: {pred_class}\n")
        parts.append(f"  Box: [{x1}, {y1}, {x2}, {y2}]\n")

    parts.append("\nClass Counts:\n")
    parts.extend(
        f"- {cls_name}: {count}\n" for cls_name, count in class_counts.items()
    )

    # `boxes` is a NumPy array, so test its length rather than its truthiness.
    if len(boxes) == 0:
        parts.append("\nNo banana detected. Try using a clearer image or different lighting.")

    return "".join(parts)


def predict_image(image):
    """Run ripeness detection on an uploaded image and describe the result.

    The image is converted from RGB to BGR, letterboxed to 640x640, written
    to a temporary JPEG and passed to the module-level YOLO ``model``. The
    temporary file is always removed before returning.

    Args:
        image: RGB image array from the Gradio input component, or ``None``
            when nothing has been uploaded.

    Returns:
        A ``(image, text)`` tuple. On success ``image`` is the annotated
        prediction converted to RGB and ``text`` lists every detection with
        its box plus the per-class counts. If inference raises, the
        preprocessed image (RGB) and an error message are returned instead.
        With no input the result is ``(None, "Please upload an image.")``.
    """
    if image is None:
        return None, "Please upload an image."

    # Preprocess before the temp file exists, so a failure here cannot leave
    # a stray file behind.
    img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    processed_img = maintain_aspect_ratio_resize(img_bgr, target_size=640)

    # Only the path is needed: the handle is closed right away so the file
    # can be reopened by name, and delete=False keeps it until the cleanup
    # in the `finally` clause below.
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
        temp_path = temp_file.name

    try:
        cv2.imwrite(temp_path, processed_img)
        print(f"Saved preprocessed image to {temp_path}")

        try:
            results = model(temp_path, imgsz=640, conf=0.3)
            result = results[0]

            pred_img = result.plot(conf=False)

            boxes = result.boxes.data.cpu().numpy()
            cls_names = result.names

            print(f"Number of detections: {len(boxes)}")
            print(f"Class names: {cls_names}")

            result_text = _format_detections(boxes, cls_names)

            # The plotted image is converted BGR -> RGB for display.
            output_img = cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB)
            return output_img, result_text

        except Exception as e:
            # UI boundary: report the failure to the user instead of raising.
            print(f"Error during prediction: {e}")
            processed_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)
            return processed_rgb, f"Error during prediction: {str(e)}"
    finally:
        # Best-effort cleanup on every exit path; a file that is already gone
        # or cannot be removed must not mask the real result.
        try:
            os.unlink(temp_path)
        except OSError:
            pass

# Gradio UI: input image with Analyze/Clear buttons on the left, annotated
# image and text report on the right.
with gr.Blocks(title="Banana Ripeness Classifier") as demo:
    gr.Markdown("Banana Ripeness Classifier")
    gr.Markdown("Upload a banana image to analyze its ripeness. The image will be processed while preserving the banana's shape.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy")
            with gr.Row():
                submit_btn = gr.Button("Analyze")
                clear_btn = gr.Button("Clear")

        with gr.Column():
            output_image = gr.Image(type="numpy")
            output_text = gr.Textbox(label="Analysis Results", lines=10)

    submit_btn.click(
        fn=predict_image,
        inputs=input_image,
        outputs=[output_image, output_text]
    )

    # Resets only the two output components; the uploaded input image is
    # left in place.
    clear_btn.click(
        fn=lambda: (None, ""),
        inputs=[],
        outputs=[output_image, output_text]
    )

    # Offer bundled sample images when an "examples" directory is present.
    # isdir (not exists) avoids a crash in os.listdir if "examples" is a file.
    if os.path.isdir("examples"):
        # Match extensions case-insensitively so files such as "a.JPG" are
        # included, and sort because os.listdir returns entries in arbitrary
        # order.
        example_images = sorted(
            os.path.join("examples", name)
            for name in os.listdir("examples")
            if name.lower().endswith(('.jpg', '.jpeg', '.png'))
        )
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=input_image
            )

# Start the web server as soon as the module is executed.
demo.launch()