Spaces:
Running
Running
File size: 3,194 Bytes
39aad74 78efd15 6aee1db 78efd15 39aad74 6aee1db 78efd15 6aee1db 78efd15 6aee1db 78efd15 6aee1db 78efd15 6aee1db 78efd15 6aee1db 78efd15 6aee1db 78efd15 6aee1db 78efd15 6aee1db 78efd15 6aee1db 78efd15 39aad74 78efd15 39aad74 5a6ed3e 39aad74 78efd15 39aad74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
# Weights file for the fine-tuned YOLO11s-Earth detector.
MODEL_WEIGHTS = 'yolo11s-earth.pt'
# Single detector instance shared by every call to process_frame.
model = YOLO(MODEL_WEIGHTS)
# Fallback vocabulary (20 names) used whenever the class textbox is left
# empty; the app title describes the model as DIOR-finetuned.
default_classes = [
    "airplane",
    "airport",
    "baseballfield",
    "basketballcourt",
    "bridge",
    "chimney",
    "dam",
    "Expressway-Service-area",
    "Expressway-toll-station",
    "golffield",
    "groundtrackfield",
    "harbor",
    "overpass",
    "ship",
    "stadium",
    "storagetank",
    "tenniscourt",
    "trainstation",
    "vehicle",
    "windmill",
]
# Class list most recently pushed into the model. Kept so that an unchanged
# prompt on consecutive streamed frames does not re-run model.set_classes.
_active_classes = None


def _parse_classes(classes_input):
    """Return the stripped, non-empty names from a comma-separated string.

    None, "" and inputs made only of commas/whitespace yield an empty tuple.
    """
    if not classes_input:
        return ()
    return tuple(
        name.strip() for name in classes_input.split(",") if name.strip()
    )


def _apply_classes(classes):
    """Set the model's vocabulary to *classes* unless it is already active."""
    global _active_classes
    if classes == _active_classes:
        return
    # Hand set_classes a fresh list so it cannot mutate the cached tuple or
    # the module-level default_classes list.
    model.set_classes(list(classes))
    _active_classes = classes


def _label_name(names, cls):
    """Map a predicted class index to its name, or "Unknown" on failure.

    Accepts either a sequence or an index-keyed mapping for *names*, so
    both IndexError and KeyError are treated as a missing name.
    """
    try:
        return names[int(cls)]
    except (IndexError, KeyError, TypeError) as e:
        print(f"Error accessing model.names: {e}")
        return "Unknown"


def process_frame(frame, classes_input):
    """Detect objects in one webcam frame and return an annotated copy.

    Args:
        frame: Image array from the Gradio webcam stream, or None when no
            frame has arrived yet.
        classes_input: Comma-separated class names typed by the user. Blank
            entries are dropped; if none remain, ``default_classes`` is used.

    Returns:
        A copy of *frame* with a bounding box and a "name:confidence" label
        drawn for every detection, or None if *frame* is None.
    """
    if frame is None:
        return None

    _apply_classes(_parse_classes(classes_input) or tuple(default_classes))

    # Draw on a copy because the incoming array may be read-only.
    frame = frame.copy()

    # Shrink so the longer side is 640 px, which speeds up inference.
    # max(1, ...) stops extreme aspect ratios from producing a zero-sized
    # dimension.
    h, w = frame.shape[:2]
    if w > h:
        new_w, new_h = 640, max(1, int(h * (640 / w)))
    else:
        new_w, new_h = max(1, int(w * (640 / h))), 640
    # cv2.resize takes dsize as (width, height).
    resized_frame = cv2.resize(frame, (new_w, new_h))

    # Per the Gradio and Ultralytics docs, type="numpy" webcam frames are RGB
    # while Ultralytics reads numpy sources as BGR. This call swaps R and B
    # (BGR2RGB and RGB2BGR perform the same swap).
    model_input = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)

    # verbose=False stops Ultralytics from logging a line for every frame.
    results = model.predict(model_input, verbose=False)

    # Read the names once rather than once per box.
    names = model.names
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            conf = float(box.conf[0])
            class_name = _label_name(names, box.cls[0])

            # Map coordinates from the resized image back to the original.
            x1, x2 = int(x1 * w / new_w), int(x2 * w / new_w)
            y1, y2 = int(y1 * h / new_h), int(y2 * h / new_h)

            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Clamp the text baseline so labels of boxes touching the top
            # edge stay visible instead of being drawn at a negative y.
            text_y = max(y1 - 10, 25)
            cv2.putText(
                frame,
                f'{class_name}:{conf:.2f}',
                (x1, text_y),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.9,
                (36, 255, 12),
                2,
            )
    return frame
def main():
    """Assemble the streaming detection UI and launch the Gradio server.

    The page shows a title and then one row holding the webcam feed, the
    class-prompt textbox and the annotated output image. Each streamed
    webcam frame, together with the textbox contents, is passed to
    ``process_frame``, and its return value is shown in the output image.
    """
    title = "# YOLO11s-Earth open vocabulary detection (DIOR finetuning)"
    app = gr.Blocks()
    with app:
        gr.Markdown(title)
        with gr.Row():
            webcam = gr.Image(
                type="numpy",
                sources=["webcam"],
                streaming=True,
                label="Webcam",
            )
            prompt_box = gr.Textbox(
                label="New classes (comma-separated)",
                placeholder="e.g.: airplane, airport, tennis court",
            )
            # Fixed 480 px display height for the annotated frame.
            annotated = gr.Image(label="Results", type="numpy", height=480)
        webcam.stream(
            fn=process_frame,
            inputs=[webcam, prompt_box],
            outputs=annotated,
        )
    app.launch()
# Start the app only when this file is executed directly, not on import.
# (The stray " |" scrape residue after main() was a syntax error.)
if __name__ == "__main__":
    main()