Spaces:
Build error
Build error
# train.py | |
from .yolo_manager import YOLOManager | |
from .utils import get_abs_path, backup_file | |
import os | |
from .config import Config | |
import yaml | |
import os | |
from pathlib import Path | |
import shutil | |
def convert_box_to_polygon(label_file: Path):
    """
    Convert YOLO box-format labels to YOLO polygon-format labels, in place.

    Each non-empty line of ``label_file`` is handled as follows:

    * ``class xc yc w h`` (5 fields): rewritten as a 4-point polygon
      ``class x1 y1 x2 y2 x3 y3 x4 y4`` tracing the box clockwise from the
      top-left corner, every coordinate clamped to [0, 1].
    * ``class x1 y1 ... xn yn`` (odd field count greater than 5): treated as
      an existing polygon; coordinates are clamped to [0, 1] and re-emitted
      with six decimals.
    * Anything else, or a line whose numbers fail to parse: kept verbatim.

    Blank lines are dropped. The file is rewritten only when at least one box
    was converted or at least one polygon coordinate had to be clamped;
    otherwise it is left untouched. A missing file is silently ignored.

    Args:
        label_file: Path to a single YOLO ``.txt`` label file.
    """
    if not label_file.exists():
        return

    def clamp(value: float) -> float:
        return max(0.0, min(1.0, value))

    new_lines = []
    changed = False
    with open(label_file, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:  # Skip empty lines
                continue
            parts = line.split()
            if len(parts) == 5:
                # Box format -> 4-point polygon.
                try:
                    cls = int(float(parts[0]))  # class ids may be written as "0.0"
                    xc, yc, bw, bh = map(float, parts[1:])
                except (ValueError, OverflowError):
                    # If parsing fails, keep original line
                    new_lines.append(line)
                    continue
                left = clamp(xc - bw / 2)
                right = clamp(xc + bw / 2)
                top = clamp(yc - bh / 2)
                bottom = clamp(yc + bh / 2)
                # Clockwise from top-left: TL, TR, BR, BL.
                corners = (left, top, right, top, right, bottom, left, bottom)
                new_lines.append(f"{cls} " + " ".join(f"{c:.6f}" for c in corners))
                changed = True
            elif len(parts) > 5 and len(parts) % 2 == 1:
                # Already polygon format: class followed by (x, y) pairs.
                try:
                    cls = int(float(parts[0]))
                    coords = [float(x) for x in parts[1:]]
                except (ValueError, OverflowError):
                    # If parsing fails, keep original line
                    new_lines.append(line)
                    continue
                clamped = [clamp(c) for c in coords]
                if clamped != coords:
                    # Out-of-range points were fixed; the file must be saved
                    # or the clamping is silently lost.
                    changed = True
                coord_str = " ".join(f"{c:.6f}" for c in clamped)
                new_lines.append(f"{cls} {coord_str}")
            else:
                # Unknown format, keep as-is
                new_lines.append(line)

    if changed:
        with open(label_file, "w") as f:
            f.write("\n".join(new_lines) + "\n")
def create_filtered_dataset(original_dataset_path, output_filtered_dataset_path):
    """
    Create a filtered copy of a YOLO dataset containing only labelled images.

    The source is read using the ``images/<split>`` + ``labels/<split>``
    layout for the ``train``, ``val`` and ``test`` splits. For every image with
    a ``.jpg``/``.jpeg``/``.png``/``.bmp`` suffix (case-insensitive) whose
    matching ``<stem>.txt`` label exists and is not blank, both files are
    copied into the same layout under the output directory, and the copied
    label (never the source) is passed through ``convert_box_to_polygon``.
    Splits whose source image or label directory is missing are skipped.

    The output directory is deleted and recreated on every call.

    Args:
        original_dataset_path: Root of the source dataset.
        output_filtered_dataset_path: Root of the dataset to (re)create.

    Returns:
        dict mapping each split name to the number of image/label pairs copied.

    Raises:
        ValueError: If the output directory is the source directory or one of
            its ancestors, because wiping it would destroy the source data.
    """
    splits = ("train", "val", "test")
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp"}

    original_path = Path(original_dataset_path)
    output_path = Path(output_filtered_dataset_path)

    # Guard the rmtree below: never wipe a directory that contains the source.
    resolved_source = original_path.resolve()
    resolved_output = output_path.resolve()
    if resolved_output == resolved_source or resolved_output in resolved_source.parents:
        raise ValueError(
            f"Refusing to overwrite {output_path}: it contains the source dataset {original_path}"
        )

    shutil.rmtree(output_path, ignore_errors=True)

    # Create output directory structure
    output_images = output_path / "images"
    output_labels = output_path / "labels"
    for split in splits:
        (output_images / split).mkdir(parents=True, exist_ok=True)
        (output_labels / split).mkdir(parents=True, exist_ok=True)

    filtered_counts = {}
    for split in splits:
        original_images_dir = original_path / 'images' / split
        original_labels_dir = original_path / 'labels' / split
        output_images_dir = output_images / split
        output_labels_dir = output_labels / split

        if not original_images_dir.exists() or not original_labels_dir.exists():
            print(f"Skipping {split} - source directory not found")
            filtered_counts[split] = 0
            continue

        total_count = 0
        copied_count = 0
        for img_file in original_images_dir.glob('*'):
            if img_file.suffix.lower() not in image_suffixes:
                continue
            total_count += 1
            label_file = original_labels_dir / f"{img_file.stem}.txt"
            # is_file() also rejects a directory that happens to match the name.
            if not label_file.is_file():
                print(f"Skipping {img_file.name} - no label file")
                continue
            if not label_file.read_text().strip():
                print(f"Skipping {img_file.name} - empty label file")
                continue
            shutil.copy2(img_file, output_images_dir / img_file.name)
            copied_label = output_labels_dir / label_file.name
            shutil.copy2(label_file, copied_label)
            convert_box_to_polygon(copied_label)
            copied_count += 1

        filtered_counts[split] = copied_count
        print(f"{split.upper()} split: {copied_count}/{total_count} images copied")
    return filtered_counts
def create_filtered_yaml(output_filtered_dataset_path, filtered_counts):
    """
    Write the dataset YAML that describes the filtered dataset.

    The file goes to ``<Config.current_path>/filtered_comic.yaml`` and
    declares a single ``panel`` class (``nc: 1``). ``train`` and ``val``
    always point at ``<output>/images/train`` and ``<output>/images/val``.
    A ``test`` entry is added only when ``filtered_counts`` reports at least
    one test image. Keys are written in insertion order.

    Args:
        output_filtered_dataset_path: Root of the filtered dataset.
        filtered_counts: Mapping of split name to copied image count, as
            returned by ``create_filtered_dataset``.

    Returns:
        The path (a string) of the YAML file that was written.
    """
    root = Path(output_filtered_dataset_path)
    images_root = root / 'images'
    destination = f'{Config.current_path}/filtered_comic.yaml'

    dataset_spec = {'names': ['panel'], 'nc': 1, 'path': str(root)}
    for split in ('train', 'val'):
        dataset_spec[split] = str(images_root / split)
    # An empty test split is omitted rather than pointing at an empty folder.
    if filtered_counts.get('test', 0) > 0:
        dataset_spec['test'] = str(images_root / 'test')

    with open(destination, 'w') as handle:
        yaml.dump(dataset_spec, handle, default_flow_style=False, sort_keys=False)

    # NOTE(review): 'β' looks like an emoji mis-decoded in a scraped copy of this file — confirm.
    print(f"\nβ Created filtered dataset YAML: {destination}")
    return destination
def main():
    """
    Train and validate the YOLO model, then back up its best weights.

    Reads the dataset YAML at ``<Config.current_path>/filtered_comic.yaml``
    (the same path ``create_filtered_yaml`` writes). Training runs under the
    run name ``Config.YOLO_MODEL_NAME``. After training and validation, the
    weights path returned by ``get_best_weights_path()`` is passed to
    ``backup_file`` along with ``Config.yolo_trained_model_path``.

    Raises:
        FileNotFoundError: If the dataset YAML does not exist.
        Exception: Any failure (including the one above) is printed and then
            re-raised unchanged.
    """
    try:
        data_yaml_path = f'{Config.current_path}/filtered_comic.yaml'
        # Check the cheap precondition before constructing the manager.
        if not os.path.isfile(data_yaml_path):
            raise FileNotFoundError(f"β Dataset YAML not found: {data_yaml_path}")

        yolo_manager = YOLOManager()

        # NOTE(review): 'π―', 'π' and 'β' look like emoji mis-decoded in a scraped copy — confirm.
        print(f"π― Training model: {Config.YOLO_MODEL_NAME}")

        # Return values are unused. validate() and get_best_weights_path()
        # take no arguments, so the manager presumably retains the trained run.
        yolo_manager.train(
            data_yaml_path=data_yaml_path,
            run_name=Config.YOLO_MODEL_NAME,
        )
        yolo_manager.validate()

        backup_file(yolo_manager.get_best_weights_path(), Config.yolo_trained_model_path)
        print("π Training completed successfully!")
    except Exception as e:
        print(f"β Training failed: {e}")
        raise
if __name__ == "__main__":
    import argparse

    # Defaults keep the original hard-coded locations; override them on the CLI.
    parser = argparse.ArgumentParser(
        description="Filter the comic panel dataset to labelled images, then train YOLO."
    )
    parser.add_argument(
        "--source",
        default="/home/jebin/git/comic-panel-extractor/comic_panel_extractor/dataset",
        help="Root of the source dataset (images/<split> and labels/<split>).",
    )
    parser.add_argument(
        "--output",
        default="/home/jebin/git/comic-panel-extractor/comic_panel_extractor/filtered_dataset",
        help="Directory to (re)create with the filtered dataset; it is deleted first.",
    )
    args = parser.parse_args()
    original_dataset_path = args.source
    output_filtered_dataset_path = args.output

    # NOTE(review): the 'π' / 'π―' / 'β' characters in these messages look like
    # emoji mis-decoded in a scraped copy of this file — confirm against the original.
    print("π Starting dataset filtering...")
    print(f"π Source: {original_dataset_path}")
    print(f"π Output: {output_filtered_dataset_path}")

    # Create filtered dataset
    filtered_counts = create_filtered_dataset(original_dataset_path, output_filtered_dataset_path)
    # Create YAML file
    yaml_path = create_filtered_yaml(output_filtered_dataset_path, filtered_counts)

    # Summary
    total_filtered = sum(filtered_counts.values())
    print("\nπ Filtering Summary:")
    for split, count in filtered_counts.items():
        if count > 0:
            print(f" {split.upper()}: {count} images")
    print(f" TOTAL: {total_filtered} images with labels")
    print(f"\nπ― Use this YAML for training: {yaml_path}")

    # Display the created YAML content
    with open(yaml_path, 'r') as f:
        yaml_content = f.read()
    print("\nπ Generated YAML content:")
    print("β" * 50)
    print(yaml_content)
    print("β" * 50)

    # Training cannot succeed without training images; stop with a clear message.
    if filtered_counts.get("train", 0) == 0:
        raise SystemExit("No labelled training images found; skipping training.")

    main()