File size: 5,511 Bytes
b26156a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/env python3
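"""Extract images and labels from Parquet files and save them into
subfolders by label.

Usage:
    python extract_images.py [--train] [--test] [--output OUTPUT_DIR]

Options:
    --train:  process training data (train-00000-of-00001.parquet)
    --test:   process test data (test-00000-of-00001.parquet)
    --output: output directory (default: TrainData; a relative path is
              resolved against the current working directory)

If neither --train nor --test is given, both splits are processed. The
Parquet files are read from the YogaDataSet/data folder next to this script.
"""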
import os
import sys
import argparse
from pathlib import Path
import pyarrow.parquet as pq


def _unique_output_path(label_dir, base_name, ext):
    """Return a path inside *label_dir* for ``base_name + ext`` that is not taken yet."""
    candidate = label_dir / f"{base_name}{ext}"
    counter = 1
    while candidate.exists():
        candidate = label_dir / f"{base_name}_{counter}{ext}"
        counter += 1
    return candidate


def extract_images_from_parquet(parquet_path, output_dir, split_name):
    """Extract images from a Parquet file and save them into label folders.

    Every row must provide an ``image`` struct with ``bytes`` and ``path``
    entries and a ``label`` value. Each image is written to
    ``<output_dir>/<split_name>/label_<label>/`` under its original base
    name; a numeric suffix is appended when that name is already taken.

    Args:
        parquet_path: Location of the Parquet file to read.
        output_dir: Root directory for the extracted images (str or Path).
        split_name: Name of the sub-folder for this split, e.g. ``"train"``.

    Returns:
        True if at least one image was written. False if nothing was
        written, the file could not be read, or a required column is missing.
    """

    print(f"Processing {parquet_path}...")

    # Accept plain strings as well as Path objects for the output root.
    split_dir = Path(output_dir) / split_name

    # read parquet file
    try:
        table = pq.read_table(parquet_path)
        df = table.to_pandas()
    except Exception as e:
        print(f"Failed to read parquet file: {e}")
        return False

    # Fail with a clear message instead of an uncaught KeyError further down.
    missing_columns = [name for name in ("image", "label") if name not in df.columns]
    if missing_columns:
        print(f"Parquet file is missing required column(s): {missing_columns}")
        return False

    print(f"Found {len(df)} images")

    # get unique labels
    unique_labels = sorted(df['label'].unique())
    print(f"Label classes: {unique_labels}")

    # create folder for each label
    for label in unique_labels:
        label_dir = split_dir / f"label_{label}"
        label_dir.mkdir(parents=True, exist_ok=True)
        print(f"Created folder: {label_dir}")

    # extract and save images
    success_count = 0
    error_count = 0

    for idx, row in df.iterrows():
        try:
            image_struct = row['image']
            image_bytes = image_struct['bytes']
            label = row['label']

            # Validate the payload before touching the filesystem so that a
            # row without data never leaves an empty file behind.
            if not isinstance(image_bytes, (bytes, bytearray, memoryview)) or len(image_bytes) == 0:
                raise ValueError("image has no embedded bytes")

            # The stored path may be None/empty; fall back to a name derived
            # from the row index rather than rejecting a valid image.
            original_path = image_struct['path'] or f"image_{idx}"
            base_name, ext = os.path.splitext(os.path.basename(original_path))
            if not base_name:
                base_name = f"image_{idx}"
            if not ext:
                ext = '.jpg'  # default extension

            label_dir = split_dir / f"label_{label}"
            output_path = _unique_output_path(label_dir, base_name, ext)

            # save image
            output_path.write_bytes(image_bytes)

            success_count += 1
            if success_count % 100 == 0:
                print(f"Processed {success_count} images...")

        except Exception as e:
            # Best effort: report the bad row and keep going with the rest.
            print(f"Error processing image {idx}: {e}")
            error_count += 1
            continue

    print(f"Done! Success: {success_count}, Failed: {error_count}")

    # report counts per label (counts every entry in the folder, including
    # files left over from earlier runs)
    print("\nImage count per label:")
    for label in unique_labels:
        label_dir = split_dir / f"label_{label}"
        count = sum(1 for _ in label_dir.glob("*"))
        print(f"  label {label}: {count} images")

    return success_count > 0


def main():
    """Parse the command line and extract the requested dataset splits.

    Reads ``train-00000-of-00001.parquet`` and/or
    ``test-00000-of-00001.parquet`` from the ``YogaDataSet/data`` folder next
    to this script and writes the images below the ``--output`` directory.
    When neither ``--train`` nor ``--test`` is given, both splits are
    processed. Exits with status 1 if any selected split is missing or fails.
    """
    parser = argparse.ArgumentParser(description="Extract images from Parquet files and organize by label")
    parser.add_argument("--train", action="store_true", help="process training data")
    parser.add_argument("--test", action="store_true", help="process test data")
    parser.add_argument("--output", "-o", default="TrainData", help="output directory")

    args = parser.parse_args()

    # if neither train nor test specified, do both by default
    if not args.train and not args.test:
        args.train = True
        args.test = True

    # set paths
    script_dir = Path(__file__).parent
    yoga_data_dir = script_dir / "YogaDataSet" / "data"
    output_dir = Path(args.output)

    # ensure output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Output directory: {output_dir.absolute()}")

    # (split folder name, label used in messages) for every selected split
    selected_splits = []
    if args.train:
        selected_splits.append(("train", "Training"))
    if args.test:
        selected_splits.append(("test", "Test"))

    success = True

    for split_name, display_name in selected_splits:
        parquet_file = yoga_data_dir / f"{split_name}-00000-of-00001.parquet"
        if not parquet_file.exists():
            print(f"{display_name} parquet file not found: {parquet_file}")
            success = False
            continue
        if not extract_images_from_parquet(parquet_file, output_dir, split_name):
            success = False

    if not success:
        print("\n❌ Errors occurred during extraction")
        sys.exit(1)

    print("\n✅ All images extracted!")
    print(f"Images saved to: {output_dir.absolute()}")
    print("Directory structure:")
    # Show the directory that was actually used, not a hard-coded name.
    root_label = output_dir.resolve().name or str(output_dir)
    print(f"{root_label}/")
    last_index = len(selected_splits) - 1
    for position, (split_name, _) in enumerate(selected_splits):
        # The final entry gets the closing connector and no vertical rule.
        is_last = position == last_index
        connector = "└──" if is_last else "├──"
        indent = "    " if is_last else "│   "
        print(f"{connector} {split_name}/")
        print(f"{indent}├── label_0/")
        print(f"{indent}├── label_1/")
        print(f"{indent}└── ...")


# Run the extraction only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()