|
import ultralytics |
|
import cv2 |
|
from ultralytics import YOLO |
|
import os |
|
import glob |
|
import argparse |
|
import sys |
|
import numpy as np |
|
import uuid |
|
|
|
def iou(box1, box2):
    """Return the intersection-over-union of two centre-format boxes.

    Args:
        box1: Sequence ``[x_center, y_center, width, height]``.
        box2: Sequence in the same format and coordinate system as ``box1``.

    Returns:
        The overlap ratio as a float in ``[0, 1]``; ``0.0`` when the union
        area is not positive (e.g. both boxes are degenerate).
    """
    # Convert (centre, size) to corner coordinates.
    half_w1, half_h1 = box1[2] / 2, box1[3] / 2
    x1_1, y1_1 = box1[0] - half_w1, box1[1] - half_h1
    x2_1, y2_1 = box1[0] + half_w1, box1[1] + half_h1

    half_w2, half_h2 = box2[2] / 2, box2[3] / 2
    x1_2, y1_2 = box2[0] - half_w2, box2[1] - half_h2
    x2_2, y2_2 = box2[0] + half_w2, box2[1] + half_h2

    # Clamp at zero so disjoint boxes contribute no intersection area.
    inter_w = max(0.0, min(x2_1, x2_2) - max(x1_1, x1_2))
    inter_h = max(0.0, min(y2_1, y2_2) - max(y1_1, y1_2))
    inter_area = inter_w * inter_h

    box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
    box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
    union_area = box1_area + box2_area - inter_area

    # Guard against division by zero for zero-area inputs.
    return inter_area / union_area if union_area > 0 else 0.0
