# train.py
"""Filter the comic-panel dataset and train a YOLO segmentation model on it.

When run as a script:

1. Copy every image that has a non-empty label file into a filtered dataset,
   converting box-format labels to polygon format along the way.
2. Write a dataset YAML describing the filtered dataset.
3. Train and validate via ``YOLOManager``, then back up the best weights with
   ``backup_file``.

Because this module uses relative imports, run it as a module
(``python -m <package>.train``) rather than as a plain script.
"""
import argparse
import os
import shutil
from pathlib import Path

import yaml

from .yolo_manager import YOLOManager
from .utils import get_abs_path, backup_file
from .config import Config

# Dataset splits handled by the filtering step.
SPLITS = ("train", "val", "test")
# Lower-case file extensions that are treated as dataset images.
IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp"})
# Defaults for the script entry point (overridable with --source / --output).
DEFAULT_ORIGINAL_DATASET_PATH = "/home/jebin/git/comic-panel-extractor/comic_panel_extractor/dataset"
DEFAULT_FILTERED_DATASET_PATH = "/home/jebin/git/comic-panel-extractor/comic_panel_extractor/filtered_dataset"


def _filtered_yaml_path():
    """Return the dataset YAML path shared by the filtering and training steps."""
    return f'{Config.current_path}/filtered_comic.yaml'


def _clamp01(value: float) -> float:
    """Clamp *value* into the normalized coordinate range [0, 1]."""
    return max(0.0, min(1.0, value))


def _normalize_label_line(line: str):
    """Normalize one stripped, non-empty YOLO label line.

    Args:
        line: A single label line with surrounding whitespace removed.

    Returns:
        ``(normalized_line, changed)``. ``changed`` is True when a box line
        was converted to a polygon, or when a polygon coordinate had to be
        clamped into [0, 1]. Lines that are neither box nor polygon format,
        or that fail to parse, are returned verbatim with ``changed=False``.
    """
    parts = line.split()
    is_box = len(parts) == 5
    # Polygon: class id followed by an even number of coordinates (x y pairs).
    is_polygon = len(parts) > 5 and len(parts) % 2 == 1
    if not (is_box or is_polygon):
        return line, False

    try:
        # int(float(...)) also accepts class ids written as floats, e.g. "0.0".
        cls = int(float(parts[0]))
        values = [float(token) for token in parts[1:]]
    except (ValueError, OverflowError):
        # OverflowError: int(float("inf")) for a malformed class id.
        return line, False

    if is_box:
        xc, yc, bw, bh = values
        left, right = xc - bw / 2, xc + bw / 2
        top, bottom = yc - bh / 2, yc + bh / 2
        # Corners clockwise from top-left: TL, TR, BR, BL.
        raw_coords = [left, top, right, top, right, bottom, left, bottom]
    else:
        raw_coords = values

    coords = [_clamp01(c) for c in raw_coords]
    # A box is always rewritten; a polygon only counts as changed when
    # clamping actually moved one of its points.
    changed = is_box or coords != raw_coords
    coord_str = " ".join(f"{c:.6f}" for c in coords)
    return f"{cls} {coord_str}", changed


