import math

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw

from tkg_dm import TKGDMPipeline

# Side length (pixels) of the preview canvas and of the demo fallback image.
CANVAS_SIZE = 512
# Default reserved region, as shown in the textbox when the app starts.
DEFAULT_BOXES_STR = "0.3,0.3,0.7,0.7"
# Same default as normalized (x1, y1, x2, y2) tuples; a tuple so it cannot be mutated.
DEFAULT_BOXES = ((0.3, 0.3, 0.7, 0.7),)
# Manually added boxes narrower/shorter than this (normalized units) are ignored.
MIN_BOX_SIZE = 0.02

# Preset layouts: "x1,y1,x2,y2" boxes separated by ';', normalized to [0, 1].
PRESETS = {
    "center_box": "0.3,0.3,0.7,0.7",
    "top_strip": "0.0,0.0,1.0,0.3",
    "bottom_strip": "0.0,0.7,1.0,1.0",
    "left_right": "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8",
    "corners": "0.0,0.0,0.4,0.4;0.6,0.0,1.0,0.4;0.0,0.6,0.4,1.0;0.6,0.6,1.0,1.0",
    "frame": "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8",
}


def create_canvas_image(width=CANVAS_SIZE, height=CANVAS_SIZE):
    """Create a blank light-gray preview canvas with a 64px grid and a caption.

    Args:
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        An RGB ``PIL.Image.Image``.
    """
    img = Image.new("RGB", (width, height), (240, 240, 240))  # Light gray background
    draw = ImageDraw.Draw(img)

    # Grid lines make normalized coordinates easier to eyeball.
    grid_size = 64
    grid_color = (200, 200, 200)
    for x in range(0, width, grid_size):
        draw.line([(x, 0), (x, height)], fill=grid_color, width=1)
    for y in range(0, height, grid_size):
        draw.line([(0, y), (width, y)], fill=grid_color, width=1)

    text_color = (100, 100, 100)
    draw.text((10, 10), "Reserved regions preview", fill=text_color)
    draw.text((10, 25), "Yellow boxes show where content will be suppressed", fill=text_color)
    return img


def _to_pixel_box(box, width, height):
    """Convert a normalized (x1, y1, x2, y2) box to ordered integer pixel coordinates."""
    x1, y1, x2, y2 = box
    # Order each axis so Pillow never receives x1 < x0 / y1 < y0 (it raises on those).
    px1, px2 = sorted((int(x1 * width), int(x2 * width)))
    py1, py2 = sorted((int(y1 * height), int(y2 * height)))
    return px1, py1, px2, py2


def draw_boxes_on_canvas(boxes, width=CANVAS_SIZE, height=CANVAS_SIZE):
    """Render normalized boxes on a fresh preview canvas.

    All semi-transparent yellow fills are composited in one pass, instead of
    one full-image RGBA round-trip per box. Outlines and labels are drawn
    afterwards so they stay crisp.

    Args:
        boxes: Iterable of normalized ``(x1, y1, x2, y2)`` tuples.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        An RGB ``PIL.Image.Image``.
    """
    pixel_boxes = [_to_pixel_box(box, width, height) for box in boxes]

    base = create_canvas_image(width, height).convert("RGBA")
    overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    for px1, py1, px2, py2 in pixel_boxes:
        overlay_draw.rectangle([px1, py1, px2, py2], fill=(255, 255, 0, 80))
    img = Image.alpha_composite(base, overlay).convert("RGB")

    draw = ImageDraw.Draw(img)
    for i, (px1, py1, px2, py2) in enumerate(pixel_boxes):
        draw.rectangle([px1, py1, px2, py2], outline="red", width=3)
        # Inner yellow outline, inset by 1px. Skip it for boxes too thin to
        # inset, which would otherwise yield inverted coordinates.
        if px2 - px1 >= 2 and py2 - py1 >= 2:
            draw.rectangle([px1 + 1, py1 + 1, px2 - 1, py2 - 1], outline="yellow", width=2)

        label = f"Reserved Region {i + 1}"
        draw.text((px1 + 5, py1 + 5), label, fill="white")  # Offset copy behind the label
        draw.text((px1 + 4, py1 + 4), label, fill="black")
    return img


def sync_text_to_canvas(bbox_str):
    """Return a preview image for the boxes encoded in ``bbox_str``.

    Falls back to the empty canvas when ``bbox_str`` holds no usable box.
    """
    boxes = parse_bounding_boxes(bbox_str)
    return draw_boxes_on_canvas(boxes) if boxes else create_canvas_image()


