File size: 3,194 Bytes
39aad74
 
 
 
 
78efd15
 
6aee1db
78efd15
39aad74
 
 
 
 
 
 
 
6aee1db
78efd15
 
6aee1db
78efd15
 
 
 
 
 
6aee1db
78efd15
6aee1db
 
78efd15
6aee1db
 
78efd15
 
 
 
6aee1db
78efd15
 
6aee1db
78efd15
 
 
 
6aee1db
 
 
 
 
 
78efd15
 
 
 
 
 
 
 
 
 
 
6aee1db
78efd15
 
 
6aee1db
 
 
 
78efd15
39aad74
78efd15
39aad74
5a6ed3e
 
 
39aad74
 
 
 
 
 
 
78efd15
39aad74
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO

# Detection model, instantiated once at import time from 'yolo11s-earth.pt'.
model = YOLO('yolo11s-earth.pt')

# Class names applied to the model when the user supplies none.
default_classes = [
    'airplane',
    'airport',
    'baseballfield',
    'basketballcourt',
    'bridge',
    'chimney',
    'dam',
    'Expressway-Service-area',
    'Expressway-toll-station',
    'golffield',
    'groundtrackfield',
    'harbor',
    'overpass',
    'ship',
    'stadium',
    'storagetank',
    'tenniscourt',
    'trainstation',
    'vehicle',
    'windmill',
]

# Class list most recently passed to model.set_classes. Remembered so that a
# stream of frames with unchanged input does not call set_classes per frame.
# NOTE(review): assumes nothing else in the project calls model.set_classes;
# if something does, this cache can go stale -- confirm.
_active_classes = None


def _parse_classes(classes_input):
    """Split a comma-separated string into a list of non-empty class names."""
    if not classes_input:
        return []
    return [name.strip() for name in classes_input.split(',') if name.strip()]


def _apply_classes(classes_list):
    """Call model.set_classes only when the requested class list changed."""
    global _active_classes
    if classes_list != _active_classes:
        model.set_classes(classes_list)
        # Record only after set_classes succeeded, and store a copy so later
        # mutation of the caller's list cannot alter the cached value.
        _active_classes = list(classes_list)


def process_frame(frame, classes_input):
    """Run detection on one frame and draw the detections on a copy of it.

    Args:
        frame: Image array to process (must expose ``copy()`` and ``shape``),
            or None when no frame is available yet.
        classes_input: Comma-separated class names. None, an empty string, or
            a string containing no non-blank names selects ``default_classes``.

    Returns:
        A copy of ``frame`` with a green rectangle and a ``name:confidence``
        label for every detected box, or None when ``frame`` is None.
    """
    # A streaming callback can fire before any frame exists; there is nothing
    # to detect or draw in that case.
    if frame is None:
        return None

    # Blank entries ("a,,b", trailing commas, ",") are dropped; if nothing is
    # left, fall back to the default class list.
    _apply_classes(_parse_classes(classes_input) or default_classes)

    # Copy frame to a writable array
    frame = frame.copy()

    # Resize so the longer side is 640 px, keeping the aspect ratio. The
    # shorter side is clamped to at least 1 px because cv2.resize rejects a
    # zero-sized target.
    h, w = frame.shape[:2]
    if w > h:
        new_size = (640, max(1, int(h * (640 / w))))
    else:
        new_size = (max(1, int(w * (640 / h))), 640)
    resized_frame = cv2.resize(frame, new_size)

    # Swap the red and blue channels before inference. Per the Gradio and
    # Ultralytics docs, the incoming array is RGB while numpy input to the
    # model is treated as BGR. (RGB2BGR and BGR2RGB are the same swap, so this
    # is numerically what the code always did.)
    model_input = cv2.cvtColor(resized_frame, cv2.COLOR_RGB2BGR)

    # Use model for detection
    results = model.predict(model_input)

    # Draw detection results
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = box.xyxy[0]
            conf = float(box.conf[0])
            try:
                class_name = model.names[int(box.cls[0])]
            except (IndexError, KeyError, TypeError) as e:
                # KeyError covers model.names being a dict without this id.
                print(f"Error accessing model.names: {e}")
                class_name = "Unknown"  # Provide a default value

            # Boxes are in resized-image coordinates; map them back onto the
            # full-size frame that is drawn on and returned.
            x1 = int(x1 * w / new_size[0])
            y1 = int(y1 * h / new_size[1])
            x2 = int(x2 * w / new_size[0])
            y2 = int(y2 * h / new_size[1])

            # Keep the label baseline inside the image when the box touches
            # the top edge (otherwise y1 - 10 can be negative).
            label_y = max(y1 - 10, 25)

            # Draw bounding box and label
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, f'{class_name}:{conf:.2f}', (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    return frame

def main():
    """Build the Gradio Blocks UI and launch the app.

    The layout is a title, a row holding the streaming webcam image and the
    class-name textbox, and a result image below. Each streamed webcam frame
    is passed to ``process_frame`` together with the textbox content, and the
    returned image is shown in the result component.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# YOLO11s-Earth open vocabulary detection (DIOR finetuning)")

        with gr.Row():
            webcam = gr.Image(
                type="numpy",
                sources=["webcam"],
                streaming=True,
                label="Webcam",
            )
            class_box = gr.Textbox(
                label="New classes (comma-separated)",
                placeholder="e.g.: airplane, airport, tennis court",
            )

        # Result image is displayed at a fixed height of 480.
        result_view = gr.Image(label="Results", type="numpy", height=480)

        # Wire the stream event: frame + textbox in, annotated frame out.
        stream_inputs = [webcam, class_box]
        webcam.stream(process_frame, inputs=stream_inputs, outputs=result_view)

    # Start the app once the layout has been defined.
    demo.launch()

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()