"""Run YOLO inference on one image, de-duplicate overlapping boxes, and save annotated copies."""

import argparse
import glob
import os
import sys
import uuid

import cv2
import numpy as np
import ultralytics
from ultralytics import YOLO


def _to_corners(box):
    """Convert a ``[x_center, y_center, width, height]`` box to ``(x1, y1, x2, y2)``."""
    x_center, y_center, width, height = box[0], box[1], box[2], box[3]
    half_w = width / 2
    half_h = height / 2
    return x_center - half_w, y_center - half_h, x_center + half_w, y_center + half_h


def iou(box1, box2):
    """Return the intersection-over-union of two centre-format boxes.

    Args:
        box1: Sequence ``[x_center, y_center, width, height]``.
        box2: Sequence in the same format and coordinate system as ``box1``.

    Returns:
        Intersection area divided by union area, or ``0.0`` when the union
        area is not positive (e.g. both boxes are degenerate).
    """
    ax1, ay1, ax2, ay2 = _to_corners(box1)
    bx1, by1, bx2, by2 = _to_corners(box2)

    # Clamp at zero so disjoint boxes contribute no intersection.
    inter_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter_area = inter_w * inter_h

    union_area = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter_area
    return inter_area / union_area if union_area > 0 else 0.0


def load_image(input_image, base_name: str = None):
    """Load an image from a path, raw bytes, or a readable file-like object.

    Args:
        input_image: A filesystem path (``str``), a bytes-like object holding
            an encoded image, or an object with a ``read()`` method returning
            such bytes.
        base_name: Optional name used for output files. When omitted it is
            the path's file name without extension, or a random UUID for
            in-memory input.

    Returns:
        A ``(image, base_name)`` tuple where ``image`` is the array returned
        by OpenCV.

    Raises:
        ValueError: If OpenCV cannot read or decode the image.
    """
    if isinstance(input_image, str):
        img = cv2.imread(input_image)
        if img is None:
            raise ValueError(f"Unable to load image from path: {input_image}")
        if base_name is None:
            base_name = os.path.splitext(os.path.basename(input_image))[0]
        return img, base_name

    # Anything that is not a path is treated as bytes-like or file-like.
    if isinstance(input_image, (bytes, bytearray, memoryview)):
        image_bytes = input_image
    else:
        image_bytes = input_image.read()

    img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Unable to decode image from input bytes or file-like object.")
    if base_name is None:
        base_name = str(uuid.uuid4())
    return img, base_name


def _write_image(path, image):
    """Write ``image`` to ``path`` and raise ``OSError`` if OpenCV reports failure."""
    if not cv2.imwrite(path, image):
        raise OSError(f"Unable to write image to '{path}'")


def _parse_label_lines(lines):
    """Parse YOLO label lines into dicts with ``class_id``, ``bbox`` and the source ``line``."""
    boxes = []
    for line in lines:
        tokens = line.split()
        if not tokens:
            # A blank line carries no box; skipping it avoids an IndexError.
            continue
        boxes.append({
            'class_id': int(tokens[0]),
            'bbox': [float(tokens[1]), float(tokens[2]), float(tokens[3]), float(tokens[4])],
            'line': line,
        })
    return boxes


def _suppress_overlaps(boxes, iou_threshold):
    """Greedily keep boxes in order, dropping same-class boxes overlapping an earlier kept one."""
    kept = []
    for candidate in boxes:
        is_duplicate = any(
            candidate['class_id'] == winner['class_id']
            and iou(winner['bbox'], candidate['bbox']) > iou_threshold
            for winner in kept
        )
        if not is_duplicate:
            kept.append(candidate)
    return kept


