File size: 3,563 Bytes
8032137
 
 
 
 
 
 
 
 
 
 
 
 
0e76afd
 
8032137
 
 
 
0e76afd
8032137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from PIL import Image
import numpy as np
import onnxruntime as ort
import os
from dotenv import load_dotenv
import ast
from openai import OpenAI

# Load environment variables (e.g. OPENROUTER_API_KEY) from a local .env file.
load_dotenv()

# === Load and clean class names ===
# The file contains a Python literal (indexed by class id in predict(); presumably
# a list of strings), with any "Classes: " prefix removed before parsing.
class_file_path = "class_names.txt"
# Decode explicitly: the platform default encoding (e.g. cp1252 on Windows)
# can mis-decode non-ASCII class names, and "utf-8-sig" also drops a leading
# BOM that would otherwise make literal_eval fail.
with open(class_file_path, "r", encoding="utf-8-sig") as f:
    raw_line = f.read()
# literal_eval only accepts Python literals, so the file cannot execute code.
class_names = ast.literal_eval(raw_line.replace("Classes: ", "").strip())

# === Load ONNX model ===
model_path = "model.onnx"
# Module-level name `learn` is used by predict(); it is an ONNX Runtime session.
learn = ort.InferenceSession(model_path)

# === OpenRouter setup ===
# OpenRouter is reached through the OpenAI SDK by overriding its base URL.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.getenv("OPENROUTER_API_KEY"),
)

def _split_sections(content):
    """Split an LLM reply into ``(description, prevention)`` text.

    Returns ``None`` when *content* is empty/``None`` or when no
    ``Description:`` heading precedes the first ``Prevention:`` heading.
    """
    if not content:
        return None
    # Split on the FIRST "Prevention:" only, so a later mention of the word
    # inside the tips does not truncate them.
    head, found, prevention = content.partition("Prevention:")
    if not found or "Description:" not in head:
        return None
    # Drop any preamble the model wrote before the heading (e.g. "Sure! ...").
    description = head.split("Description:", 1)[1]
    return description.strip(), prevention.strip()


def generate_description_and_prevention(
    label, *, llm_client=None, model="deepseek/deepseek-r1:free"
):
    """Return a ``(description, prevention)`` pair of texts for *label*.

    The ``"not_a_crop"`` label short-circuits to a fixed request for a better
    image. Any other label is sent to a chat-completions API, and the reply is
    split on its ``Description:`` / ``Prevention:`` headings.

    Args:
        label: Predicted class name.
        llm_client: Object exposing ``chat.completions.create`` (OpenAI SDK
            style). Defaults to the module-level OpenRouter ``client``.
        model: Model identifier sent with the completion request.

    Returns:
        Tuple ``(description, prevention)``. Fixed placeholder strings are
        returned when the reply lacks the expected headings or the API call
        fails.
    """
    if label == "not_a_crop":
        return (
            "The uploaded image does not seem to show a valid crop or leaf.",
            "Please upload a clear image of a single crop or a leaf showing disease symptoms."
        )

    prompt = (
        f"Explain in simple words what the plant disease or condition '{label}' is, "
        "and give 2 to 4 clear, practical prevention tips.\n"
        "Use this format:\n"
        "Description:\n"
        "Explain briefly what this disease is and how it affects the plant.\n"
        "Prevention:\n"
        "- Tip 1\n"
        "- Tip 2\n"
        "- (Optional) Tip 3\n"
        "- (Optional) Tip 4"
    )

    try:
        api = client if llm_client is None else llm_client
        response = api.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a knowledgeable plant pathologist."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7,
            max_tokens=800
        )
        content = response.choices[0].message.content
    except Exception as e:
        # UI boundary: any network/API/response-shape failure becomes a
        # readable message instead of crashing the Gradio callback.
        print(f"[ERROR] OpenRouter API error: {e}")
        return "OpenRouter error.", "Failed to generate prevention steps."

    sections = _split_sections(content)
    if sections is None:
        return "Description not structured correctly.", "No prevention steps found."
    return sections

def preprocess_image(image, size=(224, 224)):
    """Turn an image into a single-item, channels-first float32 batch.

    Steps: resize to *size* (passed straight to ``image.resize``), convert to
    a float32 array, divide every value by 255, move the last (channel) axis
    to the front, and prepend a batch axis of length 1.
    """
    resized = image.resize(size)
    scaled = np.array(resized).astype(np.float32) / 255.0
    channels_first = np.transpose(scaled, (2, 0, 1))
    return channels_first[np.newaxis, ...]

def _as_probabilities(scores):
    """Return *scores* unchanged if they already form a distribution, else softmax them."""
    scores = np.asarray(scores)
    # NOTE(review): nothing visible here shows whether the exported ONNX graph
    # ends in a softmax. If it emits raw logits, scaling them by 100 would not
    # be a percentage, so normalise unless the values already sum to 1.
    if np.all(scores >= 0) and np.isclose(scores.sum(), 1.0, atol=1e-3):
        return scores
    exp = np.exp(scores - scores.max())  # subtract max for numerical stability
    return exp / exp.sum()


def predict(image):
    """Classify an uploaded crop/leaf image and fetch advice for the result.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        Tuple ``(pred_class, confidence_pct, description, prevention)`` in the
        order of the Gradio output components; ``confidence_pct`` is rounded
        to two decimals.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    input_tensor = preprocess_image(image.convert("RGB"))

    input_name = learn.get_inputs()[0].name
    outputs = learn.run(None, {input_name: input_tensor})
    probs = _as_probabilities(outputs[0][0])  # first output, first (only) batch item

    pred_idx = int(np.argmax(probs))
    pred_class = class_names[pred_idx]
    confidence = float(probs[pred_idx] * 100)

    description, prevention = generate_description_and_prevention(pred_class)

    return pred_class, round(confidence, 2), description, prevention

# === Gradio Interface ===
# Output components follow the order of predict()'s return tuple:
# predicted label, confidence percentage, description text, prevention text.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(label="Prediction"),
        gr.Number(label="Confidence %"),
        *(gr.Textbox(label=section) for section in ("Description", "Prevention")),
    ],
    title="🌱 Crop Disease Detection",
    description="Upload a crop or leaf image to detect plant diseases and get prevention tips.",
)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860, with debug output enabled.
    iface.launch(
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
    )