def convert_box_to_polygon(label_file: Path):
    """Convert YOLO box labels to 4-point polygon labels, in place.

    Each non-empty line is handled by its token count:

    * 5 tokens (``class xc yc w h``, box format): rewritten as
      ``class x1 y1 x2 y2 x3 y3 x4 y4``, with corners clockwise from the
      top-left.
    * More than 5 tokens, odd count (already polygon format): coordinates
      clamped to [0, 1].
    * Anything else, or a line that fails to parse: kept verbatim.

    Coordinates are written with 6 decimals. Empty lines are dropped. The file
    is rewritten only when a box line was converted or a polygon coordinate
    was clamped. A missing file is ignored.

    Args:
        label_file: Path (or path string) to a YOLO ``.txt`` label file.
    """
    label_file = Path(label_file)
    if not label_file.exists():
        return

    new_lines = []
    changed = False
    with open(label_file, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:  # Skip empty lines
                continue
            normalized, line_changed = _normalize_label_line(line)
            new_lines.append(normalized)
            changed = changed or line_changed

    if changed:
        with open(label_file, "w", encoding="utf-8") as f:
            f.write("\n".join(new_lines) + "\n")


def _filter_split(source_images, source_labels, dest_images, dest_labels):
    """Copy the labeled images of one split.

    Copies each image whose label file exists and is non-empty, plus its
    label. The copied label is then normalized with ``convert_box_to_polygon``.

    Returns:
        ``(copied_count, total_image_count)`` for the split.
    """
    total_count = 0
    copied_count = 0
    for img_file in source_images.glob('*'):
        if not img_file.is_file() or img_file.suffix.lower() not in IMAGE_EXTENSIONS:
            continue
        total_count += 1
        label_file = source_labels / f"{img_file.stem}.txt"

        if not label_file.exists():
            print(f"Skipping {img_file.name} - no label file")
            continue
        if not label_file.read_text(encoding="utf-8").strip():
            print(f"Skipping {img_file.name} - empty label file")
            continue

        shutil.copy2(img_file, dest_images / img_file.name)
        copied_label = dest_labels / label_file.name
        shutil.copy2(label_file, copied_label)
        convert_box_to_polygon(copied_label)
        copied_count += 1
    return copied_count, total_count


def create_filtered_dataset(original_dataset_path, output_filtered_dataset_path):
    """Create a filtered dataset containing only images with non-empty labels.

    The output directory is deleted and rebuilt with ``images/<split>`` and
    ``labels/<split>`` for every split in ``SPLITS``. For each image whose
    ``labels/<split>/<stem>.txt`` exists and contains non-whitespace text,
    both files are copied, and the copied label is converted to polygon
    format.

    Args:
        original_dataset_path: Root of the source dataset.
        output_filtered_dataset_path: Root of the filtered dataset to (re)create.

    Returns:
        dict mapping each split name to the number of images copied (0 for
        splits whose source directories are missing).

    Raises:
        ValueError: If the output directory is the source directory or one of
            its ancestors; wiping it would delete the source data.
    """
    original_path = Path(original_dataset_path)
    output_path = Path(output_filtered_dataset_path)

    resolved_source = original_path.resolve()
    resolved_output = output_path.resolve()
    if resolved_output == resolved_source or resolved_output in resolved_source.parents:
        raise ValueError(
            f"Refusing to overwrite {output_path}: it contains the source dataset {original_path}"
        )

    shutil.rmtree(output_path, ignore_errors=True)

    # Create output directory structure
    output_images = output_path / "images"
    output_labels = output_path / "labels"
    for split in SPLITS:
        (output_images / split).mkdir(parents=True, exist_ok=True)
        (output_labels / split).mkdir(parents=True, exist_ok=True)

    filtered_counts = {}
    for split in SPLITS:
        source_images = original_path / 'images' / split
        source_labels = original_path / 'labels' / split
        if not source_images.exists() or not source_labels.exists():
            print(f"Skipping {split} - source directory not found")
            filtered_counts[split] = 0
            continue

        copied_count, total_count = _filter_split(
            source_images, source_labels, output_images / split, output_labels / split
        )
        filtered_counts[split] = copied_count
        print(f"{split.upper()} split: {copied_count}/{total_count} images copied")

    return filtered_counts


def create_filtered_yaml(output_filtered_dataset_path, filtered_counts):
    """Write the YOLO dataset YAML for the filtered dataset.

    The YAML declares the single class ``panel`` and points ``train`` and
    ``val`` at the filtered image folders. ``test`` is included only when
    ``filtered_counts['test'] > 0``. Paths are written as absolute paths, so
    a relative output path still works regardless of the training cwd.

    Args:
        output_filtered_dataset_path: Root of the filtered dataset.
        filtered_counts: Per-split image counts from ``create_filtered_dataset``.

    Returns:
        The path of the written YAML file (``<Config.current_path>/filtered_comic.yaml``).
    """
    output_path = Path(os.path.abspath(output_filtered_dataset_path))
    yaml_path = _filtered_yaml_path()

    yaml_data = {
        'names': ['panel'],
        'nc': 1,
        'path': str(output_path),
        'train': str(output_path / 'images' / 'train'),
        'val': str(output_path / 'images' / 'val'),
    }
    # Only add test if it has images
    if filtered_counts.get('test', 0) > 0:
        yaml_data['test'] = str(output_path / 'images' / 'test')

    with open(yaml_path, 'w', encoding='utf-8') as f:
        yaml.safe_dump(yaml_data, f, default_flow_style=False, sort_keys=False)

    print(f"\n✅ Created filtered dataset YAML: {yaml_path}")
    return yaml_path


def main():
    """Train and validate a YOLO model on the filtered dataset, then back up its weights.

    Expects the YAML written by ``create_filtered_yaml`` to exist. After
    training and validation, the path from
    ``YOLOManager.get_best_weights_path()`` is passed to ``backup_file``
    together with ``Config.yolo_trained_model_path``.

    Raises:
        FileNotFoundError: If the dataset YAML is missing.
        Exception: Any failure is printed and then re-raised unchanged.
    """
    try:
        yolo_manager = YOLOManager()

        data_yaml_path = _filtered_yaml_path()
        if not os.path.isfile(data_yaml_path):
            raise FileNotFoundError(f"❌ Dataset YAML not found: {data_yaml_path}")

        print(f"🎯 Training model: {Config.YOLO_MODEL_NAME}")
        yolo_manager.train(
            data_yaml_path=data_yaml_path,
            run_name=Config.YOLO_MODEL_NAME,
        )
        yolo_manager.validate()

        # Backup best weights
        backup_file(yolo_manager.get_best_weights_path(), Config.yolo_trained_model_path)
        print("🎉 Training completed successfully!")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        raise


def _parse_args(argv=None):
    """Parse CLI options; the defaults reproduce the original hard-coded paths."""
    parser = argparse.ArgumentParser(
        description="Filter the comic panel dataset and train a YOLO model on it."
    )
    parser.add_argument("--source", default=DEFAULT_ORIGINAL_DATASET_PATH,
                        help="root of the original dataset (images/<split>, labels/<split>)")
    parser.add_argument("--output", default=DEFAULT_FILTERED_DATASET_PATH,
                        help="root of the filtered dataset to (re)create; it is deleted first")
    return parser.parse_args(argv)


def _print_summary(filtered_counts, yaml_path):
    """Print per-split counts and the contents of the generated YAML."""
    total_filtered = sum(filtered_counts.values())
    print("\n📊 Filtering Summary:")
    for split, count in filtered_counts.items():
        if count > 0:
            print(f"   {split.upper()}: {count} images")
    print(f"   TOTAL: {total_filtered} images with labels")
    print(f"\n🎯 Use this YAML for training: {yaml_path}")

    with open(yaml_path, 'r', encoding='utf-8') as f:
        yaml_content = f.read()
    print("\n📄 Generated YAML content:")
    print("─" * 50)
    print(yaml_content)
    print("─" * 50)


if __name__ == "__main__":
    args = _parse_args()
    original_dataset_path = args.source
    output_filtered_dataset_path = args.output

    print("🔍 Starting dataset filtering...")
    print(f"📂 Source: {original_dataset_path}")
    print(f"📁 Output: {output_filtered_dataset_path}")

    filtered_counts = create_filtered_dataset(original_dataset_path, output_filtered_dataset_path)
    yaml_path = create_filtered_yaml(output_filtered_dataset_path, filtered_counts)
    _print_summary(filtered_counts, yaml_path)

    # Training on an empty train split can only fail; stop before loading a model.
    if filtered_counts.get('train', 0) == 0:
        raise SystemExit("❌ No labeled training images found; aborting before training.")

    main()