# Scraped from a Hugging Face Space page; the Space's status banner read "Runtime error".
import cv2 | |
import torch | |
import numpy as np | |
from tensorflow.keras.models import load_model | |
from tensorflow.keras.preprocessing.image import img_to_array | |
from collections import Counter | |
# Load the trained MoBiLSTM model used for violence classification.
# Replace the path with the location of your exported Keras model.
violence_model = load_model('MoBiLSTM_model.h5')

# Load YOLOv5 from the Ultralytics torch hub for person (crowd) detection.
# NOTE: 'yolov5m' is the *medium* variant, not the small one; use 'yolov5s'
# for the smaller/faster model. Presumably this needs network access on the
# first run to fetch the repo/weights — confirm for offline deployments.
yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5m')

# Frame-sampling constants. extract_frames/predict_video_class build model
# input of shape (1, SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3), so these
# should match the input size the MoBiLSTM model was built with.
IMAGE_HEIGHT, IMAGE_WIDTH = 64, 64
SEQUENCE_LENGTH = 16  # number of frames fed to the model per video
def extract_frames(video_path):
    """Sample ``SEQUENCE_LENGTH`` evenly spaced, normalized frames from a video.

    Frames are read at positions ``0, step, 2*step, ...`` where
    ``step = max(total_frames // SEQUENCE_LENGTH, 1)``, so the samples span the
    whole clip. Each frame is resized to ``IMAGE_WIDTH`` x ``IMAGE_HEIGHT`` and
    scaled to ``[0, 1]``. If fewer than ``SEQUENCE_LENGTH`` frames can be read,
    the remainder of the sequence is padded with all-zero frames.

    Args:
        video_path: Path (or any source accepted by ``cv2.VideoCapture``).

    Returns:
        ``np.ndarray`` of shape ``(SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3)``
        with float values in ``[0, 1]``, in OpenCV's channel order (BGR).

    Raises:
        OSError: If OpenCV cannot open ``video_path``.
    """
    video_reader = cv2.VideoCapture(video_path)
    if not video_reader.isOpened():
        video_reader.release()
        # Without this, a missing/unreadable file silently becomes an all-zero
        # sequence and still yields a (meaningless) prediction.
        raise OSError(f"Could not open video: {video_path}")

    frames_list = []
    try:
        total_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
        skip_frames_window = max(total_frames // SEQUENCE_LENGTH, 1)
        for frame_counter in range(SEQUENCE_LENGTH):
            # The window already spreads the samples over the whole clip; an
            # extra stride factor here would seek past the last frame and turn
            # most of the sequence into zero padding.
            video_reader.set(cv2.CAP_PROP_POS_FRAMES, frame_counter * skip_frames_window)
            ret, frame = video_reader.read()
            if not ret:
                break
            # cv2.resize takes dsize as (width, height).
            resized_frame = cv2.resize(frame, (IMAGE_WIDTH, IMAGE_HEIGHT))
            frames_list.append(resized_frame / 255.0)  # normalize to [0, 1]
    finally:
        video_reader.release()

    # Pad short videos with blank frames so the model always sees a full sequence.
    missing = SEQUENCE_LENGTH - len(frames_list)
    if missing > 0:
        frames_list.extend(np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH, 3)) for _ in range(missing))
    return np.array(frames_list)
def predict_video_class(video_path):
    """Classify a whole video with the MoBiLSTM violence model.

    Args:
        video_path: Video to analyze; frames are sampled by ``extract_frames``.

    Returns:
        The index of the highest-scoring class (0 = Non-Violence, 1 = Violence).
    """
    # Prepend a batch axis: (1, SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3).
    batch = extract_frames(video_path)[np.newaxis, ...]
    # NOTE(review): assumes the model emits one score per class, shape
    # (1, num_classes); with a single sigmoid output, argmax would always be 0.
    scores = violence_model.predict(batch)
    return np.argmax(scores, axis=1)[0]
def detect_crowd(video_path, frame_stride=3):
    """Estimate crowd size in a video by counting YOLOv5 "person" detections.

    Every ``frame_stride``-th frame (starting with the first) is run through
    ``yolo_model``. The per-frame person counts are averaged over the frames
    actually processed, and the average is rounded up.

    Args:
        video_path: Path (or any source accepted by ``cv2.VideoCapture``).
        frame_stride: Run detection on one frame out of every ``frame_stride``
            frames (default 3, i.e. skip two frames between detections).
            Must be >= 1.

    Returns:
        ``(crowd_class, rounded_crowd_count)``. ``crowd_class`` is
        ``"Large Crowd"`` (more than 10 people), ``"Small Crowd"`` (4-10) or
        ``"No Crowd"`` (0-3). ``rounded_crowd_count`` is the ceiling of the
        average person count, as an ``int``.

    Raises:
        ValueError: If ``frame_stride`` is less than 1.
        OSError: If OpenCV cannot open ``video_path``.
    """
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    video_reader = cv2.VideoCapture(video_path)
    if not video_reader.isOpened():
        video_reader.release()
        raise OSError(f"Could not open video: {video_path}")

    total_person_count = 0
    processed_frames = 0
    frame_index = 0
    try:
        while True:
            ret, frame = video_reader.read()
            if not ret:
                break
            # Sampling from index 0 means even 1-2 frame clips get analyzed.
            if frame_index % frame_stride == 0:
                # OpenCV decodes frames as BGR; YOLOv5's hub model expects RGB arrays.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = yolo_model(rgb_frame)
                # A single image was passed in, so take the first detections table.
                detections = results.pandas().xywh[0]
                total_person_count += int((detections['name'] == 'person').sum())
                processed_frames += 1
            frame_index += 1
    finally:
        video_reader.release()

    # Divide by the frames actually processed (not total frames / stride), so
    # clips whose length is not a multiple of the stride are not skewed.
    average_crowd_count = total_person_count / processed_frames if processed_frames else 0.0
    rounded_crowd_count = int(np.ceil(average_crowd_count))

    if rounded_crowd_count > 10:
        crowd_class = "Large Crowd"
    elif rounded_crowd_count > 3:
        crowd_class = "Small Crowd"
    else:
        crowd_class = "No Crowd"
    return crowd_class, rounded_crowd_count
def analyze_video(video_path):
    """Run violence classification and crowd estimation on one video.

    Args:
        video_path: Video to analyze (read once by each of the two analyses).

    Returns:
        ``(violence_status, crowd_status, crowd_count)``. ``violence_status`` is
        ``"Non-Violence"`` when the classifier returns 0 and ``"Violence"``
        otherwise; the other two values are passed through from ``detect_crowd``.
    """
    violence_status = "Non-Violence" if predict_video_class(video_path) == 0 else "Violence"
    crowd_status, crowd_count = detect_crowd(video_path)
    return violence_status, crowd_status, crowd_count
if __name__ == "__main__":
    import sys

    # Example usage: pass the video path as the first command-line argument;
    # without one, fall back to the sample path below (replace it with yours).
    default_video_path = r'C:\Users\Asus\Downloads\Project\2\1107342075-preview.mp4'
    video_path = sys.argv[1] if len(sys.argv) > 1 else default_video_path

    violence_status, crowd_status, crowd_count = analyze_video(video_path)
    print(f"Violence Status: {violence_status}")
    print(f"Crowd Status: {crowd_status}")
    print(f"Crowd Count (rounded): {crowd_count}")