#!/usr/bin/env python3
"""
Stage 1: Data Loading and Image Downloading

Downloads and preprocesses the top-N rows of a parquet file
(N is set in main(); the loader defaults to 2000).
"""
import os
import json
import requests
import pandas as pd
from PIL import Image
from io import BytesIO
import concurrent.futures
import time
import logging
import numpy as np
from typing import Optional, Tuple

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def setup_environment():
    """Create the output directory tree."""
    os.makedirs('./data', exist_ok=True)
    os.makedirs('./data/images', exist_ok=True)
    os.makedirs('./data/metadata', exist_ok=True)
    return True


def load_and_sample_data(parquet_path: str, n_samples: int = 2000) -> pd.DataFrame:
    """Load parquet file and keep the top N rows."""
    logger.info(f"Loading data from {parquet_path}")
    df = pd.read_parquet(parquet_path)
    logger.info(f"Loaded {len(df)} rows, sampling top {n_samples}")
    return df.head(n_samples)


def has_white_edges(img: Image.Image, threshold: int = 240) -> bool:
    """Return True if 3 or more edges are white (mean of every RGB channel > threshold)."""
    try:
        img_array = np.array(img)

        # Define edge thickness (check 5 pixels from each edge)
        edge_thickness = 5

        # Mean color of each edge strip
        top_edge = img_array[:edge_thickness, :].mean(axis=(0, 1))
        bottom_edge = img_array[-edge_thickness:, :].mean(axis=(0, 1))
        left_edge = img_array[:, :edge_thickness].mean(axis=(0, 1))
        right_edge = img_array[:, -edge_thickness:].mean(axis=(0, 1))

        # An edge counts as white if all RGB channels exceed the threshold
        edges = [top_edge, bottom_edge, left_edge, right_edge]
        white_edges = sum(1 for edge in edges if np.all(edge > threshold))

        return white_edges >= 3
    except Exception as e:
        logger.debug(f"Error checking white edges: {e}")
        return False


def download_and_process_image(url: str, target_size: int = 256) -> Optional[Image.Image]:
    """Download an image, resize and center-crop it; return None on failure or white edges."""
    try:
        response = requests.get(url, timeout=10, headers={'User-Agent': 'Mozilla/5.0'})
        response.raise_for_status()

        img = Image.open(BytesIO(response.content)).convert('RGB')

        # Check for white edges before processing
        if has_white_edges(img):
            logger.debug(f"Skipping image with white edges: {url}")
            return None

        # Resize so the shorter side equals target_size
        width, height = img.size
        min_side = min(width, height)
        scale = target_size / min_side
        new_width = int(width * scale)
        new_height = int(height * scale)
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Center crop to target_size x target_size
        left = (new_width - target_size) // 2
        top = (new_height - target_size) // 2
        right = left + target_size
        bottom = top + target_size
        img = img.crop((left, top, right, bottom))

        # Double-check after processing
        if has_white_edges(img):
            logger.debug(f"Skipping processed image with white edges: {url}")
            return None

        return img
    except Exception as e:
        logger.error(f"Error downloading {url}: {e}")
        return None


def process_single_image(args: Tuple[int, str, str, str]) -> bool:
    """Download and save a single image plus its metadata."""
    idx, url, hash_val, caption = args
    try:
        # Download and process image
        image = download_and_process_image(url)
        if image is None:
            logger.debug(f"Skipped image {idx} (white edges or download error)")
            return False

        # Save image
        image_path = f'./data/images/img_{idx}.png'
        image.save(image_path)

        # Save metadata for the next stage
        metadata = {
            "idx": idx,
            "caption": caption,
            "url": url,
            "hash": hash_val,
            "image_path": image_path
        }
        metadata_path = f'./data/metadata/meta_{idx}.json'
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Downloaded and saved image {idx}")
        return True
    except Exception as e:
        logger.error(f"Error processing image {idx}: {e}")
        return False


def download_images(df: pd.DataFrame, max_workers: int = 20):
    """Download all images in parallel using a thread pool."""
    logger.info(f"Starting image download with {max_workers} workers...")

    args_list = [(i, row['url'], row['hash'], row['caption'])
                 for i, (_, row) in enumerate(df.iterrows())]

    successful = 0
    failed_or_skipped = 0  # counts both white-edge skips and download/processing errors

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_single_image, args) for args in args_list]

        for i, future in enumerate(concurrent.futures.as_completed(futures)):
            if future.result():
                successful += 1
            else:
                failed_or_skipped += 1

            # Progress logging every 100 images
            if (i + 1) % 100 == 0:
                logger.info(f"Processed {i + 1}/{len(args_list)} images "
                            f"(successful: {successful}, failed/skipped: {failed_or_skipped})")

            # Brief pause while collecting results; all futures are already submitted,
            # so this does not throttle the downloads themselves
            time.sleep(0.01)

    logger.info(f"Download complete: {successful}/{len(args_list)} images downloaded, "
                f"{failed_or_skipped} failed or skipped (white edges or errors)")

    # Save summary
    summary = {
        "total_images": len(args_list),
        "successful_downloads": successful,
        "failed_or_skipped": failed_or_skipped,
        "download_rate": f"{successful/len(args_list)*100:.1f}%",
        "stage": "download_complete"
    }
    with open('./data/stage1_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)


def main():
    """Main execution for Stage 1"""
    logger.info("Starting Stage 1: Data Loading and Image Downloading...")

    # Setup
    setup_environment()

    # Load data
    parquet_path = '/home/fal/partiprompt_clip/curated_part_00000.parquet'
    df = load_and_sample_data(parquet_path, n_samples=5000)

    # Save the dataframe for other stages
    df.to_pickle('./data/sampled_data.pkl')

    # Download images with a higher worker count than the default
    download_images(df, max_workers=30)

    logger.info("Stage 1 completed successfully!")


if __name__ == "__main__":
    main()
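
# Minimal sanity check for has_white_edges (illustrative only, not part of the
# pipeline; run manually in a REPL after importing this module). A solid white
# image should trip the filter, a mid-gray one should not:
#
#   from PIL import Image
#   assert has_white_edges(Image.new('RGB', (64, 64), (255, 255, 255)))
#   assert not has_white_edges(Image.new('RGB', (64, 64), (128, 128, 128)))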