"""YOLOv5 detection helper.

Loads the detector once at import time and exposes ``identifications``,
which runs it on a single frame and returns the kept detections.
"""

import sys

import cv2
import numpy as np
import torch
from yolov5.models.common import DetectMultiBackend
from yolov5.models.experimental import attempt_load
from yolov5.utils.general import non_max_suppression

torch.cuda.empty_cache()

# Pick the device first so the model weights and the input tensors end up on
# the same one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

weights = "./model/yolov5n6_RGB_D2304-v1_9C.pt"
# DetectMultiBackend loads onto the CPU unless told otherwise; without the
# explicit device, CUDA inputs would be fed to CPU weights and inference
# would raise a device-mismatch error.
model = DetectMultiBackend(weights, device=device)
model.eval()

colors = np.random.randint(0, 256, (32, 3)).tolist()

# Class labels, indexed by the class id the model emits.
class_names = [
    "ANIMAL",
    "BOAT",
    "BUOY",
    "FAR_AWAY_OBJECT",
    "FLOTSAM",
    "HUMAN",
    "LEISURE_VEHICLE",
    "SAILING_BOAT",
    "SHIP",
]

# Labels that are reported to the caller; currently every known class.
classes_of_interest = class_names

# (width, height) every frame is resized to before inference.
_INPUT_SIZE = (640, 640)


def identifications(frame, iou_threshold, conf_thres):
    """Detect objects in a single frame.

    Args:
        frame: Image array accepted by ``cv2.resize`` with the channel axis
            last. The channel order is reversed before inference (BGR -> RGB,
            assuming the frame came from OpenCV).
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.
        conf_thres: Confidence threshold forwarded to
            ``non_max_suppression``.

    Returns:
        A list of ``(x1, y1, x2, y2, conf, label)`` tuples, one per kept
        detection. The corner coordinates are ints expressed in the 640x640
        resized image, NOT in the coordinate system of the original
        ``frame``; callers drawing on the original frame must rescale them.
    """
    resized = cv2.resize(frame, _INPUT_SIZE)

    # Reverse the channel order and go from HWC to CHW; the negative-stride
    # view must be made contiguous before torch can wrap it.
    chw = np.ascontiguousarray(resized[:, :, ::-1].transpose(2, 0, 1))

    img = torch.from_numpy(chw).float().to(device)
    img /= 255.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add the batch axis -> [1, 3, 640, 640]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pred = model(img)
        # The backend may return [inference_out, train_out] or a bare tensor;
        # only unwrap in the former case so the batch axis is never sliced off.
        if isinstance(pred, (list, tuple)):
            pred = pred[0]
        pred = non_max_suppression(pred, conf_thres, iou_threshold)

    ret_preds = []
    # One device->host transfer for all detections of the (single) image.
    for x1, y1, x2, y2, conf, cls in pred[0].cpu().numpy():
        label = class_names[int(cls)]
        if label not in classes_of_interest:
            continue
        ret_preds.append((int(x1), int(y1), int(x2), int(y2), conf, label))
    return ret_preds


# Example:
# img = cv2.imread('download.jpeg')
# preds = identifications(img, iou_threshold=0.45, conf_thres=0.2)
# print(preds)