Spaces:
Build error
Build error
File size: 8,865 Bytes
05be5a5 c13ce0c 05be5a5 c13ce0c efc9b5d c458a5a c13ce0c c458a5a c13ce0c efc9b5d c13ce0c c458a5a c13ce0c c458a5a c13ce0c 05be5a5 c13ce0c 05be5a5 cf3d6df 05be5a5 c13ce0c 0b2f929 c13ce0c c458a5a c13ce0c c458a5a c13ce0c c458a5a c13ce0c 05be5a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
# train.py
from .yolo_manager import YOLOManager
from .utils import get_abs_path, backup_file
import os
from .config import Config
import yaml
import os
from pathlib import Path
import shutil
def _clamp_unit(value: float) -> float:
    """Clamp *value* into the normalized [0, 1] coordinate range."""
    return max(0.0, min(1.0, value))


def _normalize_label_line(line: str) -> "tuple[str, bool]":
    """Normalize one stripped YOLO label line to polygon format.

    Returns ``(text, modified)``. ``modified`` is True only when ``text`` differs
    in substance from the input: a box line was converted to a polygon, or a
    polygon coordinate had to be clamped into [0, 1]. Unparseable lines and lines
    with an unrecognized token count are returned verbatim with ``modified=False``.
    """
    parts = line.split()

    if len(parts) == 5:
        # Box format: class xc yc w h -> 4-point rectangle, clockwise from top-left.
        try:
            cls = int(float(parts[0]))  # class ids may be written as "0.0"
            xc, yc, bw, bh = map(float, parts[1:])
        except (ValueError, OverflowError):  # OverflowError: int(float("inf"))
            return line, False
        left = _clamp_unit(xc - bw / 2)
        right = _clamp_unit(xc + bw / 2)
        top = _clamp_unit(yc - bh / 2)
        bottom = _clamp_unit(yc + bh / 2)
        corners = (left, top, right, top, right, bottom, left, bottom)
        return f"{cls} " + " ".join(f"{c:.6f}" for c in corners), True

    if len(parts) > 5 and len(parts) % 2 == 1:
        # Polygon format: class followed by (x, y) coordinate pairs.
        try:
            cls = int(float(parts[0]))
            coords = [float(p) for p in parts[1:]]
        except (ValueError, OverflowError):
            return line, False
        clamped = [_clamp_unit(c) for c in coords]
        if clamped == coords:
            # Already in range: keep the original text and its full precision.
            return line, False
        return f"{cls} " + " ".join(f"{c:.6f}" for c in clamped), True

    # Unknown format: leave untouched.
    return line, False


