Spaces:
Runtime error
Runtime error
import cv2 | |
import torch | |
import numpy as np | |
from yolov5.models.experimental import attempt_load | |
from yolov5.utils.general import non_max_suppression | |
import sys | |
from yolov5.models.common import DetectMultiBackend | |
# Release any cached CUDA memory held by this process before loading the model.
torch.cuda.empty_cache()

# Set device first so the model and the input tensors end up on the same one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the detector directly onto `device`. DetectMultiBackend defaults to CPU,
# which would not match the CUDA tensors that identifications() builds when a
# GPU is available.
weights = "./model/yolov5n6_RGB_D2304-v1_9C.pt"
model = DetectMultiBackend(weights, device=device)
model.eval()

# 32 random RGB triples (presumably one drawing colour per class index -- not
# used in the visible part of this file; confirm with callers).
colors = np.random.randint(0, 256, (32, 3)).tolist()

# Define class labels
#class_names = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
class_names = ["ANIMAL", "BOAT", "BUOY", "FAR_AWAY_OBJECT", "FLOTSAM", "HUMAN", "LEISURE_VEHICLE", "SAILING_BOAT", "SHIP"]

# Set detection threshold
#conf_thres = 0.2

# Define the classes we want to detect (currently every class the model knows).
classes_of_interest = class_names
def identifications(frame, iou_threshold, conf_thres):
    """Run the detector on one frame and return the detections kept by NMS.

    The frame is resized to 640x640 before inference and the boxes are
    returned as-is, so the coordinates refer to the resized 640x640 image,
    not to the original frame.

    Args:
        frame: Image array accepted by ``cv2.resize``. Its last axis is
            reversed before inference (presumably OpenCV BGR -> RGB; confirm
            with the caller).
        iou_threshold: IoU threshold passed to ``non_max_suppression``.
        conf_thres: Confidence threshold passed to ``non_max_suppression``.

    Returns:
        A list of ``(x1, y1, x2, y2, conf, label)`` tuples, where the box
        corners are ints, ``conf`` is the detection confidence and ``label``
        is the entry of ``class_names`` for the predicted class. Detections
        whose label is not in ``classes_of_interest`` are dropped.
    """
    # Resize frame
    img = cv2.resize(frame, (640, 640))
    # Reverse the channel axis and reorder HWC -> CHW
    img = img[:, :, ::-1].transpose(2, 0, 1)
    # The reversed/transposed view has negative strides; torch.from_numpy
    # needs a contiguous copy.
    img = np.ascontiguousarray(img)
    # Convert to torch tensor scaled to [0, 1]
    img = torch.from_numpy(img).float().to(device)
    img /= 255.0
    if img.ndimension() == 3:
        # Add the batch dimension -> (1, C, 640, 640)
        img = img.unsqueeze(0)

    # Detect objects in the image. Gradients are never needed here, so skip
    # autograd tracking; this also guarantees the result can be converted to
    # NumPy below.
    with torch.no_grad():
        pred = model(img)[0]
        pred = non_max_suppression(pred, conf_thres, iou_threshold)

    # Move all detections of the (single) image to the host in one transfer
    # instead of once per row.
    detections = pred[0].cpu().numpy()

    ret_preds = []
    for x1, y1, x2, y2, conf, cls in detections:
        label = class_names[int(cls)]
        if label not in classes_of_interest:
            continue
        ret_preds.append((int(x1), int(y1), int(x2), int(y2), conf, label))
    return ret_preds
# img = cv2.imread('download.jpeg')
# # perform object identification on the frame (both thresholds are required)
# preds = identifications(img, iou_threshold=0.45, conf_thres=0.2)
# print(preds)