def _split_box_segments(bbox_str):
    """Split a ';'-separated box string into its non-empty, stripped segments."""
    if not bbox_str:
        return []
    return [segment.strip() for segment in bbox_str.split(";") if segment.strip()]


def _normalize_box(x1, y1, x2, y2):
    """Clamp every coordinate to [0, 1], then order each axis (x1 <= x2, y1 <= y2)."""
    def clamp(value):
        return min(1.0, max(0.0, value))

    x1, x2 = sorted((clamp(x1), clamp(x2)))
    y1, y2 = sorted((clamp(y1), clamp(y2)))
    return x1, y1, x2, y2


def add_bounding_box(bbox_str, x1, y1, x2, y2):
    """Append a box built from the manual-builder fields to ``bbox_str``.

    Coordinates are clamped to [0, 1] and reordered. The box is ignored, and
    the string returned unchanged, when any field is empty (``None``) or the
    box is smaller than ``MIN_BOX_SIZE`` on either axis.

    Returns:
        ``(updated_bbox_str, preview_image)``.
    """
    bbox_str = bbox_str or ""
    if any(value is None for value in (x1, y1, x2, y2)):
        return bbox_str, sync_text_to_canvas(bbox_str)

    x1, y1, x2, y2 = _normalize_box(x1, y1, x2, y2)
    if x2 - x1 < MIN_BOX_SIZE or y2 - y1 < MIN_BOX_SIZE:
        return bbox_str, sync_text_to_canvas(bbox_str)

    new_box = f"{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f}"
    # Re-join the existing non-empty segments so trailing ';' don't pile up.
    updated_str = ";".join(_split_box_segments(bbox_str) + [new_box])
    return updated_str, sync_text_to_canvas(updated_str)


def remove_last_box(bbox_str):
    """Remove the last non-empty box segment from ``bbox_str``.

    Empty segments (for example from a trailing ';') are ignored, so the last
    real box is removed rather than an empty string.

    Returns:
        ``(updated_bbox_str, preview_image)``.
    """
    segments = _split_box_segments(bbox_str)
    if not segments:
        return "", create_canvas_image()
    updated_str = ";".join(segments[:-1])
    return updated_str, sync_text_to_canvas(updated_str)


def clear_all_boxes():
    """Clear all boxes; returns ``("", empty_canvas)``."""
    return "", create_canvas_image()


def load_preset_boxes(preset_name):
    """Return the box string for ``preset_name``, or ``""`` if it is unknown."""
    return PRESETS.get(preset_name, "")


def load_preset_handler(preset_name):
    """Load a preset's boxes and render their preview.

    Returns:
        ``(preset_bbox_str, preview_image)``; an unknown or empty preset
        yields ``""`` and the empty canvas.
    """
    preset_str = load_preset_boxes(preset_name) if preset_name else ""
    return preset_str, sync_text_to_canvas(preset_str)


def parse_bounding_boxes(bbox_str):
    """Parse ``"x1,y1,x2,y2;x1,y1,x2,y2;..."`` into normalized boxes.

    Each segment must hold exactly four finite numbers. Malformed segments
    (wrong count, non-numeric, NaN/inf) are skipped individually, so a
    half-typed trailing box does not discard the valid boxes before it.
    Coordinates are clamped to [0, 1] and ordered per axis, and zero-area
    boxes are dropped.

    Returns:
        A list of ``(x1, y1, x2, y2)`` float tuples, or ``None`` when the input
        is empty or contains no usable box.
    """
    boxes = []
    for segment in _split_box_segments(bbox_str):
        try:
            coords = [float(part) for part in segment.split(",")]
        except ValueError:
            continue
        if len(coords) != 4 or not all(math.isfinite(c) for c in coords):
            continue
        x1, y1, x2, y2 = _normalize_box(*coords)
        if x2 > x1 and y2 > y1:
            boxes.append((x1, y1, x2, y2))
    return boxes or None


# Single-slot cache: repeated generations with the same model settings reuse
# the loaded pipeline instead of reloading model weights on every click.
# NOTE(review): assumes TKGDMPipeline.__call__ does not leave per-call state
# that would leak into the next call — confirm against tkg_dm.
_pipeline_cache = {"key": None, "pipeline": None}


