pegasama committed
Commit b26156a · verified · 1 Parent(s): 61bf3ea

Train and test Python scripts

extract_images.py ADDED
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""Extract images and labels from Parquet files and save them into
subfolders by label.

Usage:
    python extract_images.py [--train] [--test] [--output OUTPUT_DIR]

Defaults:
    train: process training data (train-00000-of-00001.parquet)
    test: process test data (test-00000-of-00001.parquet)
    output: TrainData (created in the current working directory)
"""
import os
import sys
import argparse
from pathlib import Path
import pyarrow.parquet as pq
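

# Expected Parquet schema (as assumed by the code below): each row carries an
# 'image' struct with the raw file 'bytes' and its original 'path', plus a
# 'label' column used for the output folder names.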
def extract_images_from_parquet(parquet_path, output_dir, split_name):
    """Extract images from a Parquet file and save them into label folders."""
    print(f"Processing {parquet_path}...")

    # read parquet file
    try:
        table = pq.read_table(parquet_path)
        df = table.to_pandas()
    except Exception as e:
        print(f"Failed to read parquet file: {e}")
        return False

    print(f"Found {len(df)} images")

    # get unique labels
    unique_labels = sorted(df['label'].unique())
    print(f"Label classes: {unique_labels}")

    # create a folder for each label
    for label in unique_labels:
        label_dir = output_dir / split_name / f"label_{label}"
        label_dir.mkdir(parents=True, exist_ok=True)
        print(f"Created folder: {label_dir}")

    # extract and save images
    success_count = 0
    error_count = 0

    for idx, row in df.iterrows():
        try:
            # get image data
            image_struct = row['image']
            image_bytes = image_struct['bytes']
            original_path = image_struct['path']
            label = row['label']

            # get file extension
            _, ext = os.path.splitext(original_path)
            if not ext:
                ext = '.jpg'  # default extension

            # build a new filename (preserve original base name, avoid collisions)
            base_name = os.path.splitext(os.path.basename(original_path))[0]
            filename = f"{base_name}{ext}"

            # ensure the filename is unique
            label_dir = output_dir / split_name / f"label_{label}"
            output_path = label_dir / filename
            counter = 1
            while output_path.exists():
                filename = f"{base_name}_{counter}{ext}"
                output_path = label_dir / filename
                counter += 1

            # save image
            with open(output_path, 'wb') as f:
                f.write(image_bytes)

            success_count += 1
            if success_count % 100 == 0:
                print(f"Processed {success_count} images...")

        except Exception as e:
            print(f"Error processing image {idx}: {e}")
            error_count += 1
            continue

    print(f"Done! Success: {success_count}, Failed: {error_count}")

    # report counts per label
    print("\nImage count per label:")
    for label in unique_labels:
        label_dir = output_dir / split_name / f"label_{label}"
        count = len(list(label_dir.glob("*")))
        print(f"  label {label}: {count} images")

    return success_count > 0


def main():
    parser = argparse.ArgumentParser(description="Extract images from Parquet files and organize by label")
    parser.add_argument("--train", action="store_true", help="process training data")
    parser.add_argument("--test", action="store_true", help="process test data")
    parser.add_argument("--output", "-o", default="TrainData", help="output directory")

    args = parser.parse_args()

    # if neither train nor test is specified, do both by default
    if not args.train and not args.test:
        args.train = True
        args.test = True

    # set paths
    script_dir = Path(__file__).parent
    yoga_data_dir = script_dir / "YogaDataSet" / "data"
    output_dir = Path(args.output)

    # ensure the output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Output directory: {output_dir.absolute()}")

    success = True

    # process training data
    if args.train:
        train_parquet = yoga_data_dir / "train-00000-of-00001.parquet"
        if train_parquet.exists():
            if not extract_images_from_parquet(train_parquet, output_dir, "train"):
                success = False
        else:
            print(f"Training parquet file not found: {train_parquet}")
            success = False

    # process test data
    if args.test:
        test_parquet = yoga_data_dir / "test-00000-of-00001.parquet"
        if test_parquet.exists():
            if not extract_images_from_parquet(test_parquet, output_dir, "test"):
                success = False
        else:
            print(f"Test parquet file not found: {test_parquet}")
            success = False

    if success:
        print("\n✅ All images extracted!")
        print(f"Images saved to: {output_dir.absolute()}")
        print("Directory structure:")
        print("TrainData/")
        if args.train:
            print("├── train/")
            print("│   ├── label_0/")
            print("│   ├── label_1/")
            print("│   └── ...")
        if args.test:
            print("└── test/")
            print("    ├── label_0/")
            print("    ├── label_1/")
            print("    └── ...")
    else:
        print("\n❌ Errors occurred during extraction")
        sys.exit(1)


if __name__ == "__main__":
    main()
ml_pose_classifier.py ADDED
@@ -0,0 +1,1121 @@
#!/usr/bin/env python3
"""
Machine learning pose classification script.

Features:
1. Train classifiers on pose landmark inputs
2. Use selected landmark coordinates as features
3. Use folder names as class labels
4. Train and evaluate models

Usage:
    python ml_pose_classifier.py [--data DATA_DIR] [--model MODEL_TYPE] [--test-size RATIO]
"""

import json
import argparse
import numpy as np
import time
from pathlib import Path
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
# from sklearn.pipeline import Pipeline  # not used
from sklearn.neural_network import MLPRegressor
import joblib
import matplotlib.pyplot as plt

# seaborn is optional; used only for confusion matrix plotting
try:
    import seaborn as sns
    SEABORN_AVAILABLE = True
except ImportError:
    SEABORN_AVAILABLE = False

# ONNX related imports
try:
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType
    # onnx is not required here; we import it lazily where needed
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False

# onnxruntime would only be needed for ONNX runtime testing, which is not
# implemented in this script, so it is never imported.
ONNX_RUNTIME_AVAILABLE = False

class PoseClassifier:
    def __init__(self, model_type='random_forest'):
        """
        Initialize the pose classifier.

        Args:
            model_type: model type ('random_forest', 'svm', 'gradient_boost', 'logistic', 'distilled_rf')
        """
        self.model_type = model_type
        self.model = None
        self.student_model = None  # If distillation is used, save the student (MLP) model
        self.scaler = StandardScaler()
        self.label_encoder = LabelEncoder()

        # Define the joints we want to use (based on MediaPipe keypoint indices)
        self.target_joints = [
            'nose',            # Head (nose as reference, but will actually be 0,0,0)
            'left_shoulder',   # Left shoulder
            'right_shoulder',  # Right shoulder
            'left_elbow',      # Left elbow
            'right_elbow',     # Right elbow
            'left_wrist',      # Left wrist
            'right_wrist',     # Right wrist
            'left_hip',        # Left hip
            'right_hip',       # Right hip
            'left_knee',       # Left knee
            'right_knee',      # Right knee
            'left_ankle',      # Left ankle
            'right_ankle'      # Right ankle
        ]

        self.feature_columns = []
        for joint in self.target_joints:
            self.feature_columns.extend([f'{joint}_x', f'{joint}_y', f'{joint}_z'])

        print(f"Target joints: {len(self.target_joints)}")
        print(f"Feature dimension: {len(self.feature_columns)}")
        print("Joint list:", ', '.join(self.target_joints))

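    # Resulting feature vector layout (13 joints × 3 coords = 39 values), in order:
    # [nose_x, nose_y, nose_z, left_shoulder_x, left_shoulder_y, ..., right_ankle_z]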
    def _get_model(self):
        """Create a classifier based on the selected model type."""
        if self.model_type == 'random_forest':
            return RandomForestClassifier(
                n_estimators=100,
                max_depth=15,
                min_samples_split=5,
                min_samples_leaf=2,
                random_state=42,
                n_jobs=-1
            )
        elif self.model_type == 'svm':
            return SVC(
                C=1.0,
                kernel='rbf',
                gamma='scale',
                random_state=42
            )
        elif self.model_type == 'gradient_boost':
            return GradientBoostingClassifier(
                n_estimators=100,
                learning_rate=0.1,
                max_depth=6,
                random_state=42
            )
        elif self.model_type == 'logistic':
            return LogisticRegression(
                C=10.0,  # larger C = weaker regularization, allowing a more flexible fit
                max_iter=2000,  # increase the maximum number of iterations
                solver='lbfgs',  # L-BFGS solver, suitable for small datasets
                multi_class='multinomial',  # multi-class strategy
                random_state=42,
                n_jobs=-1
            )
        elif self.model_type == 'distilled_rf':
            # The teacher uses a random forest (returns an RF for the training process)
            return RandomForestClassifier(
                n_estimators=100,
                max_depth=15,
                min_samples_split=5,
                min_samples_leaf=2,
                random_state=42,
                n_jobs=-1
            )
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

    def load_data(self, data_dir):
        """
        Load pose data from JSON files.

        Args:
            data_dir: Data directory containing label folders

        Returns:
            tuple: (feature data, labels)
        """
        data_path = Path(data_dir)
        all_features = []
        all_labels = []

        print(f"Loading data from: {data_path}")

        # Iterate over each label directory
        for label_dir in data_path.iterdir():
            if not label_dir.is_dir() or not label_dir.name.startswith('label_'):
                continue

            label = label_dir.name
            json_files = list(label_dir.glob('*.json'))

            print(f"Processing {label}: {len(json_files)} files")

            for json_file in json_files:
                try:
                    with open(json_file, 'r', encoding='utf-8') as f:
                        data = json.load(f)

                    landmarks = data.get('landmarks', {})

                    # Extract coordinates of target joints
                    features = []
                    missing_joints = []

                    for joint in self.target_joints:
                        if joint in landmarks:
                            joint_data = landmarks[joint]
                            features.extend([
                                joint_data.get('x', 0.0),
                                joint_data.get('y', 0.0),
                                joint_data.get('z', 0.0)
                            ])
                        else:
                            # If a joint is missing, fill with zeros
                            features.extend([0.0, 0.0, 0.0])
                            missing_joints.append(joint)

                    if len(features) == len(self.feature_columns):
                        all_features.append(features)
                        all_labels.append(label)
                    else:
                        print(f"Skipping file {json_file}: feature dimension mismatch")

                    if missing_joints:
                        print(f"File {json_file.name} missing joints: {missing_joints}")

                except Exception as e:
                    print(f"Error reading file {json_file}: {e}")
                    continue

        print(f"Loaded {len(all_features)} samples")

        # count samples per label
        label_counts = {}
        for label in all_labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        print("Label distribution:")
        for label, count in sorted(label_counts.items()):
            print(f"  {label}: {count} samples")

        return np.array(all_features), np.array(all_labels)

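    # Input JSON sketch (illustrative values; this is the structure written by
    # pose_detection.py):
    # {"landmarks": {"nose": {"x": 0.0, "y": 0.0, "z": 0.0, "visibility": 0.999},
    #                "left_shoulder": {"x": -8.61, "y": 18.25, "z": 3.12, ...},
    #                ...}}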
    def train(self, X, y, test_size=0.2):
        """
        Train the classifier.

        Args:
            X: feature data
            y: labels
            test_size: ratio for the test split

        Returns:
            dict: a dictionary containing training results
        """
        print(f"\nStarting training for model: {self.model_type}...")
        print(f"Data shape: {X.shape}")
        print(f"Number of labels: {len(np.unique(y))}")

        # Encode labels
        y_encoded = self.label_encoder.fit_transform(y)

        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=test_size, random_state=42, stratify=y_encoded
        )

        print(f"Train set size: {X_train.shape[0]}")
        print(f"Test set size: {X_test.shape[0]}")

        # standardize features
        X_train_scaled = self.scaler.fit_transform(X_train)
        X_test_scaled = self.scaler.transform(X_test)

        # Distillation flow: train an RF teacher first, then train an MLPRegressor
        # student to fit the teacher's predict_proba outputs
        if self.model_type == 'distilled_rf':
            print("Using distillation: train a RandomForest teacher, then fit an MLPRegressor student to the teacher's soft labels")
            # Train teacher
            teacher = self._get_model()
            teacher.fit(X_train_scaled, y_train)

            # Get the teacher's probability distribution as soft labels
            y_train_proba = teacher.predict_proba(X_train_scaled)

            # Create and train the student (MLPRegressor) to fit the probability vectors
            student = MLPRegressor(hidden_layer_sizes=(128, 64, 32),
                                   activation='relu',
                                   solver='adam',
                                   max_iter=1000,
                                   learning_rate_init=0.001,
                                   random_state=42,
                                   early_stopping=True,
                                   validation_fraction=0.1)

            print("Training student model to fit teacher probability outputs...")
            print(f"Teacher probability output shape: {y_train_proba.shape}")

            # Multi-output regression; the target is the probability vector
            student.fit(X_train_scaled, y_train_proba)

            # Save models
            self.model = teacher
            self.student_model = student

            # Use the student to predict on the train/test sets
            y_train_pred_proba = student.predict(X_train_scaled)
            y_test_pred_proba = student.predict(X_test_scaled)

            # Apply softmax so the regressor outputs form a proper distribution
            def softmax(x):
                exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
                return exp_x / np.sum(exp_x, axis=1, keepdims=True)

            y_train_pred_proba = softmax(y_train_pred_proba)
            y_test_pred_proba = softmax(y_test_pred_proba)

            y_train_pred = np.argmax(y_train_pred_proba, axis=1)
            y_test_pred = np.argmax(y_test_pred_proba, axis=1)

            print(f"Student predicted probability shape: {y_test_pred_proba.shape}")
            print(f"Student training accuracy: {accuracy_score(y_train, y_train_pred):.4f}")

        else:
            # Standard flow: train a single model
            self.model = self._get_model()
            self.model.fit(X_train_scaled, y_train)

            y_train_pred = self.model.predict(X_train_scaled)
            y_test_pred = self.model.predict(X_test_scaled)

        # compute accuracies
        train_accuracy = accuracy_score(y_train, y_train_pred)
        test_accuracy = accuracy_score(y_test, y_test_pred)

        # cross-validation on the model used for training
        # (if a student_model exists, the teacher is still used for cross-validation)
        cv_model = self.model if self.model is not None else None
        if cv_model is not None:
            cv_scores = cross_val_score(cv_model, X_train_scaled, y_train, cv=5)
        else:
            cv_scores = np.array([])

        print("\nTraining results:")
        print(f"Train accuracy: {train_accuracy:.4f}")
        print(f"Test accuracy: {test_accuracy:.4f}")
        print(f"5-fold CV accuracy: {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")

        # classification report
        print("\nTest set classification report:")
        target_names = self.label_encoder.classes_
        print(classification_report(y_test, y_test_pred, target_names=target_names))

        # confusion matrix
        cm = confusion_matrix(y_test, y_test_pred)

        return {
            'train_accuracy': train_accuracy,
            'test_accuracy': test_accuracy,
            'cv_scores': cv_scores,
            'confusion_matrix': cm,
            'target_names': target_names,
            'X_test': X_test_scaled,
            'y_test': y_test,
            'y_test_pred': y_test_pred
        }

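    # Typical flow (sketch):
    #   clf = PoseClassifier(model_type='distilled_rf')
    #   X, y = clf.load_data('PoseData')
    #   results = clf.train(X, y, test_size=0.2)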
    def save_model(self, filepath):
        """Save the trained model to disk."""
        model_data = {
            'model': self.model,
            'student_model': self.student_model,  # None unless distillation was used
            'scaler': self.scaler,
            'label_encoder': self.label_encoder,
            'model_type': self.model_type,
            'target_joints': self.target_joints,
            'feature_columns': self.feature_columns
        }
        joblib.dump(model_data, filepath)
        print(f"Model saved to: {filepath}")

    def load_model(self, filepath):
        """Load a trained model from disk."""
        model_data = joblib.load(filepath)
        self.model = model_data['model']
        # .get() keeps compatibility with files saved before student_model was persisted
        self.student_model = model_data.get('student_model')
        self.scaler = model_data['scaler']
        self.label_encoder = model_data['label_encoder']
        self.model_type = model_data['model_type']
        self.target_joints = model_data['target_joints']
        self.feature_columns = model_data['feature_columns']
        print(f"Model loaded from: {filepath}")

    def predict(self, X):
        """Run prediction on input features."""
        if self.model is None and self.student_model is None:
            raise ValueError("Model not trained or loaded")

        X_scaled = self.scaler.transform(X)

        # Prefer the student_model (if it exists) to generate probability output
        if self.student_model is not None:
            proba = self.student_model.predict(X_scaled)  # returns probability vectors
            preds = np.argmax(proba, axis=1)
            labels = self.label_encoder.inverse_transform(preds)
            return labels, proba

        # Otherwise fall back to the original model
        predictions = self.model.predict(X_scaled)
        probabilities = None
        if hasattr(self.model, 'predict_proba'):
            probabilities = self.model.predict_proba(X_scaled)
        return self.label_encoder.inverse_transform(predictions), probabilities

    def predict_single_json(self, json_path):
        """
        Predict the pose class for a single JSON file.

        Args:
            json_path: path to the JSON file

        Returns:
            dict: prediction details or error information
        """
        if self.model is None and self.student_model is None:
            raise ValueError("Model not trained or loaded")

        try:
            # Read JSON file
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            landmarks = data.get('landmarks', {})

            # Extract coordinates of target joints
            features = []
            missing_joints = []
            available_joints = []

            for joint in self.target_joints:
                if joint in landmarks:
                    joint_data = landmarks[joint]
                    features.extend([
                        joint_data.get('x', 0.0),
                        joint_data.get('y', 0.0),
                        joint_data.get('z', 0.0)
                    ])
                    available_joints.append(joint)
                else:
                    # If a joint is missing, fill with zeros
                    features.extend([0.0, 0.0, 0.0])
                    missing_joints.append(joint)

            if len(features) != len(self.feature_columns):
                raise ValueError(f"Feature dimension mismatch: expected {len(self.feature_columns)}, got {len(features)}")

            # Convert to a numpy array and predict
            X = np.array([features])
            predictions, probabilities = self.predict(X)

            # build the result dict
            result = {
                'file_path': str(json_path),
                'file_name': Path(json_path).name,
                'predicted_label': predictions[0],
                'confidence_scores': {},
                'available_joints': available_joints,
                'missing_joints': missing_joints,
                'joint_coverage': f"{len(available_joints)}/{len(self.target_joints)}"
            }

            # add per-class confidence scores (only when probabilities are available)
            if probabilities is not None:
                for i, label in enumerate(self.label_encoder.classes_):
                    result['confidence_scores'][label] = float(probabilities[0][i])

                # highest confidence
                max_prob_idx = np.argmax(probabilities[0])
                result['max_confidence'] = float(probabilities[0][max_prob_idx])

            return result

        except Exception as e:
            return {
                'file_path': str(json_path),
                'file_name': Path(json_path).name,
                'error': str(e),
                'predicted_label': None
            }

    def evaluate_test_directory(self, test_dir):
        """
        Evaluate all data in a test directory.

        Args:
            test_dir: path to the test data directory

        Returns:
            dict: dictionary containing detailed evaluation results
        """
        if self.model is None:
            raise ValueError("Model not trained or loaded")

        test_path = Path(test_dir)
        if not test_path.exists():
            raise ValueError(f"Test directory does not exist: {test_dir}")

        # start timing
        start_time = time.time()
        print(f"Starting evaluation on test dataset: {test_path}")
        print(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}")

        # store all prediction results
        all_results = []
        label_stats = {}
        total_prediction_time = 0.0
        prediction_count = 0

        # iterate over label folders
        for label_dir in test_path.iterdir():
            if not label_dir.is_dir() or not label_dir.name.startswith('label_'):
                continue

            true_label = label_dir.name
            json_files = list(label_dir.glob('*.json'))

            print(f"Evaluating {true_label}: {len(json_files)} files")

            label_stats[true_label] = {
                'total': len(json_files),
                'correct': 0,
                'incorrect': 0,
                'errors': 0,
                'predictions': {},
                'confidence_scores': [],
                'prediction_times': []
            }

            for json_file in json_files:
                # Time each single prediction
                pred_start_time = time.time()
                result = self.predict_single_json(json_file)
                pred_end_time = time.time()

                single_prediction_time = pred_end_time - pred_start_time
                total_prediction_time += single_prediction_time
                prediction_count += 1

                if 'error' in result:
                    label_stats[true_label]['errors'] += 1
                    print(f"  Error: {json_file.name} - {result['error']}")
                    continue

                predicted_label = result['predicted_label']
                is_correct = predicted_label == true_label

                if is_correct:
                    label_stats[true_label]['correct'] += 1
                else:
                    label_stats[true_label]['incorrect'] += 1

                # Count the prediction distribution
                if predicted_label not in label_stats[true_label]['predictions']:
                    label_stats[true_label]['predictions'][predicted_label] = 0
                label_stats[true_label]['predictions'][predicted_label] += 1

                # Record confidence and prediction time
                if 'max_confidence' in result:
                    label_stats[true_label]['confidence_scores'].append(result['max_confidence'])
                label_stats[true_label]['prediction_times'].append(single_prediction_time)

                # Save the detailed result
                all_results.append({
                    'file_path': str(json_file),
                    'file_name': json_file.name,
                    'true_label': true_label,
                    'predicted_label': predicted_label,
                    'is_correct': is_correct,
                    'confidence': result.get('max_confidence', 0.0),
                    'confidence_scores': result.get('confidence_scores', {}),
                    'joint_coverage': result.get('joint_coverage', '0/13'),
                    'prediction_time': single_prediction_time
                })

        # end timing
        end_time = time.time()
        total_execution_time = end_time - start_time

        # compute aggregate statistics
        total_samples = sum(stats['total'] for stats in label_stats.values())
        total_correct = sum(stats['correct'] for stats in label_stats.values())
        total_errors = sum(stats['errors'] for stats in label_stats.values())
        total_tested = total_samples - total_errors

        overall_accuracy = total_correct / total_tested if total_tested > 0 else 0.0
        avg_prediction_time = total_prediction_time / prediction_count if prediction_count > 0 else 0.0

        # build the confusion matrix
        confusion_matrix = {}
        for true_label in label_stats.keys():
            confusion_matrix[true_label] = {}
            for predicted_label in label_stats.keys():
                confusion_matrix[true_label][predicted_label] = 0

        for result in all_results:
            if result.get('is_correct') is not None:  # exclude error cases
                true_label = result['true_label']
                predicted_label = result['predicted_label']
                confusion_matrix[true_label][predicted_label] += 1

        return {
            'label_stats': label_stats,
            'overall_accuracy': overall_accuracy,
            'total_samples': total_samples,
            'total_correct': total_correct,
            'total_errors': total_errors,
            'total_tested': total_tested,
            'confusion_matrix': confusion_matrix,
            'detailed_results': all_results,
            'timing_stats': {
                'total_execution_time': total_execution_time,
                'total_prediction_time': total_prediction_time,
                'avg_prediction_time': avg_prediction_time,
                'prediction_count': prediction_count,
                'start_time': start_time,
                'end_time': end_time,
                'overhead_time': total_execution_time - total_prediction_time
            }
        }

    def print_evaluation_report(self, eval_results):
        """
        Print a detailed evaluation report.

        Args:
            eval_results: dictionary returned by evaluate_test_directory
        """
        timing_stats = eval_results.get('timing_stats', {})

        print("\n" + "=" * 80)
        print("Test dataset evaluation report")
        print("=" * 80)

        # Overall statistics
        print(f"Total samples: {eval_results['total_samples']}")
        print(f"Successfully tested: {eval_results['total_tested']}")
        print(f"Errors: {eval_results['total_errors']}")
        print(
            f"Overall accuracy: {eval_results['overall_accuracy']:.4f} "
            f"({eval_results['total_correct']}/{eval_results['total_tested']})"
        )

        # Timing statistics
        if timing_stats:
            total_time = timing_stats['total_execution_time']
            prediction_time = timing_stats['total_prediction_time']
            avg_time = timing_stats['avg_prediction_time']
            overhead_time = timing_stats['overhead_time']
            prediction_count = timing_stats['prediction_count']

            print("\nTiming statistics:")
            print("-" * 50)
            print(f"Total execution time: {total_time:.4f} s")
            print(f"Total prediction time: {prediction_time:.4f} s")
            print(f"Overhead time: {overhead_time:.4f} s")
            print(f"Average prediction time: {avg_time * 1000:.2f} ms")
            print(f"Prediction throughput: {prediction_count / total_time:.2f} preds/s")
            print(
                f"Prediction efficiency: {(prediction_time / total_time) * 100:.1f}% "
                f"(prediction time / total)"
            )

        # Per-label detailed statistics
        print("\nPer-label stats:")
        print("-" * 80)
        print(
            f"{'Label':<10} {'Total':<6} {'Correct':<6} {'Wrong':<6} "
            f"{'Accuracy':<8} {'AvgConf':<10} {'AvgPredTime':<12}"
        )
        print("-" * 80)

        for label, stats in sorted(eval_results['label_stats'].items()):
            accuracy = (
                stats['correct'] / (stats['total'] - stats['errors'])
                if (stats['total'] - stats['errors']) > 0
                else 0.0
            )
            avg_confidence = (
                np.mean(stats['confidence_scores']) if stats['confidence_scores'] else 0.0
            )
            avg_pred_time = (
                np.mean(stats['prediction_times'])
                if 'prediction_times' in stats and stats['prediction_times']
                else 0.0
            )

            print(
                f"{label:<10} {stats['total']:<6} {stats['correct']:<6} {stats['incorrect']:<6} "
                f"{accuracy:.4f} {avg_confidence:.4f} {avg_pred_time * 1000:.2f}ms"
            )

        # Confusion matrix
        print("\nConfusion matrix:")
        print("-" * 60)
        labels = sorted(eval_results['label_stats'].keys())

        # Header row; build the 'True\Pred' label outside the f-string, because
        # backslashes are not allowed in f-string expressions before Python 3.12
        header = 'True\\Pred'
        print(f"{header:<12}", end="")
        for label in labels:
            print(f"{label:<10}", end="")
        print()

        # Data rows
        for true_label in labels:
            print(f"{true_label:<12}", end="")
            for pred_label in labels:
                count = eval_results['confusion_matrix'][true_label][pred_label]
                print(f"{count:<10}", end="")
            print()

        # Per-label prediction distribution
        print("\nPer-label prediction distribution:")
        print("-" * 80)
        for true_label, stats in sorted(eval_results['label_stats'].items()):
            if stats['predictions']:
                print(f"{true_label}:")
                total_predictions = sum(stats['predictions'].values())
                for pred_label, count in sorted(stats['predictions'].items()):
                    percentage = (count / total_predictions) * 100
                    print(f"  -> {pred_label}: {count} ({percentage:.1f}%)")

        # Error analysis
        print("\nError analysis:")
        print("-" * 40)
        incorrect_results = [r for r in eval_results['detailed_results'] if not r['is_correct']]

        if incorrect_results:
            # Sort by confidence and show the top mistaken predictions
            incorrect_results.sort(key=lambda x: x['confidence'], reverse=True)
            print("Highest-confidence incorrect predictions (top 10):")
            for i, result in enumerate(incorrect_results[:10]):
                pred_time = result.get('prediction_time', 0) * 1000  # ms
                print(
                    f"{i + 1:2d}. {result['file_name']}: {result['true_label']} -> {result['predicted_label']} "
                    f"(conf: {result['confidence']:.4f}, time: {pred_time:.2f}ms)"
                )
        else:
            print("No incorrect predictions found.")

        # Performance analysis
        if timing_stats and eval_results['detailed_results']:
            print("\nPerformance analysis:")
            print("-" * 40)
            prediction_times = [
                r.get('prediction_time', 0) for r in eval_results['detailed_results'] if 'prediction_time' in r
            ]
            if prediction_times:
                min_time = min(prediction_times) * 1000
                max_time = max(prediction_times) * 1000
                median_time = np.median(prediction_times) * 1000
                std_time = np.std(prediction_times) * 1000

                print("Prediction time distribution:")
                print(f"  Fastest: {min_time:.2f}ms")
                print(f"  Slowest: {max_time:.2f}ms")
                print(f"  Median: {median_time:.2f}ms")
                print(f"  Stddev: {std_time:.2f}ms")

        print("\n" + "=" * 80)

    def plot_confusion_matrix(self, cm, target_names, save_path=None):
        """Plot the confusion matrix."""
        plt.figure(figsize=(10, 8))
        if SEABORN_AVAILABLE:
            sns.heatmap(
                cm,
                annot=True,
                fmt='d',
                cmap='Blues',
                xticklabels=target_names,
                yticklabels=target_names,
            )
        else:
            # Fallback using matplotlib only
            im = plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
            plt.colorbar(im)
            tick_marks = np.arange(len(target_names))
            plt.xticks(tick_marks, target_names, rotation=45, ha='right')
            plt.yticks(tick_marks, target_names)
            # Annotate cells
            thresh = cm.max() / 2.0 if cm.size else 0
            for i in range(cm.shape[0]):
                for j in range(cm.shape[1]):
                    plt.text(j, i, format(cm[i, j], 'd'),
                             ha="center", va="center",
                             color="white" if cm[i, j] > thresh else "black")

        plt.title(f"{self.model_type.title()} model confusion matrix")
        plt.xlabel('Predicted')
        plt.ylabel('True')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Confusion matrix saved to: {save_path}")

        plt.show()

    def export_to_onnx(self, model_type='random_forest', output_path=None):
        """
        Export the trained model to ONNX format (only models supported by Barracuda).
        Note: Barracuda does not support LinearClassifier layers (e.g. LogisticRegression/SVM);
        only tree models are supported.
        """
        if not ONNX_AVAILABLE:
            print("Error: ONNX export is unavailable. Please install the skl2onnx and onnx packages:")
            print("pip install skl2onnx onnx")
            return None

        if not hasattr(self, 'model') or self.model is None:
            print("Error: Model is not trained yet. Please train the model first.")
            return None

        # Check whether the current model type matches the requested export type
        if hasattr(self, 'model_type') and self.model_type != model_type:
            print(f"Warning: Currently trained {self.model_type} model, but requested to export {model_type} model")
            print(f"Will export the currently trained {self.model_type} model")
            model_name = self.model_type
        else:
            model_name = model_type

        # Barracuda only supports tree models, not LinearClassifier
        if model_name in ['logistic', 'svm']:
            print(f"❌ Barracuda/Unity does not support ONNX import for {model_name} models (LinearClassifier layer).")
            print("Please use random_forest or gradient_boost for export.")
            return None

        # If a student_model exists, export it (MLP); otherwise export self.model
        model_to_export = None
        export_name = None

        if self.student_model is not None:
            model_to_export = self.student_model
            export_name = 'distilled_mlp'
            print("Detected student_model. Exporting the student (MLP) to ONNX (suitable for Unity/Barracuda).")
        else:
            model_to_export = self.model
            export_name = model_name

        if model_to_export is None:
            print("Error: No model available for export.")
            return None

        # Generate the output file path
        if output_path is None:
            output_path = f"pose_classifier_{export_name}.onnx"

        print(f"About to export model to: {output_path}, export target: {export_name}")

        try:
            feature_count = len(self.target_joints) * 3
            initial_type = [('float_input', FloatTensorType([None, feature_count]))]

            onnx_model = convert_sklearn(
                model_to_export,
                initial_types=initial_type,
                target_opset=12
            )

            with open(output_path, "wb") as f:
                f.write(onnx_model.SerializeToString())

            print(f"✅ Successfully exported {export_name} model to ONNX format: {output_path}")

            # Save the label mapping and scaler parameters
            label_mapping_path = output_path.replace('.onnx', '_labels.json')
            label_mapping = {
                'label_encoder_classes': self.label_encoder.classes_.tolist(),
                'model_type': export_name,
                'feature_count': feature_count,
                'target_joints': self.target_joints,
                'description': f'Pose classifier - {len(self.target_joints)} landmarks with x,y,z coordinates',
                'scaler_mean': self.scaler.mean_.tolist(),
                'scaler_scale': self.scaler.scale_.tolist()
            }

            with open(label_mapping_path, 'w', encoding='utf-8') as f:
                json.dump(label_mapping, f, ensure_ascii=False, indent=2)

            print(f"✅ Label mapping and scaler parameters saved to: {label_mapping_path}")

            print("⚠️ Note: The exported ONNX model expects inputs standardized with scaler_mean/scaler_scale.")

            return output_path

        except Exception as e:
            print(f"❌ ONNX export failed: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

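    # Consumer-side note (sketch, not part of this script): the importing runtime
    # must standardize each input feature with the exported parameters before
    # inference, i.e. scaled[i] = (raw[i] - scaler_mean[i]) / scaler_scale[i].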
    def export_to_tflite(self, output_path=None):
        """
        Export the student_model (MLP) to TFLite format.
        Dependencies: skl2onnx, onnx, onnx-tf, tensorflow
        """
        if self.student_model is None:
            print("❌ Only exporting the student_model (MLPRegressor) to TFLite is supported. Please train with --model distilled_rf first.")
            return None

        try:
            import onnx
            from skl2onnx import convert_sklearn
            from skl2onnx.common.data_types import FloatTensorType
            from onnx_tf.backend import prepare
            import tensorflow as tf
        except ImportError:
            print("❌ You need to install skl2onnx, onnx, onnx-tf and tensorflow.")
            print("pip install skl2onnx onnx onnx-tf tensorflow")
            return None

        feature_count = len(self.target_joints) * 3
        initial_type = [('float_input', FloatTensorType([None, feature_count]))]

        # 1. Export to ONNX
        print("Exporting student_model to ONNX...")
        onnx_model = convert_sklearn(
            self.student_model,
            initial_types=initial_type,
            target_opset=12
        )
        onnx_path = "temp_student.onnx"
        with open(onnx_path, "wb") as f:
            f.write(onnx_model.SerializeToString())
        print(f"✅ ONNX export successful: {onnx_path}")

        # 2. ONNX -> TensorFlow SavedModel
        print("Converting ONNX to TensorFlow SavedModel...")
        tf_model = prepare(onnx.load(onnx_path))
        tf_saved_path = "temp_student_tf"
        tf_model.export_graph(tf_saved_path)
        print(f"✅ SavedModel export successful: {tf_saved_path}")

        # 3. SavedModel -> TFLite
        print("Converting SavedModel to TFLite...")
        converter = tf.lite.TFLiteConverter.from_saved_model(tf_saved_path)
        tflite_model = converter.convert()
        if output_path is None:
            output_path = "pose_classifier_distilled_mlp.tflite"
        with open(output_path, "wb") as f:
            f.write(tflite_model)
        print(f"✅ TFLite export successful: {output_path}")

        # Clean up temporary files (optional)
        import os
        import shutil
        os.remove(onnx_path)
        shutil.rmtree(tf_saved_path, ignore_errors=True)

        return output_path

def main():
    parser = argparse.ArgumentParser(description="Pose classification machine learning script")
    parser.add_argument("--data", "-d", default="PoseData", help="Pose data directory (default: PoseData)")
    parser.add_argument(
        "--model",
        "-m",
        choices=['random_forest', 'svm', 'gradient_boost', 'logistic', 'distilled_rf'],
        default='random_forest',
        help="Model type (default: random_forest)",
    )
    parser.add_argument("--test-size", "-t", type=float, default=0.2, help="Test set ratio (default: 0.2)")
    parser.add_argument("--save-model", "-s", help="Path to save the trained model")
    parser.add_argument("--load-model", "-l", help="Path to load an already trained model")
    parser.add_argument("--predict", "-p", help="Path of a single JSON file to predict")
    parser.add_argument("--evaluate", "-e", help="Path of a test directory to evaluate all JSON files")
    parser.add_argument("--no-plot", action="store_true", help="Do not display the confusion matrix plot")
    parser.add_argument("--train", action="store_true", help="Force training even if --load-model is provided")
    parser.add_argument("--export-onnx", help="Export the model to ONNX format; specify the output file path")
    parser.add_argument(
        "--export-model-type",
        choices=['random_forest', 'logistic', 'distilled_rf'],
        default='random_forest',
        help="Model type to export (default: random_forest)",
    )
    parser.add_argument("--test-onnx", help="Test an ONNX model; specify the ONNX file path")
    parser.add_argument("--onnx-labels", help="ONNX label mapping JSON path (auto-detect if not provided)")
    parser.add_argument("--onnx-test-data", help="ONNX batch test data directory (if not provided, single-sample test)")
    parser.add_argument(
        "--export-tflite",
        help="Export the model to TFLite format; specify the output path (supported for the distilled_rf student model only)",
    )

    args = parser.parse_args()

    print("Pose classification ML tool")
    print("=" * 60)

    # ONNX test mode
    if args.test_onnx:
        print("ONNX model test mode")
        print(f"ONNX model: {args.test_onnx}")
        print("=" * 60)

        # Create a classifier instance for testing
        classifier = PoseClassifier()
        # Note: test_onnx_model is not implemented in this script; this is a placeholder.
        # You can implement it later if needed.
        print("ONNX test requested but the functionality is not implemented in this script.")
        return

    # Evaluation mode
    if args.evaluate:
        if not args.load_model:
            # Try to use the default model file
            default_model = f"pose_classifier_{args.model}.pkl"
            if Path(default_model).exists():
                args.load_model = default_model
            else:
                print(
                    f"Error: Need to specify a model file path (--load-model) or ensure the default model file exists: {default_model}"
                )
                return

        print("Evaluation mode")
        print(f"Test data directory: {args.evaluate}")
        print(f"Model file: {args.load_model}")
        print("=" * 60)

        # Create the classifier and load the model
        classifier = PoseClassifier(model_type=args.model)
        classifier.load_model(args.load_model)

        # Perform the comprehensive evaluation
        try:
            eval_results = classifier.evaluate_test_directory(args.evaluate)
            classifier.print_evaluation_report(eval_results)
        except Exception as e:
            print(f"Error during evaluation: {e}")

        return

    # Prediction-only mode
    if args.predict:
        if not args.load_model:
            # Try to use the default model file
            default_model = f"pose_classifier_{args.model}.pkl"
            if Path(default_model).exists():
                args.load_model = default_model
            else:
                print(
                    f"Error: Need to specify a model file path (--load-model) or ensure the default model file exists: {default_model}"
                )
                return

        print("Prediction mode")
        print(f"JSON file: {args.predict}")
        print(f"Model file: {args.load_model}")
        print("=" * 60)

        # Create the classifier and load the model
        classifier = PoseClassifier(model_type=args.model)
        classifier.load_model(args.load_model)

        # Run the prediction
        result = classifier.predict_single_json(args.predict)

        # Show the prediction result
        print("\nPrediction result:")
        print(f"File: {result['file_name']}")

        if 'error' in result:
            print(f"Error: {result['error']}")
        else:
            print(f"Predicted label: {result['predicted_label']}")
            print(f"Joint coverage: {result['joint_coverage']}")

            if result['confidence_scores']:
                print(f"Max confidence: {result['max_confidence']:.4f}")
                print("\nPer-class confidence:")
                sorted_scores = sorted(result['confidence_scores'].items(), key=lambda x: x[1], reverse=True)
                for label, score in sorted_scores:
                    print(f"  {label}: {score:.4f}")

            if result['missing_joints']:
                print(f"\nMissing joints: {', '.join(result['missing_joints'])}")

        return

    # Training mode
    print("Training mode")
    print(f"Data directory: {args.data}")
    print(f"Model type: {args.model}")
    print(f"Test size: {args.test_size}")
    print("=" * 60)

    # Check the data directory
    if not Path(args.data).exists():
        print(f"Error: data directory does not exist: {args.data}")
        return

    # Create the classifier
    classifier = PoseClassifier(model_type=args.model)

    # If loading an existing model and not forcing training
    if args.load_model and not args.train:
        print(f"Loading existing model: {args.load_model}")
        classifier.load_model(args.load_model)
        print("Model loaded, skipping the training step")
    else:
        # Load data
        X, y = classifier.load_data(args.data)
        if len(X) == 0:
            print("Error: no valid data found")
            return
        # Train the model
        results = classifier.train(X, y, test_size=args.test_size)
        # Plot the confusion matrix (if not disabled)
        if not args.no_plot:
            try:
                classifier.plot_confusion_matrix(
                    results['confusion_matrix'], results['target_names'], save_path=f"confusion_matrix_{args.model}.png"
                )
            except Exception as e:
                print(f"Error while plotting confusion matrix: {e}")
        # Save the model (if specified)
        if args.save_model:
            classifier.save_model(args.save_model)
        else:
            # Default save path
            default_path = f"pose_classifier_{args.model}.pkl"
            classifier.save_model(default_path)
        print("\nTraining complete!")
        print(f"Final test accuracy: {results['test_accuracy']:.4f}")

    # Export ONNX if requested
    if args.export_onnx:
        print(f"\nExporting {args.export_model_type} model to ONNX format...")
        onnx_path = classifier.export_to_onnx(model_type=args.export_model_type, output_path=args.export_onnx)
        if onnx_path:
            print(f"✅ ONNX model exported: {onnx_path}")

    # Export TFLite if requested
    if args.export_tflite:
        print("\nExporting student_model to TFLite format...")
        tflite_path = classifier.export_to_tflite(output_path=args.export_tflite)
        if tflite_path:
            print(f"✅ TFLite model exported: {tflite_path}")

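# Example invocations (paths are illustrative):
#   python ml_pose_classifier.py --data PoseData --model random_forest
#   python ml_pose_classifier.py --evaluate TestPoseData --load-model pose_classifier_random_forest.pkl
#   python ml_pose_classifier.py --model distilled_rf --export-tflite pose_classifier_distilled_mlp.tflite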
if __name__ == "__main__":
    main()
pose_detection.py ADDED
@@ -0,0 +1,382 @@
#!/usr/bin/env python3
"""
Use MediaPipe to detect poses in images and extract landmark coordinates.

Features:
1. Run MediaPipe pose detection on images in the train folder
2. Use the nose as the head reference point (headPos)
3. Process coordinates as: pos = (pos - headPos) * 100, rounded to 2 decimals
4. Save processed landmarks into JSON files named after the image files

Usage:
    python pose_detection.py [--input INPUT_DIR] [--output OUTPUT_DIR]
"""
import os
import json
import argparse
from pathlib import Path
import cv2
import mediapipe as mp

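# Worked example of the normalization (illustrative numbers): with the nose at
# (0.31, 0.42, -0.05), the nose itself maps to (0.0, 0.0, 0.0), and a landmark at
# (0.35, 0.60, -0.02) maps to ((0.35-0.31)*100, (0.60-0.42)*100, (-0.02+0.05)*100)
# = (4.0, 18.0, 3.0).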
class PoseDetector:
    def __init__(self):
        """Initialize the MediaPipe pose detector."""
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=True,
            model_complexity=2,
            enable_segmentation=False,
            min_detection_confidence=0.5
        )

        # MediaPipe pose landmark name mapping
        self.landmark_names = [
            'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer',
            'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
            'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
            'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
            'left_index', 'right_index', 'left_thumb', 'right_thumb',
            'left_hip', 'right_hip', 'left_knee', 'right_knee',
            'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
            'left_foot_index', 'right_foot_index'
        ]

    def get_head_position(self, landmarks):
        """
        Compute the head reference position (using the nose landmark).

        Args:
            landmarks: MediaPipe detected landmarks

        Returns:
            tuple: (x, y, z) head coordinates
        """
        # use the nose as the head reference point
        nose = landmarks[0]  # nose is the 0th landmark
        return (nose.x, nose.y, nose.z)

    def process_landmarks(self, landmarks, head_pos):
        """
        Process landmarks: pos = (pos - headPos) * 100, rounded to 2 decimals.

        Args:
            landmarks: MediaPipe detected landmarks
            head_pos: head coordinates (x, y, z)

        Returns:
            dict: processed landmarks dictionary
        """
        processed_landmarks = {}
        head_pos_x = head_pos[0]
        head_pos_y = head_pos[1]
        head_pos_z = head_pos[2]

        for i, landmark in enumerate(landmarks):
            if i < len(self.landmark_names):
                name = self.landmark_names[i]

                # Calculate coordinates relative to the head and multiply by 100
                rel_x = round((landmark.x - head_pos_x) * 100, 2)
                rel_y = round((landmark.y - head_pos_y) * 100, 2)
                rel_z = round((landmark.z - head_pos_z) * 100, 2)

                processed_landmarks[name] = {
                    'x': rel_x,
                    'y': rel_y,
                    'z': rel_z,
                    'visibility': round(landmark.visibility, 3)
                }

        return processed_landmarks

    def detect_pose(self, image_path):
        """
        Detect the pose for a single image.

        Args:
            image_path: path to the image file

        Returns:
            dict: processed landmarks and metadata, or None on failure
        """
        try:
            # Read the image
            image = cv2.imread(str(image_path))
            if image is None:
                print(f"Unable to read image: {image_path}")
                return None

            # Convert the color space (BGR -> RGB)
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Run pose detection
            results = self.pose.process(image_rgb)

            if results.pose_landmarks is None:
                print(f"No pose detected: {image_path}")
                return None

            # Get the keypoints
            landmarks = results.pose_landmarks.landmark

            # Get the head position
            head_pos = self.get_head_position(landmarks)

            # Process the keypoint coordinates
            processed_landmarks = self.process_landmarks(landmarks, head_pos)

            # extract the label from the parent folder name
            label = image_path.parent.name

            # Add metadata
            result = {
                'image_path': str(image_path),
                'image_name': image_path.name,
                'label': label,
                'head_position': {
                    'x': round(head_pos[0], 4),
                    'y': round(head_pos[1], 4),
                    'z': round(head_pos[2], 4)
                },
                'landmarks': processed_landmarks,
                'total_landmarks': len(processed_landmarks)
            }

            return result

        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

    def close(self):
        """Close MediaPipe resources."""
        self.pose.close()

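# Minimal usage sketch (path illustrative):
#   detector = PoseDetector()
#   result = detector.detect_pose(Path("TrainData/train/label_0/example.jpg"))
#   detector.close()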
158
+ def process_all_training_data(input_dir, output_dir, batch_size=100):
159
+ """
160
+ Process all images in the training dataset and write JSON files.
161
+
162
+ Args:
163
+ input_dir: input images directory (TrainData/train)
164
+ output_dir: output JSON directory (PoseData)
165
+ batch_size: progress report batch size
166
+ """
167
+ input_path = Path(input_dir)
168
+ output_path = Path(output_dir)
169
+ output_path.mkdir(parents=True, exist_ok=True)
170
+
171
+ # Supported image formats
172
+ image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
173
+
174
+ detector = PoseDetector()
175
+
176
+ try:
177
+ # statistics
178
+ total_images = 0
179
+ success_count = 0
180
+ failed_count = 0
181
+ label_stats = {}
182
+
183
+ print(f"Starting processing dataset: {input_path}")
184
+ print(f"Output directory: {output_path}")
185
+
186
+ # first count all images
187
+ print("Counting images...")
188
+ label_dirs = []
189
+ for item in input_path.iterdir():
190
+ if item.is_dir() and item.name.startswith('label_'):
191
+ label = item.name
192
+ image_files = [f for f in item.iterdir()
193
+ if f.is_file() and f.suffix.lower() in image_extensions]
194
+ if image_files:
195
+ label_dirs.append((item, label, image_files))
196
+ total_images += len(image_files)
197
+ label_stats[label] = {'total': len(image_files), 'success': 0, 'failed': 0}
198
+
199
+ print(f"Found {len(label_dirs)} label directories, total {total_images} images")
200
+ for label, stats in label_stats.items():
201
+ print(f" {label}: {stats['total']} images")
202
+
203
+ print("\nStarting to process images...")
204
+
205
+ # process each label directory
206
+ for label_dir, label_name, image_files in label_dirs:
207
+ print(f"\n--- Processing {label_name} ({len(image_files)} images) ---")
208
+
209
+ # create output folder for this label
210
+ output_label_dir = output_path / label_name
211
+ output_label_dir.mkdir(parents=True, exist_ok=True)
212
+
213
+ # process every image in this label
214
+ for i, image_file in enumerate(image_files, 1):
215
+ json_filename = image_file.stem + '.json'
216
+ json_path = output_label_dir / json_filename
217
+
218
+ # detect pose
219
+ result = detector.detect_pose(image_file)
220
+
221
+ if result is not None:
222
+ # save JSON
223
+ try:
224
+ with open(json_path, 'w', encoding='utf-8') as f:
225
+ json.dump(result, f, ensure_ascii=False, indent=2)
226
+ success_count += 1
227
+ label_stats[label_name]['success'] += 1
228
+
229
+ # progress
230
+ if success_count % batch_size == 0:
231
+ progress = (success_count / total_images) * 100 if total_images else 0
232
+ print(f" Progress: {success_count}/{total_images} ({progress:.1f}%) - Current: {label_name} {i}/{len(image_files)}")
233
+
234
+ except Exception as e:
235
+ print(f" Failed to save JSON {json_path}: {e}")
236
+ failed_count += 1
237
+ label_stats[label_name]['failed'] += 1
238
+ else:
239
+ failed_count += 1
240
+ label_stats[label_name]['failed'] += 1
241
+ if failed_count % 10 == 0: # print every 10 failures
242
+ print(f" Detection failed: {image_file.name}")
243
+
244
+ # report for this label
245
+ stats = label_stats[label_name]
246
+ success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
247
+ print(f" {label_name} Done: Success {stats['success']}, Failed {stats['failed']}, Success rate: {success_rate:.1f}%")
248
+
249
+ print("\n" + "=" * 60)
250
+ print("Processing complete!")
251
+ print(f"Total images: {total_images}")
252
+ print(f"Successfully processed: {success_count}")
253
+ print(f"Failed: {failed_count}")
254
+ total_success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
255
+ print(f"Overall success rate: {total_success_rate:.1f}%")
256
+
257
+ print("\nPer-label statistics:")
258
+ for label, stats in label_stats.items():
259
+ success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
260
+ print(f" {label}: {stats['success']}/{stats['total']} ({success_rate:.1f}%)")
261
+
262
+ print(f"\nJSON files saved to: {output_path.absolute()}")
263
+ print("Directory structure:")
264
+ print("PoseData/")
265
+ for label in sorted(label_stats.keys()):
266
+ print(f"├── {label}/")
267
+ print("│ └── *.json")
268
+
269
+ finally:
270
+ detector.close()
271
+
272
+
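+ # Minimal call sketch, assuming the TrainData/train layout produced by
+ # extract_images.py (arguments match the defaults in main() below):
+ #   process_all_training_data("TrainData/train", "PoseData", batch_size=100)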
273
+ def process_directory(input_dir, output_dir):
274
+ """
275
+ Process all images in a directory tree and write JSON files.
276
+
277
+ Args:
278
+ input_dir: input images directory
279
+ output_dir: output JSON directory
280
+ """
281
+ input_path = Path(input_dir)
282
+ output_path = Path(output_dir)
283
+ output_path.mkdir(parents=True, exist_ok=True)
284
+
285
+ # Supported image formats
286
+ image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
287
+
288
+ detector = PoseDetector()
289
+
290
+ try:
291
+ # statistics
292
+ total_images = 0
293
+ success_count = 0
294
+ failed_count = 0
295
+
296
+ print(f"Starting to process directory: {input_path}")
297
+ print(f"Output directory: {output_path}")
298
+
299
+ # walk through the tree
300
+ for root, dirs, files in os.walk(input_path):
301
+ root_path = Path(root)
302
+
303
+ # create corresponding output folder
304
+ relative_path = root_path.relative_to(input_path)
305
+ current_output_dir = output_path / relative_path
306
+ current_output_dir.mkdir(parents=True, exist_ok=True)
307
+
308
+ # collect image files in this folder
309
+ image_files = [f for f in files if Path(f).suffix.lower() in image_extensions]
310
+
311
+ if image_files:
312
+ print(f"\nProcessing directory: {root_path}")
313
+ print(f"Found {len(image_files)} images")
314
+
315
+ for filename in image_files:
316
+ total_images += 1
317
+ image_path = root_path / filename
318
+
319
+ # generate JSON filename (replace extension with .json)
320
+ json_filename = Path(filename).stem + '.json'
321
+ json_path = current_output_dir / json_filename
322
+
323
+ # detect pose
324
+ result = detector.detect_pose(image_path)
325
+
326
+ if result is not None:
327
+ # save JSON file
328
+ try:
329
+ with open(json_path, 'w', encoding='utf-8') as f:
330
+ json.dump(result, f, ensure_ascii=False, indent=2)
331
+ success_count += 1
332
+
333
+ if success_count % 50 == 0:
334
+ print(f"Successfully processed {success_count} images...")
335
+
336
+ except Exception as e:
337
+ print(f"Failed to save JSON {json_path}: {e}")
338
+ failed_count += 1
339
+ else:
340
+ failed_count += 1
341
+
342
+ print("\nProcessing complete!")
343
+ print(f"Total images: {total_images}")
344
+ print(f"Successfully processed: {success_count}")
345
+ print(f"Failed: {failed_count}")
346
+ print(f"Success rate: {success_count/total_images*100:.1f}%")
347
+
348
+ finally:
349
+ detector.close()
350
+
351
+
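+ # Minimal call sketch; "PoseDataTest" is only an example output name:
+ #   process_directory("TrainData/test", "PoseDataTest")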
352
+ def main():
353
+ parser = argparse.ArgumentParser(description="Run MediaPipe pose detection and save landmark data")
354
+ parser.add_argument("--input", "-i", default="TrainData/train",
355
+ help="input images directory (default: TrainData/train)")
356
+ parser.add_argument("--output", "-o", default="PoseData",
357
+ help="output JSON directory (default: PoseData)")
358
+ parser.add_argument("--batch-size", "-b", type=int, default=100,
359
+ help="batch size for progress reporting (default: 100)")
360
+
361
+ args = parser.parse_args()
362
+
363
+ # check input directory exists
364
+ if not Path(args.input).exists():
365
+ print(f"Error: input directory does not exist: {args.input}")
366
+ return
367
+
368
+ print("MediaPipe pose detection tool")
369
+ print("=" * 60)
370
+ print(f"Input directory: {args.input}")
371
+ print(f"Output directory: {args.output}")
372
+ print("Processing rule: pos = (pos - headPos) * 100, round to 2 decimals")
373
+ print("Head reference: nose")
374
+ print(f"Batch size: show progress every {args.batch_size} images")
375
+ print("=" * 60)
376
+
377
+ # Start processing the entire training dataset
378
+ process_all_training_data(args.input, args.output, args.batch_size)
379
+
380
+
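+ # Illustrative invocations (the script filename here is assumed):
+ #   python pose_detection.py
+ #   python pose_detection.py --input TrainData/train --output PoseData --batch-size 200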
381
+ if __name__ == "__main__":
382
+ main()
realtime_pose_classifier.py ADDED
@@ -0,0 +1,456 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Real-time pose classifier
4
+ Captures camera frames with OpenCV, runs MediaPipe pose estimation, classifies the pose with a trained model, and overlays the results on the video feed
5
+
6
+ Features:
7
+ 1. Use MediaPipe to obtain real-time pose data from camera
8
+ 2. Extract joint coordinates and preprocess them
9
+ 3. Use trained machine learning models for pose classification
10
+ 4. Display classification results and keypoints in real time on the video feed
11
+
12
+ Dependencies:
13
+ pip install opencv-python mediapipe numpy scikit-learn
14
+
15
+ Usage:
16
+ python realtime_pose_classifier.py [--model MODEL_PATH] [--camera CAMERA_ID]
17
+ """
18
+
19
+ import cv2
20
+ import mediapipe as mp
21
+ import numpy as np
22
+ import json
23
+ import joblib
24
+ import argparse
25
+ import time
26
+ from pathlib import Path
27
+ import traceback
28
+
29
+
30
+ class RealtimePoseClassifier:
31
+ def __init__(self, model_path=None, camera_id=0):
32
+ """
33
+ Initialize real-time pose classifier
34
+
35
+ Args:
36
+ model_path (str): Model file path, auto-detect if None
37
+ camera_id (int): Camera ID, default 0
38
+ """
39
+ self.camera_id = camera_id
40
+
41
+ # Initialize MediaPipe
42
+ self.mp_pose = mp.solutions.pose
43
+ self.mp_drawing = mp.solutions.drawing_utils
44
+ self.mp_drawing_styles = mp.solutions.drawing_styles
45
+
46
+ # Configure pose detector
47
+ self.pose = self.mp_pose.Pose(
48
+ static_image_mode=False,
49
+ model_complexity=1, # Use lower complexity for real-time applications
50
+ enable_segmentation=False,
51
+ min_detection_confidence=0.7,
52
+ min_tracking_confidence=0.5
53
+ )
54
+
55
+ # MediaPipe landmark name mapping
56
+ self.landmark_names = [
57
+ 'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
58
+ 'right_eye_inner', 'right_eye', 'right_eye_outer',
59
+ 'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
60
+ 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
61
+ 'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
62
+ 'left_index', 'right_index', 'left_thumb', 'right_thumb',
63
+ 'left_hip', 'right_hip', 'left_knee', 'right_knee',
64
+ 'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
65
+ 'left_foot_index', 'right_foot_index'
66
+ ]
67
+
68
+ # Load model
69
+ self.model = None
70
+ self.scaler = None
71
+ self.label_encoder = None
72
+ self.target_joints = None
73
+ self.model_info = None
74
+
75
+ self.load_model(model_path)
76
+
77
+ # Prediction result cache
78
+ self.prediction_history = []
79
+ self.history_size = 5  # Keep the 5 most recent predictions for smoothing
80
+
81
+ # Performance statistics
82
+ self.fps_counter = 0
83
+ self.fps_start_time = time.time()
84
+ self.current_fps = 0
85
+
86
+ # Timing statistics (MediaPipe detection vs. feature extraction + prediction)
87
+ self.mediapipe_time_total = 0.0
88
+ self.mediapipe_time_count = 0
89
+ self.feature_pred_time_total = 0.0
90
+ self.feature_pred_time_count = 0
91
+
92
+ # Display settings
93
+ self.show_landmarks = True
94
+ self.show_connections = True
95
+
96
+ def load_model(self, model_path=None):
97
+ """Load trained model"""
98
+ if model_path is None:
99
+ # Auto-detect available model files
100
+ possible_models = [
101
+ 'pose_classifier_random_forest.pkl',
102
+ 'pose_classifier_logistic.pkl',
103
+ 'pose_classifier_distilled_rf.pkl'
104
+ ]
105
+
106
+ for model_file in possible_models:
107
+ if Path(model_file).exists():
108
+ model_path = model_file
109
+ break
110
+
111
+ if model_path is None:
112
+ raise FileNotFoundError("No model file found; please specify a model path")
113
+
114
+ try:
115
+ print(f"Loading model: {model_path}")
116
+ model_data = joblib.load(model_path)
117
+
118
+ self.model = model_data['model']
119
+ self.scaler = model_data['scaler']
120
+ self.label_encoder = model_data['label_encoder']
121
+ self.target_joints = model_data['target_joints']
122
+
123
+ # Try to load corresponding labels file
124
+ labels_path = model_path.replace('.pkl', '_labels.json')
125
+ if Path(labels_path).exists():
126
+ with open(labels_path, 'r') as f:
127
+ self.model_info = json.load(f)
128
+ print(f"Loaded label information: {labels_path}")
129
+
130
+ print("Model loaded successfully!")
131
+ print(f"Target joints: {self.target_joints}")
132
+ print(f"Classification classes: {self.label_encoder.classes_}")
133
+
134
+ except Exception as e:
135
+ raise RuntimeError(f"Model loading failed: {e}")
136
+
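+ # The loader above expects a joblib bundle with these keys (a sketch of how a
+ # compatible training script might save it; the variable names are examples):
+ #   joblib.dump({'model': clf, 'scaler': scaler,
+ #                'label_encoder': label_encoder,
+ #                'target_joints': target_joints},
+ #               'pose_classifier_random_forest.pkl')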
137
+ def extract_pose_features(self, landmarks):
138
+ """
139
+ Extract pose features from MediaPipe landmarks (vectorized implementation)
140
+ """
141
+ if landmarks is None:
142
+ return None
143
+
144
+ # Get all joint coordinates as NumPy array
145
+ coords = np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark], dtype=np.float32)
146
+
147
+ # Get head position (nose as reference point)
148
+ try:
149
+ head_idx = self.landmark_names.index('nose')
150
+ head_pos = coords[head_idx]
151
+ except ValueError:
152
+ return None
153
+
154
+ # Build target joint indices list
155
+ joint_indices = [self.landmark_names.index(j) if j in self.landmark_names else -1 for j in self.target_joints]
156
+
157
+ # Extract target joint coordinates (fill with 0 if not exist)
158
+ joint_coords = np.array([
159
+ coords[idx] if idx >= 0 else np.zeros(3, dtype=np.float32)
160
+ for idx in joint_indices
161
+ ], dtype=np.float32)
162
+
163
+ # Calculate relative position to head and scale
164
+ relative_coords = (joint_coords - head_pos) * 100 # Keep consistent with training processing
165
+
166
+ # Keep two decimal places
167
+ features = np.round(relative_coords, 2).flatten()
168
+
169
+ return features
170
+
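+ # Worked example of the rule above (illustrative coordinates): a left wrist at
+ # (0.60, 0.50, 0.0) with the nose at (0.55, 0.30, 0.0) yields
+ # round((wrist - nose) * 100, 2) == (5.0, 20.0, 0.0), matching the training
+ # preprocessing of (pos - headPos) * 100 rounded to 2 decimals.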
171
+ def predict_pose(self, features):
172
+ """
173
+ Use machine learning model to predict pose
174
+
175
+ Args:
176
+ features: Feature vector
177
+
178
+ Returns:
179
+ dict: Prediction result containing label, confidence, etc.
180
+ """
181
+ if features is None or self.model is None:
182
+ return None
183
+
184
+ try:
185
+ # Standardize features
186
+ features_scaled = self.scaler.transform(features.reshape(1, -1))
187
+
188
+ # Predict
189
+ prediction = self.model.predict(features_scaled)[0]
190
+ predicted_label = self.label_encoder.inverse_transform([prediction])[0]
191
+
192
+ # Get confidence (if model supports probability prediction)
193
+ confidence = 0.0
194
+ probabilities = None
195
+ if hasattr(self.model, 'predict_proba'):
196
+ probs = self.model.predict_proba(features_scaled)[0]
197
+ confidence = float(np.max(probs))
198
+ probabilities = dict(zip(self.label_encoder.classes_, probs))
199
+
200
+ return {
201
+ 'predicted_label': predicted_label,
202
+ 'confidence': confidence,
203
+ 'probabilities': probabilities
204
+ }
205
+
206
+ except Exception as e:
207
+ print(f"Prediction error: {e}")
208
+ return None
209
+
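+ # Example return value (illustrative numbers; 'probabilities' stays None for
+ # models without predict_proba):
+ #   {'predicted_label': 'label_2', 'confidence': 0.87,
+ #    'probabilities': {'label_0': 0.05, 'label_1': 0.08, 'label_2': 0.87}}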
210
+ def smooth_predictions(self, current_prediction):
211
+ """
212
+ Smooth prediction results
213
+
214
+ Args:
215
+ current_prediction: Current prediction result
216
+
217
+ Returns:
218
+ dict: Smoothed prediction result
219
+ """
220
+ if current_prediction is None:
221
+ return None
222
+
223
+ # Add to history
224
+ self.prediction_history.append(current_prediction)
225
+ if len(self.prediction_history) > self.history_size:
226
+ self.prediction_history.pop(0)
227
+
228
+ # If history is insufficient, return current prediction directly
229
+ if len(self.prediction_history) < 3:
230
+ return current_prediction
231
+
232
+ # Count recent prediction labels
233
+ recent_labels = [pred['predicted_label'] for pred in self.prediction_history]
234
+
235
+ # Use mode as final prediction
236
+ from collections import Counter
237
+ label_counts = Counter(recent_labels)
238
+ most_common_label = label_counts.most_common(1)[0][0]
239
+
240
+ # Calculate average confidence for this label
241
+ avg_confidence = np.mean([
242
+ pred['confidence'] for pred in self.prediction_history
243
+ if pred['predicted_label'] == most_common_label
244
+ ])
245
+
246
+ return {
247
+ 'predicted_label': most_common_label,
248
+ 'confidence': avg_confidence,
249
+ 'stability': label_counts[most_common_label] / len(recent_labels)
250
+ }
251
+
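+ # Example of the smoothing above: with the last five labels
+ # ['A', 'A', 'B', 'A', 'A'], the mode 'A' is returned with
+ # stability 4/5 = 0.8 and confidence averaged over the four 'A' frames.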
252
+ def draw_pose_info(self, image, landmarks, prediction_result):
253
+ """
254
+ Draw pose information on image
255
+
256
+ Args:
257
+ image: OpenCV image
258
+ landmarks: MediaPipe landmarks
259
+ prediction_result: Prediction result
260
+ """
261
+ height, width = image.shape[:2]
262
+
263
+ # Draw pose skeleton
264
+ if landmarks and self.show_connections:
265
+ self.mp_drawing.draw_landmarks(
266
+ image,
267
+ landmarks,
268
+ self.mp_pose.POSE_CONNECTIONS,
269
+ landmark_drawing_spec=self.mp_drawing_styles.get_default_pose_landmarks_style()
270
+ )
271
+
272
+ # Draw keypoints
273
+ if landmarks and self.show_landmarks:
274
+ for i, landmark in enumerate(landmarks.landmark):
275
+ if self.landmark_names[i] in self.target_joints:
276
+ x = int(landmark.x * width)
277
+ y = int(landmark.y * height)
278
+ cv2.circle(image, (x, y), 8, (0, 255, 0), -1)
279
+ cv2.putText(image, self.landmark_names[i], (x + 10, y - 10),
280
+ cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
281
+
282
+ # Display prediction results
283
+ if prediction_result:
284
+ label = prediction_result['predicted_label']
285
+ confidence = prediction_result.get('confidence', 0.0)
286
+ stability = prediction_result.get('stability', 1.0)
287
+
288
+ # Set color based on confidence
289
+ if confidence > 0.8:
290
+ color = (0, 255, 0) # Green - high confidence
291
+ elif confidence > 0.6:
292
+ color = (0, 255, 255) # Yellow - medium confidence
293
+ else:
294
+ color = (0, 0, 255) # Red - low confidence
295
+
296
+ # Draw prediction result background box
297
+ cv2.rectangle(image, (10, 10), (400, 120), (0, 0, 0), -1)
298
+ cv2.rectangle(image, (10, 10), (400, 120), color, 2)
299
+
300
+ # Display prediction label
301
+ cv2.putText(image, f"Pose: {label}", (20, 40),
302
+ cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
303
+
304
+ # Display confidence
305
+ cv2.putText(image, f"Confidence: {confidence:.2f}", (20, 70),
306
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
307
+
308
+ # Display stability
309
+ cv2.putText(image, f"Stability: {stability:.2f}", (20, 95),
310
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
311
+
312
+ # Display FPS
313
+ cv2.putText(image, f"FPS: {self.current_fps:.1f}", (width - 150, 30),
314
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
315
+
316
+ # Display control instructions
317
+ instructions = [
318
+ "Controls:",
319
+ "Q - Quit",
320
+ "L - Toggle Landmarks",
321
+ "C - Toggle Connections",
322
+ "R - Reset History"
323
+ ]
324
+
325
+ for i, instruction in enumerate(instructions):
326
+ cv2.putText(image, instruction, (width - 200, height - 120 + i * 25),
327
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)
328
+
329
+ # Display timing statistics
330
+ mp_avg = self.mediapipe_time_total / self.mediapipe_time_count if self.mediapipe_time_count else 0.0
331
+ fp_avg = self.feature_pred_time_total / self.feature_pred_time_count if self.feature_pred_time_count else 0.0
332
+ cv2.putText(image, f"MP avg: {mp_avg*1000:.1f}ms", (width - 150, 55),
333
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
334
+ cv2.putText(image, f"FP avg: {fp_avg*1000:.1f}ms", (width - 150, 75),
335
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
336
+ # Display average processing throughput (compute time only, excludes display/IO)
337
+ total_frames = max(self.mediapipe_time_count, 1)
338
+ avg_fps = total_frames / max(self.mediapipe_time_total + self.feature_pred_time_total, 1e-6)
339
+ cv2.putText(image, f"Avg FPS: {avg_fps:.1f}", (width - 150, 95),
340
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
341
+
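+ # Note: OpenCV color tuples are BGR, so (0, 255, 0) is green, (0, 255, 255)
+ # is yellow, and (0, 0, 255) is red, matching the confidence thresholds above.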
342
+ def update_fps(self):
343
+ """Update FPS calculation"""
344
+ self.fps_counter += 1
345
+ if self.fps_counter >= 30: # Update FPS every 30 frames
346
+ current_time = time.time()
347
+ self.current_fps = 30 / (current_time - self.fps_start_time)
348
+ self.fps_start_time = current_time
349
+ self.fps_counter = 0
350
+
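+ # Example: if 30 frames take 0.5 s, current_fps = 30 / 0.5 = 60.0.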
351
+ def run(self):
352
+ """Run real-time pose classification"""
353
+ print("Starting real-time pose classifier...")
354
+ print("Press 'Q' to quit, 'L' to toggle landmark display, 'C' to toggle skeleton connections, 'R' to reset history")
355
+
356
+ # Initialize camera
357
+ cap = cv2.VideoCapture(self.camera_id)
358
+ if not cap.isOpened():
359
+ raise RuntimeError(f"Cannot open camera {self.camera_id}")
360
+
361
+ # Set camera parameters
362
+ cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
363
+ cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
364
+ cap.set(cv2.CAP_PROP_FPS, 30)
365
+
366
+ try:
367
+ while True:
368
+ success, frame = cap.read()
369
+ if not success:
370
+ print("Cannot read camera frame")
371
+ break
372
+
373
+ # Flip image horizontally (mirror effect)
374
+ frame = cv2.flip(frame, 1)
375
+
376
+ # Convert color space
377
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
378
+
379
+ # Time MediaPipe pose detection
380
+ mp_start = time.time()
381
+ results = self.pose.process(rgb_frame)
382
+ mp_end = time.time()
383
+ self.mediapipe_time_total += (mp_end - mp_start)
384
+ self.mediapipe_time_count += 1
385
+
386
+ # Extract features and predict
387
+ fp_start = time.time()
388
+ prediction_result = None
389
+ if results.pose_landmarks:
390
+ features = self.extract_pose_features(results.pose_landmarks)
391
+ if features is not None:
392
+ raw_prediction = self.predict_pose(features)
393
+ prediction_result = self.smooth_predictions(raw_prediction)
394
+ fp_end = time.time()
395
+ self.feature_pred_time_total += (fp_end - fp_start)
396
+ self.feature_pred_time_count += 1
397
+
398
+ # Draw results
399
+ self.draw_pose_info(frame, results.pose_landmarks, prediction_result)
400
+
401
+ # Update FPS
402
+ self.update_fps()
403
+
404
+ # Display image
405
+ cv2.imshow('Real-time Pose Classification', frame)
406
+
407
+ # Handle key presses
408
+ key = cv2.waitKey(1) & 0xFF
409
+ if key == ord('q') or key == ord('Q'):
410
+ break
411
+ elif key == ord('l') or key == ord('L'):
412
+ self.show_landmarks = not self.show_landmarks
413
+ print(f"Landmark display: {'On' if self.show_landmarks else 'Off'}")
414
+ elif key == ord('c') or key == ord('C'):
415
+ self.show_connections = not self.show_connections
416
+ print(f"Skeleton connection display: {'On' if self.show_connections else 'Off'}")
417
+ elif key == ord('r') or key == ord('R'):
418
+ self.prediction_history.clear()
419
+ print("Prediction history reset")
420
+
421
+ except KeyboardInterrupt:
422
+ print("\nUser interrupted program")
423
+ except Exception as e:
424
+ print(f"Runtime error: {e}")
425
+ traceback.print_exc()
426
+ finally:
427
+ cap.release()
428
+ cv2.destroyAllWindows()
429
+ print("Program exited")
430
+
431
+
432
+ def main():
433
+ """Main function"""
434
+ parser = argparse.ArgumentParser(description='Real-time pose classifier')
435
+ parser.add_argument('--model', '-m', type=str, default=None,
436
+ help='Model file path (auto-detect by default)')
437
+ parser.add_argument('--camera', '-c', type=int, default=0,
438
+ help='Camera ID (default 0)')
439
+
440
+ args = parser.parse_args()
441
+
442
+ try:
443
+ classifier = RealtimePoseClassifier(
444
+ model_path=args.model,
445
+ camera_id=args.camera
446
+ )
447
+ classifier.run()
448
+ except Exception as e:
449
+ print(f"Program startup failed: {e}")
450
+ return 1
451
+
452
+ return 0
453
+
454
+
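+ # Illustrative invocations:
+ #   python realtime_pose_classifier.py
+ #   python realtime_pose_classifier.py --model pose_classifier_logistic.pkl --camera 1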
455
+ if __name__ == "__main__":
456
+ exit(main())