File size: 6,559 Bytes
6ff22d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import ultralytics
import cv2
from ultralytics import YOLO
import os
import glob
import argparse
import sys
import numpy as np
import uuid

def iou(box1, box2):
    """Return the intersection-over-union of two center-format boxes.

    Each box is indexed as ``(x_center, y_center, width, height)``. Both boxes
    must use the same units (normalized or pixels). When the combined area is
    not positive the result is 0.
    """
    def corners(box):
        # (center, size) -> (left, top, right, bottom)
        cx, cy, w, h = box[0], box[1], box[2], box[3]
        return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

    left_a, top_a, right_a, bottom_a = corners(box1)
    left_b, top_b, right_b, bottom_b = corners(box2)

    overlap_w = max(0, min(right_a, right_b) - max(left_a, left_b))
    overlap_h = max(0, min(bottom_a, bottom_b) - max(top_a, top_b))
    intersection = overlap_w * overlap_h

    area_a = (right_a - left_a) * (bottom_a - top_a)
    area_b = (right_b - left_b) * (bottom_b - top_b)
    union = area_a + area_b - intersection

    if union > 0:
        return intersection / union
    return 0

def load_image(input_image, base_name: str = None):
    """Decode an image with OpenCV and choose a base name for derived outputs.

    Args:
        input_image: One of
            * a filesystem path (``str`` or ``os.PathLike``), read with
              ``cv2.imread``;
            * a bytes-like object (``bytes``, ``bytearray``, ``memoryview``)
              holding encoded image data;
            * a file-like object whose ``read()`` returns such data.
        base_name: Stem for output files. When ``None``, path input uses the
            file name without its extension and any other input gets a random
            UUID4 string.

    Returns:
        ``(img, base_name)`` where ``img`` is the array returned by
        ``cv2.imread`` / ``cv2.imdecode`` (``IMREAD_COLOR``).

    Raises:
        ValueError: If the path cannot be read or the data cannot be decoded
            (including empty data).
    """
    # Path objects are paths too; turn them into str so the branch below
    # handles them instead of treating them as a file-like object.
    if isinstance(input_image, os.PathLike):
        input_image = os.fsdecode(input_image)

    if isinstance(input_image, str):
        img = cv2.imread(input_image)
        if img is None:
            raise ValueError(f"Unable to load image from path: {input_image}")
        if base_name is None:
            base_name = os.path.splitext(os.path.basename(input_image))[0]
        return img, base_name

    # Anything else is encoded image data, given directly or via .read().
    if isinstance(input_image, (bytes, bytearray, memoryview)):
        image_bytes = input_image
    else:
        image_bytes = input_image.read()
    nparr = np.frombuffer(image_bytes, np.uint8)
    # cv2.imdecode raises cv2.error on an empty buffer; report it the same way
    # as any other undecodable input.
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if img is None:
        raise ValueError("Unable to decode image from input bytes or file-like object.")
    if base_name is None:
        base_name = str(uuid.uuid4())
    return img, base_name

def _read_label_boxes(label_file: str) -> list:
    """Parse a YOLO-format label file into box records.

    Each non-blank line must start with
    ``class_id x_center y_center width height``; any further tokens are kept
    verbatim in the record's ``'line'``.

    Returns:
        A list of dicts with keys ``'class_id'`` (int), ``'bbox'``
        (``[x_center, y_center, width, height]``), ``'line'`` (the raw,
        newline-terminated line) and ``'index'`` (0-based line number).
    """
    with open(label_file, 'r') as f:
        lines = f.readlines()

    boxes = []
    for idx, line in enumerate(lines):
        tokens = line.split()
        if not tokens:
            # Tolerate blank lines instead of failing with IndexError.
            continue
        if not line.endswith('\n'):
            # Boxes are re-ordered before being written back, so an
            # unterminated last line would otherwise be glued to another one.
            line += '\n'
        boxes.append({
            'class_id': int(tokens[0]),
            'bbox': [float(tokens[1]), float(tokens[2]), float(tokens[3]), float(tokens[4])],
            'line': line,
            'index': idx,
        })
    return boxes


