File size: 3,470 Bytes
28c65ef
5148c74
 
28c65ef
5148c74
28c65ef
5148c74
 
28c65ef
5148c74
0ede587
 
 
 
 
 
 
9390992
 
 
bd6d077
9390992
 
 
 
 
 
28c65ef
9390992
28c65ef
9390992
 
 
 
 
 
 
 
 
 
 
28c65ef
 
 
 
 
 
 
 
 
9390992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28c65ef
 
 
 
 
 
 
 
9390992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ede587
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Detect the runtime: being able to import the `spaces` package is used as the
# signal that we are running inside a Hugging Face Spaces environment.
try:
    import spaces
except ImportError:
    HF_SPACES = False
    print("Running in local environment")
else:
    HF_SPACES = True
    print("Running in Hugging Face Spaces environment")

import gradio as gr
from PIL import Image
import os
from classifier import GarbageClassifier
from config import Config


def _startup_classifier(cfg):
    """Build a GarbageClassifier for *cfg*, load its model, and report progress."""
    instance = GarbageClassifier(cfg)
    print("Loading model...")
    instance.load_model()
    print("Model loaded successfully!")
    return instance


# Module-level configuration and classifier shared by the handler and the UI.
# The model is loaded here, at import time, before the interface is built.
config = Config()
classifier = _startup_classifier(config)


def classify_garbage_impl(image):
    """
    Classify a single image with the module-level ``classifier``.

    Args:
        image: The value of the ``gr.Image(type="pil")`` input, or ``None``
            when no image is provided.

    Returns:
        A ``(classification, full_response)`` pair of strings for the two
        output textboxes. If classification raises, returns
        ``("Error", "Classification failed: <message>")`` instead of raising,
        so the UI always receives displayable text.
    """
    if image is None:
        return "Please upload an image", "No image provided"

    try:
        classification, full_response = classifier.classify_image(image)
    except Exception as e:  # UI boundary: report instead of crashing the handler
        # Keep the full traceback in the server log; the UI only shows the message.
        import traceback

        traceback.print_exc()
        return "Error", f"Classification failed: {e}"
    return classification, full_response


# Pick the request handler for this environment: on Hugging Face Spaces the
# implementation is wrapped with `spaces.GPU`; locally it is used as-is.
if not HF_SPACES:
    classify_garbage = classify_garbage_impl
    print("Running without GPU decorator")
else:
    classify_garbage = spaces.GPU(classify_garbage_impl)
    print("GPU decorator applied for Hugging Face Spaces")


def get_example_images(example_dir="test_images", limit=3):
    """
    Return paths of up to ``limit`` example images found in ``example_dir``.

    Only regular files whose names end in .png, .jpg or .jpeg
    (case-insensitive) are included. Names are sorted so the selection is
    deterministic, because ``os.listdir`` returns entries in arbitrary order.

    Args:
        example_dir: Directory to scan. Defaults to "test_images".
        limit: Maximum number of paths to return. Defaults to 3;
            ``None`` returns all matches.

    Returns:
        A list of paths joined onto ``example_dir``. The list is empty if the
        path does not exist or is not a directory.
    """
    # isdir (not exists): a plain file at this path would make listdir raise.
    if not os.path.isdir(example_dir):
        return []
    examples = [
        os.path.join(example_dir, name)
        for name in sorted(os.listdir(example_dir))
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
        and os.path.isfile(os.path.join(example_dir, name))
    ]
    return examples[:limit]


# Create Gradio interface
with gr.Blocks(title="Garbage Classification System") as demo:
    gr.Markdown("# 🗂️ Garbage Classification System")
    gr.Markdown(
        "Upload an image to classify garbage into: Recyclable Waste, Food/Kitchen Waste, Hazardous Waste, or Other Waste"
    )

    # Left column: image input and trigger button; right column: results.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Garbage Image")

            classify_btn = gr.Button("Classify Garbage", variant="primary", size="lg")

        with gr.Column():
            classification_output = gr.Textbox(
                label="Classification Result",
                placeholder="Upload an image and click classify",
            )

            full_response_output = gr.Textbox(
                label="Detailed Analysis",
                placeholder="Detailed reasoning will appear here",
                lines=10,
            )

    # Category information, supplied by the classifier. If it cannot be
    # retrieved, show the error text instead of failing to build the UI.
    with gr.Accordion("📋 Garbage Categories Information", open=False):
        try:
            category_info = classifier.get_categories_info()
            for category, description in category_info.items():
                gr.Markdown(f"**{category}**: {description}")
        except Exception as e:
            gr.Markdown(f"Categories information not available: {str(e)}")

    # Examples section (only rendered when example images were found).
    examples = get_example_images()
    if examples:
        gr.Examples(examples=examples, inputs=image_input, label="Example Images")

    # Event handlers: both write (classification, full_response) into the two
    # result textboxes.
    classify_btn.click(
        fn=classify_garbage,
        inputs=image_input,
        outputs=[classification_output, full_response_output],
    )

    # Auto-classify whenever the image input's value changes. A None value
    # produces the "Please upload an image" prompt from the handler.
    image_input.change(
        fn=classify_garbage,
        inputs=image_input,
        outputs=[classification_output, full_response_output],
    )

def _main():
    """Start the Gradio server for the module-level ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    _main()