#!/usr/bin/env python3 """ Real-time pose classifier Uses MediaPipe to capture camera input, perform pose recognition and classification, and display results on screen Features: 1. Use MediaPipe to obtain real-time pose data from camera 2. Extract joint coordinates and preprocess them 3. Use trained machine learning models for pose classification 4. Display classification results and keypoints in real-time on video screen Dependencies: pip install opencv-python mediapipe numpy scikit-learn Usage: python realtime_pose_classifier.py [--model MODEL_PATH] [--camera CAMERA_ID] """ import cv2 import mediapipe as mp import numpy as np import json import joblib import argparse import time from pathlib import Path import traceback class RealtimePoseClassifier: def __init__(self, model_path=None, camera_id=0): """ Initialize real-time pose classifier Args: model_path (str): Model file path, auto-detect if None camera_id (int): Camera ID, default 0 """ self.camera_id = camera_id # Initialize MediaPipe self.mp_pose = mp.solutions.pose self.mp_drawing = mp.solutions.drawing_utils self.mp_drawing_styles = mp.solutions.drawing_styles # Configure pose detector self.pose = self.mp_pose.Pose( static_image_mode=False, model_complexity=1, # Use lower complexity for real-time applications enable_segmentation=False, min_detection_confidence=0.7, min_tracking_confidence=0.5 ) # MediaPipe landmark name mapping self.landmark_names = [ 'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer', 'right_eye_inner', 'right_eye', 'right_eye_outer', 'left_ear', 'right_ear', 'mouth_left', 'mouth_right', 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky', 'left_index', 'right_index', 'left_thumb', 'right_thumb', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', 'left_heel', 'right_heel', 'left_foot_index', 'right_foot_index' ] # Load model self.model = None self.scaler = None self.label_encoder = None self.target_joints = None self.model_info = None self.load_model(model_path) # Prediction result cache self.prediction_history = [] self.history_size = 5 # Keep recent 5 predictions for smoothing # Performance statistics self.fps_counter = 0 self.fps_start_time = time.time() self.current_fps = 0 # Added: Time statistics self.mediapipe_time_total = 0.0 self.mediapipe_time_count = 0 self.feature_pred_time_total = 0.0 self.feature_pred_time_count = 0 # Display settings self.show_landmarks = True self.show_connections = True def load_model(self, model_path=None): """Load trained model""" if model_path is None: # Auto-detect available model files possible_models = [ 'pose_classifier_random_forest.pkl', 'pose_classifier_logistic.pkl', 'pose_classifier_distilled_rf.pkl' ] for model_file in possible_models: if Path(model_file).exists(): model_path = model_file break if model_path is None: raise FileNotFoundError("No available model file found, please specify model path") try: print(f"Loading model: {model_path}") model_data = joblib.load(model_path) self.model = model_data['model'] self.scaler = model_data['scaler'] self.label_encoder = model_data['label_encoder'] self.target_joints = model_data['target_joints'] # Try to load corresponding labels file labels_path = model_path.replace('.pkl', '_labels.json') if Path(labels_path).exists(): with open(labels_path, 'r') as f: self.model_info = json.load(f) print(f"Loaded label information: {labels_path}") print("Model loaded successfully!") print(f"Target joints: {self.target_joints}") print(f"Classification classes: {self.label_encoder.classes_}") except Exception as e: raise RuntimeError(f"Model loading failed: {e}") def extract_pose_features(self, landmarks): """ Extract pose features from MediaPipe landmarks (vectorized optimized version) """ if landmarks is None: return None # Get all joint coordinates as NumPy array coords = np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark], dtype=np.float32) # Get head position (nose as reference point) try: head_idx = self.landmark_names.index('nose') head_pos = coords[head_idx] except ValueError: return None # Build target joint indices list joint_indices = [self.landmark_names.index(j) if j in self.landmark_names else -1 for j in self.target_joints] # Extract target joint coordinates (fill with 0 if not exist) joint_coords = np.array([ coords[idx] if idx >= 0 else np.zeros(3, dtype=np.float32) for idx in joint_indices ], dtype=np.float32) # Calculate relative position to head and scale relative_coords = (joint_coords - head_pos) * 100 # Keep consistent with training processing # Keep two decimal places features = np.round(relative_coords, 2).flatten() return features def predict_pose(self, features): """ Use machine learning model to predict pose Args: features: Feature vector Returns: dict: Prediction result containing label, confidence, etc. """ if features is None or self.model is None: return None try: # Standardize features features_scaled = self.scaler.transform(features.reshape(1, -1)) # Predict prediction = self.model.predict(features_scaled)[0] predicted_label = self.label_encoder.inverse_transform([prediction])[0] # Get confidence (if model supports probability prediction) confidence = 0.0 probabilities = None if hasattr(self.model, 'predict_proba'): probs = self.model.predict_proba(features_scaled)[0] confidence = float(np.max(probs)) probabilities = dict(zip(self.label_encoder.classes_, probs)) return { 'predicted_label': predicted_label, 'confidence': confidence, 'probabilities': probabilities } except Exception as e: print(f"Prediction error: {e}") return None def smooth_predictions(self, current_prediction): """ Smooth prediction results Args: current_prediction: Current prediction result Returns: dict: Smoothed prediction result """ if current_prediction is None: return None # Add to history self.prediction_history.append(current_prediction) if len(self.prediction_history) > self.history_size: self.prediction_history.pop(0) # If history is insufficient, return current prediction directly if len(self.prediction_history) < 3: return current_prediction # Count recent prediction labels recent_labels = [pred['predicted_label'] for pred in self.prediction_history] # Use mode as final prediction from collections import Counter label_counts = Counter(recent_labels) most_common_label = label_counts.most_common(1)[0][0] # Calculate average confidence for this label avg_confidence = np.mean([ pred['confidence'] for pred in self.prediction_history if pred['predicted_label'] == most_common_label ]) return { 'predicted_label': most_common_label, 'confidence': avg_confidence, 'stability': label_counts[most_common_label] / len(recent_labels) } def draw_pose_info(self, image, landmarks, prediction_result): """ Draw pose information on image Args: image: OpenCV image landmarks: MediaPipe landmarks prediction_result: Prediction result """ height, width = image.shape[:2] # Draw pose skeleton if landmarks and self.show_connections: self.mp_drawing.draw_landmarks( image, landmarks, self.mp_pose.POSE_CONNECTIONS, landmark_drawing_spec=self.mp_drawing_styles.get_default_pose_landmarks_style() ) # Draw keypoints if landmarks and self.show_landmarks: for i, landmark in enumerate(landmarks.landmark): if self.landmark_names[i] in self.target_joints: x = int(landmark.x * width) y = int(landmark.y * height) cv2.circle(image, (x, y), 8, (0, 255, 0), -1) cv2.putText(image, self.landmark_names[i], (x + 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) # Display prediction results if prediction_result: label = prediction_result['predicted_label'] confidence = prediction_result.get('confidence', 0.0) stability = prediction_result.get('stability', 1.0) # Set color based on confidence if confidence > 0.8: color = (0, 255, 0) # Green - high confidence elif confidence > 0.6: color = (0, 255, 255) # Yellow - medium confidence else: color = (0, 0, 255) # Red - low confidence # Draw prediction result background box cv2.rectangle(image, (10, 10), (400, 120), (0, 0, 0), -1) cv2.rectangle(image, (10, 10), (400, 120), color, 2) # Display prediction label cv2.putText(image, f"Pose: {label}", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2) # Display confidence cv2.putText(image, f"Confidence: {confidence:.2f}", (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) # Display stability cv2.putText(image, f"Stability: {stability:.2f}", (20, 95), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) # Display FPS cv2.putText(image, f"FPS: {self.current_fps:.1f}", (width - 150, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) # Display control instructions instructions = [ "Controls:", "Q - Quit", "L - Toggle Landmarks", "C - Toggle Connections", "R - Reset History" ] for i, instruction in enumerate(instructions): cv2.putText(image, instruction, (width - 200, height - 120 + i * 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1) # Added: Display timing statistics mp_avg = self.mediapipe_time_total / self.mediapipe_time_count if self.mediapipe_time_count else 0.0 fp_avg = self.feature_pred_time_total / self.feature_pred_time_count if self.feature_pred_time_count else 0.0 cv2.putText(image, f"MP avg: {mp_avg*1000:.1f}ms", (width - 150, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) cv2.putText(image, f"FP avg: {fp_avg*1000:.1f}ms", (width - 150, 75), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) # Display average frame rate total_frames = max(self.mediapipe_time_count, 1) avg_fps = total_frames / max(self.mediapipe_time_total + self.feature_pred_time_total, 1e-6) cv2.putText(image, f"Avg FPS: {avg_fps:.1f}", (width - 150, 95), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) def update_fps(self): """Update FPS calculation""" self.fps_counter += 1 if self.fps_counter >= 30: # Update FPS every 30 frames current_time = time.time() self.current_fps = 30 / (current_time - self.fps_start_time) self.fps_start_time = current_time self.fps_counter = 0 def run(self): """Run real-time pose classification""" print("Starting real-time pose classifier...") print("Press 'Q' to quit, 'L' to toggle landmark display, 'C' to toggle skeleton connections, 'R' to reset history") # Initialize camera cap = cv2.VideoCapture(self.camera_id) if not cap.isOpened(): raise RuntimeError(f"Cannot open camera {self.camera_id}") # Set camera parameters cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) cap.set(cv2.CAP_PROP_FPS, 30) try: while True: success, frame = cap.read() if not success: print("Cannot read camera frame") break # Flip image horizontally (mirror effect) frame = cv2.flip(frame, 1) # Convert color space rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Time MediaPipe pose detection mp_start = time.time() results = self.pose.process(rgb_frame) mp_end = time.time() self.mediapipe_time_total += (mp_end - mp_start) self.mediapipe_time_count += 1 # Extract features and predict fp_start = time.time() prediction_result = None if results.pose_landmarks: features = self.extract_pose_features(results.pose_landmarks) if features is not None: raw_prediction = self.predict_pose(features) prediction_result = self.smooth_predictions(raw_prediction) fp_end = time.time() self.feature_pred_time_total += (fp_end - fp_start) self.feature_pred_time_count += 1 # Draw results self.draw_pose_info(frame, results.pose_landmarks, prediction_result) # Update FPS self.update_fps() # Display image cv2.imshow('Real-time Pose Classification', frame) # Handle key presses key = cv2.waitKey(1) & 0xFF if key == ord('q') or key == ord('Q'): break elif key == ord('l') or key == ord('L'): self.show_landmarks = not self.show_landmarks print(f"Landmark display: {'On' if self.show_landmarks else 'Off'}") elif key == ord('c') or key == ord('C'): self.show_connections = not self.show_connections print(f"Skeleton connection display: {'On' if self.show_connections else 'Off'}") elif key == ord('r') or key == ord('R'): self.prediction_history.clear() print("Prediction history reset") except KeyboardInterrupt: print("\nUser interrupted program") except Exception as e: print(f"Runtime error: {e}") traceback.print_exc() finally: cap.release() cv2.destroyAllWindows() print("Program exited") def main(): """Main function""" parser = argparse.ArgumentParser(description='Real-time pose classifier') parser.add_argument('--model', '-m', type=str, default=None, help='Model file path (auto-detect by default)') parser.add_argument('--camera', '-c', type=int, default=0, help='Camera ID (default 0)') args = parser.parse_args() try: classifier = RealtimePoseClassifier( model_path=args.model, camera_id=args.camera ) classifier.run() except Exception as e: print(f"Program startup failed: {e}") return 1 return 0 if __name__ == "__main__": exit(main())