#!/usr/bin/env python3
"""
Stage 1: Data Loading and Image Downloading
Downloads and preprocesses the top N images listed in a parquet file
"""

import os
import json
import requests
import pandas as pd
from PIL import Image
from io import BytesIO
import concurrent.futures
from pathlib import Path
import time
import logging
import numpy as np
from typing import Optional, Tuple

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def setup_environment():
    """Setup data directories"""
    os.makedirs('./data', exist_ok=True)
    os.makedirs('./data/images', exist_ok=True)
    os.makedirs('./data/metadata', exist_ok=True)
    return True


def load_and_sample_data(parquet_path: str, n_samples: int = 2000) -> pd.DataFrame:
    """Load the parquet file and take the first n_samples rows"""
    logger.info(f"Loading data from {parquet_path}")
    df = pd.read_parquet(parquet_path)
    logger.info(f"Loaded {len(df)} rows, taking top {n_samples}")
    return df.head(n_samples)


def has_white_edges(img: Image.Image, threshold: int = 240) -> bool:
    """Check if image has 3 or more white edges (mean RGB > threshold)"""
    try:
        img_array = np.array(img)
        height, width = img_array.shape[:2]

        # Define edge thickness (check 5 pixels from each edge)
        edge_thickness = 5

        # Get edges
        top_edge = img_array[:edge_thickness, :].mean(axis=(0, 1))
        bottom_edge = img_array[-edge_thickness:, :].mean(axis=(0, 1))
        left_edge = img_array[:, :edge_thickness].mean(axis=(0, 1))
        right_edge = img_array[:, -edge_thickness:].mean(axis=(0, 1))

        # Check if edge is white (all RGB channels > threshold)
        edges = [top_edge, bottom_edge, left_edge, right_edge]
        white_edges = sum(1 for edge in edges if np.all(edge > threshold))

        return white_edges >= 3
    except Exception as e:
        logger.debug(f"Error checking white edges: {e}")
        return False
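

# Illustrative only (not used by the pipeline): a minimal sanity check of the
# heuristic above, assuming RGB inputs as produced elsewhere in this script.
# A solid white image should trip all four edges, a solid black one none.
def _demo_white_edge_check():
    white = Image.new('RGB', (64, 64), (255, 255, 255))
    black = Image.new('RGB', (64, 64), (0, 0, 0))
    assert has_white_edges(white)
    assert not has_white_edges(black)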


def download_and_process_image(url: str, target_size: int = 256) -> Optional[Image.Image]:
    """Download an image, resize and center crop it; return None on failure or if it has white edges"""
    try:
        response = requests.get(url, timeout=10, headers={'User-Agent': 'Mozilla/5.0'})
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert('RGB')

        # Check for white edges before processing
        if has_white_edges(img):
            logger.debug(f"Skipping image with white edges: {url}")
            return None

        # Resize so the short side equals target_size, then center crop to target_size x target_size
        width, height = img.size
        min_side = min(width, height)
        scale = target_size / min_side
        new_width = int(width * scale)
        new_height = int(height * scale)
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Center crop
        left = (new_width - target_size) // 2
        top = (new_height - target_size) // 2
        right = left + target_size
        bottom = top + target_size
        img = img.crop((left, top, right, bottom))

        # Double-check after processing
        if has_white_edges(img):
            logger.debug(f"Skipping processed image with white edges: {url}")
            return None

        return img
    except Exception as e:
        logger.error(f"Error downloading {url}: {e}")
        return None


def process_single_image(args: Tuple[int, str, str, str]) -> bool:
    """Download and save a single image"""
    idx, url, hash_val, caption = args
    try:
        # Download and process image
        image = download_and_process_image(url)
        if image is None:
            logger.debug(f"Skipped image {idx} (white edges or download error)")
            return False

        # Save image
        image_path = f'./data/images/img_{idx}.png'
        image.save(image_path)

        # Save metadata for next stage
        metadata = {
            "idx": idx,
            "caption": caption,
            "url": url,
            "hash": hash_val,
            "image_path": image_path
        }
        metadata_path = f'./data/metadata/meta_{idx}.json'
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Downloaded and saved image {idx}")
        return True
    except Exception as e:
        logger.error(f"Error processing image {idx}: {e}")
        return False


def download_images(df: pd.DataFrame, max_workers: int = 20):
    """Download all images with parallel processing"""
    logger.info(f"Starting image download with {max_workers} workers...")
    args_list = [(i, row['url'], row['hash'], row['caption'])
                 for i, (_, row) in enumerate(df.iterrows())]

    successful = 0
    # Note: this counter includes download/processing failures as well as white-edge skips
    skipped_white_edges = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_single_image, args) for args in args_list]
        for i, future in enumerate(concurrent.futures.as_completed(futures)):
            if future.result():
                successful += 1
            else:
                skipped_white_edges += 1

            # Progress logging every 100 images
            if (i + 1) % 100 == 0:
                logger.info(f"Processed {i + 1}/{len(args_list)} images (successful: {successful}, skipped: {skipped_white_edges})")

            # Small pause while collecting results (does not throttle the worker threads themselves)
            time.sleep(0.01)

    logger.info(f"Download complete: {successful}/{len(args_list)} images downloaded, {skipped_white_edges} skipped (white edges) or failed")

    # Save summary
    summary = {
        "total_images": len(args_list),
        "successful_downloads": successful,
        "skipped_white_edges": skipped_white_edges,
        "download_rate": f"{successful/len(args_list)*100:.1f}%",
        "stage": "download_complete"
    }
    with open('./data/stage1_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)


def main():
    """Main execution for Stage 1"""
    logger.info("Starting Stage 1: Data Loading and Image Downloading...")

    # Setup
    setup_environment()

    # Load data
    parquet_path = '/home/fal/partiprompt_clip/curated_part_00000.parquet'
    df = load_and_sample_data(parquet_path, n_samples=5000)

    # Save the dataframe for other stages
    df.to_pickle('./data/sampled_data.pkl')

    # Download images with optimized settings
    download_images(df, max_workers=30)

    logger.info("Stage 1 completed successfully!")


if __name__ == "__main__":
    main()