def _suppress_overlaps(boxes: list, iou_threshold: float) -> tuple:
    """Greedily suppress same-class boxes, keeping earlier boxes in list order.

    A box is dropped when its IoU with an already-kept box of the same class
    is strictly greater than ``iou_threshold``.

    Returns:
        ``(keep_indices, num_removed)``, with ``keep_indices`` indexing into
        ``boxes`` in their existing order.
    """
    keep_indices = []
    suppressed = [False] * len(boxes)
    num_removed = 0
    for i, box in enumerate(boxes):
        if suppressed[i]:
            continue
        keep_indices.append(i)
        for j in range(i + 1, len(boxes)):
            if suppressed[j] or boxes[j]['class_id'] != box['class_id']:
                continue
            if iou(box['bbox'], boxes[j]['bbox']) > iou_threshold:
                suppressed[j] = True
                num_removed += 1
    return keep_indices, num_removed


def _draw_numbered_boxes(image, boxes: list, keep_indices: list):
    """Return a copy of ``image`` with the kept boxes drawn and numbered 1..N."""
    drawn_image = image.copy()
    h_img, w_img = drawn_image.shape[:2]
    for number, idx in enumerate(keep_indices, start=1):
        x_center, y_center, w_norm, h_norm = boxes[idx]['bbox']
        # Label coordinates are normalized; scale them to pixels.
        x_center *= w_img
        y_center *= h_img
        w_box = w_norm * w_img
        h_box = h_norm * h_img
        x1 = int(x_center - w_box / 2)
        y1 = int(y_center - h_box / 2)
        x2 = int(x_center + w_box / 2)
        y2 = int(y_center + h_box / 2)
        cv2.rectangle(drawn_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(drawn_image, str(number), (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (143, 10, 18), 1)
    return drawn_image


def _save_image(path: str, image) -> None:
    """Write ``image`` with ``cv2.imwrite``; raise instead of failing silently."""
    if not cv2.imwrite(path, image):
        raise IOError(f"Failed to write image to '{path}'")


def process_yolo(input_image, weights_file: str, output_dir: str = './yolo_run',
                 model_obj: YOLO = None, base_name: str = None,
                 iou_threshold: float = 0.7) -> str:
    """Run YOLO on one image, drop overlapping same-class boxes, save results.

    Steps:
    1. Run the model with ``save_txt=True`` so a label file is written under
       ``<output_dir>/labels``.
    2. Save the model's own rendering as ``<base_name>_yolo<ext>``.
    3. Parse the labels and sort the boxes by their top edge.
    4. Greedily suppress same-class boxes whose IoU exceeds ``iou_threshold``.
    5. Rewrite ``labels/<base_name>.txt`` with the kept boxes.
    6. Save a numbered rendering as ``<base_name>_yolo_updated<ext>``.

    Args:
        input_image: Image path (``str`` / ``os.PathLike``), raw ``bytes``, or
            a file-like object (anything ``load_image`` accepts).
        weights_file: Weights to load when ``model_obj`` is not given.
        output_dir: Directory for all outputs; created if missing.
        model_obj: Optional pre-loaded ``YOLO`` model to reuse across calls.
        base_name: Optional stem for output files; defaults as in
            ``load_image``.
        iou_threshold: Same-class boxes with IoU strictly greater than this
            are suppressed. Defaults to 0.7.

    Returns:
        Path of the updated (post-suppression) annotated image.

    Raises:
        ValueError: If the image cannot be loaded (from ``load_image``).
        FileNotFoundError: If no label file can be found even though the model
            did not report zero detections.
        OSError: If an output image cannot be written.
    """
    # Normalize path-like objects so every "is this a path?" check agrees.
    if isinstance(input_image, os.PathLike):
        input_image = os.fsdecode(input_image)
    is_path = isinstance(input_image, str)

    orig_image, base_name = load_image(input_image, base_name)
    os.makedirs(output_dir, exist_ok=True)

    # Keep the input's extension for path input. Fall back to .png for
    # in-memory input or extensionless paths: cv2.imwrite picks its encoder
    # from the extension.
    ext = os.path.splitext(input_image)[1] if is_path else ''
    if not ext:
        ext = '.png'
    output_image_path = os.path.join(output_dir, f"{base_name}_yolo{ext}")
    updated_output_image_path = os.path.join(output_dir, f"{base_name}_yolo_updated{ext}")

    labels_dir = os.path.join(output_dir, 'labels')
    label_file = os.path.join(labels_dir, f"{base_name}.txt")

    # NOTE(review): ultralytics' save_txt appears to open label files in
    # append mode, so leftovers from an earlier run on the same image would
    # be merged in. Clear them before inferring.
    stale_label_files = {label_file}
    if is_path:
        path_stem = os.path.splitext(os.path.basename(input_image))[0]
        stale_label_files.add(os.path.join(labels_dir, f"{path_stem}.txt"))
    for stale in stale_label_files:
        if os.path.isfile(stale):
            os.remove(stale)

    model = YOLO(weights_file) if model_obj is None else model_obj
    # Pass the path itself when there is one so YOLO names its outputs after
    # it; in-memory input is passed as the decoded array.
    source_input = input_image if is_path else orig_image
    results = model(
        source=source_input,
        save_txt=True,
        project=output_dir,
        name='.',
        exist_ok=True,
    )
    result = results[0]

    _save_image(output_image_path, result.plot(font_size=2, line_width=1))
    print(f"Image saved as '{output_image_path}'")

    # YOLO names the label file after its own view of the source: the path
    # stem, or a generic name for in-memory arrays. That need not equal
    # base_name, so move the file to <base_name>.txt.
    result_path = getattr(result, 'path', None)
    if result_path:
        yolo_stem = os.path.splitext(os.path.basename(str(result_path)))[0]
        yolo_label_file = os.path.join(labels_dir, f"{yolo_stem}.txt")
        if yolo_label_file != label_file and os.path.isfile(yolo_label_file):
            os.replace(yolo_label_file, label_file)

    if os.path.isfile(label_file):
        boxes = _read_label_boxes(label_file)
    else:
        detections = getattr(result, 'boxes', None)
        if detections is None or len(detections) > 0:
            raise FileNotFoundError(f"No label files found for the image '{base_name}' at path '{label_file}'.")
        # Presumably no label file is written when nothing was detected;
        # treat that as an empty result rather than an error.
        boxes = []

    # Order boxes top-to-bottom by their top edge (y_center - height / 2).
    # This decides which box of an overlapping pair survives and how the
    # kept boxes are numbered.
    boxes.sort(key=lambda b: b['bbox'][1] - (b['bbox'][3] / 2))

    keep_indices, num_removed = _suppress_overlaps(boxes, iou_threshold)

    os.makedirs(labels_dir, exist_ok=True)
    with open(label_file, 'w') as f:
        f.writelines(boxes[idx]['line'] for idx in keep_indices)

    print(f"Number of bounding boxes removed: {num_removed}")

    drawn_image = _draw_numbered_boxes(orig_image, boxes, keep_indices)
    _save_image(updated_output_image_path, drawn_image)
    print(f"Updated image saved as '{updated_output_image_path}'")

    return updated_output_image_path

if __name__ == '__main__':
    # Command-line entry point: parse arguments, run process_yolo, and exit
    # with status 1 (error printed to stderr) on any failure.
    parser = argparse.ArgumentParser(description='Process YOLO inference and NMS on an image.')
    parser.add_argument('input_image', help="Path to the input image, or '-' to read encoded image bytes from stdin.")
    parser.add_argument('weights_file', help='Path to the YOLO weights file.')
    parser.add_argument('output_dir', nargs='?', default='./yolo_run', help='Output directory (optional).')
    parser.add_argument('--base_name', help='Optional base name for output files (without extension).')
    args = parser.parse_args()

    # argparse only yields strings, so '-' is how the CLI supplies raw bytes:
    # process_yolo accepts a binary file-like object.
    input_image = sys.stdin.buffer if args.input_image == '-' else args.input_image

    try:
        process_yolo(input_image, args.weights_file, args.output_dir, base_name=args.base_name)
    except Exception as e:
        # Top-level boundary: report on stderr so stdout stays clean for output.
        print(e, file=sys.stderr)
        sys.exit(1)