def _draw_numbered_boxes(image, boxes):
    """Return a copy of ``image`` with each box outlined and labelled with its 1-based rank."""
    canvas = image.copy()
    # Slice instead of 3-way unpacking so the channel count does not matter.
    h_img, w_img = canvas.shape[:2]
    for rank, box in enumerate(boxes, start=1):
        x_center, y_center, w_norm, h_norm = box['bbox']
        # Label coordinates are fractions of the image size; scale to pixels.
        x_center *= w_img
        y_center *= h_img
        w_box = w_norm * w_img
        h_box = h_norm * h_img
        x1 = int(x_center - w_box / 2)
        y1 = int(y_center - h_box / 2)
        x2 = int(x_center + w_box / 2)
        y2 = int(y_center + h_box / 2)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(canvas, str(rank), (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (143, 10, 18), 1)
    return canvas


def process_yolo(input_image, weights_file: str, output_dir: str = './yolo_run',
                 model_obj: YOLO = None, base_name: str = None,
                 iou_threshold: float = 0.7) -> str:
    """Run YOLO on one image, remove overlapping same-class boxes, and save the results.

    Two images are written to ``output_dir``: ``<base_name>_yolo<ext>`` with
    the model's own plot, and ``<base_name>_yolo_updated<ext>`` with only the
    boxes that survive suppression, numbered from top to bottom. The label
    file written by the model under ``<output_dir>/labels`` is rewritten to
    contain only the surviving boxes, in that same order.

    Args:
        input_image: Image path, encoded image bytes, or a readable
            file-like object (see ``load_image``).
        weights_file: Path to the weights; used only when ``model_obj`` is
            ``None``.
        output_dir: Directory for images and labels; created if missing.
        model_obj: Optional already-loaded model to reuse.
        base_name: Optional base name for the output image files.
        iou_threshold: Same-class boxes whose IoU with an earlier kept box
            exceeds this value are dropped.

    Returns:
        Path of the updated (post-suppression) image.

    Raises:
        ValueError: If the input image cannot be loaded.
        FileNotFoundError: If no label file for the image is found after
            inference.
        OSError: If an output image cannot be written.
    """
    orig_image, base_name = load_image(input_image, base_name)
    os.makedirs(output_dir, exist_ok=True)

    # Determine file extension: if input_image is a file path, use its
    # extension. Otherwise (or when the path has none), default to .png so
    # that OpenCV can always pick an encoder.
    ext = os.path.splitext(input_image)[1] if isinstance(input_image, str) else ""
    if not ext:
        ext = ".png"
    output_image_name = f"{base_name}_yolo{ext}"
    updated_output_image_name = f"{base_name}_yolo_updated{ext}"

    # If input_image is a file path, call YOLO with the path to preserve
    # filename-based labeling. Otherwise pass the decoded image.
    source_input = input_image if isinstance(input_image, str) else orig_image

    # Use provided model or load from weights_file.
    model = YOLO(weights_file) if model_obj is None else model_obj
    results = model(
        source=source_input,
        save_txt=True,
        project=output_dir,
        name='.',
        exist_ok=True,
    )

    # Save the initial inference image.
    img_with_boxes = results[0].plot(font_size=2, line_width=1)
    output_image_path = os.path.join(output_dir, output_image_name)
    _write_image(output_image_path, img_with_boxes)
    print(f"Image saved as '{output_image_path}'")

    # The label file is named by the model after the source it processed, not
    # after base_name, so derive the stem from the result itself. This keeps
    # the lookup working when base_name is overridden or the input is
    # in-memory. Fall back to base_name if the result exposes no path.
    # NOTE(review): the label file may accumulate lines across runs if the
    # model appends rather than overwrites — confirm against the installed
    # ultralytics version.
    result_path = getattr(results[0], 'path', None)
    if result_path:
        label_stem = os.path.splitext(os.path.basename(result_path))[0]
    else:
        label_stem = base_name
    label_file = os.path.join(output_dir, 'labels', f"{label_stem}.txt")
    if not os.path.isfile(label_file):
        raise FileNotFoundError(
            f"No label files found for the image '{base_name}' at path '{label_file}'."
        )

    with open(label_file, 'r') as f:
        boxes = _parse_label_lines(f.readlines())

    # Order by top edge so suppression and numbering run top-to-bottom.
    boxes.sort(key=lambda b: b['bbox'][1] - (b['bbox'][3] / 2))

    # Perform NMS.
    kept_boxes = _suppress_overlaps(boxes, iou_threshold)
    num_removed = len(boxes) - len(kept_boxes)

    with open(label_file, 'w') as f:
        # Normalise line endings so reordered lines can never run together.
        f.writelines(box['line'].rstrip('\r\n') + '\n' for box in kept_boxes)
    print(f"Number of bounding boxes removed: {num_removed}")

    # Draw updated bounding boxes on the original image (loaded in memory).
    drawn_image = _draw_numbered_boxes(orig_image, kept_boxes)
    updated_output_image_path = os.path.join(output_dir, updated_output_image_name)
    _write_image(updated_output_image_path, drawn_image)
    print(f"Updated image saved as '{updated_output_image_path}'")
    return updated_output_image_path


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process YOLO inference and NMS on an image.')
    parser.add_argument('input_image', help='Path to the input image or pass raw bytes via a file-like object.')
    parser.add_argument('weights_file', help='Path to the YOLO weights file.')
    parser.add_argument('output_dir', nargs='?', default='./yolo_run', help='Output directory (optional).')
    parser.add_argument('--base_name', help='Optional base name for output files (without extension).')
    args = parser.parse_args()
    try:
        process_yolo(args.input_image, args.weights_file, args.output_dir, base_name=args.base_name)
    except Exception as e:
        # Top-level CLI boundary: report the failure on stderr and exit non-zero.
        print(e, file=sys.stderr)
        sys.exit(1)