|
import os |
|
import random |
|
from pathlib import Path |
|
import shutil |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Source folder for each location key used in `splits`. The values are relative
# paths, so they are resolved against the current working directory.
SOURCE_DIRS = dict(
    location_1='mpala',
    location_2='opc',
    location_3='wilds',
)

# Root of the generated dataset: images/<split>, labels/<split> and dataset.yaml.
DEST_DIR = "/data"

# YOLO class id -> class name, emitted in the `names:` section of dataset.yaml.
CLASS_LABELS = dict(enumerate(("Zebra", "Giraffe", "Onager", "Dog")))

# Keep every SAMPLING_RATE-th frame of each video (after sorting frames by name).
SAMPLING_RATE = 10
|
|
|
|
|
# Train/test assignment of videos.
# Structure: split -> location key (looked up in SOURCE_DIRS) -> session folder
# name -> list of video names passed to find_video_images().
# The processing loop below also accepts ``True`` instead of a list, meaning
# "every video found in that session folder".
# NOTE(review): several sessions appear in both splits (with different videos),
# and DJI_0035 is divided into _part1 (train) / _part2 (test), so test frames can
# come from the same session folders as training frames -- confirm this overlap
# is acceptable for evaluation.
splits = {
    'train': {
        'location_3': {
            'session_1': ['DJI_0034', 'DJI_0035_part1'],
            'session_2': ['P0140018'],
            'session_3': ['P0100010', 'P0110011', 'P0080008', 'P0090009'],
        },
        'location_1': {
            'session_1': ['DJI_0001', 'DJI_0002'],
            'session_2': ['DJI_0005', 'DJI_0006'],
            'session_3': ['DJI_0068', 'DJI_0069'],
            'session_4': ['DJI_0142', 'DJI_0143', 'DJI_0144'],
            'session_5': ['DJI_0206', 'DJI_0208'],
        },
        'location_2': {
            'session_1': ['P0800081', 'P0830086', 'P0840087', 'P0870091'],
            'session_2': ['P0910095'],
        }
    },
    'test': {
        'location_3': {
            'session_1': ['DJI_0035_part2'],
            'session_3': ['P0070007', 'P0160016', 'P0120012'],
            'session_2': ['P0150019'],
            'session_4': ['P0070010'],
        },
        'location_1': {
            'session_3': ['DJI_0070', 'DJI_0071'],
            'session_4': ['DJI_0145', 'DJI_0146', 'DJI_0147'],
            'session_5': ['DJI_0210', 'DJI_0211'],
        },
        'location_2': {
            'session_1': ['P0860090'],
            'session_2': ['P0940098'],
        }
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
# Create the output tree <DEST_DIR>/images/<split> and <DEST_DIR>/labels/<split>.
for split in ['train', 'test']:
    for kind in ("images", "labels"):
        os.makedirs(f"{DEST_DIR}/{kind}/{split}", exist_ok=True)
|
|
|
def find_images_in_directory(dir_path):
    """Return the names of the image files directly inside *dir_path*.

    Only regular files ending in ``.jpg``, ``.png`` or ``.jpeg`` are kept.
    The extension match is case-sensitive, so e.g. ``.JPG`` files are skipped.

    Args:
        dir_path: Directory to scan, as a ``str`` or ``pathlib.Path``.

    Returns:
        Sorted list of bare file names (not full paths). The list is empty
        when the directory is missing, is not a directory, or cannot be read;
        in that case the error is printed.
    """
    # Accept plain strings too: the `/` join below needs a Path.
    dir_path = Path(dir_path)
    try:
        entries = os.listdir(dir_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Error accessing {dir_path}: {e}")
        return []
    # os.listdir order is arbitrary; sort so repeated runs see the same order.
    return sorted(
        name for name in entries
        if name.endswith(('.jpg', '.png', '.jpeg')) and os.path.isfile(dir_path / name)
    )
|
|
|
def find_partitions(session_path):
    """Return the ``partition_*`` sub-directory names of *session_path*.

    Args:
        session_path: Directory to scan, as a ``str`` or ``pathlib.Path``.

    Returns:
        Sorted list of directory names starting with ``partition_``. The
        list is empty when the directory is missing, is not a directory, or
        cannot be read; in that case the error is printed.
    """
    # Accept plain strings too: the `/` join below needs a Path.
    session_path = Path(session_path)
    try:
        entries = os.listdir(session_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Error accessing {session_path}: {e}")
        return []
    # Sorted so callers that pick "the first partition" get a deterministic
    # choice instead of whatever order the filesystem returns.
    return sorted(
        d for d in entries
        if d.startswith('partition_') and os.path.isdir(session_path / d)
    )
|
|
|
def find_video_images(session_path, video_name):
    """Collect every image belonging to *video_name* inside *session_path*.

    Two layouts are searched, and their results are concatenated in this order:

    1. ``session_path/<video_name>/``. If it has ``partition_*``
       sub-folders, their images are used. Otherwise the images directly
       inside it are used, reported with an empty partition name.
    2. ``session_path/partition_*/``. Here only images whose file name
       contains *video_name* as a substring are used.
       NOTE(review): substring matching can also pick up frames of another
       video whose name contains this one -- confirm against naming scheme.

    Returns:
        List of ``(directory, image_name, partition_name)`` tuples.
    """
    collected = []

    # Layout 1: a dedicated folder per video.
    video_dir = session_path / video_name
    if os.path.isdir(video_dir):
        video_partitions = find_partitions(video_dir)
        if not video_partitions:
            collected += [(video_dir, name, "") for name in find_images_in_directory(video_dir)]
        for part in video_partitions:
            part_dir = video_dir / part
            collected += [(part_dir, name, part) for name in find_images_in_directory(part_dir)]

    # Layout 2: session-level partitions holding frames of many videos.
    for part in find_partitions(session_path):
        part_dir = session_path / part
        collected += [
            (part_dir, name, part)
            for name in find_images_in_directory(part_dir)
            if video_name in name
        ]

    return collected
|
|
|
|
|
def _list_session_videos(session_path):
    """Discover the videos of a session whose split entry is ``True``.

    Non-``partition_*`` sub-directories of *session_path* are preferred. When
    there are none, this falls back to the distinct prefixes (text before the
    first ``_``) of the image names in the first ``partition_*`` folder.

    Returns:
        Sorted list of video names, or ``None`` if *session_path* cannot be
        listed (a warning is printed).
    """
    try:
        entries = os.listdir(session_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        # PermissionError matches the helpers above; previously it escaped and
        # aborted the whole run.
        print(f"Warning: Could not list directory {session_path}: {e}")
        return None

    videos = sorted(
        v for v in entries
        if os.path.isdir(session_path / v) and not v.startswith('partition_')
    )
    if videos:
        return videos

    partitions = find_partitions(session_path)
    if not partitions:
        return []
    # NOTE(review): assumes flat image names look like "<video>_<frame>.<ext>";
    # a name such as "DJI_0034_000001.jpg" would yield "DJI" -- confirm layout.
    first_partition = session_path / partitions[0]
    return sorted({
        img.split('_')[0]
        for img in find_images_in_directory(first_partition)
        if '_' in img
    })


def _resolve_videos(session_path, video_info):
    """Turn one session's entry from `splits` into the videos to process.

    ``True`` means "every video in the session". A falsy value (``False``,
    ``None``, ``''``, ``[]``) selects nothing. A bare string is one video name;
    iterating it would treat each character as a video. Anything else is
    taken as an iterable of video names.

    Returns:
        List of video names, or ``None`` when the session must be skipped.
    """
    if video_info is True:
        return _list_session_videos(session_path)
    if not video_info:
        return []
    if isinstance(video_info, str):
        return [video_info]
    return list(video_info)


def _label_name_for(image_name):
    """Swap only the final extension of *image_name* for ``.txt``."""
    return os.path.splitext(image_name)[0] + '.txt'


def _find_label(frame_dir, label_name, session_path, video, partition):
    """Return the first existing candidate label path for a frame, or None."""
    candidates = (
        # Label next to the image.
        frame_dir / label_name,
        # `labels` folder inside the image folder.
        frame_dir / "labels" / label_name,
        # Session-level labels mirrored per partition ("" collapses to the next one).
        session_path / "labels" / partition / label_name,
        # Flat session-level labels folder.
        session_path / "labels" / label_name,
        # Labels folder inside the video folder.
        session_path / video / "labels" / label_name,
    )
    return next((path for path in candidates if os.path.exists(path)), None)


def _process_video(split_name, location_name, session_name, session_path, video):
    """Copy every SAMPLING_RATE-th frame of *video*, and its label, into DEST_DIR."""
    print(f"Processing {location_name}/{session_name}/{video}...")

    frame_info = find_video_images(session_path, video)
    if not frame_info:
        print(f"Warning: No frames found for {video} in {session_name}")
        return

    # Order by frame name; the directory breaks ties between partitions that
    # reuse the same file names, keeping the sampled subset deterministic.
    frame_info.sort(key=lambda item: (item[1], str(item[0])))

    for frame_dir, frame_name, partition in frame_info[::SAMPLING_RATE]:
        partition_str = f"_{partition}" if partition else ""
        src_img = frame_dir / frame_name
        # Prefix with location/session/video/partition so names stay unique in
        # the flat destination folder.
        dest_img_name = f"{location_name}_{session_name}_{video}{partition_str}_{frame_name}"
        dest_img = Path(DEST_DIR) / "images" / split_name / dest_img_name

        try:
            shutil.copy(src_img, dest_img)
        except OSError as e:
            print(f"Error copying image {src_img}: {e}")
            continue

        src_label = _find_label(frame_dir, _label_name_for(frame_name), session_path, video, partition)
        if src_label is None:
            # NOTE(review): the image was still copied; YOLO treats an image
            # without a label file as background -- confirm this is intended.
            print(f"Warning: No label found for {src_img}")
            continue

        dest_label = Path(DEST_DIR) / "labels" / split_name / _label_name_for(dest_img_name)
        try:
            shutil.copy(src_label, dest_label)
        except OSError as e:
            print(f"Error copying label {src_label}: {e}")


# Walk split -> location -> session -> video and copy sampled frames + labels.
# NOTE(review): DEST_DIR is not cleared first, so files from earlier runs remain.
for split_name, locations in splits.items():
    for location_name, sessions in locations.items():
        if location_name not in SOURCE_DIRS:
            print(f"Warning: No source directory defined for {location_name}. Skipping.")
            continue

        location_source_dir = Path(SOURCE_DIRS[location_name])

        for session_name, video_info in sessions.items():
            session_path = location_source_dir / session_name
            if not os.path.exists(session_path):
                print(f"Warning: Session path {session_path} does not exist. Skipping.")
                continue

            videos = _resolve_videos(session_path, video_info)
            if videos is None:
                continue

            for video in videos:
                _process_video(split_name, location_name, session_name, session_path, video)

print("Dataset split completed successfully!")
|
|
|
|
|
def create_dataset_yaml(dest_dir=None, class_labels=None):
    """Write the YOLO ``dataset.yaml`` for the generated dataset.

    Args:
        dest_dir: Dataset root, used both as the file location and for the
            ``path:`` entry. Defaults to the module-level ``DEST_DIR``.
        class_labels: Mapping of class id to class name for ``names:``.
            Defaults to the module-level ``CLASS_LABELS``.
    """
    # Resolve defaults at call time so the module constants can be overridden.
    if dest_dir is None:
        dest_dir = DEST_DIR
    if class_labels is None:
        class_labels = CLASS_LABELS

    lines = [
        "# YOLOv11 dataset config\n",
        f"path: {os.path.abspath(dest_dir)} # dataset root dir\n",
        "train: images/train # train images\n",
        # NOTE(review): validating on the training images makes val metrics
        # optimistic; point this at a held-out split if one becomes available.
        "val: images/train # validation uses train images\n",
        "test: images/test # test images\n\n",
        "# Classes\n",
        "names:\n",
    ]
    lines.extend(f" {class_id}: {class_name}\n" for class_id, class_name in class_labels.items())

    with open(os.path.join(dest_dir, "dataset.yaml"), "w", encoding="utf-8") as f:
        f.writelines(lines)
|
|
|
# Write <DEST_DIR>/dataset.yaml after the copy loop above has run.
create_dataset_yaml()
|
|
|
|
|
def _parse_location_session(image_name):
    """Return ``(location, session)`` parsed from an output image name, or None.

    Output images are named ``f"{location}_{session}_{video}..."``, and the
    location/session keys used in `splits` each contain exactly one underscore
    (``location_1``, ``session_2``). The first four ``_``-separated tokens
    therefore hold both keys. Splitting on the first two tokens only (the old
    behavior) yielded ``('location', '1')`` for every image.
    """
    parts = image_name.split('_')
    if len(parts) < 4:
        return None
    return '_'.join(parts[:2]), '_'.join(parts[2:4])


stats = {"train": {}, "test": {}}

for split in ['train', 'test']:
    locations = {}
    # Keyed by "<location>_<session>"; stored under the historical "species" key.
    species_count = {}
    # Pre-fill so consumers never hit a KeyError, even when the split dir is missing.
    stats[split] = {"total": 0, "locations": locations, "species": species_count}

    img_dir = Path(DEST_DIR) / "images" / split
    if not os.path.exists(img_dir):
        print(f"Warning: Directory {img_dir} does not exist.")
        continue

    total_count = 0
    # Sorted so the per-location/session report prints in a stable order.
    for img in sorted(os.listdir(img_dir)):
        parsed = _parse_location_session(img)
        if parsed is None:
            continue
        location, session = parsed
        locations[location] = locations.get(location, 0) + 1
        species_key = f"{location}_{session}"
        species_count[species_key] = species_count.get(species_key, 0) + 1
        total_count += 1

    stats[split]["total"] = total_count
|
|
|
|
|
# Print per-split totals and percentage breakdowns collected in `stats`.
for split, data in stats.items():
    # .get() tolerates a split whose stats were never filled in.
    split_total = data.get("total", 0)
    print(f"\n{split.upper()} set:")
    print(f"Total images: {split_total}")

    print("Distribution by location:")
    for loc, count in data.get("locations", {}).items():
        percentage = (count / split_total * 100) if split_total > 0 else 0
        print(f" - {loc}: {count} ({percentage:.1f}%)")

    print("\nDistribution by location_session:")
    for species_key, count in data.get("species", {}).items():
        percentage = (count / split_total * 100) if split_total > 0 else 0
        print(f" - {species_key}: {count} ({percentage:.1f}%)")

train_total = stats.get("train", {}).get("total", 0)
test_total = stats.get("test", {}).get("total", 0)
grand_total = train_total + test_total
# Guard the ratio: with no images at all the old code raised ZeroDivisionError.
if grand_total > 0:
    print("\nOverall train/test ratio:",
          f"{train_total / grand_total:.1%}",
          f"/ {test_total / grand_total:.1%}")
else:
    print("\nOverall train/test ratio: n/a (no images found)")