|
|
|
def load_image(input_image, base_name: str = None):
    """Load an image from a path, raw bytes, or a file-like object.

    Args:
        input_image: One of
            * a ``str`` or ``os.PathLike`` path, read with ``cv2.imread``;
            * a ``bytes`` / ``bytearray`` / ``memoryview`` holding an encoded
              image, decoded with ``cv2.imdecode``;
            * any object with a ``read()`` method returning such bytes.
        base_name: Optional name to return instead of the inferred one.

    Returns:
        A tuple ``(img, base_name)``. ``img`` is the array returned by OpenCV.
        When ``base_name`` was not supplied it is the file stem for path
        inputs and a random UUID4 string for in-memory inputs.

    Raises:
        ValueError: If the path cannot be read or the bytes cannot be decoded
            (including an empty payload).
        TypeError: If ``input_image`` is none of the supported kinds.
    """
    if isinstance(input_image, os.PathLike):
        input_image = os.fspath(input_image)

    if isinstance(input_image, str):
        img = cv2.imread(input_image)
        if img is None:
            raise ValueError(f"Unable to load image from path: {input_image}")
        if base_name is None:
            base_name = os.path.splitext(os.path.basename(input_image))[0]
        return img, base_name

    if isinstance(input_image, (bytes, bytearray, memoryview)):
        image_bytes = input_image
    elif hasattr(input_image, "read"):
        image_bytes = input_image.read()
    else:
        raise TypeError(
            "input_image must be a path, bytes-like data, or a file-like object; "
            f"got {type(input_image).__name__}"
        )

    decode_error = "Unable to decode image from input bytes or file-like object."
    # Reject empty payloads up front so the caller gets the documented
    # ValueError instead of whatever the decoder does with an empty buffer.
    if len(image_bytes) == 0:
        raise ValueError(decode_error)

    nparr = np.frombuffer(image_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(decode_error)
    if base_name is None:
        # In-memory data has no file name to derive a stem from.
        base_name = str(uuid.uuid4())
    return img, base_name
|
|
|
def _parse_label_lines(lines):
    """Parse YOLO label lines into box records, skipping blank/short lines."""
    boxes = []
    for line in lines:
        tokens = line.split()
        if len(tokens) < 5:
            # Blank or truncated line: nothing usable to parse.
            continue
        boxes.append({
            'class_id': int(tokens[0]),
            'bbox': [float(tok) for tok in tokens[1:5]],
            # Guarantee a terminator so rewritten lines never run together.
            'line': line if line.endswith('\n') else line + '\n',
        })
    return boxes


def _suppress_duplicates(boxes, iou_threshold):
    """Greedily drop later same-class boxes overlapping an earlier kept box.

    Returns ``(keep_indices, num_removed)`` where the indices refer to
    positions in ``boxes`` in their current order.
    """
    keep_indices = []
    suppressed = [False] * len(boxes)
    num_removed = 0
    for i, anchor in enumerate(boxes):
        if suppressed[i]:
            continue
        keep_indices.append(i)
        for j in range(i + 1, len(boxes)):
            if suppressed[j] or boxes[j]['class_id'] != anchor['class_id']:
                continue
            if iou(anchor['bbox'], boxes[j]['bbox']) > iou_threshold:
                suppressed[j] = True
                num_removed += 1
    return keep_indices, num_removed


def _write_image(path, image):
    """Write ``image`` to ``path`` and fail loudly if OpenCV reports failure."""
    if not cv2.imwrite(path, image):
        raise OSError(f"Failed to write image to '{path}'")


def process_yolo(input_image, weights_file: str, output_dir: str = './yolo_run', model_obj: YOLO = None, base_name: str = None, iou_threshold: float = 0.7) -> str:
    """Run YOLO on one image, de-duplicate its boxes, and save annotated copies.

    Steps performed:
      1. Load the image via ``load_image`` and run the model with
         ``save_txt=True`` so a label file is produced under
         ``<output_dir>/labels``.
      2. Save the model's own plot as ``<base_name>_yolo<ext>``.
      3. Read the label file, sort boxes by their top edge, and drop any
         later box of the same class whose IoU with a kept box exceeds
         ``iou_threshold``.
      4. Write the surviving lines to ``<output_dir>/labels/<base_name>.txt``
         and save the original image with numbered green boxes as
         ``<base_name>_yolo_updated<ext>``.

    Args:
        input_image: Path (``str`` or ``os.PathLike``), encoded image bytes,
            or a file-like object; see ``load_image``.
        weights_file: Weights path used to build a ``YOLO`` model when
            ``model_obj`` is not given.
        output_dir: Directory receiving the images and the ``labels`` folder.
        model_obj: Optional pre-built model, used instead of loading weights.
        base_name: Optional stem for the output file names.
        iou_threshold: Overlap above which a same-class box is treated as a
            duplicate. Defaults to the previously hard-coded ``0.7``.

    Returns:
        Path of the annotated ``..._yolo_updated`` image.

    Raises:
        FileNotFoundError: If no label file is found after inference.
        OSError: If an output image cannot be written.
        ValueError / TypeError: Propagated from ``load_image``.
    """
    if isinstance(input_image, os.PathLike):
        input_image = os.fspath(input_image)
    from_path = isinstance(input_image, str)

    orig_image, base_name = load_image(input_image, base_name)
    os.makedirs(output_dir, exist_ok=True)

    # cv2.imwrite picks the encoder from the extension, so an extension-less
    # input path needs a concrete fallback just like in-memory input does.
    ext = os.path.splitext(input_image)[1] if from_path else ""
    if not ext:
        ext = ".png"

    model = YOLO(weights_file) if model_obj is None else model_obj
    results = model(
        source=input_image if from_path else orig_image,
        save_txt=True,
        project=output_dir,
        name='.',
        exist_ok=True,
    )

    output_image_path = os.path.join(output_dir, f"{base_name}_yolo{ext}")
    _write_image(output_image_path, results[0].plot(font_size=2, line_width=1))
    print(f"Image saved as '{output_image_path}'")

    labels_dir = os.path.join(output_dir, 'labels')
    label_file = os.path.join(labels_dir, f"{base_name}.txt")

    # The label file written during inference is named after the *source*,
    # not after ``base_name``: the input file stem for paths, and (presumably)
    # the stem of ``results[0].path`` for in-memory arrays.
    # NOTE(review): confirm the in-memory naming against the installed
    # ultralytics version; the ``base_name`` fallback below keeps the
    # previous lookup working if it differs.
    if from_path:
        source_stem = os.path.splitext(os.path.basename(input_image))[0]
    else:
        result_path = getattr(results[0], 'path', None)
        source_stem = (
            os.path.splitext(os.path.basename(result_path))[0]
            if result_path else base_name
        )
    raw_label_file = os.path.join(labels_dir, f"{source_stem}.txt")
    if not os.path.isfile(raw_label_file):
        raw_label_file = label_file

    if not os.path.isfile(raw_label_file):
        raise FileNotFoundError(f"No label files found for the image '{base_name}' at path '{raw_label_file}'.")

    with open(raw_label_file, 'r') as f:
        boxes = _parse_label_lines(f)

    # Top-to-bottom order (by top edge) decides which duplicate survives and
    # the numbering drawn on the updated image.
    boxes.sort(key=lambda b: b['bbox'][1] - (b['bbox'][3] / 2))

    keep_indices, num_removed = _suppress_duplicates(boxes, iou_threshold)

    with open(label_file, 'w') as f:
        f.writelines(boxes[idx]['line'] for idx in keep_indices)

    # When the raw file was stored under a different name it has now been
    # consumed into ``label_file``; remove it so a later run that maps to the
    # same raw name does not pick up this image's detections.
    if not os.path.samefile(raw_label_file, label_file):
        os.remove(raw_label_file)

    print(f"Number of bounding boxes removed: {num_removed}")

    drawn_image = orig_image.copy()
    h_img, w_img = drawn_image.shape[:2]

    for rank, idx in enumerate(keep_indices, start=1):
        x_center, y_center, w_norm, h_norm = boxes[idx]['bbox']
        # Label coordinates are treated as fractions of the image size.
        x_center *= w_img
        y_center *= h_img
        w_box = w_norm * w_img
        h_box = h_norm * h_img
        x1 = int(x_center - w_box / 2)
        y1 = int(y_center - h_box / 2)
        x2 = int(x_center + w_box / 2)
        y2 = int(y_center + h_box / 2)
        cv2.rectangle(drawn_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(drawn_image, str(rank), (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (143, 10, 18), 1)

    updated_output_image_path = os.path.join(output_dir, f"{base_name}_yolo_updated{ext}")
    _write_image(updated_output_image_path, drawn_image)
    print(f"Updated image saved as '{updated_output_image_path}'")

    return updated_output_image_path
|
|
|
def main(argv=None):
    """Command-line entry point: run ``process_yolo`` on one image path.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use
            ``sys.argv[1:]``.

    Returns:
        Process exit status: ``0`` on success, ``1`` if processing raised.
    """
    parser = argparse.ArgumentParser(description='Process YOLO inference and NMS on an image.')
    # Only a filesystem path can be supplied from the command line.
    parser.add_argument('input_image', help='Path to the input image.')
    parser.add_argument('weights_file', help='Path to the YOLO weights file.')
    parser.add_argument('output_dir', nargs='?', default='./yolo_run', help='Output directory (optional).')
    parser.add_argument('--base_name', help='Optional base name for output files (without extension).')
    args = parser.parse_args(argv)

    try:
        process_yolo(args.input_image, args.weights_file, args.output_dir, base_name=args.base_name)
    except Exception as e:
        # Top-level boundary: report the failure on stderr and signal it
        # through the exit status instead of dumping a traceback.
        print(e, file=sys.stderr)
        return 1
    return 0


if __name__ == '__main__':
    sys.exit(main())
|
|