File size: 2,758 Bytes
2b84d47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import cv2
import torch
import numpy as np
from yolov5.models.experimental import attempt_load
from yolov5.utils.general import non_max_suppression
import sys

from yolov5.models.common import DetectMultiBackend

# Release any cached CUDA memory held by PyTorch's allocator before loading.
torch.cuda.empty_cache()

# Pick the device first so that the model and the input tensors built in
# identifications() are placed on the same one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

weights = "./model/yolov5n6_RGB_D2304-v1_9C.pt"
# Pass the device explicitly: without it the weights are not guaranteed to
# land on `device`, and a CUDA input fed to a CPU model fails at inference.
model = DetectMultiBackend(weights, device=device)

# Inference only: switch layers such as dropout/batch-norm to eval behaviour.
model.eval()

# 32 random 3-channel colour triples (values 0-255); not used in this section
# of the file.
colors = np.random.randint(0, 256, (32, 3)).tolist()

# Class labels, indexed by the class id the model predicts.
class_names = ["ANIMAL", "BOAT", "BUOY", "FAR_AWAY_OBJECT", "FLOTSAM", "HUMAN", "LEISURE_VEHICLE", "SAILING_BOAT", "SHIP"]

# Labels that identifications() is allowed to return. Currently every class;
# narrow this list to filter detections by label.
classes_of_interest = class_names

def identifications(frame, iou_threshold, conf_thres):
    """Run the detector on one frame and return the detections kept by NMS.

    Args:
        frame: Image array indexed as (height, width, channel). The channel
            axis is reversed before inference (presumably BGR -> RGB for
            frames read with OpenCV -- confirm against the caller).
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.
        conf_thres: Confidence threshold forwarded to ``non_max_suppression``.

    Returns:
        A list of ``(x1, y1, x2, y2, conf, label)`` tuples, one per kept
        detection whose label is in ``classes_of_interest``. The box corners
        are truncated to ``int``; ``label`` is taken from ``class_names``.

        NOTE(review): the boxes are expressed in the 640x640 resized image,
        not in the coordinate space of ``frame``; they are not scaled back
        here. Confirm whether callers rescale them.
    """
    # The network input is a fixed 640x640 square; aspect ratio is not kept.
    resized = cv2.resize(frame, (640, 640))
    # Reverse the channel axis and reorder HWC -> CHW. The reversed view has
    # a negative stride, so copy it into contiguous memory before handing it
    # to torch.from_numpy.
    chw = np.ascontiguousarray(resized[:, :, ::-1].transpose(2, 0, 1))

    # Divide by 255 (maps 8-bit pixel values to [0, 1]) and add the batch
    # axis, giving a tensor of shape [1, C, 640, 640].
    batch = torch.from_numpy(chw).float().to(device)
    batch /= 255.0
    batch = batch.unsqueeze(0)

    # Inference only: skip autograd bookkeeping, which saves memory and
    # guarantees the outputs can be converted to NumPy below.
    with torch.no_grad():
        raw = model(batch)[0]
        kept = non_max_suppression(raw, conf_thres, iou_threshold)

    # Single batch element; copy all of its detections to the host at once
    # instead of issuing one transfer per detection.
    rows = kept[0].cpu().numpy()

    ret_preds = []
    for x1, y1, x2, y2, conf, cls in rows:
        label = class_names[int(cls)]
        if label not in classes_of_interest:
            continue
        ret_preds.append((int(x1), int(y1), int(x2), int(y2), conf, label))

    return ret_preds


# img = cv2.imread('download.jpeg')

# # perform object identification on the frame
# preds = identifications(img, iou_threshold=0.45, conf_thres=0.25)

# print(preds)