def _get_pipeline(model_id, model_type, device):
    """Return a TKGDMPipeline for these settings, reusing the cached one if it matches.

    Only pipelines whose ``pipe`` attribute is not ``None`` are cached, so a
    failed load is retried on the next request.
    """
    key = (model_id, model_type, device)
    if _pipeline_cache["key"] == key:
        return _pipeline_cache["pipeline"]

    # Release the previous pipeline first so it can be garbage-collected
    # before the new model is loaded.
    _pipeline_cache["key"] = None
    _pipeline_cache["pipeline"] = None

    pipeline = TKGDMPipeline(model_id=model_id, model_type=model_type, device=device)
    if pipeline.pipe is not None:
        _pipeline_cache["key"] = key
        _pipeline_cache["pipeline"] = pipeline
    return pipeline


def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                          intensity, steps, shift_percent, blur_sigma,
                          model_type, custom_model_id, bounding_boxes_str):
    """Generate an image with TKG-DM, or a demo visualization if that fails.

    Args:
        prompt: Text prompt.
        ch0_shift, ch1_shift, ch2_shift, ch3_shift: Per-channel shifts passed
            to the pipeline as ``channel_shifts``.
        intensity: Multiplier applied to ``shift_percent``.
        steps: Number of inference steps (cast to ``int``).
        shift_percent: Base value for ``target_shift_percent``.
        blur_sigma: Transition blur; ``0`` is passed as ``None``, which the
            UI describes as auto-calculate.
        model_type: Architecture key forwarded to ``TKGDMPipeline``.
        custom_model_id: Optional model id; blank means ``None``.
        bounding_boxes_str: Reserved regions as ``"x1,y1,x2,y2;..."``. When it
            contains no usable box, ``DEFAULT_BOXES`` is used.

    Returns:
        The pipeline's image, or the demo image on any failure. The fallback
        is a deliberate best-effort so the UI always shows something.
    """
    # Parse outside the try block so the fallback path always has the boxes.
    bounding_boxes = parse_bounding_boxes(bounding_boxes_str) or list(DEFAULT_BOXES)
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_id = (custom_model_id or "").strip() or None
        pipeline = _get_pipeline(model_id, model_type, device)
        if pipeline.pipe is None:
            raise RuntimeError("Pipeline not available")

        return pipeline(
            prompt=prompt,
            channel_shifts=[ch0_shift, ch1_shift, ch2_shift, ch3_shift],
            bounding_boxes=bounding_boxes,
            target_shift_percent=shift_percent * intensity,
            blur_sigma=None if blur_sigma == 0 else blur_sigma,
            num_inference_steps=int(steps),
            guidance_scale=7.5,
        )
    except Exception as e:  # Best-effort: any failure falls back to the demo image.
        print(f"Using demo mode due to: {e}")
        return create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes)


def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes=None):
    """Draw a placeholder image that roughly reflects the requested settings.

    The background RGB comes from the first three channel shifts: each is
    mapped as ``128 + int(shift * 127)`` and clamped to [0, 255]. Reserved
    boxes (``DEFAULT_BOXES`` when none are given) are outlined in yellow.

    Returns:
        An RGB ``PIL.Image.Image`` of size ``CANVAS_SIZE`` x ``CANVAS_SIZE``.
    """
    size = CANVAS_SIZE

    def to_byte(shift):
        return max(0, min(255, 128 + int(shift * 127)))

    approx_color = (to_byte(ch0_shift), to_byte(ch1_shift), to_byte(ch2_shift))
    img = Image.new("RGB", (size, size), approx_color)
    draw = ImageDraw.Draw(img)

    for i, box in enumerate(bounding_boxes or DEFAULT_BOXES):
        px1, py1, px2, py2 = _to_pixel_box(box, size, size)
        draw.rectangle([px1, py1, px2, py2], outline="yellow", width=3)
        draw.text((px1 + 5, py1 + 5), f"Reserved {i + 1}", fill="white")

    prompt = prompt or ""
    # Only show an ellipsis when the prompt was actually truncated.
    prompt_preview = prompt if len(prompt) <= 40 else prompt[:40] + "..."
    draw.text((10, 10), "TKG-DM Demo", fill="white")
    draw.text((10, 30), f"Prompt: {prompt_preview}", fill="white")
    draw.text(
        (10, size - 32),
        f"Channels: [{ch0_shift:+.2f},{ch1_shift:+.2f},{ch2_shift:+.2f},{ch3_shift:+.2f}]",
        fill="white",
    )
    return img


