#!/usr/bin/env python3 """ Use MediaPipe to detect poses in images and extract landmark coordinates. Features: 1. Run MediaPipe pose detection on images in the train folder 2. Use the nose as the head reference point (headPos) 3. Process coordinates as: pos = (pos - headPos) * 100 and round to 2 decimals 4. Save processed landmarks into JSON files named after the image files Usage: python pose_detection.py [--input INPUT_DIR] [--output OUTPUT_DIR] """ import os import json import argparse from pathlib import Path import cv2 import mediapipe as mp class PoseDetector: def __init__(self): """Initialize MediaPipe pose detector.""" self.mp_pose = mp.solutions.pose self.pose = self.mp_pose.Pose( static_image_mode=True, model_complexity=2, enable_segmentation=False, min_detection_confidence=0.5 ) # MediaPipe pose landmark name mapping self.landmark_names = [ 'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer', 'right_eye_inner', 'right_eye', 'right_eye_outer', 'left_ear', 'right_ear', 'mouth_left', 'mouth_right', 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky', 'left_index', 'right_index', 'left_thumb', 'right_thumb', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', 'left_heel', 'right_heel', 'left_foot_index', 'right_foot_index' ] def get_head_position(self, landmarks): """ Compute the head reference position (use the nose landmark). Args: landmarks: MediaPipe detected landmarks Returns: tuple: (x, y, z) head coordinates """ # use nose as the head reference point nose = landmarks[0] # nose is the 0th landmark return (nose.x, nose.y, nose.z) def process_landmarks(self, landmarks, head_pos): """ Process landmarks: pos = (pos - headPos) * 100 and round to 2 decimals. Args: landmarks: MediaPipe detected landmarks head_pos: head coordinates (x, y, z) Returns: dict: processed landmarks dictionary """ processed_landmarks = {} head_pos_x = head_pos[0] head_pos_y = head_pos[1] head_pos_z = head_pos[2] for i, landmark in enumerate(landmarks): if i < len(self.landmark_names): name = self.landmark_names[i] # Calculate coordinates relative to head and multiply by 100 rel_x = round((landmark.x - head_pos_x) * 100, 2) rel_y = round((landmark.y - head_pos_y) * 100, 2) rel_z = round((landmark.z - head_pos_z) * 100, 2) processed_landmarks[name] = { 'x': rel_x, 'y': rel_y, 'z': rel_z, 'visibility': round(landmark.visibility, 3) } return processed_landmarks def detect_pose(self, image_path): """ Detect pose for a single image. 

    def detect_pose(self, image_path):
        """
        Detect pose for a single image.

        Args:
            image_path: path to the image file

        Returns:
            dict: processed landmarks and metadata, or None on failure
        """
        try:
            # Read image
            image = cv2.imread(str(image_path))
            if image is None:
                print(f"Unable to read image: {image_path}")
                return None

            # Convert color space (BGR -> RGB)
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Run pose detection
            results = self.pose.process(image_rgb)

            if results.pose_landmarks is None:
                print(f"No pose detected: {image_path}")
                return None

            # Get keypoints
            landmarks = results.pose_landmarks.landmark

            # Get head position
            head_pos = self.get_head_position(landmarks)

            # Process keypoint coordinates
            processed_landmarks = self.process_landmarks(landmarks, head_pos)

            # Extract the label from the parent folder name
            label = image_path.parent.name

            # Add metadata
            result = {
                'image_path': str(image_path),
                'image_name': image_path.name,
                'label': label,
                'head_position': {
                    'x': round(head_pos[0], 4),
                    'y': round(head_pos[1], 4),
                    'z': round(head_pos[2], 4)
                },
                'landmarks': processed_landmarks,
                'total_landmarks': len(processed_landmarks)
            }

            return result

        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

    def close(self):
        """Close MediaPipe resources."""
        self.pose.close()
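

# Shape of the dict returned by detect_pose (and of each emitted JSON file).
# Paths and values below are illustrative, not real output:
#
#     {
#       "image_path": "TrainData/train/label_0/img_0001.jpg",
#       "image_name": "img_0001.jpg",
#       "label": "label_0",
#       "head_position": {"x": 0.5123, "y": 0.3377, "z": -0.2101},
#       "landmarks": {
#         "nose": {"x": 0.0, "y": 0.0, "z": 0.0, "visibility": 0.999},
#         "left_shoulder": {"x": 8.42, "y": 15.3, "z": -4.77, "visibility": 0.998},
#         ...
#       },
#       "total_landmarks": 33
#     }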


def process_all_training_data(input_dir, output_dir, batch_size=100):
    """
    Process all images in the training dataset and write JSON files.

    Args:
        input_dir: input images directory (TrainData/train)
        output_dir: output JSON directory (PoseData)
        batch_size: progress report batch size
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Supported image formats
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    detector = PoseDetector()

    try:
        # Statistics
        total_images = 0
        success_count = 0
        failed_count = 0
        label_stats = {}

        print(f"Starting processing dataset: {input_path}")
        print(f"Output directory: {output_path}")

        # First count all images
        print("Counting images...")
        label_dirs = []
        for item in input_path.iterdir():
            if item.is_dir() and item.name.startswith('label_'):
                label = item.name
                image_files = [f for f in item.iterdir()
                               if f.is_file() and f.suffix.lower() in image_extensions]
                if image_files:
                    label_dirs.append((item, label, image_files))
                    total_images += len(image_files)
                    label_stats[label] = {'total': len(image_files), 'success': 0, 'failed': 0}

        print(f"Found {len(label_dirs)} label directories, total {total_images} images")
        for label, stats in label_stats.items():
            print(f"  {label}: {stats['total']} images")

        print("\nStarting to process images...")

        # Process each label directory
        for label_dir, label_name, image_files in label_dirs:
            print(f"\n--- Processing {label_name} ({len(image_files)} images) ---")

            # Create the output folder for this label
            output_label_dir = output_path / label_name
            output_label_dir.mkdir(parents=True, exist_ok=True)

            # Process every image in this label
            for i, image_file in enumerate(image_files, 1):
                json_filename = image_file.stem + '.json'
                json_path = output_label_dir / json_filename

                # Detect pose
                result = detector.detect_pose(image_file)

                if result is not None:
                    # Save JSON
                    try:
                        with open(json_path, 'w', encoding='utf-8') as f:
                            json.dump(result, f, ensure_ascii=False, indent=2)
                        success_count += 1
                        label_stats[label_name]['success'] += 1

                        # Progress report
                        if success_count % batch_size == 0:
                            progress = (success_count / total_images) * 100 if total_images else 0
                            print(f"  Progress: {success_count}/{total_images} ({progress:.1f}%) "
                                  f"- Current: {label_name} {i}/{len(image_files)}")
                    except Exception as e:
                        print(f"  Failed to save JSON {json_path}: {e}")
                        failed_count += 1
                        label_stats[label_name]['failed'] += 1
                else:
                    failed_count += 1
                    label_stats[label_name]['failed'] += 1
                    if failed_count % 10 == 0:  # print every 10 failures
                        print(f"  Detection failed: {image_file.name}")

            # Report for this label
            stats = label_stats[label_name]
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            print(f"  {label_name} done: success {stats['success']}, failed {stats['failed']}, "
                  f"success rate: {success_rate:.1f}%")

        print("\n" + "=" * 60)
        print("Processing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        total_success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Overall success rate: {total_success_rate:.1f}%")

        print("\nPer-label statistics:")
        for label, stats in label_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            print(f"  {label}: {stats['success']}/{stats['total']} ({success_rate:.1f}%)")

        print(f"\nJSON files saved to: {output_path.absolute()}")
        print("Directory structure:")
        print("PoseData/")
        for label in sorted(label_stats.keys()):
            print(f"├── {label}/")
            print("│   └── *.json")

    finally:
        detector.close()


def process_directory(input_dir, output_dir):
    """
    Process all images in a directory tree and write JSON files.

    Note: this is a generic walker that mirrors the input tree in the output;
    main() uses process_all_training_data() for the label_* layout instead.

    Args:
        input_dir: input images directory
        output_dir: output JSON directory
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Supported image formats
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    detector = PoseDetector()

    try:
        # Statistics
        total_images = 0
        success_count = 0
        failed_count = 0

        print(f"Starting to process directory: {input_path}")
        print(f"Output directory: {output_path}")

        # Walk through the tree
        for root, dirs, files in os.walk(input_path):
            root_path = Path(root)

            # Create the corresponding output folder
            relative_path = root_path.relative_to(input_path)
            current_output_dir = output_path / relative_path
            current_output_dir.mkdir(parents=True, exist_ok=True)

            # Collect image files in this folder
            image_files = [f for f in files if Path(f).suffix.lower() in image_extensions]

            if image_files:
                print(f"\nProcessing directory: {root_path}")
                print(f"Found {len(image_files)} images")

                for filename in image_files:
                    total_images += 1
                    image_path = root_path / filename

                    # Generate the JSON filename (replace extension with .json)
                    json_filename = Path(filename).stem + '.json'
                    json_path = current_output_dir / json_filename

                    # Detect pose
                    result = detector.detect_pose(image_path)

                    if result is not None:
                        # Save JSON file
                        try:
                            with open(json_path, 'w', encoding='utf-8') as f:
                                json.dump(result, f, ensure_ascii=False, indent=2)
                            success_count += 1

                            if success_count % 50 == 0:
                                print(f"Successfully processed {success_count} images...")
                        except Exception as e:
                            print(f"Failed to save JSON {json_path}: {e}")
                            failed_count += 1
                    else:
                        failed_count += 1

        print("\nProcessing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        # Guard against division by zero when no images were found
        success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Success rate: {success_rate:.1f}%")

    finally:
        detector.close()
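

# Reading an emitted JSON file back (a minimal sketch; the path is hypothetical):
#
#     with open("PoseData/label_0/img_0001.json", encoding="utf-8") as f:
#         data = json.load(f)
#     print(data["label"], data["total_landmarks"])
#     print(data["landmarks"]["left_wrist"])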
help="output JSON directory (default: PoseData)") parser.add_argument("--batch-size", "-b", type=int, default=100, help="batch size for progress reporting (default: 100)") args = parser.parse_args() # check input directory exists if not Path(args.input).exists(): print(f"Error: input directory does not exist: {args.input}") return print("MediaPipe pose detection tool") print("=" * 60) print(f"Input directory: {args.input}") print(f"Output directory: {args.output}") print("Processing rule: pos = (pos - headPos) * 100, round to 2 decimals") print("Head reference: nose") print(f"Batch size: show progress every {args.batch_size} images") print("=" * 60) # Start processing the entire training dataset process_all_training_data(args.input, args.output, args.batch_size) if __name__ == "__main__": main()