# File size: 6,559 bytes
# Revision: 6ff22d6
import ultralytics
import cv2
from ultralytics import YOLO
import os
import glob
import argparse
import sys
import numpy as np
import uuid
def iou(box1, box2):
    """Return the intersection-over-union of two centre-format boxes.

    Each box is indexable as ``(x_center, y_center, width, height)``; only the
    first four items are read. Both boxes must use the same units (e.g. both
    normalised). Returns 0 when the union area is not positive.
    """
    def corners(box):
        # (x_center, y_center, w, h) -> (left, top, right, bottom)
        half_w = box[2] / 2
        half_h = box[3] / 2
        return box[0] - half_w, box[1] - half_h, box[0] + half_w, box[1] + half_h

    left_a, top_a, right_a, bottom_a = corners(box1)
    left_b, top_b, right_b, bottom_b = corners(box2)

    # Width/height of the overlapping rectangle, clamped at zero when disjoint.
    overlap_w = max(0, min(right_a, right_b) - max(left_a, left_b))
    overlap_h = max(0, min(bottom_a, bottom_b) - max(top_a, top_b))
    intersection = overlap_w * overlap_h

    area_a = (right_a - left_a) * (bottom_a - top_a)
    area_b = (right_b - left_b) * (bottom_b - top_b)
    union = area_a + area_b - intersection
    if union > 0:
        return intersection / union
    return 0
def load_image(input_image, base_name: str = None):
    """Load an image into memory and choose a base name for output files.

    Args:
        input_image: A filesystem path (``str`` or ``os.PathLike``), raw
            encoded image data (``bytes``, ``bytearray`` or ``memoryview``),
            or a file-like object whose ``read()`` returns the encoded data.
        base_name: Base name to return. When ``None``, the file name without
            its extension is used for path inputs and a random UUID4 string
            for in-memory inputs.

    Returns:
        A ``(img, base_name)`` tuple, where ``img`` is the array produced by
        ``cv2.imread`` / ``cv2.imdecode`` in colour mode.

    Raises:
        ValueError: If the path cannot be read or the data cannot be decoded
            (including empty data).
    """
    if isinstance(input_image, (str, os.PathLike)):
        path = os.fspath(input_image)
        img = cv2.imread(path)
        if img is None:
            raise ValueError(f"Unable to load image from path: {path}")
        if base_name is None:
            base_name = os.path.splitext(os.path.basename(path))[0]
        return img, base_name

    # In-memory input: any bytes-like object is used directly; anything else
    # is treated as a file-like object holding the encoded data.
    if isinstance(input_image, (bytes, bytearray, memoryview)):
        image_bytes = input_image
    else:
        image_bytes = input_image.read()
    nparr = np.frombuffer(image_bytes, np.uint8)
    # cv2.imdecode may raise an OpenCV assertion on an empty buffer instead of
    # returning None, so empty input is rejected up front with the same error.
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if img is None:
        raise ValueError("Unable to decode image from input bytes or file-like object.")
    if base_name is None:
        base_name = str(uuid.uuid4())
    return img, base_name
def _write_image(path: str, image) -> None:
    """Write *image* with cv2.imwrite, raising instead of failing silently."""
    if not cv2.imwrite(path, image):
        raise OSError(f"Unable to write image to path: {path}")


def _result_to_boxes(result) -> list:
    """Convert one ultralytics result into label records (class id, bbox, text line)."""
    if result.boxes is None:
        raise ValueError("Model output contains no bounding boxes; a detection model is required.")
    records = []
    for cls_value, xywhn in zip(result.boxes.cls.tolist(), result.boxes.xywhn.tolist()):
        class_id = int(cls_value)
        # "<class> <x_center> <y_center> <width> <height>" with normalised
        # coordinates, %g-formatted in YOLO label-file style.
        line = " ".join("%g" % value for value in (class_id, *xywhn)) + "\n"
        records.append({
            'class_id': class_id,
            # Re-parse the text so suppression uses exactly the values that
            # end up in the label file.
            'bbox': [float(token) for token in line.split()[1:5]],
            'line': line,
        })
    return records


def _suppress_overlaps(boxes: list, iou_threshold: float) -> list:
    """Greedy same-class suppression; return indices of kept boxes in list order."""
    keep_indices = []
    suppressed = [False] * len(boxes)
    for i, box in enumerate(boxes):
        if suppressed[i]:
            continue
        keep_indices.append(i)
        # An earlier (kept) box suppresses later boxes of the same class.
        for j in range(i + 1, len(boxes)):
            if (not suppressed[j]
                    and boxes[j]['class_id'] == box['class_id']
                    and iou(box['bbox'], boxes[j]['bbox']) > iou_threshold):
                suppressed[j] = True
    return keep_indices