# Create Gradio 5.x compatible interface with enhanced explanations
with gr.Blocks(title="🎨 SAWNA: Space-Aware Text-to-Image Generation", theme=gr.themes.Default()) as demo:
    gr.Markdown("""
# 🎨 SAWNA: Space-Aware Text-to-Image Generation

## What is SAWNA?

**SAWNA** (Space-Aware Weighted Noise Addition) is an advanced AI image generation technique that gives you **precise control** over where content appears in your images. Unlike traditional text-to-image models that fill the entire canvas, SAWNA can **guarantee empty spaces** for your design elements.
""")

    with gr.Row():
        # Left column - Controls
        with gr.Column(scale=2):
            # Step 1: Basic Settings
            with gr.Group():
                gr.Markdown("## 📝 Step 1: Basic Image Settings")
                prompt = gr.Textbox(
                    value="A majestic lion in a natural landscape",
                    label="Text Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=2,
                )
                with gr.Row():
                    model_type = gr.Dropdown(
                        choices=["sd1.5", "sdxl", "sd2.1"],
                        value="sd1.5",
                        label="Model Architecture",
                        info="SDXL = highest quality, SD1.5 = fastest",
                    )
                    steps = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=25,
                        step=1,  # Inference steps are a whole number.
                        label="Generation Steps",
                        info="More steps = higher quality, slower generation",
                    )
                custom_model_id = gr.Textbox(
                    value="",
                    label="Custom Model ID (Optional)",
                    placeholder="e.g., dreamlike-art/dreamlike-diffusion-1.0",
                    info="Use any Hugging Face Stable Diffusion model",
                )

            # Step 2: Reserved Regions
            with gr.Group():
                gr.Markdown("## 🔲 Step 2: Define Reserved Regions")
                with gr.Row():
                    preset = gr.Dropdown(
                        choices=[
                            ("Center Box", "center_box"),
                            ("Top Banner", "top_strip"),
                            ("Bottom Banner", "bottom_strip"),
                            ("Side Panels", "left_right"),
                            ("Corner Logos", "corners"),
                            ("Full Frame", "frame"),
                        ],
                        value="center_box",
                        label="🚀 Quick Presets",
                        info="Pre-designed layouts for common use cases",
                    )
                    bounding_boxes_str = gr.Textbox(
                        value=DEFAULT_BOXES_STR,
                        label="📋 Region Coordinates",
                        placeholder="x1,y1,x2,y2;x1,y1,x2,y2 (values 0-1)",
                        info="Format: x1,y1,x2,y2 where (0,0)=top-left, (1,1)=bottom-right",
                    )

                # Manual box controls
                gr.Markdown("**Manual Box Builder:**")
                with gr.Row():
                    x1 = gr.Number(value=0.3, minimum=0, maximum=1, label="Left (X1)", step=0.01)
                    y1 = gr.Number(value=0.3, minimum=0, maximum=1, label="Top (Y1)", step=0.01)
                    x2 = gr.Number(value=0.7, minimum=0, maximum=1, label="Right (X2)", step=0.01)
                    y2 = gr.Number(value=0.7, minimum=0, maximum=1, label="Bottom (Y2)", step=0.01)
                with gr.Row():
                    add_box_btn = gr.Button("➕ Add Box", variant="secondary", size="sm")
                    remove_box_btn = gr.Button("❌ Remove Last", variant="secondary", size="sm")
                    clear_boxes_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")

            # Step 3: Advanced Controls
            with gr.Group():
                gr.Markdown("## Step 3: Advanced Channel & Technical Controls")
                with gr.Row():
                    ch0_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=0.0,
                        label="Channel 1 (Luminance/Color)",
                        info="Controls brightness and overall color balance",
                    )
                    ch1_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=1.0,
                        label="Channel 2 (Blue+, Red-)",
                        info="Positive = Sky Blue bias, Negative = Red bias",
                    )
                with gr.Row():
                    ch2_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=1.0,
                        label="Channel 3 (Yellow+, Blue-)",
                        info="Positive = Yellow bias, Negative = Blue bias",
                    )
                    ch3_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=0.0,
                        label="Channel 4 (Luminance/Color)",
                        info="Secondary luminance and color control",
                    )

                gr.Markdown("**Technical Parameters:**")
                with gr.Row():
                    intensity = gr.Slider(
                        minimum=0.5, maximum=3.0, value=1.0,
                        label="🎯 Effect Intensity",
                        info="Multiplier for overall SAWNA effect strength",
                    )
                    shift_percent = gr.Slider(
                        minimum=0.01, maximum=0.15, value=0.07,
                        label="📊 Base Shift Percent",
                        info="Fundamental noise modification percentage (7% recommended)",
                    )
                    blur_sigma = gr.Slider(
                        minimum=0.0, maximum=5.0, value=0.0,
                        label="🌫️ Transition Blur Sigma",
                        info="Gaussian blur for region boundaries (0 = auto-calculate)",
                    )

            generate_btn = gr.Button("🎨 Generate Space-Aware Image", variant="primary", size="lg")

        # Right column - Preview and Results
        with gr.Column(scale=1):
            # Reserved regions preview
            with gr.Group():
                gr.Markdown("## 👁️ Reserved Regions Preview")
                bbox_preview = gr.Image(
                    # Start in sync with the textbox's default boxes.
                    value=sync_text_to_canvas(DEFAULT_BOXES_STR),
                    label="Yellow areas will be kept empty",
                    interactive=False,
                    type="pil",
                    height=300,
                )
                gr.Markdown("*Yellow boxes show where content generation will be suppressed*")

            # Generated image output
            with gr.Group():
                gr.Markdown("## ✨ Generated Result")
                output_image = gr.Image(
                    label="Your space-aware image",
                    type="pil",
                    height=400,
                )

    # Inputs to generate_tkg_dm_image, in signature order; shared by the
    # examples and the generate button so the two cannot drift apart.
    generation_inputs = [
        prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps,
        shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str,
    ]

    # Examples section
    with gr.Accordion("📚 Professional Use Case Examples", open=False):
        gr.Markdown("""
### Real-World Applications

Click any example below to see how SAWNA handles different professional design scenarios:

- **🦁 Product Photography**: Center space for product details
- **🏙️ Website Headers**: Top banner space for navigation
- **🚗 Marketing Banners**: Bottom space for call-to-action
- **🚀 Mobile UI**: Side panels for app interface elements
- **⌚ E-commerce**: Frame layout for product information
""")
        examples = gr.Examples(
            examples=[
                [
                    "A majestic lion in African savanna, professional product photography",
                    0.2, 0.3, 0.0, 0.0, 1.0, 25, 0.07, 0.0, "sd1.5", "",
                    "0.3,0.3,0.7,0.7",
                ],
                [
                    "Modern cityscape with skyscrapers at sunset, website header background",
                    -0.1, -0.3, 0.2, 0.1, 1.2, 30, 0.08, 0.0, "sdxl", "",
                    "0.0,0.0,1.0,0.3",
                ],
                [
                    "Vintage luxury car on mountain road, marketing banner",
                    0.1, 0.2, -0.1, -0.2, 0.9, 25, 0.06, 0.0, "sd1.5", "",
                    "0.0,0.7,1.0,1.0",
                ],
                [
                    "Space astronaut floating in colorful nebula, mobile app background",
                    0.0, 0.4, -0.2, 0.3, 1.1, 35, 0.09, 1.8, "sd2.1", "",
                    "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8",
                ],
                [
                    "Premium luxury watch product photography, e-commerce layout",
                    0.2, 0.0, 0.1, -0.1, 1.3, 40, 0.12, 2.5, "sdxl", "",
                    "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8",
                ],
            ],
            inputs=generation_inputs,
        )

    # Event handlers
    generate_btn.click(
        fn=generate_tkg_dm_image,
        inputs=generation_inputs,
        outputs=output_image,
    )

    # Preset dropdown handler
    preset.change(
        fn=load_preset_handler,
        inputs=[preset],
        outputs=[bounding_boxes_str, bbox_preview],
    )

    # Box building handlers
    add_box_btn.click(
        fn=add_bounding_box,
        inputs=[bounding_boxes_str, x1, y1, x2, y2],
        outputs=[bounding_boxes_str, bbox_preview],
    )
    remove_box_btn.click(
        fn=remove_last_box,
        inputs=[bounding_boxes_str],
        outputs=[bounding_boxes_str, bbox_preview],
    )
    clear_boxes_btn.click(
        fn=clear_all_boxes,
        outputs=[bounding_boxes_str, bbox_preview],
    )

    # Text box sync handler
    bounding_boxes_str.change(
        fn=sync_text_to_canvas,
        inputs=[bounding_boxes_str],
        outputs=[bbox_preview],
    )


if __name__ == "__main__":
    # NOTE: share=True creates a public gradio.live link and 0.0.0.0 listens on
    # all interfaces — the app is reachable from outside this machine.
    demo.launch(share=True, server_name="0.0.0.0")