# VicAndTheBoys / yolo_detect.py
# VascoDVRodrigues
# cringe
# 2b84d47
import cv2
import torch
import numpy as np
from yolov5.models.experimental import attempt_load
from yolov5.utils.general import non_max_suppression
import sys
from yolov5.models.common import DetectMultiBackend
torch.cuda.empty_cache()

# Choose the device before building the model so that the model and the input
# tensors created in identifications() end up on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

weights = "./model/yolov5n6_RGB_D2304-v1_9C.pt"
# Pass the device explicitly: identifications() sends its input tensor to
# `device`, so the model has to be placed there as well or inference fails
# with a device mismatch whenever CUDA is available.
model = DetectMultiBackend(weights, device=device)
model.eval()

# 32 random colour triples with components in [0, 255]; not used in this
# module (presumably kept for callers that draw the detections).
colors = np.random.randint(0, 256, (32, 3)).tolist()

# Class labels, indexed by the class id predicted by the model
# (see class_names[int(cls)] in identifications()).
class_names = ["ANIMAL", "BOAT", "BUOY", "FAR_AWAY_OBJECT", "FLOTSAM", "HUMAN", "LEISURE_VEHICLE", "SAILING_BOAT", "SHIP"]

# Labels that identifications() is allowed to return; currently all of them.
classes_of_interest = class_names
def identifications(frame, iou_threshold, conf_thres):
    """Detect objects in a single frame with the module-level YOLOv5 model.

    The frame is resized to 640x640 (aspect ratio is not preserved), its
    channel axis is reversed, it is converted to a channel-first float tensor
    scaled to [0, 1], and it is run through ``model``. The raw predictions
    are filtered with ``non_max_suppression``.

    Args:
        frame: Image as an H x W x 3 NumPy array. Assumed to be BGR as read
            by OpenCV, since the channel order is reversed before inference;
            confirm against the caller.
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.
        conf_thres: Confidence threshold forwarded to
            ``non_max_suppression``.

    Returns:
        A list of ``(x1, y1, x2, y2, conf, label)`` tuples, one per kept
        detection whose label is in ``classes_of_interest``. The box corners
        are ints expressed in the 640x640 resized image, NOT in the
        coordinates of the original ``frame``; ``conf`` is the detection
        confidence and ``label`` the entry of ``class_names`` for the
        predicted class id.
    """
    # Fixed-size network input.
    img = cv2.resize(frame, (640, 640))
    # Reverse the channel axis, then HWC -> CHW; make the memory contiguous
    # so torch.from_numpy accepts the (negatively strided) view.
    img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))
    img = torch.from_numpy(img).float().to(device)
    img /= 255.0
    # Add the batch dimension -> [1, 3, 640, 640].
    img = img.unsqueeze(0)

    # Pure inference: disable autograd so no graph is built and the outputs
    # can be converted to NumPy without detaching.
    with torch.no_grad():
        pred = model(img)[0]
        pred = non_max_suppression(pred, conf_thres, iou_threshold)

    # Single image in the batch; move all of its detections to the host in
    # one transfer instead of one transfer per detection.
    detections = pred[0].cpu().numpy()

    ret_preds = []
    for x1, y1, x2, y2, conf, cls in detections:
        label = class_names[int(cls)]
        if label not in classes_of_interest:
            continue
        ret_preds.append((int(x1), int(y1), int(x2), int(y2), conf, label))
    return ret_preds
# Example usage:
# img = cv2.imread('download.jpeg')
# # perform object identification on the frame
# preds = identifications(img, iou_threshold=0.45, conf_thres=0.2)
# print(preds)