Spaces:
Runtime error
Runtime error
from typing import Any, List | |
import numpy as np | |
import gradio as gr | |
from PIL import Image, ImageDraw | |
import yolo_detect as yod | |
import cv2 | |
import fiftyone as fo | |
from flagging import FlaggingCallback, SimpleCSVLogger | |
class NewLogger(FlaggingCallback):
    """Flagging callback that keeps only the most recent flag in memory.

    Nothing is persisted. Each ``flag`` call overwrites ``flag_data``,
    ``flag_option`` and ``flag_index`` and increments ``flag_count``. A sample
    flagged with the option ``"Bad"`` is also printed to stdout.

    NOTE(review): the ``gr.Interface`` in this file is configured with
    ``SimpleCSVLogger``, not this class.
    """

    def __init__(self):
        self.components = None  # components passed to setup()
        self.flagging_dir = None  # directory passed to setup()
        self.flag_data = None  # payload of the most recent flag() call
        self.flag_option = None  # option chosen in the most recent flag() call
        self.flag_index = None  # index given in the most recent flag() call
        self.flag_count = 0  # number of flag() calls on this instance
        super().__init__()

    def setup(self, components, flagging_dir):
        """Store the interface components and flagging directory.

        Gradio's ``FlaggingCallback`` interface calls ``setup`` once before
        any ``flag`` call and declares it abstract, so a subclass lacking it
        cannot be instantiated. NOTE(review): presumably the local
        ``flagging`` module mirrors that interface; confirm.
        """
        self.components = components
        self.flagging_dir = flagging_dir

    def flag(
        self,
        flag_data: List[Any],
        flag_option=None,
        flag_index=None,
        username=None,
    ) -> int:
        """Record a flagged sample.

        Args:
            flag_data: Data for the flagged sample. Presumably one value per
                interface component; verify against the caller.
            flag_option: Option the user picked (this app offers "Good" and
                "Bad").
            flag_index: Stored as-is on the instance.
            username: Accepted for interface compatibility; unused.

        Returns:
            The number of samples flagged through this instance so far.
        """
        self.flag_data = flag_data
        self.flag_option = flag_option
        self.flag_index = flag_index
        self.flag_count += 1
        if flag_option == "Bad":
            print(flag_option)
        return self.flag_count
def draw_boxes(input_img, iou_threshold, confidence_threshdol):
    """Detect objects in an image and return a 640x640 copy with boxes drawn.

    The image is resized to 640x640 first, and the detector is run on that
    same resized array. The returned box coordinates are therefore in the
    frame of the image they are drawn on, whether ``yod.identifications``
    reports coordinates relative to its input or to a fixed 640x640 frame.

    Args:
        input_img: Image array from the ``gr.Image`` input, or ``None`` when
            the user submits without an image.
        iou_threshold: IoU threshold forwarded to ``yod.identifications``.
        confidence_threshdol: Confidence threshold forwarded to
            ``yod.identifications`` (misspelled name kept for compatibility).

    Returns:
        The annotated 640x640 image as a numpy array, or ``None`` when no
        image was supplied.
    """
    # Nothing to annotate; returning None clears the output instead of
    # letting cv2.resize raise on a None input.
    if input_img is None:
        return None

    resized = cv2.resize(input_img, (640, 640))
    # Detect on the exact array we draw on so boxes cannot be misaligned
    # by the resize.
    boxes = yod.identifications(resized, iou_threshold, confidence_threshdol)

    img = Image.fromarray(resized)
    draw = ImageDraw.Draw(img)
    # Assumes each box is (x1, y1, x2, y2, score, label, ...), inferred from
    # the indexing below. TODO: confirm against yolo_detect.
    for box in boxes:
        x1, y1, x2, y2 = (float(v) for v in box[:4])
        draw.rectangle((x1, y1, x2, y2), outline="red", width=2)
        # Put the label just above the box, but never above the image top.
        label_y = max(y1 - 16, 0)
        draw.text((x1, label_y), f"{box[5]} -> {str(round(box[4], 2))}", fill="red")
    return np.array(img)
def _threshold_slider(label):
    """Return a 0-to-1 slider with 0.01 steps that starts at 0.0."""
    return gr.Slider(minimum=0, maximum=1, step=0.01, label=label, value=0.0)


# Heading shown at the top of the app.
title = "Vic and the boyzzzzzz"

# Threshold controls, passed to draw_boxes as its 2nd and 3rd arguments.
iou_slider = _threshold_slider("IoU Threshold")
conf_slider = _threshold_slider("Confidence Threshold")

# NOTE(review): this textbox is created but never passed to the Interface
# below, so it is not shown. Wire it in or remove it.
text_box = gr.Textbox(label="Write a description for bad behavior")

# The image and both thresholds go into draw_boxes, and its annotated image
# is the output. Users can manually flag a result as "Good" or "Bad"; flags
# are handed to SimpleCSVLogger.
demo = gr.Interface(
    fn=draw_boxes,
    inputs=[gr.Image(), iou_slider, conf_slider],
    outputs=gr.Image(),
    title=title,
    flagging_options=["Good", "Bad"],
    allow_flagging="manual",
    flagging_callback=SimpleCSVLogger(),
)

# Start the app.
demo.launch()