def convert_box_to_polygon(label_file: Path):
    """Normalize a YOLO label file to polygon (segmentation) format, in place.

    Box lines (``class xc yc w h``) become a 4-point rectangle polygon, clockwise
    from the top-left corner, with every corner clamped to [0, 1]. Polygon lines
    with any coordinate outside [0, 1] are clamped and re-emitted with 6 decimals.
    In-range polygon lines and unrecognized lines are kept verbatim.

    The file is rewritten only when at least one line actually changed; a rewrite
    drops blank lines and ends the file with a single newline. A missing file is
    ignored.

    Args:
        label_file: path to the ``.txt`` label file to normalize.
    """
    if not label_file.exists():
        return

    new_lines = []
    changed = False
    with open(label_file, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:  # blank lines are not carried into a rewrite
                continue
            normalized, modified = _normalize_label_line(line)
            new_lines.append(normalized)
            # Previously only box conversions set this flag, so clamped
            # polygon coordinates were computed and then thrown away.
            changed = changed or modified

    if changed:
        with open(label_file, "w", encoding="utf-8") as f:
            f.write("\n".join(new_lines) + "\n")
def create_filtered_dataset(original_dataset_path, output_filtered_dataset_path):
    """Create a filtered dataset containing only images with non-empty labels.

    Reads the layout ``<root>/images/<split>`` and ``<root>/labels/<split>`` for
    the splits train/val/test. The output directory is deleted and recreated
    first. Each image whose ``<stem>.txt`` label exists and is not blank is copied
    together with its label, and the copied label is passed through
    ``convert_box_to_polygon``.

    Args:
        original_dataset_path: root directory of the source dataset.
        output_filtered_dataset_path: root directory of the filtered dataset to
            (re)create.

    Returns:
        dict mapping each split name to the number of image/label pairs copied
        (0 for splits whose source directories are missing).

    Raises:
        ValueError: if the output path is the source path or one of its parent
            directories, because the output is wiped with ``shutil.rmtree``.
    """
    original_path = Path(original_dataset_path)
    output_path = Path(output_filtered_dataset_path)

    # The output tree is wiped below; refuse paths that would take the source with it.
    source_resolved = original_path.resolve()
    output_resolved = output_path.resolve()
    if output_resolved == source_resolved or output_resolved in source_resolved.parents:
        raise ValueError(
            f"Refusing to overwrite {output_path}: it contains the source dataset {original_path}"
        )
    shutil.rmtree(output_path, ignore_errors=True)

    splits = ("train", "val", "test")
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp"}
    output_images = output_path / "images"
    output_labels = output_path / "labels"

    # Create every split directory up front, including splits absent from the source.
    for split in splits:
        (output_images / split).mkdir(parents=True, exist_ok=True)
        (output_labels / split).mkdir(parents=True, exist_ok=True)

    filtered_counts = {}
    for split in splits:
        original_images_dir = original_path / "images" / split
        original_labels_dir = original_path / "labels" / split
        if not original_images_dir.exists() or not original_labels_dir.exists():
            print(f"Skipping {split} - source directory not found")
            filtered_counts[split] = 0
            continue

        output_images_dir = output_images / split
        output_labels_dir = output_labels / split
        total_count = 0
        copied_count = 0

        # sorted() makes copy order and log output deterministic across runs.
        for img_file in sorted(original_images_dir.glob("*")):
            if not img_file.is_file() or img_file.suffix.lower() not in image_suffixes:
                continue
            total_count += 1

            label_file = original_labels_dir / f"{img_file.stem}.txt"
            if not label_file.exists():
                print(f"Skipping {img_file.name} - no label file")
                continue
            if not label_file.read_text(encoding="utf-8").strip():
                print(f"Skipping {img_file.name} - empty label file")
                continue

            shutil.copy2(img_file, output_images_dir / img_file.name)
            copied_label = output_labels_dir / label_file.name
            shutil.copy2(label_file, copied_label)
            # Normalize the copy, never the source label.
            convert_box_to_polygon(copied_label)
            copied_count += 1

        filtered_counts[split] = copied_count
        print(f"{split.upper()} split: {copied_count}/{total_count} images copied")

    return filtered_counts
def create_filtered_yaml(output_filtered_dataset_path, filtered_counts):
    """Write the dataset YAML that describes the filtered dataset.

    The YAML is written to ``{Config.current_path}/filtered_comic.yaml`` with a
    single class ``panel`` and absolute train/val image directories under
    ``output_filtered_dataset_path``. A ``test`` entry is added only when
    ``filtered_counts['test']`` is positive.

    Args:
        output_filtered_dataset_path: root of the filtered dataset.
        filtered_counts: per-split copy counts, as returned by
            ``create_filtered_dataset``.

    Returns:
        The path (str) of the written YAML file.
    """
    output_path = Path(output_filtered_dataset_path)
    images_dir = output_path / 'images'
    yaml_path = f'{Config.current_path}/filtered_comic.yaml'

    yaml_data = {
        'names': ['panel'],
        'nc': 1,
        'path': str(output_path),
        'train': str(images_dir / 'train'),
        'val': str(images_dir / 'val'),
    }
    # Only reference the test split when it actually received images.
    if filtered_counts.get('test', 0) > 0:
        yaml_data['test'] = str(images_dir / 'test')

    with open(yaml_path, 'w', encoding='utf-8') as f:
        # sort_keys=False keeps the key order above in the written file.
        yaml.dump(yaml_data, f, default_flow_style=False, sort_keys=False)

    # This literal was previously split across two lines by a mis-decoded emoji,
    # leaving an unterminated string (SyntaxError at import time).
    print(f"\n✅ Created filtered dataset YAML: {yaml_path}")
    return yaml_path
def main():
    """Train, validate and back up the YOLO model on the filtered dataset.

    Reads ``{Config.current_path}/filtered_comic.yaml``, trains with run name
    ``Config.YOLO_MODEL_NAME``, validates, then copies the best weights to
    ``Config.yolo_trained_model_path`` via ``backup_file``.

    Raises:
        FileNotFoundError: if the dataset YAML does not exist.
        Exception: any failure is printed and then re-raised unchanged.
    """
    # NOTE(review): the non-ASCII prefixes in the messages below look like
    # mis-decoded emoji; confirm them against the original file.
    try:
        data_yaml_path = f'{Config.current_path}/filtered_comic.yaml'
        # Check the cheap precondition before constructing the manager.
        if not os.path.isfile(data_yaml_path):
            raise FileNotFoundError(f"β Dataset YAML not found: {data_yaml_path}")

        yolo_manager = YOLOManager()
        print(f"π― Training model: {Config.YOLO_MODEL_NAME}")

        # train() and validate() are called for their effects; their return
        # values were previously bound to locals that were never used.
        yolo_manager.train(
            data_yaml_path=data_yaml_path,
            run_name=Config.YOLO_MODEL_NAME,
        )
        yolo_manager.validate()

        weights_path = yolo_manager.get_best_weights_path()
        backup_file(weights_path, Config.yolo_trained_model_path)
        print("π Training completed successfully!")
    except Exception as e:
        print(f"β Training failed: {e}")
        raise
if __name__ == "__main__":
    # Configuration. The defaults are the original author's absolute paths;
    # set these environment variables to run on a different machine.
    original_dataset_path = os.environ.get(
        "COMIC_DATASET_PATH",
        "/home/jebin/git/comic-panel-extractor/comic_panel_extractor/dataset",
    )
    output_filtered_dataset_path = os.environ.get(
        "COMIC_FILTERED_DATASET_PATH",
        "/home/jebin/git/comic-panel-extractor/comic_panel_extractor/filtered_dataset",
    )

    print("π Starting dataset filtering...")
    print(f"π Source: {original_dataset_path}")
    print(f"π Output: {output_filtered_dataset_path}")

    # Build the filtered copy, then the YAML that points at it.
    filtered_counts = create_filtered_dataset(original_dataset_path, output_filtered_dataset_path)
    yaml_path = create_filtered_yaml(output_filtered_dataset_path, filtered_counts)

    # Summary of copied images per non-empty split.
    total_filtered = sum(filtered_counts.values())
    print("\nπ Filtering Summary:")
    for split, count in filtered_counts.items():
        if count > 0:
            print(f" {split.upper()}: {count} images")
    print(f" TOTAL: {total_filtered} images with labels")
    print(f"\nπ― Use this YAML for training: {yaml_path}")

    # Echo the generated YAML so the run log shows exactly what training will use.
    with open(yaml_path, 'r') as f:
        yaml_content = f.read()
    print("\nπ Generated YAML content:")
    print("β" * 50)
    print(yaml_content)
    print("β" * 50)

    main()