def _draw_numbered_boxes(image, boxes: list, keep_indices: list):
    """Return a copy of *image* with the kept boxes drawn and numbered from 1."""
    drawn = image.copy()
    h_img, w_img = drawn.shape[:2]
    for number, idx in enumerate(keep_indices, start=1):
        x_center, y_center, w_norm, h_norm = boxes[idx]['bbox']
        # Normalised centre/size -> integer pixel corners.
        x_center *= w_img
        y_center *= h_img
        w_box = w_norm * w_img
        h_box = h_norm * h_img
        x1 = int(x_center - w_box / 2)
        y1 = int(y_center - h_box / 2)
        x2 = int(x_center + w_box / 2)
        y2 = int(y_center + h_box / 2)
        cv2.rectangle(drawn, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(drawn, str(number), (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (143, 10, 18), 1)
    return drawn


def process_yolo(input_image, weights_file: str, output_dir: str = './yolo_run', model_obj: YOLO = None, base_name: str = None, iou_threshold: float = 0.7) -> str:
    """Run YOLO on an image, drop overlapping same-class boxes and save results.

    Files written to *output_dir*:
      * ``<base_name>_yolo<ext>``: the plot of YOLO's raw predictions;
      * ``labels/<base_name>.txt``: YOLO-format label lines of the kept boxes;
      * ``<base_name>_yolo_updated<ext>``: the original image with the kept
        boxes drawn and numbered top-to-bottom.

    Args:
        input_image: Image path, bytes-like data or file-like object (as
            accepted by ``load_image``).
        weights_file: YOLO weights path; ignored when *model_obj* is given.
        output_dir: Directory for all outputs (created if missing).
        model_obj: Optional preloaded model to reuse across calls.
        base_name: Optional base name for output files.
        iou_threshold: A box is removed when its IoU with an earlier kept box
            of the same class exceeds this value.

    Returns:
        Path of the updated (numbered) image.

    Raises:
        ValueError: If the image cannot be loaded or the model output has no
            bounding boxes.
        OSError: If an output image cannot be written.
    """
    orig_image, base_name = load_image(input_image, base_name)
    os.makedirs(output_dir, exist_ok=True)

    # Reuse the input file's extension; in-memory inputs and paths without an
    # extension (which cv2.imwrite cannot encode) fall back to PNG.
    ext = os.path.splitext(input_image)[1] if isinstance(input_image, str) else ""
    ext = ext or ".png"
    output_image_path = os.path.join(output_dir, f"{base_name}_yolo{ext}")
    updated_output_image_path = os.path.join(output_dir, f"{base_name}_yolo_updated{ext}")

    model = YOLO(weights_file) if model_obj is None else model_obj
    # File paths are handed to YOLO as-is; in-memory inputs use the decoded array.
    source_input = input_image if isinstance(input_image, str) else orig_image
    # Boxes are taken from the returned results instead of a save_txt label
    # file: that file is named after the YOLO source (a generic name for
    # arrays), not base_name, and may be appended to across runs.
    results = model(
        source=source_input,
        project=output_dir,
        name='.',
        exist_ok=True,
    )
    result = results[0]

    _write_image(output_image_path, result.plot(font_size=2, line_width=1))
    print(f"Image saved as '{output_image_path}'")

    # Order boxes top-to-bottom by their upper edge; this order decides which
    # box survives an overlap and the numbering on the updated image.
    boxes = sorted(_result_to_boxes(result), key=lambda b: b['bbox'][1] - b['bbox'][3] / 2)
    keep_indices = _suppress_overlaps(boxes, iou_threshold)

    labels_dir = os.path.join(output_dir, 'labels')
    os.makedirs(labels_dir, exist_ok=True)
    label_file = os.path.join(labels_dir, f"{base_name}.txt")
    with open(label_file, 'w') as f:
        f.writelines(boxes[idx]['line'] for idx in keep_indices)
    print(f"Number of bounding boxes removed: {len(boxes) - len(keep_indices)}")

    drawn_image = _draw_numbered_boxes(orig_image, boxes, keep_indices)
    _write_image(updated_output_image_path, drawn_image)
    print(f"Updated image saved as '{updated_output_image_path}'")
    return updated_output_image_path
if __name__ == '__main__':
    # Command-line entry point: run the full pipeline on one image file.
    parser = argparse.ArgumentParser(description='Process YOLO inference and NMS on an image.')
    # The CLI can only supply a path; bytes/file-like inputs are available
    # solely through the Python API (process_yolo / load_image).
    parser.add_argument('input_image', help='Path to the input image.')
    parser.add_argument('weights_file', help='Path to the YOLO weights file.')
    parser.add_argument('output_dir', nargs='?', default='./yolo_run', help='Output directory (optional).')
    parser.add_argument('--base_name', help='Optional base name for output files (without extension).')
    args = parser.parse_args()
    try:
        process_yolo(args.input_image, args.weights_file, args.output_dir, base_name=args.base_name)
    except Exception as e:
        # Top-level boundary: report the failure on stderr and exit non-zero.
        print(e, file=sys.stderr)
        sys.exit(1)
|