Train and test Python scripts
- extract_images.py +166 -0
- ml_pose_classifier.py +1121 -0
- pose_detection.py +382 -0
- realtime_pose_classifier.py +456 -0
extract_images.py
ADDED
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""Extract images and labels from Parquet files and save them into
subfolders by label.

Usage:
    python extract_images.py [--train] [--test] [--output OUTPUT_DIR]

Defaults:
    train: process training data (train-00000-of-00001.parquet)
    test: process test data (test-00000-of-00001.parquet)
    output: TrainData (relative to script location)
"""
import os
import sys
import argparse
from pathlib import Path
import pyarrow.parquet as pq


def extract_images_from_parquet(parquet_path, output_dir, split_name):
    """Extract images from a Parquet file and save them into label folders."""

    print(f"Processing {parquet_path}...")

    # read parquet file
    try:
        table = pq.read_table(parquet_path)
        df = table.to_pandas()
    except Exception as e:
        print(f"Failed to read parquet file: {e}")
        return False

    print(f"Found {len(df)} images")

    # get unique labels
    unique_labels = sorted(df['label'].unique())
    print(f"Label classes: {unique_labels}")

    # create a folder for each label
    for label in unique_labels:
        label_dir = output_dir / split_name / f"label_{label}"
        label_dir.mkdir(parents=True, exist_ok=True)
        print(f"Created folder: {label_dir}")

    # extract and save images
    success_count = 0
    error_count = 0

    for idx, row in df.iterrows():
        try:
            # get image data
            image_struct = row['image']
            image_bytes = image_struct['bytes']
            original_path = image_struct['path']
            label = row['label']

            # get file extension
            _, ext = os.path.splitext(original_path)
            if not ext:
                ext = '.jpg'  # default extension

            # build a new filename (preserve original base name, avoid collisions)
            base_name = os.path.splitext(os.path.basename(original_path))[0]
            filename = f"{base_name}{ext}"

            # ensure filename is unique
            label_dir = output_dir / split_name / f"label_{label}"
            output_path = label_dir / filename
            counter = 1
            while output_path.exists():
                filename = f"{base_name}_{counter}{ext}"
                output_path = label_dir / filename
                counter += 1

            # save image
            with open(output_path, 'wb') as f:
                f.write(image_bytes)

            success_count += 1
            if success_count % 100 == 0:
                print(f"Processed {success_count} images...")

        except Exception as e:
            print(f"Error processing image {idx}: {e}")
            error_count += 1
            continue

    print(f"Done! Success: {success_count}, Failed: {error_count}")

    # report counts per label
    print("\nImage count per label:")
    for label in unique_labels:
        label_dir = output_dir / split_name / f"label_{label}"
        count = len(list(label_dir.glob("*")))
        print(f"  label {label}: {count} images")

    return success_count > 0


def main():
    parser = argparse.ArgumentParser(description="Extract images from Parquet files and organize by label")
    parser.add_argument("--train", action="store_true", help="process training data")
    parser.add_argument("--test", action="store_true", help="process test data")
    parser.add_argument("--output", "-o", default="TrainData", help="output directory")

    args = parser.parse_args()

    # if neither train nor test is specified, do both by default
    if not args.train and not args.test:
        args.train = True
        args.test = True

    # set paths
    script_dir = Path(__file__).parent
    yoga_data_dir = script_dir / "YogaDataSet" / "data"
    output_dir = Path(args.output)

    # ensure output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Output directory: {output_dir.absolute()}")

    success = True

    # process training data
    if args.train:
        train_parquet = yoga_data_dir / "train-00000-of-00001.parquet"
        if train_parquet.exists():
            if not extract_images_from_parquet(train_parquet, output_dir, "train"):
                success = False
        else:
            print(f"Training parquet file not found: {train_parquet}")
            success = False

    # process test data
    if args.test:
        test_parquet = yoga_data_dir / "test-00000-of-00001.parquet"
        if test_parquet.exists():
            if not extract_images_from_parquet(test_parquet, output_dir, "test"):
                success = False
        else:
            print(f"Test parquet file not found: {test_parquet}")
            success = False

    if success:
        print("\n✅ All images extracted!")
        print(f"Images saved to: {output_dir.absolute()}")
        print("Directory structure:")
        print("TrainData/")
        if args.train:
            print("├── train/")
            print("│   ├── label_0/")
            print("│   ├── label_1/")
            print("│   └── ...")
        if args.test:
            print("└── test/")
            print("    ├── label_0/")
            print("    ├── label_1/")
            print("    └── ...")
    else:
        print("\n❌ Errors occurred during extraction")
        sys.exit(1)


if __name__ == "__main__":
    main()
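The extractor above assumes each Parquet row carries an `image` struct column with `bytes` and `path` fields plus a `label` column, the layout Hugging Face image datasets typically use. A quick way to confirm that layout before running the script is sketched below; this snippet is illustrative only and not part of the commit:

import pyarrow.parquet as pq

# Inspect the columns extract_images.py reads: row['image']['bytes'],
# row['image']['path'], and row['label'].
table = pq.read_table("YogaDataSet/data/train-00000-of-00001.parquet")
print(table.schema)

row = table.to_pandas().iloc[0]
print(row['label'])           # class id used for the label_<N> folder name
print(row['image']['path'])   # original filename reused when saving the image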
ml_pose_classifier.py
ADDED
@@ -0,0 +1,1121 @@
#!/usr/bin/env python3
"""
Machine learning pose classification script.

Features:
1. Train classifiers on pose landmark inputs
2. Use selected landmark coordinates as features
3. Use folder names as class labels
4. Train and evaluate models

Usage:
    python ml_pose_classifier.py [--data DATA_DIR] [--model MODEL_TYPE] [--test-size RATIO]
"""

import json
import argparse
import numpy as np
import time
from pathlib import Path
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
# from sklearn.pipeline import Pipeline  # not used
from sklearn.neural_network import MLPRegressor
import joblib
import matplotlib.pyplot as plt
# seaborn is optional; used only for confusion matrix plotting
try:
    import seaborn as sns
    SEABORN_AVAILABLE = True
except ImportError:
    SEABORN_AVAILABLE = False

# ONNX related imports
try:
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType
    # onnx is not required here; we import it lazily where needed
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False

# ONNX Runtime import
try:
    # onnxruntime is optional and not required unless ONNX runtime testing is implemented
    ONNX_RUNTIME_AVAILABLE = False
except ImportError:
    ONNX_RUNTIME_AVAILABLE = False


class PoseClassifier:
    def __init__(self, model_type='random_forest'):
        """
        Initialize the pose classifier.

        Args:
            model_type: model type ('random_forest', 'svm', 'gradient_boost', 'logistic', 'distilled_rf')
        """
        self.model_type = model_type
        self.model = None
        self.student_model = None  # If distillation is used, holds the student (MLP) model
        self.scaler = StandardScaler()
        self.label_encoder = LabelEncoder()

        # Define the joints we want to use (based on MediaPipe keypoint indices)
        self.target_joints = [
            'nose',            # Head (nose as reference, but will actually be 0,0,0)
            'left_shoulder',   # Left shoulder
            'right_shoulder',  # Right shoulder
            'left_elbow',      # Left elbow
            'right_elbow',     # Right elbow
            'left_wrist',      # Left wrist
            'right_wrist',     # Right wrist
            'left_hip',        # Left hip
            'right_hip',       # Right hip
            'left_knee',       # Left knee
            'right_knee',      # Right knee
            'left_ankle',      # Left ankle
            'right_ankle'      # Right ankle
        ]

        self.feature_columns = []
        for joint in self.target_joints:
            self.feature_columns.extend([f'{joint}_x', f'{joint}_y', f'{joint}_z'])

        print(f"Target joints: {len(self.target_joints)}")
        print(f"Feature dimension: {len(self.feature_columns)}")
        print("Joint list:", ', '.join(self.target_joints))

    def _get_model(self):
        """Create a classifier based on the selected model type."""
        if self.model_type == 'random_forest':
            return RandomForestClassifier(
                n_estimators=100,
                max_depth=15,
                min_samples_split=5,
                min_samples_leaf=2,
                random_state=42,
                n_jobs=-1
            )
        elif self.model_type == 'svm':
            return SVC(
                C=1.0,
                kernel='rbf',
                gamma='scale',
                random_state=42
            )
        elif self.model_type == 'gradient_boost':
            return GradientBoostingClassifier(
                n_estimators=100,
                learning_rate=0.1,
                max_depth=6,
                random_state=42
            )
        elif self.model_type == 'logistic':
            return LogisticRegression(
                C=10.0,  # Larger C weakens regularization, allowing more model complexity
                max_iter=2000,  # Increase maximum iterations
                solver='lbfgs',  # L-BFGS solver, suitable for small datasets
                multi_class='multinomial',  # Multi-class strategy
                random_state=42,
                n_jobs=-1
            )
        elif self.model_type == 'distilled_rf':
            # The teacher is a random forest (returned here for the training process)
            return RandomForestClassifier(
                n_estimators=100,
                max_depth=15,
                min_samples_split=5,
                min_samples_leaf=2,
                random_state=42,
                n_jobs=-1
            )
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

    def load_data(self, data_dir):
        """
        Load pose data from JSON files.

        Args:
            data_dir: Data directory containing label folders

        Returns:
            tuple: (feature data, labels)
        """
        data_path = Path(data_dir)
        all_features = []
        all_labels = []

        print(f"Loading data from: {data_path}")

        # Iterate over each label directory
        for label_dir in data_path.iterdir():
            if not label_dir.is_dir() or not label_dir.name.startswith('label_'):
                continue

            label = label_dir.name
            json_files = list(label_dir.glob('*.json'))

            print(f"Processing {label}: {len(json_files)} files")

            for json_file in json_files:
                try:
                    with open(json_file, 'r', encoding='utf-8') as f:
                        data = json.load(f)

                    landmarks = data.get('landmarks', {})

                    # Extract coordinates of the target joints
                    features = []
                    missing_joints = []

                    for joint in self.target_joints:
                        if joint in landmarks:
                            joint_data = landmarks[joint]
                            features.extend([
                                joint_data.get('x', 0.0),
                                joint_data.get('y', 0.0),
                                joint_data.get('z', 0.0)
                            ])
                        else:
                            # If a joint is missing, fill with zeros
                            features.extend([0.0, 0.0, 0.0])
                            missing_joints.append(joint)

                    if len(features) == len(self.feature_columns):
                        all_features.append(features)
                        all_labels.append(label)
                    else:
                        print(f"Skipping file {json_file}: feature dimension mismatch")

                    if missing_joints:
                        print(f"File {json_file.name} missing joints: {missing_joints}")

                except Exception as e:
                    print(f"Error reading file {json_file}: {e}")
                    continue

        print(f"Loaded {len(all_features)} samples")

        # count samples per label
        label_counts = {}
        for label in all_labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        print("Label distribution:")
        for label, count in sorted(label_counts.items()):
            print(f"  {label}: {count} samples")

        return np.array(all_features), np.array(all_labels)

    def train(self, X, y, test_size=0.2):
        """
        Train the classifier.

        Args:
            X: feature data
            y: labels
            test_size: ratio for the test split

        Returns:
            dict: a dictionary containing training results
        """
        print(f"\nStarting training for model: {self.model_type}...")
        print(f"Data shape: {X.shape}")
        print(f"Number of labels: {len(np.unique(y))}")

        # Encode labels
        y_encoded = self.label_encoder.fit_transform(y)

        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=test_size, random_state=42, stratify=y_encoded
        )

        print(f"Train set size: {X_train.shape[0]}")
        print(f"Test set size: {X_test.shape[0]}")

        # standardize features
        X_train_scaled = self.scaler.fit_transform(X_train)
        X_test_scaled = self.scaler.transform(X_test)

        # Distillation flow: train the RF teacher first, then train an MLPRegressor
        # student to fit the teacher's predict_proba outputs
        if self.model_type == 'distilled_rf':
            print("Using distillation: train RandomForest teacher, then fit an MLPRegressor student to teacher soft labels")
            # Train teacher
            teacher = self._get_model()
            teacher.fit(X_train_scaled, y_train)

            # Get the teacher's probability distribution as soft labels
            y_train_proba = teacher.predict_proba(X_train_scaled)

            # Create and train the student (MLPRegressor) to fit probability vectors
            student = MLPRegressor(hidden_layer_sizes=(128, 64, 32),
                                   activation='relu',
                                   solver='adam',
                                   max_iter=1000,
                                   learning_rate_init=0.001,
                                   random_state=42,
                                   early_stopping=True,
                                   validation_fraction=0.1)

            print("Training student model to fit teacher probability outputs...")
            print(f"Teacher probability output shape: {y_train_proba.shape}")

            # Multi-output regression; the target is the probability vector
            student.fit(X_train_scaled, y_train_proba)

            # Save models
            self.model = teacher
            self.student_model = student

            # Use the student to predict on train/test sets
            y_train_pred_proba = student.predict(X_train_scaled)
            y_test_pred_proba = student.predict(X_test_scaled)

            # Apply softmax to ensure probabilities sum to 1
            def softmax(x):
                exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
                return exp_x / np.sum(exp_x, axis=1, keepdims=True)

            y_train_pred_proba = softmax(y_train_pred_proba)
            y_test_pred_proba = softmax(y_test_pred_proba)

            y_train_pred = np.argmax(y_train_pred_proba, axis=1)
            y_test_pred = np.argmax(y_test_pred_proba, axis=1)

            print(f"Student predicted probability shape: {y_test_pred_proba.shape}")
            print(f"Student training accuracy: {accuracy_score(y_train, y_train_pred):.4f}")

        else:
            # Standard flow: train a single model
            self.model = self._get_model()
            self.model.fit(X_train_scaled, y_train)

            y_train_pred = self.model.predict(X_train_scaled)
            y_test_pred = self.model.predict(X_test_scaled)

        # compute accuracies
        train_accuracy = accuracy_score(y_train, y_train_pred)
        test_accuracy = accuracy_score(y_test, y_test_pred)

        # cross-validation on the model used for training;
        # if a student_model exists, still use the teacher for cross-validation
        cv_model = self.model if self.model is not None else None
        if cv_model is not None:
            cv_scores = cross_val_score(cv_model, X_train_scaled, y_train, cv=5)
        else:
            cv_scores = np.array([])

        print("\nTraining results:")
        print(f"Train accuracy: {train_accuracy:.4f}")
        print(f"Test accuracy: {test_accuracy:.4f}")
        print(f"5-fold CV accuracy: {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")

        # classification report
        print("\nTest set classification report:")
        target_names = self.label_encoder.classes_
        print(classification_report(y_test, y_test_pred, target_names=target_names))

        # confusion matrix
        cm = confusion_matrix(y_test, y_test_pred)

        return {
            'train_accuracy': train_accuracy,
            'test_accuracy': test_accuracy,
            'cv_scores': cv_scores,
            'confusion_matrix': cm,
            'target_names': target_names,
            'X_test': X_test_scaled,
            'y_test': y_test,
            'y_test_pred': y_test_pred
        }

    def save_model(self, filepath):
        """Save the trained model to disk."""
        model_data = {
            'model': self.model,
            'scaler': self.scaler,
            'label_encoder': self.label_encoder,
            'model_type': self.model_type,
            'target_joints': self.target_joints,
            'feature_columns': self.feature_columns
        }
        joblib.dump(model_data, filepath)
        print(f"Model saved to: {filepath}")

    def load_model(self, filepath):
        """Load a trained model from disk."""
        model_data = joblib.load(filepath)
        self.model = model_data['model']
        self.scaler = model_data['scaler']
        self.label_encoder = model_data['label_encoder']
        self.model_type = model_data['model_type']
        self.target_joints = model_data['target_joints']
        self.feature_columns = model_data['feature_columns']
        print(f"Model loaded from: {filepath}")

    def predict(self, X):
        """Run prediction on input features."""
        if self.model is None and self.student_model is None:
            raise ValueError("Model not trained or loaded")

        X_scaled = self.scaler.transform(X)

        # Prefer the student_model (if it exists) to generate probability output
        if self.student_model is not None:
            proba = self.student_model.predict(X_scaled)  # Returns probability vectors
            preds = np.argmax(proba, axis=1)
            labels = self.label_encoder.inverse_transform(preds)
            return labels, proba

        # Otherwise fall back to the original model
        predictions = self.model.predict(X_scaled)
        probabilities = None
        if hasattr(self.model, 'predict_proba'):
            probabilities = self.model.predict_proba(X_scaled)
        return self.label_encoder.inverse_transform(predictions), probabilities

    def predict_single_json(self, json_path):
        """
        Predict the pose class for a single JSON file.

        Args:
            json_path: path to the JSON file

        Returns:
            dict: prediction details or error information
        """
        if self.model is None:
            raise ValueError("Model not trained or loaded")

        try:
            # Read the JSON file
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            landmarks = data.get('landmarks', {})

            # Extract coordinates of the target joints
            features = []
            missing_joints = []
            available_joints = []

            for joint in self.target_joints:
                if joint in landmarks:
                    joint_data = landmarks[joint]
                    features.extend([
                        joint_data.get('x', 0.0),
                        joint_data.get('y', 0.0),
                        joint_data.get('z', 0.0)
                    ])
                    available_joints.append(joint)
                else:
                    # If a joint is missing, fill with zeros
                    features.extend([0.0, 0.0, 0.0])
                    missing_joints.append(joint)

            if len(features) != len(self.feature_columns):
                raise ValueError(f"Feature dimension mismatch: expected {len(self.feature_columns)}, got {len(features)}")

            # Convert to a numpy array and predict
            X = np.array([features])
            predictions, probabilities = self.predict(X)

            # build the result dict
            result = {
                'file_path': str(json_path),
                'file_name': Path(json_path).name,
                'predicted_label': predictions[0],
                'confidence_scores': {},
                'available_joints': available_joints,
                'missing_joints': missing_joints,
                'joint_coverage': f"{len(available_joints)}/{len(self.target_joints)}"
            }

            # add per-class confidence scores
            if probabilities is not None:
                for i, label in enumerate(self.label_encoder.classes_):
                    result['confidence_scores'][label] = float(probabilities[0][i])

                # highest confidence
                max_prob_idx = np.argmax(probabilities[0])
                result['max_confidence'] = float(probabilities[0][max_prob_idx])

            return result

        except Exception as e:
            return {
                'file_path': str(json_path),
                'file_name': Path(json_path).name,
                'error': str(e),
                'predicted_label': None
            }

    def evaluate_test_directory(self, test_dir):
        """
        Evaluate all data in a test directory.

        Args:
            test_dir: path to the test data directory

        Returns:
            dict: dictionary containing detailed evaluation results
        """
        if self.model is None:
            raise ValueError("Model not trained or loaded")

        test_path = Path(test_dir)
        if not test_path.exists():
            raise ValueError(f"Test directory does not exist: {test_dir}")

        # start timing
        start_time = time.time()
        print(f"Starting evaluation on test dataset: {test_path}")
        print(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}")

        # store all prediction results
        all_results = []
        label_stats = {}
        total_prediction_time = 0.0
        prediction_count = 0

        # iterate over label folders
        for label_dir in test_path.iterdir():
            if not label_dir.is_dir() or not label_dir.name.startswith('label_'):
                continue

            true_label = label_dir.name
            json_files = list(label_dir.glob('*.json'))

            print(f"Evaluating {true_label}: {len(json_files)} files")

            label_stats[true_label] = {
                'total': len(json_files),
                'correct': 0,
                'incorrect': 0,
                'errors': 0,
                'predictions': {},
                'confidence_scores': [],
                'prediction_times': []
            }

            for json_file in json_files:
                # Time each single prediction
                pred_start_time = time.time()
                result = self.predict_single_json(json_file)
                pred_end_time = time.time()

                single_prediction_time = pred_end_time - pred_start_time
                total_prediction_time += single_prediction_time
                prediction_count += 1

                if 'error' in result:
                    label_stats[true_label]['errors'] += 1
                    print(f"  Error: {json_file.name} - {result['error']}")
                    continue

                predicted_label = result['predicted_label']
                is_correct = predicted_label == true_label

                if is_correct:
                    label_stats[true_label]['correct'] += 1
                else:
                    label_stats[true_label]['incorrect'] += 1

                # Count the prediction distribution
                if predicted_label not in label_stats[true_label]['predictions']:
                    label_stats[true_label]['predictions'][predicted_label] = 0
                label_stats[true_label]['predictions'][predicted_label] += 1

                # Record confidence and prediction time
                if 'max_confidence' in result:
                    label_stats[true_label]['confidence_scores'].append(result['max_confidence'])
                    label_stats[true_label]['prediction_times'].append(single_prediction_time)

                # Save the detailed result
                all_results.append({
                    'file_path': str(json_file),
                    'file_name': json_file.name,
                    'true_label': true_label,
                    'predicted_label': predicted_label,
                    'is_correct': is_correct,
                    'confidence': result.get('max_confidence', 0.0),
                    'confidence_scores': result.get('confidence_scores', {}),
                    'joint_coverage': result.get('joint_coverage', '0/13'),
                    'prediction_time': single_prediction_time
                })

        # end timing
        end_time = time.time()
        total_execution_time = end_time - start_time

        # compute aggregate statistics
        total_samples = sum(stats['total'] for stats in label_stats.values())
        total_correct = sum(stats['correct'] for stats in label_stats.values())
        total_errors = sum(stats['errors'] for stats in label_stats.values())
        total_tested = total_samples - total_errors

        overall_accuracy = total_correct / total_tested if total_tested > 0 else 0.0
        avg_prediction_time = total_prediction_time / prediction_count if prediction_count > 0 else 0.0

        # build the confusion matrix
        confusion_matrix = {}
        for true_label in label_stats.keys():
            confusion_matrix[true_label] = {}
            for predicted_label in label_stats.keys():
                confusion_matrix[true_label][predicted_label] = 0

        for result in all_results:
            if result.get('is_correct') is not None:  # exclude error cases
                true_label = result['true_label']
                predicted_label = result['predicted_label']
                confusion_matrix[true_label][predicted_label] += 1

        return {
            'label_stats': label_stats,
            'overall_accuracy': overall_accuracy,
            'total_samples': total_samples,
            'total_correct': total_correct,
            'total_errors': total_errors,
            'total_tested': total_tested,
            'confusion_matrix': confusion_matrix,
            'detailed_results': all_results,
            'timing_stats': {
                'total_execution_time': total_execution_time,
                'total_prediction_time': total_prediction_time,
                'avg_prediction_time': avg_prediction_time,
                'prediction_count': prediction_count,
                'start_time': start_time,
                'end_time': end_time,
                'overhead_time': total_execution_time - total_prediction_time
            }
        }

    def print_evaluation_report(self, eval_results):
        """
        Print a detailed evaluation report.

        Args:
            eval_results: dictionary returned by evaluate_test_directory
        """
        timing_stats = eval_results.get('timing_stats', {})

        print("\n" + "=" * 80)
        print("Test dataset evaluation report")
        print("=" * 80)

        # Overall statistics
        print(f"Total samples: {eval_results['total_samples']}")
        print(f"Successfully tested: {eval_results['total_tested']}")
        print(f"Errors: {eval_results['total_errors']}")
        print(
            f"Overall accuracy: {eval_results['overall_accuracy']:.4f} "
            f"({eval_results['total_correct']}/{eval_results['total_tested']})"
        )

        # Timing statistics
        if timing_stats:
            total_time = timing_stats['total_execution_time']
            prediction_time = timing_stats['total_prediction_time']
            avg_time = timing_stats['avg_prediction_time']
            overhead_time = timing_stats['overhead_time']
            prediction_count = timing_stats['prediction_count']

            print("\nTiming statistics:")
            print("-" * 50)
            print(f"Total execution time: {total_time:.4f} s")
            print(f"Total prediction time: {prediction_time:.4f} s")
            print(f"Overhead time: {overhead_time:.4f} s")
            print(f"Average prediction time: {avg_time * 1000:.2f} ms")
            print(f"Prediction throughput: {prediction_count / total_time:.2f} preds/s")
            print(
                f"Prediction efficiency: {(prediction_time / total_time) * 100:.1f}% "
                f"(prediction time / total)"
            )

        # Per-label detailed statistics
        print("\nPer-label stats:")
        print("-" * 80)
        print(
            f"{'Label':<10} {'Total':<6} {'Correct':<6} {'Wrong':<6} "
            f"{'Accuracy':<8} {'AvgConf':<10} {'AvgPredTime':<12}"
        )
        print("-" * 80)

        for label, stats in sorted(eval_results['label_stats'].items()):
            accuracy = (
                stats['correct'] / (stats['total'] - stats['errors'])
                if (stats['total'] - stats['errors']) > 0
                else 0.0
            )
            avg_confidence = (
                np.mean(stats['confidence_scores']) if stats['confidence_scores'] else 0.0
            )
            avg_pred_time = (
                np.mean(stats['prediction_times'])
                if 'prediction_times' in stats and stats['prediction_times']
                else 0.0
            )

            print(
                f"{label:<10} {stats['total']:<6} {stats['correct']:<6} {stats['incorrect']:<6} "
                f"{accuracy:.4f} {avg_confidence:.4f} {avg_pred_time * 1000:.2f}ms"
            )

        # Confusion matrix
        print("\nConfusion matrix:")
        print("-" * 60)
        labels = sorted(eval_results['label_stats'].keys())

        # Header row (build the label outside the f-string: a backslash inside an
        # f-string expression is a SyntaxError before Python 3.12)
        header = "True\\Pred"
        print(f"{header:<12}", end="")
        for label in labels:
            print(f"{label:<10}", end="")
        print()

        # Data rows
        for true_label in labels:
            print(f"{true_label:<12}", end="")
            for pred_label in labels:
                count = eval_results['confusion_matrix'][true_label][pred_label]
                print(f"{count:<10}", end="")
            print()

        # Per-label prediction distribution
        print("\nPer-label prediction distribution:")
        print("-" * 80)
        for true_label, stats in sorted(eval_results['label_stats'].items()):
            if stats['predictions']:
                print(f"{true_label}:")
                total_predictions = sum(stats['predictions'].values())
                for pred_label, count in sorted(stats['predictions'].items()):
                    percentage = (count / total_predictions) * 100
                    print(f"  -> {pred_label}: {count} ({percentage:.1f}%)")

        # Error analysis
        print("\nError analysis:")
        print("-" * 40)
        incorrect_results = [r for r in eval_results['detailed_results'] if not r['is_correct']]

        if incorrect_results:
            # Sort by confidence and show the top mistaken predictions
            incorrect_results.sort(key=lambda x: x['confidence'], reverse=True)
            print("Highest-confidence incorrect predictions (top 10):")
            for i, result in enumerate(incorrect_results[:10]):
                pred_time = result.get('prediction_time', 0) * 1000  # ms
                print(
                    f"{i + 1:2d}. {result['file_name']}: {result['true_label']} -> {result['predicted_label']} "
                    f"(conf: {result['confidence']:.4f}, time: {pred_time:.2f}ms)"
                )
        else:
            print("No incorrect predictions found.")

        # Performance analysis
        if timing_stats and eval_results['detailed_results']:
            print("\nPerformance analysis:")
            print("-" * 40)
            prediction_times = [
                r.get('prediction_time', 0) for r in eval_results['detailed_results'] if 'prediction_time' in r
            ]
            if prediction_times:
                min_time = min(prediction_times) * 1000
                max_time = max(prediction_times) * 1000
                median_time = np.median(prediction_times) * 1000
                std_time = np.std(prediction_times) * 1000

                print("Prediction time distribution:")
                print(f"  Fastest: {min_time:.2f}ms")
                print(f"  Slowest: {max_time:.2f}ms")
                print(f"  Median: {median_time:.2f}ms")
                print(f"  Stddev: {std_time:.2f}ms")

        print("\n" + "=" * 80)

    def plot_confusion_matrix(self, cm, target_names, save_path=None):
        """Plot the confusion matrix."""
        plt.figure(figsize=(10, 8))
        if SEABORN_AVAILABLE:
            sns.heatmap(
                cm,
                annot=True,
                fmt='d',
                cmap='Blues',
                xticklabels=target_names,
                yticklabels=target_names,
            )
        else:
            # Fallback using matplotlib only
            im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
            plt.colorbar(im)
            tick_marks = np.arange(len(target_names))
            plt.xticks(tick_marks, target_names, rotation=45, ha='right')
            plt.yticks(tick_marks, target_names)
            # Annotate cells
            thresh = cm.max() / 2.0 if cm.size else 0
            for i in range(cm.shape[0]):
                for j in range(cm.shape[1]):
                    plt.text(j, i, format(cm[i, j], 'd'),
                             ha="center", va="center",
                             color="white" if cm[i, j] > thresh else "black")

        plt.title(f"{self.model_type.title()} model confusion matrix")
        plt.xlabel('Predicted')
        plt.ylabel('True')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Confusion matrix saved to: {save_path}")

        plt.show()

    def export_to_onnx(self, model_type='random_forest', output_path=None):
        """
        Export the trained model to ONNX format (only models supported by Barracuda).
        Note: Barracuda does not support LinearClassifier layers (e.g., LogisticRegression/SVM);
        only tree models are supported.
        """
        if not ONNX_AVAILABLE:
            print("Error: ONNX export is unavailable. Please install the skl2onnx and onnx packages:")
            print("pip install skl2onnx onnx")
            return None

        if not hasattr(self, 'model') or self.model is None:
            print("Error: Model is not trained yet. Please train the model first.")
            return None

        # Check whether the current model type matches the requested export type
        if hasattr(self, 'model_type') and self.model_type != model_type:
            print(f"Warning: Currently trained {self.model_type} model, but requested to export {model_type} model")
            print(f"Will export the currently trained {self.model_type} model")
            model_name = self.model_type
        else:
            model_name = model_type

        # Barracuda only supports tree models, not LinearClassifier
        if model_name in ['logistic', 'svm']:
            print(f"❌ Barracuda/Unity does not support ONNX import for {model_name} models (LinearClassifier layer).")
            print("Please use random_forest or gradient_boost for export.")
            return None

        # If a student_model exists, export it (MLP); otherwise export self.model
        model_to_export = None
        export_name = None

        if self.student_model is not None:
            model_to_export = self.student_model
            export_name = 'distilled_mlp'
            print("Detected student_model. Exporting student (MLP) to ONNX (suitable for Unity/Barracuda).")
        else:
            model_to_export = self.model
            export_name = model_name

        if model_to_export is None:
            print("Error: No model available for export.")
            return None

        # Generate the output file path
        if output_path is None:
            output_path = f"pose_classifier_{export_name}.onnx"

        print(f"About to export model to: {output_path}, export target: {export_name}")

        try:
            feature_count = len(self.target_joints) * 3
            initial_type = [('float_input', FloatTensorType([None, feature_count]))]

            onnx_model = convert_sklearn(
                model_to_export,
                initial_types=initial_type,
                target_opset=12
            )

            with open(output_path, "wb") as f:
                f.write(onnx_model.SerializeToString())

            print(f"✅ Successfully exported {export_name} model to ONNX format: {output_path}")

            # Save the label mapping and scaler parameters
            label_mapping_path = output_path.replace('.onnx', '_labels.json')
            label_mapping = {
                'label_encoder_classes': self.label_encoder.classes_.tolist(),
                'model_type': export_name,
                'feature_count': feature_count,
                'target_joints': self.target_joints,
                'description': f'Pose classifier - {len(self.target_joints)} landmarks with x,y,z coordinates',
                'scaler_mean': self.scaler.mean_.tolist(),
                'scaler_scale': self.scaler.scale_.tolist()
            }

            with open(label_mapping_path, 'w', encoding='utf-8') as f:
                json.dump(label_mapping, f, ensure_ascii=False, indent=2)

            print(f"✅ Label mapping and scaler parameters saved to: {label_mapping_path}")

            print("⚠️ Note: The exported ONNX expects inputs to be standardized with scaler_mean/scaler_scale.")

            return output_path

        except Exception as e:
            print(f"❌ ONNX export failed: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

    def export_to_tflite(self, output_path=None):
        """
        Export student_model (MLP) to TFLite format.
        Dependencies: skl2onnx, onnx, onnx-tf, tensorflow
        """
        if self.student_model is None:
            print("❌ Only exporting student_model (MLPRegressor) to TFLite is supported. Please train with --model distilled_rf first.")
            return None

        try:
            import onnx
            from skl2onnx import convert_sklearn
            from skl2onnx.common.data_types import FloatTensorType
            from onnx_tf.backend import prepare
            import tensorflow as tf
        except ImportError:
            print("❌ You need to install skl2onnx, onnx, onnx-tf, tensorflow.")
            print("pip install skl2onnx onnx onnx-tf tensorflow")
            return None

        feature_count = len(self.target_joints) * 3
        initial_type = [('float_input', FloatTensorType([None, feature_count]))]

        # 1. Export to ONNX
        print("Exporting student_model to ONNX...")
        onnx_model = convert_sklearn(
            self.student_model,
            initial_types=initial_type,
            target_opset=12
        )
        onnx_path = "temp_student.onnx"
        with open(onnx_path, "wb") as f:
            f.write(onnx_model.SerializeToString())
        print(f"✅ ONNX export successful: {onnx_path}")

        # 2. ONNX -> TensorFlow SavedModel
        print("Converting ONNX to TensorFlow SavedModel...")
        tf_model = prepare(onnx.load(onnx_path))
        tf_saved_path = "temp_student_tf"
        tf_model.export_graph(tf_saved_path)
        print(f"✅ SavedModel export successful: {tf_saved_path}")

        # 3. SavedModel -> TFLite
        print("Converting SavedModel to TFLite...")
        converter = tf.lite.TFLiteConverter.from_saved_model(tf_saved_path)
        tflite_model = converter.convert()
        if output_path is None:
            output_path = "pose_classifier_distilled_mlp.tflite"
        with open(output_path, "wb") as f:
            f.write(tflite_model)
        print(f"✅ TFLite export successful: {output_path}")

        # Clean up temporary files (optional)
        import os
        os.remove(onnx_path)
        import shutil
        shutil.rmtree(tf_saved_path, ignore_errors=True)

        return output_path


def main():
    parser = argparse.ArgumentParser(description="Pose classification machine learning script")
    parser.add_argument("--data", "-d", default="PoseData", help="Pose data directory (default: PoseData)")
    parser.add_argument(
        "--model",
        "-m",
        choices=['random_forest', 'svm', 'gradient_boost', 'logistic', 'distilled_rf'],
        default='random_forest',
        help="Model type (default: random_forest)",
    )
    parser.add_argument("--test-size", "-t", type=float, default=0.2, help="Test set ratio (default: 0.2)")
    parser.add_argument("--save-model", "-s", help="Path to save the trained model")
    parser.add_argument("--load-model", "-l", help="Path to load an already trained model")
    parser.add_argument("--predict", "-p", help="Path of a single JSON file to predict")
    parser.add_argument("--evaluate", "-e", help="Path of a test directory to evaluate all JSON files")
    parser.add_argument("--no-plot", action="store_true", help="Do not display the confusion matrix plot")
    parser.add_argument("--train", action="store_true", help="Force training even if --load-model is provided")
    parser.add_argument("--export-onnx", help="Export model to ONNX format; specify output file path")
    parser.add_argument(
        "--export-model-type",
        choices=['random_forest', 'logistic', 'distilled_rf'],
        default='random_forest',
        help="Model type to export (default: random_forest)",
    )
    parser.add_argument("--test-onnx", help="Test an ONNX model; specify ONNX file path")
    parser.add_argument("--onnx-labels", help="ONNX label mapping JSON path (auto-detect if not provided)")
    parser.add_argument("--onnx-test-data", help="ONNX batch test data directory (if not provided, single-sample test)")
    parser.add_argument(
        "--export-tflite",
        help="Export model to TFLite format; specify output path (supported for the distilled_rf student model only)",
    )

    args = parser.parse_args()

    print("Pose classification ML tool")
    print("=" * 60)

    # ONNX test mode
    if args.test_onnx:
        print("ONNX model test mode")
        print(f"ONNX model: {args.test_onnx}")
        print("=" * 60)

        # Create a classifier instance for testing
        classifier = PoseClassifier()
        # Note: test_onnx_model is not implemented in this script; this is a placeholder.
        # You can implement it later if needed.
        print("ONNX test requested but functionality is not implemented in this script.")
        return

    # Evaluation mode
    if args.evaluate:
        if not args.load_model:
            # Try to use the default model file
            default_model = f"pose_classifier_{args.model}.pkl"
            if Path(default_model).exists():
                args.load_model = default_model
            else:
                print(
                    f"Error: Need to specify model file path (--load-model) or ensure default model file exists: {default_model}"
                )
                return

        print("Evaluation mode")
        print(f"Test data directory: {args.evaluate}")
        print(f"Model file: {args.load_model}")
        print("=" * 60)

        # Create the classifier and load the model
        classifier = PoseClassifier(model_type=args.model)
        classifier.load_model(args.load_model)

        # Perform a comprehensive evaluation
        try:
            eval_results = classifier.evaluate_test_directory(args.evaluate)
            classifier.print_evaluation_report(eval_results)
        except Exception as e:
            print(f"Error during evaluation: {e}")

        return

    # Prediction-only mode
    if args.predict:
        if not args.load_model:
            # Try to use the default model file
            default_model = f"pose_classifier_{args.model}.pkl"
            if Path(default_model).exists():
                args.load_model = default_model
            else:
                print(
                    f"Error: Need to specify model file path (--load-model) or ensure default model file exists: {default_model}"
                )
                return

        print("Prediction mode")
        print(f"JSON file: {args.predict}")
        print(f"Model file: {args.load_model}")
        print("=" * 60)

        # Create the classifier and load the model
        classifier = PoseClassifier(model_type=args.model)
        classifier.load_model(args.load_model)

        # Run prediction
        result = classifier.predict_single_json(args.predict)

        # Show the prediction result
        print("\nPrediction result:")
        print(f"File: {result['file_name']}")

        if 'error' in result:
            print(f"Error: {result['error']}")
        else:
            print(f"Predicted label: {result['predicted_label']}")
            print(f"Joint coverage: {result['joint_coverage']}")

            if result['confidence_scores']:
                print(f"Max confidence: {result['max_confidence']:.4f}")
                print("\nPer-class confidence:")
                sorted_scores = sorted(result['confidence_scores'].items(), key=lambda x: x[1], reverse=True)
                for label, score in sorted_scores:
                    print(f"  {label}: {score:.4f}")

            if result['missing_joints']:
                print(f"\nMissing joints: {', '.join(result['missing_joints'])}")

        return

    # Training mode
    print("Training mode")
    print(f"Data directory: {args.data}")
    print(f"Model type: {args.model}")
    print(f"Test size: {args.test_size}")
    print("=" * 60)

    # Check the data directory
    if not Path(args.data).exists():
        print(f"Error: data directory does not exist: {args.data}")
        return

    # Create the classifier
    classifier = PoseClassifier(model_type=args.model)

    # If loading an existing model and not forcing training
    if args.load_model and not args.train:
        print(f"Loading existing model: {args.load_model}")
        classifier.load_model(args.load_model)
        print("Model loaded, skipping training step")
    else:
        # Load data
        X, y = classifier.load_data(args.data)
        if len(X) == 0:
            print("Error: no valid data found")
            return
        # Train the model
        results = classifier.train(X, y, test_size=args.test_size)
        # Plot the confusion matrix (if not disabled)
        if not args.no_plot:
            try:
                classifier.plot_confusion_matrix(
                    results['confusion_matrix'], results['target_names'], save_path=f"confusion_matrix_{args.model}.png"
                )
            except Exception as e:
                print(f"Error while plotting confusion matrix: {e}")
        # Save the model (if specified)
        if args.save_model:
            classifier.save_model(args.save_model)
        else:
            # Default save path
            default_path = f"pose_classifier_{args.model}.pkl"
            classifier.save_model(default_path)
        print("\nTraining complete!")
        print(f"Final test accuracy: {results['test_accuracy']:.4f}")

    # Export ONNX if requested
    if args.export_onnx:
        print(f"\nExporting {args.export_model_type} model to ONNX format...")
        onnx_path = classifier.export_to_onnx(model_type=args.export_model_type, output_path=args.export_onnx)
|
| 1107 |
+
if onnx_path:
|
| 1108 |
+
print(f"✅ ONNX model exported: {onnx_path}")
|
| 1109 |
+
|
| 1110 |
+
# Export TFLite if requested
|
| 1111 |
+
if args.export_tflite:
|
| 1112 |
+
print("\nExporting student_model to TFLite format...")
|
| 1113 |
+
tflite_path = classifier.export_to_tflite(output_path=args.export_tflite)
|
| 1114 |
+
if tflite_path:
|
| 1115 |
+
print(f"✅ TFLite model exported: {tflite_path}")
|
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
if __name__ == "__main__":
|
| 1120 |
+
main()
|
| 1121 |
+
|
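For reference, the CLI flow above can also be driven directly from Python. The sketch below is not part of the committed files; it only calls PoseClassifier methods that appear in ml_pose_classifier.py (load_data, train, save_model, load_model, predict_single_json), while the 'PoseData' directory, the sample JSON path, and the 0.2 split are placeholder example values.

# Sketch: programmatic equivalent of the training and prediction modes above.
# Assumptions: 'PoseData' and the sample JSON path are placeholders.
from ml_pose_classifier import PoseClassifier

clf = PoseClassifier(model_type='random_forest')
X, y = clf.load_data('PoseData')                      # same loader the CLI uses
results = clf.train(X, y, test_size=0.2)              # 0.2 is an example split
clf.save_model('pose_classifier_random_forest.pkl')
print(results['test_accuracy'])

clf2 = PoseClassifier(model_type='random_forest')
clf2.load_model('pose_classifier_random_forest.pkl')
pred = clf2.predict_single_json('PoseData/label_0/sample.json')  # placeholder file
if 'error' not in pred:
    print(pred['predicted_label'], pred.get('max_confidence'))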
pose_detection.py
ADDED
@@ -0,0 +1,382 @@
#!/usr/bin/env python3
"""
Use MediaPipe to detect poses in images and extract landmark coordinates.

Features:
1. Run MediaPipe pose detection on images in the train folder
2. Use the nose as the head reference point (headPos)
3. Process coordinates as: pos = (pos - headPos) * 100 and round to 2 decimals
4. Save processed landmarks into JSON files named after the image files

Usage:
    python pose_detection.py [--input INPUT_DIR] [--output OUTPUT_DIR]
"""
import os
import json
import argparse
from pathlib import Path
import cv2
import mediapipe as mp


class PoseDetector:
    def __init__(self):
        """Initialize MediaPipe pose detector."""
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=True,
            model_complexity=2,
            enable_segmentation=False,
            min_detection_confidence=0.5
        )

        # MediaPipe pose landmark name mapping
        self.landmark_names = [
            'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer',
            'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
            'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
            'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
            'left_index', 'right_index', 'left_thumb', 'right_thumb',
            'left_hip', 'right_hip', 'left_knee', 'right_knee',
            'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
            'left_foot_index', 'right_foot_index'
        ]

    def get_head_position(self, landmarks):
        """
        Compute the head reference position (use the nose landmark).

        Args:
            landmarks: MediaPipe detected landmarks

        Returns:
            tuple: (x, y, z) head coordinates
        """
        # use nose as the head reference point
        nose = landmarks[0]  # nose is the 0th landmark
        return (nose.x, nose.y, nose.z)

    def process_landmarks(self, landmarks, head_pos):
        """
        Process landmarks: pos = (pos - headPos) * 100 and round to 2 decimals.

        Args:
            landmarks: MediaPipe detected landmarks
            head_pos: head coordinates (x, y, z)

        Returns:
            dict: processed landmarks dictionary
        """
        processed_landmarks = {}
        head_pos_x = head_pos[0]
        head_pos_y = head_pos[1]
        head_pos_z = head_pos[2]

        for i, landmark in enumerate(landmarks):
            if i < len(self.landmark_names):
                name = self.landmark_names[i]

                # Calculate coordinates relative to head and multiply by 100
                rel_x = round((landmark.x - head_pos_x) * 100, 2)
                rel_y = round((landmark.y - head_pos_y) * 100, 2)
                rel_z = round((landmark.z - head_pos_z) * 100, 2)

                processed_landmarks[name] = {
                    'x': rel_x,
                    'y': rel_y,
                    'z': rel_z,
                    'visibility': round(landmark.visibility, 3)
                }

        return processed_landmarks

    def detect_pose(self, image_path):
        """
        Detect pose for a single image.

        Args:
            image_path: path to the image file

        Returns:
            dict: processed landmarks and metadata, or None on failure
        """
        try:
            # Read image
            image = cv2.imread(str(image_path))
            if image is None:
                print(f"Unable to read image: {image_path}")
                return None

            # Convert color space (BGR -> RGB)
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Run pose detection
            results = self.pose.process(image_rgb)

            if results.pose_landmarks is None:
                print(f"No pose detected: {image_path}")
                return None

            # Get keypoints
            landmarks = results.pose_landmarks.landmark

            # Get head position
            head_pos = self.get_head_position(landmarks)

            # Process keypoint coordinates
            processed_landmarks = self.process_landmarks(landmarks, head_pos)

            # extract label from parent folder name
            label = image_path.parent.name

            # Add metadata
            result = {
                'image_path': str(image_path),
                'image_name': image_path.name,
                'label': label,
                'head_position': {
                    'x': round(head_pos[0], 4),
                    'y': round(head_pos[1], 4),
                    'z': round(head_pos[2], 4)
                },
                'landmarks': processed_landmarks,
                'total_landmarks': len(processed_landmarks)
            }

            return result

        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

    def close(self):
        """Close MediaPipe resources."""
        self.pose.close()


def process_all_training_data(input_dir, output_dir, batch_size=100):
    """
    Process all images in the training dataset and write JSON files.

    Args:
        input_dir: input images directory (TrainData/train)
        output_dir: output JSON directory (PoseData)
        batch_size: progress report batch size
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Supported image formats
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    detector = PoseDetector()

    try:
        # statistics
        total_images = 0
        success_count = 0
        failed_count = 0
        label_stats = {}

        print(f"Starting to process dataset: {input_path}")
        print(f"Output directory: {output_path}")

        # first count all images
        print("Counting images...")
        label_dirs = []
        for item in input_path.iterdir():
            if item.is_dir() and item.name.startswith('label_'):
                label = item.name
                image_files = [f for f in item.iterdir()
                               if f.is_file() and f.suffix.lower() in image_extensions]
                if image_files:
                    label_dirs.append((item, label, image_files))
                    total_images += len(image_files)
                    label_stats[label] = {'total': len(image_files), 'success': 0, 'failed': 0}

        print(f"Found {len(label_dirs)} label directories, total {total_images} images")
        for label, stats in label_stats.items():
            print(f"  {label}: {stats['total']} images")

        print("\nStarting to process images...")

        # process each label directory
        for label_dir, label_name, image_files in label_dirs:
            print(f"\n--- Processing {label_name} ({len(image_files)} images) ---")

            # create output folder for this label
            output_label_dir = output_path / label_name
            output_label_dir.mkdir(parents=True, exist_ok=True)

            # process every image in this label
            for i, image_file in enumerate(image_files, 1):
                json_filename = image_file.stem + '.json'
                json_path = output_label_dir / json_filename

                # detect pose
                result = detector.detect_pose(image_file)

                if result is not None:
                    # save JSON
                    try:
                        with open(json_path, 'w', encoding='utf-8') as f:
                            json.dump(result, f, ensure_ascii=False, indent=2)
                        success_count += 1
                        label_stats[label_name]['success'] += 1

                        # progress
                        if success_count % batch_size == 0:
                            progress = (success_count / total_images) * 100 if total_images else 0
                            print(f"  Progress: {success_count}/{total_images} ({progress:.1f}%) - Current: {label_name} {i}/{len(image_files)}")

                    except Exception as e:
                        print(f"  Failed to save JSON {json_path}: {e}")
                        failed_count += 1
                        label_stats[label_name]['failed'] += 1
                else:
                    failed_count += 1
                    label_stats[label_name]['failed'] += 1
                    if failed_count % 10 == 0:  # print every 10 failures
                        print(f"  Detection failed: {image_file.name}")

            # report for this label
            stats = label_stats[label_name]
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            print(f"  {label_name} Done: Success {stats['success']}, Failed {stats['failed']}, Success rate: {success_rate:.1f}%")

        print("\n" + "=" * 60)
        print("Processing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        total_success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Overall success rate: {total_success_rate:.1f}%")

        print("\nPer-label statistics:")
        for label, stats in label_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            print(f"  {label}: {stats['success']}/{stats['total']} ({success_rate:.1f}%)")

        print(f"\nJSON files saved to: {output_path.absolute()}")
        print("Directory structure:")
        print("PoseData/")
        for label in sorted(label_stats.keys()):
            print(f"├── {label}/")
            print("│   └── *.json")

    finally:
        detector.close()


def process_directory(input_dir, output_dir):
    """
    Process all images in a directory tree and write JSON files.

    Args:
        input_dir: input images directory
        output_dir: output JSON directory
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Supported image formats
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    detector = PoseDetector()

    try:
        # statistics
        total_images = 0
        success_count = 0
        failed_count = 0

        print(f"Starting to process directory: {input_path}")
        print(f"Output directory: {output_path}")

        # walk through the tree
        for root, dirs, files in os.walk(input_path):
            root_path = Path(root)

            # create corresponding output folder
            relative_path = root_path.relative_to(input_path)
            current_output_dir = output_path / relative_path
            current_output_dir.mkdir(parents=True, exist_ok=True)

            # collect image files in this folder
            image_files = [f for f in files if Path(f).suffix.lower() in image_extensions]

            if image_files:
                print(f"\nProcessing directory: {root_path}")
                print(f"Found {len(image_files)} images")

                for filename in image_files:
                    total_images += 1
                    image_path = root_path / filename

                    # generate JSON filename (replace extension with .json)
                    json_filename = Path(filename).stem + '.json'
                    json_path = current_output_dir / json_filename

                    # detect pose
                    result = detector.detect_pose(image_path)

                    if result is not None:
                        # save JSON file
                        try:
                            with open(json_path, 'w', encoding='utf-8') as f:
                                json.dump(result, f, ensure_ascii=False, indent=2)
                            success_count += 1

                            if success_count % 50 == 0:
                                print(f"Successfully processed {success_count} images...")

                        except Exception as e:
                            print(f"Failed to save JSON {json_path}: {e}")
                            failed_count += 1
                    else:
                        failed_count += 1

        print("\nProcessing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        # guard against an empty directory to avoid division by zero
        success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Success rate: {success_rate:.1f}%")

    finally:
        detector.close()


def main():
    parser = argparse.ArgumentParser(description="Run MediaPipe pose detection and save landmark data")
    parser.add_argument("--input", "-i", default="TrainData/train",
                        help="input images directory (default: TrainData/train)")
    parser.add_argument("--output", "-o", default="PoseData",
                        help="output JSON directory (default: PoseData)")
    parser.add_argument("--batch-size", "-b", type=int, default=100,
                        help="batch size for progress reporting (default: 100)")

    args = parser.parse_args()

    # check that the input directory exists
    if not Path(args.input).exists():
        print(f"Error: input directory does not exist: {args.input}")
        return

    print("MediaPipe pose detection tool")
    print("=" * 60)
    print(f"Input directory: {args.input}")
    print(f"Output directory: {args.output}")
    print("Processing rule: pos = (pos - headPos) * 100, round to 2 decimals")
    print("Head reference: nose")
    print(f"Batch size: show progress every {args.batch_size} images")
    print("=" * 60)

    # Start processing the entire training dataset
    process_all_training_data(args.input, args.output, args.batch_size)


if __name__ == "__main__":
    main()
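The pos = (pos - headPos) * 100 rule implemented in process_landmarks above is easy to sanity-check by hand. A minimal sketch with made-up coordinates (not taken from the dataset); the hypothetical nose and left_wrist values stand in for MediaPipe's normalized image coordinates:

# Sketch: head-relative normalization applied to one hypothetical landmark.
head_pos = (0.5021, 0.3310, -0.0125)   # made-up nose position
wrist = (0.6133, 0.5049, -0.0402)      # made-up left_wrist position

rel = tuple(round((p - h) * 100, 2) for p, h in zip(wrist, head_pos))
print(rel)  # (11.12, 17.39, -2.77)

These are the x/y/z values the per-image JSON would store under landmarks['left_wrist'], alongside the rounded visibility score.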
realtime_pose_classifier.py
ADDED
@@ -0,0 +1,456 @@
#!/usr/bin/env python3
"""
Real-time pose classifier
Uses MediaPipe to capture camera input, perform pose recognition and classification, and display results on screen

Features:
1. Use MediaPipe to obtain real-time pose data from camera
2. Extract joint coordinates and preprocess them
3. Use trained machine learning models for pose classification
4. Display classification results and keypoints in real-time on video screen

Dependencies:
    pip install opencv-python mediapipe numpy scikit-learn

Usage:
    python realtime_pose_classifier.py [--model MODEL_PATH] [--camera CAMERA_ID]
"""

import sys
import cv2
import mediapipe as mp
import numpy as np
import json
import joblib
import argparse
import time
from pathlib import Path
import traceback


class RealtimePoseClassifier:
    def __init__(self, model_path=None, camera_id=0):
        """
        Initialize real-time pose classifier

        Args:
            model_path (str): Model file path, auto-detect if None
            camera_id (int): Camera ID, default 0
        """
        self.camera_id = camera_id

        # Initialize MediaPipe
        self.mp_pose = mp.solutions.pose
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_drawing_styles = mp.solutions.drawing_styles

        # Configure pose detector
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=1,  # Use lower complexity for real-time applications
            enable_segmentation=False,
            min_detection_confidence=0.7,
            min_tracking_confidence=0.5
        )

        # MediaPipe landmark name mapping
        self.landmark_names = [
            'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer',
            'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
            'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
            'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
            'left_index', 'right_index', 'left_thumb', 'right_thumb',
            'left_hip', 'right_hip', 'left_knee', 'right_knee',
            'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
            'left_foot_index', 'right_foot_index'
        ]

        # Load model
        self.model = None
        self.scaler = None
        self.label_encoder = None
        self.target_joints = None
        self.model_info = None

        self.load_model(model_path)

        # Prediction result cache
        self.prediction_history = []
        self.history_size = 5  # Keep recent 5 predictions for smoothing

        # Performance statistics
        self.fps_counter = 0
        self.fps_start_time = time.time()
        self.current_fps = 0

        # Added: Time statistics
        self.mediapipe_time_total = 0.0
        self.mediapipe_time_count = 0
        self.feature_pred_time_total = 0.0
        self.feature_pred_time_count = 0

        # Display settings
        self.show_landmarks = True
        self.show_connections = True

    def load_model(self, model_path=None):
        """Load trained model"""
        if model_path is None:
            # Auto-detect available model files
            possible_models = [
                'pose_classifier_random_forest.pkl',
                'pose_classifier_logistic.pkl',
                'pose_classifier_distilled_rf.pkl'
            ]

            for model_file in possible_models:
                if Path(model_file).exists():
                    model_path = model_file
                    break

        if model_path is None:
            raise FileNotFoundError("No available model file found, please specify model path")

        try:
            print(f"Loading model: {model_path}")
            model_data = joblib.load(model_path)

            self.model = model_data['model']
            self.scaler = model_data['scaler']
            self.label_encoder = model_data['label_encoder']
            self.target_joints = model_data['target_joints']

            # Try to load corresponding labels file
            labels_path = model_path.replace('.pkl', '_labels.json')
            if Path(labels_path).exists():
                with open(labels_path, 'r') as f:
                    self.model_info = json.load(f)
                print(f"Loaded label information: {labels_path}")

            print("Model loaded successfully!")
            print(f"Target joints: {self.target_joints}")
            print(f"Classification classes: {self.label_encoder.classes_}")

        except Exception as e:
            raise RuntimeError(f"Model loading failed: {e}")

    def extract_pose_features(self, landmarks):
        """
        Extract pose features from MediaPipe landmarks (vectorized optimized version)
        """
        if landmarks is None:
            return None

        # Get all joint coordinates as NumPy array
        coords = np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark], dtype=np.float32)

        # Get head position (nose as reference point)
        try:
            head_idx = self.landmark_names.index('nose')
            head_pos = coords[head_idx]
        except ValueError:
            return None

        # Build target joint indices list
        joint_indices = [self.landmark_names.index(j) if j in self.landmark_names else -1 for j in self.target_joints]

        # Extract target joint coordinates (fill with 0 if not exist)
        joint_coords = np.array([
            coords[idx] if idx >= 0 else np.zeros(3, dtype=np.float32)
            for idx in joint_indices
        ], dtype=np.float32)

        # Calculate relative position to head and scale
        relative_coords = (joint_coords - head_pos) * 100  # Keep consistent with training processing

        # Keep two decimal places
        features = np.round(relative_coords, 2).flatten()

        return features

    def predict_pose(self, features):
        """
        Use machine learning model to predict pose

        Args:
            features: Feature vector

        Returns:
            dict: Prediction result containing label, confidence, etc.
        """
        if features is None or self.model is None:
            return None

        try:
            # Standardize features
            features_scaled = self.scaler.transform(features.reshape(1, -1))

            # Predict
            prediction = self.model.predict(features_scaled)[0]
            predicted_label = self.label_encoder.inverse_transform([prediction])[0]

            # Get confidence (if model supports probability prediction)
            confidence = 0.0
            probabilities = None
            if hasattr(self.model, 'predict_proba'):
                probs = self.model.predict_proba(features_scaled)[0]
                confidence = float(np.max(probs))
                probabilities = dict(zip(self.label_encoder.classes_, probs))

            return {
                'predicted_label': predicted_label,
                'confidence': confidence,
                'probabilities': probabilities
            }

        except Exception as e:
            print(f"Prediction error: {e}")
            return None

    def smooth_predictions(self, current_prediction):
        """
        Smooth prediction results

        Args:
            current_prediction: Current prediction result

        Returns:
            dict: Smoothed prediction result
        """
        if current_prediction is None:
            return None

        # Add to history
        self.prediction_history.append(current_prediction)
        if len(self.prediction_history) > self.history_size:
            self.prediction_history.pop(0)

        # If history is insufficient, return current prediction directly
        if len(self.prediction_history) < 3:
            return current_prediction

        # Count recent prediction labels
        recent_labels = [pred['predicted_label'] for pred in self.prediction_history]

        # Use mode as final prediction
        from collections import Counter
        label_counts = Counter(recent_labels)
        most_common_label = label_counts.most_common(1)[0][0]

        # Calculate average confidence for this label
        avg_confidence = np.mean([
            pred['confidence'] for pred in self.prediction_history
            if pred['predicted_label'] == most_common_label
        ])

        return {
            'predicted_label': most_common_label,
            'confidence': avg_confidence,
            'stability': label_counts[most_common_label] / len(recent_labels)
        }

    def draw_pose_info(self, image, landmarks, prediction_result):
        """
        Draw pose information on image

        Args:
            image: OpenCV image
            landmarks: MediaPipe landmarks
            prediction_result: Prediction result
        """
        height, width = image.shape[:2]

        # Draw pose skeleton
        if landmarks and self.show_connections:
            self.mp_drawing.draw_landmarks(
                image,
                landmarks,
                self.mp_pose.POSE_CONNECTIONS,
                landmark_drawing_spec=self.mp_drawing_styles.get_default_pose_landmarks_style()
            )

        # Draw keypoints
        if landmarks and self.show_landmarks:
            for i, landmark in enumerate(landmarks.landmark):
                if self.landmark_names[i] in self.target_joints:
                    x = int(landmark.x * width)
                    y = int(landmark.y * height)
                    cv2.circle(image, (x, y), 8, (0, 255, 0), -1)
                    cv2.putText(image, self.landmark_names[i], (x + 10, y - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

        # Display prediction results
        if prediction_result:
            label = prediction_result['predicted_label']
            confidence = prediction_result.get('confidence', 0.0)
            stability = prediction_result.get('stability', 1.0)

            # Set color based on confidence
            if confidence > 0.8:
                color = (0, 255, 0)  # Green - high confidence
            elif confidence > 0.6:
                color = (0, 255, 255)  # Yellow - medium confidence
            else:
                color = (0, 0, 255)  # Red - low confidence

            # Draw prediction result background box
            cv2.rectangle(image, (10, 10), (400, 120), (0, 0, 0), -1)
            cv2.rectangle(image, (10, 10), (400, 120), color, 2)

            # Display prediction label
            cv2.putText(image, f"Pose: {label}", (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)

            # Display confidence
            cv2.putText(image, f"Confidence: {confidence:.2f}", (20, 70),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Display stability
            cv2.putText(image, f"Stability: {stability:.2f}", (20, 95),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        # Display FPS
        cv2.putText(image, f"FPS: {self.current_fps:.1f}", (width - 150, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        # Display control instructions
        instructions = [
            "Controls:",
            "Q - Quit",
            "L - Toggle Landmarks",
            "C - Toggle Connections",
            "R - Reset History"
        ]

        for i, instruction in enumerate(instructions):
            cv2.putText(image, instruction, (width - 200, height - 120 + i * 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)

        # Added: Display timing statistics
        mp_avg = self.mediapipe_time_total / self.mediapipe_time_count if self.mediapipe_time_count else 0.0
        fp_avg = self.feature_pred_time_total / self.feature_pred_time_count if self.feature_pred_time_count else 0.0
        cv2.putText(image, f"MP avg: {mp_avg*1000:.1f}ms", (width - 150, 55),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        cv2.putText(image, f"FP avg: {fp_avg*1000:.1f}ms", (width - 150, 75),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        # Display average frame rate
        total_frames = max(self.mediapipe_time_count, 1)
        avg_fps = total_frames / max(self.mediapipe_time_total + self.feature_pred_time_total, 1e-6)
        cv2.putText(image, f"Avg FPS: {avg_fps:.1f}", (width - 150, 95),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    def update_fps(self):
        """Update FPS calculation"""
        self.fps_counter += 1
        if self.fps_counter >= 30:  # Update FPS every 30 frames
            current_time = time.time()
            self.current_fps = 30 / (current_time - self.fps_start_time)
            self.fps_start_time = current_time
            self.fps_counter = 0

    def run(self):
        """Run real-time pose classification"""
        print("Starting real-time pose classifier...")
        print("Press 'Q' to quit, 'L' to toggle landmark display, 'C' to toggle skeleton connections, 'R' to reset history")

        # Initialize camera
        cap = cv2.VideoCapture(self.camera_id)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open camera {self.camera_id}")

        # Set camera parameters
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        cap.set(cv2.CAP_PROP_FPS, 30)

        try:
            while True:
                success, frame = cap.read()
                if not success:
                    print("Cannot read camera frame")
                    break

                # Flip image horizontally (mirror effect)
                frame = cv2.flip(frame, 1)

                # Convert color space
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Time MediaPipe pose detection
                mp_start = time.time()
                results = self.pose.process(rgb_frame)
                mp_end = time.time()
                self.mediapipe_time_total += (mp_end - mp_start)
                self.mediapipe_time_count += 1

                # Extract features and predict
                fp_start = time.time()
                prediction_result = None
                if results.pose_landmarks:
                    features = self.extract_pose_features(results.pose_landmarks)
                    if features is not None:
                        raw_prediction = self.predict_pose(features)
                        prediction_result = self.smooth_predictions(raw_prediction)
                fp_end = time.time()
                self.feature_pred_time_total += (fp_end - fp_start)
                self.feature_pred_time_count += 1

                # Draw results
                self.draw_pose_info(frame, results.pose_landmarks, prediction_result)

                # Update FPS
                self.update_fps()

                # Display image
                cv2.imshow('Real-time Pose Classification', frame)

                # Handle key presses
                key = cv2.waitKey(1) & 0xFF
                if key == ord('q') or key == ord('Q'):
                    break
                elif key == ord('l') or key == ord('L'):
                    self.show_landmarks = not self.show_landmarks
                    print(f"Landmark display: {'On' if self.show_landmarks else 'Off'}")
                elif key == ord('c') or key == ord('C'):
                    self.show_connections = not self.show_connections
                    print(f"Skeleton connection display: {'On' if self.show_connections else 'Off'}")
                elif key == ord('r') or key == ord('R'):
                    self.prediction_history.clear()
                    print("Prediction history reset")

        except KeyboardInterrupt:
            print("\nUser interrupted program")
        except Exception as e:
            print(f"Runtime error: {e}")
            traceback.print_exc()
        finally:
            cap.release()
            cv2.destroyAllWindows()
            print("Program exited")


def main():
    """Main function"""
    parser = argparse.ArgumentParser(description='Real-time pose classifier')
    parser.add_argument('--model', '-m', type=str, default=None,
                        help='Model file path (auto-detect by default)')
    parser.add_argument('--camera', '-c', type=int, default=0,
                        help='Camera ID (default 0)')

    args = parser.parse_args()

    try:
        classifier = RealtimePoseClassifier(
            model_path=args.model,
            camera_id=args.camera
        )
        classifier.run()
    except Exception as e:
        print(f"Program startup failed: {e}")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
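For context on smooth_predictions above: it is a plain majority vote over the last few frames, reporting the average confidence of the winning label and a stability ratio. A standalone sketch with fabricated predictions shaped like predict_pose output (not part of the committed script; the labels and confidences are made up):

# Sketch: majority-vote smoothing over a noisy label sequence.
from collections import Counter

history = [
    {'predicted_label': 'label_1', 'confidence': 0.91},
    {'predicted_label': 'label_1', 'confidence': 0.84},
    {'predicted_label': 'label_3', 'confidence': 0.55},   # one-frame flicker
    {'predicted_label': 'label_1', 'confidence': 0.88},
]
counts = Counter(p['predicted_label'] for p in history)
winner = counts.most_common(1)[0][0]
avg_conf = sum(p['confidence'] for p in history
               if p['predicted_label'] == winner) / counts[winner]
print(winner, round(avg_conf, 2), counts[winner] / len(history))
# label_1 0.88 0.75 -> the single label_3 frame is suppressed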