StoneSeller's picture
Update app.py
245988b verified
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from collections import Counter
import os
import tempfile
# Load the trained YOLO weights once, at import time; predict_image reuses
# this module-level instance for every request.
model = YOLO("best.pt")
print(f"Model loaded successfully. Class names: {model.names}")

# Colour triple per class name.
# NOTE(review): not referenced by the code visible in this file --
# predict_image draws boxes via result.plot() without passing these.
# The values look like RGB; confirm before wiring them into drawing code.
class_colors = dict(
    freshripe=(50, 205, 50),
    freshunripe=(173, 255, 47),
    overripe=(255, 165, 0),
    ripe=(0, 128, 0),
    rotten=(128, 0, 0),
    unripe=(255, 255, 0),
)
def maintain_aspect_ratio_resize(image, target_size=640, pad_value=255):
    """Letterbox *image* into a square ``target_size`` x ``target_size`` canvas.

    The image is scaled so its longer side equals ``target_size`` while the
    width/height ratio is kept, then centred on a canvas filled with
    ``pad_value`` (white by default).

    Args:
        image: Array of shape (H, W) or (H, W, C) accepted by ``cv2.resize``.
        target_size: Side length, in pixels, of the square output.
        pad_value: Fill value for the padding around the resized image.

    Returns:
        A new ``uint8`` array of shape ``(target_size, target_size) + image.shape[2:]``.
    """
    h, w = image.shape[:2]
    aspect = w / h
    if aspect > 1:
        new_w = target_size
        # Clamp to 1: a very wide image would otherwise truncate to 0 rows,
        # which cv2.resize rejects.
        new_h = max(1, int(target_size / aspect))
    else:
        new_h = target_size
        # Same clamp for very tall images.
        new_w = max(1, int(target_size * aspect))
    resized = cv2.resize(image, (new_w, new_h))
    # cv2.resize drops a singleton channel axis; restore the input's trailing
    # dims so the slice assignment below lines up for any channel count.
    resized = resized.reshape((new_h, new_w) + image.shape[2:])
    square_img = np.full(
        (target_size, target_size) + image.shape[2:], pad_value, dtype=np.uint8
    )
    offset_x = (target_size - new_w) // 2
    offset_y = (target_size - new_h) // 2
    square_img[offset_y:offset_y + new_h, offset_x:offset_x + new_w] = resized
    return square_img
def predict_image(image):
if image is None:
return None, "Please upload an image."
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
temp_path = temp_file.name
img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
processed_img = maintain_aspect_ratio_resize(img_bgr, target_size=640)
cv2.imwrite(temp_path, processed_img)
print(f"Saved preprocessed image to {temp_path}")
try:
results = model(temp_path, imgsz=640, conf=0.3)
result = results[0]
pred_img = result.plot(conf=False)
boxes = result.boxes.data.cpu().numpy()
cls_names = result.names
print(f"Number of detections: {len(boxes)}")
print(f"Class names: {cls_names}")
class_counts = Counter()
result_text = "Prediction Results:\n"
for box in boxes:
x1, y1, x2, y2, conf, cls_id = box
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
pred_class = cls_names[int(cls_id)]
class_counts[pred_class] += 1
result_text += f"- Class: {pred_class}\n"
result_text += f" Box: [{x1}, {y1}, {x2}, {y2}]\n"
result_text += "\nClass Counts:\n"
for cls_name, count in class_counts.items():
result_text += f"- {cls_name}: {count}\n"
if len(boxes) == 0:
result_text += "\nNo banana detected. Try using a clearer image or different lighting."
output_img = cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB)
os.unlink(temp_path)
return output_img, result_text
except Exception as e:
print(f"Error during prediction: {e}")
try:
os.unlink(temp_path)
except:
pass
processed_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)
return processed_rgb, f"Error during prediction: {str(e)}"
# Build the Gradio UI: an input column (upload + buttons) beside an output
# column (annotated image + text report), then start the server.
with gr.Blocks(title="Banana Ripeness Classifier") as demo:
    gr.Markdown("Banana Ripeness Classifier")
    gr.Markdown("Upload a banana image to analyze its ripeness. The image will be processed while preserving the banana's shape.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy")
            with gr.Row():
                submit_btn = gr.Button("Analyze")
                clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(type="numpy")
            output_text = gr.Textbox(label="Analysis Results", lines=10)
    submit_btn.click(
        fn=predict_image,
        inputs=input_image,
        outputs=[output_image, output_text],
    )
    # Reset the uploaded image as well as the results; previously "Clear"
    # left the input image in place.
    clear_btn.click(
        fn=lambda: (None, None, ""),
        inputs=[],
        outputs=[input_image, output_image, output_text],
    )
    # Offer bundled sample images when an examples/ directory is present.
    # isdir (not exists) so a stray file named "examples" cannot crash listdir;
    # sorted for a stable order; lower() so ".JPG"/".PNG" files are included.
    if os.path.isdir("examples"):
        example_images = sorted(
            os.path.join("examples", f)
            for f in os.listdir("examples")
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=input_image,
            )
demo.launch()