import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
from sklearn.cluster import DBSCAN

# Load the YOLO model
model = YOLO('models/rugai_m_v2.pt')


def remove_overlapping_boxes(boxes, iou_threshold=0.3):
    """Remove overlapping boxes, keeping the largest box of each overlapping group.

    Note: the overlap score used here is the intersection area divided by the
    area of the smaller box, which is stricter than true IoU for nested boxes.
    """
    if not boxes:
        return []

    # Convert boxes to numpy array
    boxes = np.array(boxes)

    # Calculate areas
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # Sort by area (largest first)
    indices = np.argsort(areas)[::-1]

    keep = []
    while indices.size > 0:
        i = indices[0]
        keep.append(i)

        # Intersection rectangle between the kept box and the remaining boxes
        xx1 = np.maximum(boxes[i, 0], boxes[indices[1:], 0])
        yy1 = np.maximum(boxes[i, 1], boxes[indices[1:], 1])
        xx2 = np.minimum(boxes[i, 2], boxes[indices[1:], 2])
        yy2 = np.minimum(boxes[i, 3], boxes[indices[1:], 3])

        w = np.maximum(0, xx2 - xx1)
        h = np.maximum(0, yy2 - yy1)

        # Overlap ratio: intersection over the smaller box's area
        overlap = (w * h) / areas[indices[1:]]

        # Keep boxes with overlap less than threshold
        indices = indices[1:][overlap < iou_threshold]

    return keep


def process_image(image, show_boxes=True):
    # Convert PIL Image to numpy array if needed
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Run inference with specific parameters
    results = model.predict(image, imgsz=320, conf=0.4, iou=0.9)[0]

    # Lists to store center points of knots
    centers_x = []
    centers_y = []

    # Process each result and extract boxes
    boxes = []  # Store all boxes and their centers
    height, width = image.shape[:2]

    for box in results.boxes:
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
        x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])

        # Calculate box center
        center_x = (x1 + x2) // 2
        center_y = (y1 + y2) // 2

        boxes.append({
            'coords': (x1, y1, x2, y2),
            'center': (center_x, center_y)
        })
        centers_x.append(center_x)
        centers_y.append(center_y)

    # Remove overlapping boxes
    if boxes:
        box_coords = [box['coords'] for box in boxes]
        keep_indices = remove_overlapping_boxes(box_coords, iou_threshold=0.3)
        boxes = [boxes[i] for i in keep_indices]
        centers_x = [centers_x[i] for i in keep_indices]
        centers_y = [centers_y[i] for i in keep_indices]

    # Sort centers
    centers_y.sort()
    centers_x.sort()

    # Set clustering tolerances relative to the average knot size
    if len(boxes) > 0:
        avg_width = sum((b['coords'][2] - b['coords'][0]) for b in boxes) / len(boxes)
        avg_height = sum((b['coords'][3] - b['coords'][1]) for b in boxes) / len(boxes)
        x_tolerance = int(avg_width * 0.22)
        y_tolerance = int(avg_height * 0.22)
    else:
        x_tolerance = y_tolerance = 5

    # Find representative points for rows and columns using DBSCAN
    rows = []
    cols = []

    # Cluster y-coordinates into rows
    if centers_y:
        y_centers = np.array(centers_y).reshape(-1, 1)
        y_clustering = DBSCAN(eps=y_tolerance, min_samples=2, metric='euclidean').fit(y_centers)
        unique_labels = np.unique(y_clustering.labels_)
        for label in unique_labels:
            if label != -1:  # Skip noise points
                cluster_points = y_centers[y_clustering.labels_ == label]
                rows.append(int(np.mean(cluster_points)))

    # Cluster x-coordinates into columns
    if centers_x:
        x_centers = np.array(centers_x).reshape(-1, 1)
        x_clustering = DBSCAN(eps=x_tolerance, min_samples=2, metric='euclidean').fit(x_centers)
        unique_labels = np.unique(x_clustering.labels_)
        for label in unique_labels:
            if label != -1:  # Skip noise points
                cluster_points = x_centers[x_clustering.labels_ == label]
                cols.append(int(np.mean(cluster_points)))

    # Sort rows and columns
    rows.sort()
    cols.sort()

    # Calculate total knots, assuming the knots form a regular rows-by-columns grid
    total_knots = len(rows) * len(cols)

    # Add white padding around the image so measurement lines and labels fit
    padding = 100
    padded_img = np.full((height + 2 * padding, width + 2 * padding, 3), 255, dtype=np.uint8)
    padded_img[padding:padding + height, padding:padding + width] = image

    # Draw boxes if requested
    if show_boxes:
        for box in boxes:
            x1, y1, x2, y2 = box['coords']
            cv2.rectangle(padded_img, (x1 + padding, y1 + padding),
                          (x2 + padding, y2 + padding), (0, 255, 0), 2)

    # Draw measurement lines and labels
    cv2.line(padded_img, (padding, padding // 2),
             (width + padding, padding // 2), (0, 0, 0), 2)
    cv2.putText(padded_img, f"{len(cols)} knots",
                (padding + width // 2 - 100, padding // 2 - 10),
                cv2.FONT_HERSHEY_DUPLEX, 0.7, (0, 0, 0), 2)

    cv2.line(padded_img, (width + padding + padding // 2, padding),
             (width + padding + padding // 2, height + padding), (0, 0, 0), 2)
    cv2.putText(padded_img, f"{len(rows)} knots",
                (width + padding + padding // 2 + 10, padding + height // 2),
                cv2.FONT_HERSHEY_DUPLEX, 0.7, (0, 0, 0), 2)

    # Add total knot count and density
    # (the density label assumes the photographed area covers one square centimetre)
    cv2.putText(padded_img, f"{int(total_knots)} Total Knots",
                (padding + width // 2 - 100, height + padding + padding // 2),
                cv2.FONT_HERSHEY_DUPLEX, 0.7, (0, 0, 0), 2)
    cv2.putText(padded_img, f"{int(total_knots)} knots/sqcm",
                (padding + width // 2 - 100, height + padding + padding // 2 + 30),
                cv2.FONT_HERSHEY_DUPLEX, 0.7, (0, 0, 0), 2)

    # Prepare detection information
    detection_info = f"Total Knots: {int(total_knots)}\n"
    detection_info += f"Rows: {len(rows)}\n"
    detection_info += f"Columns: {len(cols)}\n"
    detection_info += f"Density: {int(total_knots)} knots/sqcm"

    return padded_img, detection_info


# Create Gradio interface
with gr.Blocks(title="Rug Knot Detector") as demo:
    gr.Markdown("# 🧶 Rug Knot Detector")
    gr.Markdown("Upload an image of a rug to detect and analyze knots using our custom YOLO model.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Rug Image")
            show_boxes = gr.Checkbox(label="Show Detection Boxes", value=True)
            detect_btn = gr.Button("Detect Knots")
        with gr.Column():
            output_image = gr.Image(label="Detection Results")
            output_text = gr.Textbox(label="Detection Information", lines=5)

    detect_btn.click(
        fn=process_image,
        inputs=[input_image, show_boxes],
        outputs=[output_image, output_text]
    )

if __name__ == "__main__":
    demo.launch(share=True)
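
# A minimal sketch of exercising process_image headlessly (e.g. for debugging),
# without launching the Gradio UI. "sample_rug.jpg" is a hypothetical path, not
# a file shipped with this repo:
#
#     from PIL import Image
#     annotated, info = process_image(Image.open("sample_rug.jpg"), show_boxes=True)
#     print(info)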