import cv2
import torch
import numpy as np
from tensorflow.keras.models import load_model
# Load the trained MoBiLSTM model for violence detection
violence_model = load_model('MoBiLSTM_model.h5') # Replace with the path to your model
# Load the YOLOv5 medium model from Torch Hub for person (crowd) detection
yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5m')
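# Optional tuning (assumed values, not from the original script): the torch-hub
# AutoShape wrapper exposes runtime attributes for filtering detections. Keeping
# only COCO class 0 ('person') and raising the confidence threshold can reduce
# false positives in the crowd count; 0.4 is illustrative, tune on your footage.
# yolo_model.conf = 0.4      # confidence threshold
# yolo_model.classes = [0]   # restrict detections to 'person'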
# Define constants for frame extraction
IMAGE_HEIGHT, IMAGE_WIDTH = 64, 64 # Adjust based on your model's input size
SEQUENCE_LENGTH = 16 # Number of frames to pass for sequence input
# Function to extract SEQUENCE_LENGTH evenly spaced frames from the video
def extract_frames(video_path):
    frames_list = []
    video_reader = cv2.VideoCapture(video_path)
    total_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
    # Spread the sampled frames evenly across the whole clip
    skip_frames_window = max(int(total_frames / SEQUENCE_LENGTH), 1)
    for frame_counter in range(SEQUENCE_LENGTH):
        video_reader.set(cv2.CAP_PROP_POS_FRAMES, frame_counter * skip_frames_window)
        ret, frame = video_reader.read()
        if not ret:
            break
        # cv2.resize takes (width, height); both are 64 here
        resized_frame = cv2.resize(frame, (IMAGE_WIDTH, IMAGE_HEIGHT))
        normalized_frame = resized_frame / 255.0  # Normalize pixel values to [0, 1]
        frames_list.append(normalized_frame)
    video_reader.release()
    # Pad with zero frames if the clip yielded fewer than SEQUENCE_LENGTH frames
    if len(frames_list) < SEQUENCE_LENGTH:
        frames_list.extend([np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH, 3))] * (SEQUENCE_LENGTH - len(frames_list)))
    return np.array(frames_list)
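# Illustrative check (hypothetical clip path): the returned array should be
# (SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3) with values in [0, 1].
# frames = extract_frames('sample.mp4')
# assert frames.shape == (SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3)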
# Function to predict violence in the video
def predict_video_class(video_path):
    frames = extract_frames(video_path)
    frames = np.expand_dims(frames, axis=0)  # Add batch dimension: (1, SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3)
    # Predict the class (0 = Non-Violence, 1 = Violence)
    prediction = violence_model.predict(frames)
    class_index = np.argmax(prediction, axis=1)[0]
    return class_index
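# Sketch of a variant that also surfaces the model's confidence (assumption:
# the final layer is a softmax over the two classes, so prediction[0] holds
# the class probabilities). Useful for thresholding borderline clips.
def predict_video_class_with_confidence(video_path):
    frames = np.expand_dims(extract_frames(video_path), axis=0)
    prediction = violence_model.predict(frames)[0]
    class_index = int(np.argmax(prediction))
    return class_index, float(prediction[class_index])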
# Function to estimate crowd size using YOLOv5 person detections
def detect_crowd(video_path):
    total_person_count = 0
    frame_counter = 0
    processed_frames = 0
    video_reader = cv2.VideoCapture(video_path)
    while True:
        ret, frame = video_reader.read()
        if not ret:
            break
        frame_counter += 1
        # Skip 2 frames (process every 3rd frame)
        if frame_counter % 3 == 0:
            # YOLOv5's AutoShape wrapper expects RGB input; OpenCV frames are BGR
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = yolo_model(rgb_frame)
            # Get the detections for this image as a pandas DataFrame
            df = results.pandas().xywh[0]
            # Count only the "person" detections
            person_count = df[df['name'] == 'person'].shape[0]
            total_person_count += person_count
            processed_frames += 1
    video_reader.release()
    # Average number of people per processed frame
    average_crowd_count = total_person_count / processed_frames if processed_frames > 0 else 0
    # Round the average up to the nearest whole person
    rounded_crowd_count = int(np.ceil(average_crowd_count))
    # Classify crowd size based on the average number of people detected
    if rounded_crowd_count > 10:
        crowd_class = "Large Crowd"
    elif rounded_crowd_count > 3:
        crowd_class = "Small Crowd"
    else:
        crowd_class = "No Crowd"
    return crowd_class, rounded_crowd_count
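# Optional debugging sketch (assumptions: a recent ultralytics/yolov5 release where
# Detections.render() returns the annotated images, and a display is available).
# Handy for spot-checking what the crowd counter actually sees on one frame.
def show_detections(frame):
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # AutoShape expects RGB
    results = yolo_model(rgb_frame)
    annotated = results.render()[0]  # frame with detection boxes drawn on
    cv2.imshow('YOLOv5 detections', cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
    cv2.destroyAllWindows()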
# Main function to analyze both violence and crowd size
def analyze_video(video_path):
    # Get violence prediction (0 = Non-Violence, 1 = Violence)
    violence_class = predict_video_class(video_path)
    violence_status = "Non-Violence" if violence_class == 0 else "Violence"
    # Get crowd detection
    crowd_status, crowd_count = detect_crowd(video_path)
    return violence_status, crowd_status, crowd_count
# Example usage
if __name__ == '__main__':
    video_path = r'C:\Users\Asus\Downloads\Project\2\1107342075-preview.mp4'  # Replace with the path to your test video
    violence_status, crowd_status, crowd_count = analyze_video(video_path)
    print(f"Violence Status: {violence_status}")
    print(f"Crowd Status: {crowd_status}")
    print(f"Crowd Count (rounded): {crowd_count}")