YogaPoseClassify / extract_images.py
pegasama's picture
train and test python script
b26156a verified
raw
history blame
5.51 kB
#!/usr/bin/env python3
"""Extract images and labels from Parquet files and save them into
subfolders by label.
Usage:
python extract_images.py [--train] [--test] [--output OUTPUT_DIR]
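Defaults:
    If neither --train nor --test is given, both splits are processed.
    --train reads YogaDataSet/data/train-00000-of-00001.parquet
    --test reads YogaDataSet/data/test-00000-of-00001.parquet
    (both are looked up relative to this script's directory).
    --output defaults to TrainData, resolved relative to the current
    working directory.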
"""
import os
import sys
import argparse
from pathlib import Path
import pyarrow.parquet as pq
def extract_images_from_parquet(parquet_path, output_dir, split_name):
    """Extract images from a Parquet file and save them into label folders.

    The Parquet file must contain an ``image`` column whose values are
    mappings with ``bytes`` and ``path`` entries, and a ``label`` column.
    Each image is written to ``output_dir/split_name/label_<label>/``.

    Args:
        parquet_path: Path of the Parquet file to read.
        output_dir: Root output directory (``str`` or ``pathlib.Path``).
        split_name: Sub-directory name for this split, e.g. ``"train"``.

    Returns:
        True if at least one image was written; False if the file could not
        be read, a required column is missing, or no image was saved.
    """
    print(f"Processing {parquet_path}...")
    # read parquet file
    try:
        table = pq.read_table(parquet_path)
        df = table.to_pandas()
    except Exception as e:
        print(f"Failed to read parquet file: {e}")
        return False
    print(f"Found {len(df)} images")

    # fail cleanly instead of raising KeyError on an unexpected schema
    missing_columns = [col for col in ("image", "label") if col not in df.columns]
    if missing_columns:
        print(f"Missing required column(s): {missing_columns}")
        return False

    split_dir = Path(output_dir) / split_name

    # get unique labels
    unique_labels = sorted(df['label'].unique())
    print(f"Label classes: {unique_labels}")

    # create folder for each label
    for label in unique_labels:
        label_dir = split_dir / f"label_{label}"
        label_dir.mkdir(parents=True, exist_ok=True)
        print(f"Created folder: {label_dir}")

    # extract and save images
    success_count = 0
    error_count = 0
    for idx, image_struct, label in zip(df.index, df['image'], df['label']):
        try:
            image_bytes = image_struct['bytes']
            # validate before touching the disk so a missing payload does
            # not leave an empty file behind
            if not image_bytes:
                raise ValueError("image has no embedded bytes")
            base_name, ext = _split_image_name(image_struct['path'], idx)
            output_path = _unique_path(split_dir / f"label_{label}", base_name, ext)
            # save image
            with open(output_path, 'wb') as f:
                f.write(image_bytes)
        except Exception as e:
            # deliberate best-effort: one bad record must not stop the run
            print(f"Error processing image {idx}: {e}")
            error_count += 1
            continue
        success_count += 1
        if success_count % 100 == 0:
            print(f"Processed {success_count} images...")
    print(f"Done! Success: {success_count}, Failed: {error_count}")

    # report counts per label (counts every entry present in the folder)
    print("\nImage count per label:")
    for label in unique_labels:
        label_dir = split_dir / f"label_{label}"
        count = len(list(label_dir.glob("*")))
        print(f" label {label}: {count} images")
    return success_count > 0


def _split_image_name(original_path, idx):
    """Return ``(base_name, ext)`` for an image, with fallbacks for missing parts."""
    # the stored path may be None or empty when only the bytes are embedded
    base_name, ext = os.path.splitext(os.path.basename(original_path or ""))
    return base_name or f"image_{idx}", ext or ".jpg"


def _unique_path(label_dir, base_name, ext):
    """Return a not-yet-existing path in ``label_dir``, appending ``_N`` on collision."""
    candidate = label_dir / f"{base_name}{ext}"
    counter = 1
    while candidate.exists():
        candidate = label_dir / f"{base_name}_{counter}{ext}"
        counter += 1
    return candidate
def main(argv=None):
    """Command-line entry point: extract the requested dataset splits.

    Args:
        argv: Optional list of command-line arguments. When None, argparse
            reads them from ``sys.argv``.

    Exits the process with status 1 when a requested Parquet file is missing
    or its extraction reports failure.
    """
    parser = argparse.ArgumentParser(description="Extract images from Parquet files and organize by label")
    parser.add_argument("--train", action="store_true", help="process training data")
    parser.add_argument("--test", action="store_true", help="process test data")
    parser.add_argument("--output", "-o", default="TrainData", help="output directory")
    args = parser.parse_args(argv)

    # if neither train nor test specified, do both by default
    if not (args.train or args.test):
        args.train = args.test = True

    # input files live next to the script; output is relative to the CWD
    script_dir = Path(__file__).parent
    yoga_data_dir = script_dir / "YogaDataSet" / "data"
    output_dir = Path(args.output)

    # ensure output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {output_dir.absolute()}")

    # (split folder name, label used in messages, parquet file name)
    splits = []
    if args.train:
        splits.append(("train", "Training", "train-00000-of-00001.parquet"))
    if args.test:
        splits.append(("test", "Test", "test-00000-of-00001.parquet"))

    success = True
    for split_name, title, parquet_name in splits:
        parquet_file = yoga_data_dir / parquet_name
        if not parquet_file.exists():
            print(f"{title} parquet file not found: {parquet_file}")
            success = False
        elif not extract_images_from_parquet(parquet_file, output_dir, split_name):
            success = False

    if not success:
        print("\n❌ Errors occurred during extraction")
        sys.exit(1)

    print("\n✅ All images extracted!")
    print(f"Images saved to: {output_dir.absolute()}")
    print("Directory structure:")
    # show the directory actually used, not a hard-coded name
    print(f"{output_dir}/")
    last_index = len(splits) - 1
    for position, (split_name, _title, _parquet_name) in enumerate(splits):
        # the final branch of a tree uses the corner connector
        if position == last_index:
            branch, pad = "└── ", "    "
        else:
            branch, pad = "├── ", "│   "
        print(f"{branch}{split_name}/")
        print(f"{pad}├── label_0/")
        print(f"{pad}├── label_1/")
        print(f"{pad}└── ...")


if __name__ == "__main__":
    main()