import io
import os
import random
import time

import numpy as np
import requests
from datasets import load_dataset
from PIL import Image, ImageFilter, ImageOps
from tqdm import tqdm


def download_image(url, timeout=10, retries=2):
    """Download an image from a URL with a simple retry mechanism."""
    for attempt in range(retries):
        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            response = requests.get(url, timeout=timeout, headers=headers)
            if response.status_code == 200:
                return Image.open(io.BytesIO(response.content))
            return None
        except Exception as e:
            if attempt == retries - 1:  # Last attempt
                print(f"Failed to download {url}: {e}")
                return None
            time.sleep(0.5)  # Brief pause before retrying
    return None


def preprocess_image(image, target_size=512, quality_threshold=0.7):
    """Preprocess an image: quality filter, center crop, resize, and enhance."""
    if image is None:
        return None
    try:
        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Filter out images too small to upscale cleanly
        width, height = image.size
        if min(width, height) < target_size * quality_threshold:
            return None

        # Center crop to a square if not already square
        if width != height:
            size = min(width, height)
            left = (width - size) // 2
            top = (height - size) // 2
            image = image.crop((left, top, left + size, top + size))

        # Resize to the target resolution
        image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)

        # Slightly sharpen, then auto-adjust levels
        image = image.filter(ImageFilter.UnsharpMask(radius=0.5, percent=120, threshold=3))
        image = ImageOps.autocontrast(image, cutoff=1)

        return image
    except Exception as e:
        print(f"Error preprocessing image: {e}")
        return None


def clean_prompt(prompt):
    """Clean and normalize a caption; return None if it is unusable."""
    if not prompt:
        return None

    # Collapse all whitespace runs (including non-breaking spaces) to single spaces
    prompt = ' '.join(prompt.split())
    prompt = prompt.strip(' .,;:')

    # Filter out very short or very long prompts
    words = prompt.split()
    if len(words) < 3 or len(words) > 50:
        return None

    return prompt


def prepare_dreambooth_data():
    """Stream LAION samples, download and preprocess images, save image/caption pairs."""
    print("Loading LAION dataset...")
    dataset = load_dataset("laion/laion2B-en-aesthetic", split="train", streaming=True)

    # Create directory structure
    data_dir = "./laion_dataset"
    os.makedirs(data_dir, exist_ok=True)

    valid_samples = 0
    processed_count = 0
    max_samples = 1000  # Limit on the total number of samples to examine

    print(f"Starting to process up to {max_samples} samples...")

    for idx, sample in enumerate(tqdm(dataset, desc="Processing LAION samples")):
        if processed_count >= max_samples:
            break
        processed_count += 1

        try:
            # LAION rows store the image URL and caption in uppercase columns
            image_url = sample.get('URL', '')
            text_prompt = sample.get('TEXT', '')
            if not image_url or not text_prompt:
                continue

            # Clean the prompt first so unusable captions skip the download entirely
            prompt = clean_prompt(text_prompt)
            if prompt is None:
                continue

            # Download the image from its URL
            print(f"Downloading image {valid_samples + 1}: {image_url[:50]}...")
            image = download_image(image_url)
            if image is None:
                continue

            # Preprocess the downloaded image
            processed_image = preprocess_image(image)
            if processed_image is None:
                continue

            # Save the processed image
            image_path = os.path.join(data_dir, f"image_{valid_samples:04d}.jpg")
            processed_image.save(image_path, "JPEG", quality=95, optimize=True)

            # Save the cleaned caption alongside it
            caption_path = os.path.join(data_dir, f"image_{valid_samples:04d}.txt")
            with open(caption_path, 'w', encoding='utf-8') as f:
                f.write(prompt)

            # Optional: save a metadata file under the same index
            metadata_path = os.path.join(data_dir, f"image_{valid_samples:04d}_meta.txt")
            with open(metadata_path, 'w', encoding='utf-8') as f:
                f.write(f"URL: {image_url}\n")
                f.write(f"Aesthetic: {sample.get('aesthetic', 'N/A')}\n")
                f.write(f"Width: {sample.get('WIDTH', 'N/A')}\n")
                f.write(f"Height: {sample.get('HEIGHT', 'N/A')}\n")

            valid_samples += 1

            # Stop once we have enough valid samples
            if valid_samples >= 100:  # Adjust this number as needed
                break

        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            continue

    print(f"Processed {processed_count} samples, saved {valid_samples} valid images to {data_dir}")
    return data_dir


def create_demo_dataset():
    """Create a small synthetic dataset as a last-resort fallback."""
    print("Creating demo dataset...")
    data_dir = "./demo_dataset"
    os.makedirs(data_dir, exist_ok=True)

    demo_prompts = [
        "a beautiful landscape with mountains",
        "portrait of a person with detailed features",
        "abstract colorful digital artwork",
        "modern architecture building design",
        "natural forest scene with trees",
        "urban cityscape at sunset",
        "artistic oil painting style",
        "vintage photography aesthetic",
        "minimalist geometric composition",
        "vibrant surreal art piece",
    ]

    for idx, prompt in enumerate(demo_prompts):
        # Create a vertical gradient between two random colors
        color1 = np.array([random.randint(50, 200) for _ in range(3)], dtype=np.float32)
        color2 = np.array([random.randint(100, 255) for _ in range(3)], dtype=np.float32)
        t = np.linspace(0.0, 1.0, 512, dtype=np.float32)[:, None, None]
        gradient = (color1 * (1 - t) + color2 * t).astype(np.uint8)  # (512, 1, 3)
        gradient = np.repeat(gradient, 512, axis=1)                  # (512, 512, 3)
        image = Image.fromarray(gradient)

        # Save the image/caption pair
        image_path = os.path.join(data_dir, f"image_{idx:04d}.jpg")
        image.save(image_path, "JPEG", quality=95)

        caption_path = os.path.join(data_dir, f"image_{idx:04d}.txt")
        with open(caption_path, 'w', encoding='utf-8') as f:
            f.write(prompt)

    print(f"Created {len(demo_prompts)} demo samples")
    return data_dir


# Main execution, falling back to the demo dataset if LAION preparation fails
def main():
    data_dir = None
    try:
        data_dir = prepare_dreambooth_data()
        if not any(name.endswith('.jpg') for name in os.listdir(data_dir)):
            print("No valid LAION samples were saved.")
            data_dir = None
    except Exception as e:
        print(f"LAION preparation failed: {e}")
    if data_dir is None:
        data_dir = create_demo_dataset()

    # Generate the training command
    training_command = f"""
accelerate launch \\
  --use_deepspeed \\
  --deepspeed_config_file ds_config.json \\
  diffusers/examples/dreambooth/train_dreambooth.py \\
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \\
  --instance_data_dir="{data_dir}" \\
  --instance_prompt="a high quality image" \\
  --output_dir="./laion-model" \\
  --resolution=512 \\
  --train_batch_size=1 \\
  --gradient_accumulation_steps=1 \\
  --gradient_checkpointing \\
  --learning_rate=5e-6 \\
  --lr_scheduler="constant" \\
  --lr_warmup_steps=0 \\
  --max_train_steps=400 \\
  --mixed_precision="fp16" \\
  --checkpointing_steps=100 \\
  --checkpoints_total_limit=1 \\
  --report_to="tensorboard" \\
  --logging_dir="./laion-model/logs"
"""

    print(f"\nāœ… Dataset prepared in: {data_dir}")
    print("šŸš€ Run this command to train:")
    print(training_command)
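
# Hedged sketch: the training command printed by main() references ds_config.json,
# which this script never creates. The helper below shows one plausible minimal
# DeepSpeed ZeRO stage-2 config whose values mirror the command's flags (micro
# batch size 1, gradient accumulation 1, fp16). The function name and the exact
# values are illustrative assumptions, not part of the original pipeline.
def write_ds_config(path="ds_config.json"):
    """Write a minimal DeepSpeed config (illustrative defaults; tune per setup)."""
    import json  # Local import: this helper is optional

    config = {
        # Should agree with --train_batch_size and --gradient_accumulation_steps
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        # Matches --mixed_precision="fp16"
        "fp16": {"enabled": True},
        # ZeRO stage 2 partitions optimizer state and gradients across GPUs
        "zero_optimization": {"stage": 2},
    }
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)
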

if __name__ == "__main__":
    main()