# eiji
# change parameter controls
# e0a8f84
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
from tkg_dm import TKGDMPipeline
def create_canvas_image(width=512, height=512):
    """Build the empty preview canvas used for the reserved-region display.

    The result is a light-gray RGB image of ``width`` x ``height`` pixels
    with grid lines every 64 pixels and two caption lines in the top-left
    corner.
    """
    background = (240, 240, 240)  # Light gray
    grid_color = (200, 200, 200)
    caption_color = (100, 100, 100)
    spacing = 64

    canvas = Image.new('RGB', (width, height), background)
    pen = ImageDraw.Draw(canvas)

    # Grid for easier visual placement: vertical lines, then horizontal ones.
    vertical = [[(x, 0), (x, height)] for x in range(0, width, spacing)]
    horizontal = [[(0, y), (width, y)] for y in range(0, height, spacing)]
    for segment in vertical + horizontal:
        pen.line(segment, fill=grid_color, width=1)

    # Caption explaining what the preview shows.
    captions = (
        ((10, 10), "Reserved regions preview"),
        ((10, 25), "Yellow boxes show where content will be suppressed"),
    )
    for position, text in captions:
        pen.text(position, text, fill=caption_color)

    return canvas
def draw_boxes_on_canvas(boxes, width=512, height=512):
    """Render reserved-region boxes on top of the preview canvas.

    Args:
        boxes: Iterable of ``(x1, y1, x2, y2)`` tuples in normalized
            coordinates (fractions of the canvas size). ``None`` or an
            empty iterable yields the bare canvas.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        An RGB PIL image with every box outlined, tinted yellow and labeled
        "Reserved Region N".
    """
    img = create_canvas_image(width, height)
    for i, (x1, y1, x2, y2) in enumerate(boxes or ()):
        # Convert normalized coordinates to pixels. Sorting guarantees
        # left <= right and top <= bottom, which recent Pillow versions
        # require of rectangle coordinates.
        left, right = sorted((int(x1 * width), int(x2 * width)))
        top, bottom = sorted((int(y1 * height), int(y2 * height)))

        # Outline: red outer border plus a yellow inner border.
        draw = ImageDraw.Draw(img)
        draw.rectangle([left, top, right, bottom], outline='red', width=3)
        # The inner border is inset by one pixel on each side, so it only
        # exists for boxes at least 2 px wide and tall; skip it otherwise
        # instead of handing Pillow an inverted rectangle.
        if right - left >= 2 and bottom - top >= 2:
            draw.rectangle([left + 1, top + 1, right - 1, bottom - 1], outline='yellow', width=2)

        # Semi-transparent yellow fill, composited per box so overlapping
        # regions appear progressively more opaque.
        overlay = Image.new('RGBA', (width, height), (0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        overlay_draw.rectangle([left, top, right, bottom], fill=(255, 255, 0, 80))
        img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')

        # Label: black text drawn over a white copy offset by one pixel,
        # which leaves a light edge around the glyphs for contrast.
        draw = ImageDraw.Draw(img)
        label = f"Reserved Region {i+1}"
        draw.text((left + 5, top + 5), label, fill='white')
        draw.text((left + 4, top + 4), label, fill='black')
    return img
def sync_text_to_canvas(bbox_str):
    """Turn the coordinate text into the matching preview image.

    Returns the canvas with the parsed boxes drawn on it, or the blank
    canvas when the text yields no boxes.
    """
    parsed = parse_bounding_boxes(bbox_str)
    if not parsed:
        return create_canvas_image()
    return draw_boxes_on_canvas(parsed)
def add_bounding_box(bbox_str, x1, y1, x2, y2):
    """Append a new box to the coordinate string and refresh the preview.

    Args:
        bbox_str: Current ``x1,y1,x2,y2;...`` coordinate string (may be
            empty or ``None``).
        x1, y1, x2, y2: Corner coordinates of the new box as fractions of
            the canvas. Any of them may be ``None`` (an emptied number
            field); the corners may be given in either order.

    Returns:
        ``(updated_str, preview_image)``. When the box is incomplete or
        smaller than 2% of the canvas in either dimension, the existing
        string is returned unchanged.
    """
    bbox_str = bbox_str or ""

    # A cleared numeric field arrives as None; treat it as "nothing to add".
    if any(value is None for value in (x1, y1, x2, y2)):
        return bbox_str, sync_text_to_canvas(bbox_str)

    def clamp01(value):
        return min(1, max(0, value))

    # Put the corners in ascending order and clamp them into [0, 1].
    x1, x2 = clamp01(min(x1, x2)), clamp01(max(x1, x2))
    y1, y2 = clamp01(min(y1, y2)), clamp01(max(y1, y2))

    # Reject boxes too small to be a meaningful reserved region.
    if x2 - x1 < 0.02 or y2 - y1 < 0.02:
        return bbox_str, sync_text_to_canvas(bbox_str)

    # Rebuild the list from non-empty segments so stray separators in the
    # existing text (e.g. a trailing ';') do not produce empty entries.
    segments = [part.strip() for part in bbox_str.split(';') if part.strip()]
    segments.append(f"{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f}")
    updated_str = ';'.join(segments)
    return updated_str, sync_text_to_canvas(updated_str)
def remove_last_box(bbox_str):
    """Drop the last box from the coordinate string and refresh the preview.

    Empty segments (such as the one produced by a trailing ';') are ignored,
    so the last *actual* box is always the one removed.

    Args:
        bbox_str: Current ``x1,y1,x2,y2;...`` coordinate string (may be
            empty or ``None``).

    Returns:
        ``(updated_str, preview_image)``.
    """
    segments = [part.strip() for part in (bbox_str or "").split(';') if part.strip()]
    if not segments:
        return "", create_canvas_image()
    updated_str = ';'.join(segments[:-1])
    return updated_str, sync_text_to_canvas(updated_str)
def clear_all_boxes():
    """Reset the editor: empty coordinate string plus a blank preview canvas."""
    empty_spec = ""
    blank_preview = create_canvas_image()
    return empty_spec, blank_preview
def load_preset_boxes(preset_name):
    """Look up the coordinate string for a named preset layout.

    Each preset is a list of ``x1,y1,x2,y2`` boxes; the boxes are joined
    with ';'. Unknown names yield an empty string.
    """
    layouts = dict(
        center_box=["0.3,0.3,0.7,0.7"],
        top_strip=["0.0,0.0,1.0,0.3"],
        bottom_strip=["0.0,0.7,1.0,1.0"],
        left_right=["0.0,0.2,0.3,0.8", "0.7,0.2,1.0,0.8"],
        corners=[
            "0.0,0.0,0.4,0.4",
            "0.6,0.0,1.0,0.4",
            "0.0,0.6,0.4,1.0",
            "0.6,0.6,1.0,1.0",
        ],
        frame=[
            "0.0,0.0,1.0,0.2",
            "0.0,0.8,1.0,1.0",
            "0.0,0.2,0.2,0.8",
            "0.8,0.2,1.0,0.8",
        ],
    )
    regions = layouts.get(preset_name)
    if regions is None:
        return ""
    return ";".join(regions)
def load_preset_handler(preset_name):
    """Load a preset's coordinate string and the matching preview image.

    All presets, including "center_box", are resolved through
    ``load_preset_boxes`` so the coordinates live in exactly one place.

    Args:
        preset_name: Preset key selected in the dropdown; may be empty or
            ``None``.

    Returns:
        ``(coordinate_str, preview_image)``. An empty selection yields an
        empty string and the blank canvas.
    """
    if not preset_name:
        return "", create_canvas_image()
    preset_str = load_preset_boxes(preset_name)
    return preset_str, sync_text_to_canvas(preset_str)
def parse_bounding_boxes(bbox_str):
    """Parse a ``x1,y1,x2,y2;x1,y1,x2,y2;...`` string into box tuples.

    Each box's corners are put in ascending order and every coordinate is
    clamped into [0, 1]. Empty segments and segments that do not have
    exactly four values are skipped.

    Args:
        bbox_str: The coordinate string, or ``None``.

    Returns:
        A list of ``(x1, y1, x2, y2)`` float tuples, or ``None`` when the
        input is empty, contains no usable box, or contains a value that is
        not a number (a single bad value rejects the whole input).
    """
    if not bbox_str or not bbox_str.strip():
        return None

    def clamp01(value):
        return min(1.0, max(0.0, value))

    boxes = []
    for box_str in bbox_str.split(';'):
        if not box_str.strip():
            continue
        try:
            coords = [float(part) for part in box_str.split(',')]
        except ValueError:
            # Non-numeric text: treat the whole specification as invalid.
            return None
        if len(coords) != 4:
            continue
        x1, y1, x2, y2 = coords
        # Clamp both ends of each axis; clamping only one side would let a
        # box lying entirely outside the canvas come back inverted.
        boxes.append((
            clamp01(min(x1, x2)),
            clamp01(min(y1, y2)),
            clamp01(max(x1, x2)),
            clamp01(max(y1, y2)),
        ))
    return boxes or None
def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str):
    """Generate an image with the TKG-DM pipeline, or fall back to a demo image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        ch0_shift, ch1_shift, ch2_shift, ch3_shift: Per-channel shift values,
            passed to the pipeline as a four-element list.
        intensity: Multiplier applied to ``shift_percent``.
        steps: Number of inference steps; converted to ``int`` before use.
        shift_percent: Base shift percentage (scaled by ``intensity``).
        blur_sigma: Blur sigma; ``0`` is forwarded as ``None``.
        model_type: Model architecture key forwarded to ``TKGDMPipeline``.
        custom_model_id: Optional model id; blank or ``None`` means "use the
            pipeline's default".
        bounding_boxes_str: Reserved regions as ``x1,y1,x2,y2;...``; when no
            box can be parsed a single centered box is used.

    Returns:
        Whatever the pipeline call returns (presumably a PIL image - TODO
        confirm against ``TKGDMPipeline``), or the image produced by
        ``create_demo_visualization`` if anything in the pipeline path fails.
    """
    # Parsed before the try block so the fallback below can always use it.
    bounding_boxes = parse_bounding_boxes(bounding_boxes_str)
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_id = (custom_model_id or "").strip() or None
        # NOTE(review): a new pipeline object is built on every call; if its
        # constructor loads model weights, caching it per
        # (model_id, model_type, device) would be worth considering.
        pipeline = TKGDMPipeline(model_id=model_id, model_type=model_type, device=device)
        if pipeline.pipe is None:
            raise RuntimeError("Pipeline not available")
        return pipeline(
            prompt=prompt,
            channel_shifts=[ch0_shift, ch1_shift, ch2_shift, ch3_shift],
            bounding_boxes=bounding_boxes or [(0.3, 0.3, 0.7, 0.7)],
            target_shift_percent=shift_percent * intensity,
            blur_sigma=None if blur_sigma == 0 else blur_sigma,
            # The UI slider can deliver a float; the step count must be integral.
            num_inference_steps=int(steps),
            guidance_scale=7.5
        )
    except Exception as e:
        # Deliberate best-effort: any failure degrades to the demo picture.
        print(f"Using demo mode due to: {e}")
        return create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes)
def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes=None):
    """Draw a 512x512 placeholder image used when the real pipeline is unavailable.

    The background color is derived from the first three channel shifts
    (each mapped around mid-gray and clamped to 0-255). Reserved regions
    are outlined in yellow and labeled, and the prompt and channel values
    are printed as text.

    Args:
        prompt: Prompt text to show (first 40 characters); ``None`` is
            treated as empty.
        ch0_shift, ch1_shift, ch2_shift, ch3_shift: Channel shift values.
        bounding_boxes: Iterable of normalized ``(x1, y1, x2, y2)`` tuples;
            a single centered box is used when empty or ``None``.

    Returns:
        An RGB PIL image.
    """
    size = 512
    approx_color = tuple(
        max(0, min(255, 128 + int(shift * 127)))
        for shift in (ch0_shift, ch1_shift, ch2_shift)
    )
    img = Image.new('RGB', (size, size), approx_color)
    draw = ImageDraw.Draw(img)

    if not bounding_boxes:
        bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]
    for i, (x1, y1, x2, y2) in enumerate(bounding_boxes):
        # Sort the pixel coordinates so an inverted box cannot reach Pillow.
        px1, px2 = sorted((int(x1 * size), int(x2 * size)))
        py1, py2 = sorted((int(y1 * size), int(y2 * size)))
        draw.rectangle([px1, py1, px2, py2], outline='yellow', width=3)
        draw.text((px1+5, py1+5), f"Reserved {i+1}", fill='white')

    # Only mark the prompt as truncated when it actually was.
    prompt = prompt or ""
    shown_prompt = prompt[:40] + ("..." if len(prompt) > 40 else "")
    draw.text((10, 10), "TKG-DM Demo", fill='white')
    draw.text((10, 30), f"Prompt: {shown_prompt}", fill='white')
    draw.text((10, 480), f"Channels: [{ch0_shift:+.2f},{ch1_shift:+.2f},{ch2_shift:+.2f},{ch3_shift:+.2f}]", fill='white')
    return img
# Create Gradio 5.x compatible interface with enhanced explanations.
#
# Layout: a two-column row (controls on the left, preview and result on the
# right), followed by an examples accordion. Event handlers are registered
# inside the Blocks context at the end.

# Region coordinates shown when the app starts; also used to render the
# initial preview so that it matches the text box.
DEFAULT_BOXES = "0.3,0.3,0.7,0.7"

with gr.Blocks(title="🎨 SAWNA: Space-Aware Text-to-Image Generation", theme=gr.themes.Default()) as demo:
    gr.Markdown("""
# 🎨 SAWNA: Space-Aware Text-to-Image Generation
## What is SAWNA?
**SAWNA** (Space-Aware Weighted Noise Addition) is an advanced AI image generation technique that gives you **precise control** over where content appears in your images. Unlike traditional text-to-image models that fill the entire canvas, SAWNA can **guarantee empty spaces** for your design elements.""")
    with gr.Row():
        # Left column - Controls
        with gr.Column(scale=2):
            # Step 1: Basic Settings
            with gr.Group():
                gr.Markdown("## 📝 Step 1: Basic Image Settings")
                prompt = gr.Textbox(
                    value="A majestic lion in a natural landscape",
                    label="Text Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=2
                )
                with gr.Row():
                    model_type = gr.Dropdown(
                        choices=["sd1.5", "sdxl", "sd2.1"],
                        value="sd1.5",
                        label="Model Architecture",
                        info="SDXL = highest quality, SD1.5 = fastest"
                    )
                    # step=1 keeps the value integral; the step count is
                    # passed on as the number of inference steps.
                    steps = gr.Slider(
                        minimum=10, maximum=100, value=25, step=1,
                        label="Generation Steps",
                        info="More steps = higher quality, slower generation"
                    )
                custom_model_id = gr.Textbox(
                    value="",
                    label="Custom Model ID (Optional)",
                    placeholder="e.g., dreamlike-art/dreamlike-diffusion-1.0",
                    info="Use any Hugging Face Stable Diffusion model"
                )
            # Step 2: Reserved Regions
            with gr.Group():
                gr.Markdown("## 🔲 Step 2: Define Reserved Regions")
                with gr.Row():
                    preset = gr.Dropdown(
                        choices=[
                            ("Center Box", "center_box"),
                            ("Top Banner", "top_strip"),
                            ("Bottom Banner", "bottom_strip"),
                            ("Side Panels", "left_right"),
                            ("Corner Logos", "corners"),
                            ("Full Frame", "frame")
                        ],
                        value="center_box",
                        label="🚀 Quick Presets",
                        info="Pre-designed layouts for common use cases"
                    )
                bounding_boxes_str = gr.Textbox(
                    value=DEFAULT_BOXES,
                    label="📋 Region Coordinates",
                    placeholder="x1,y1,x2,y2;x1,y1,x2,y2 (values 0-1)",
                    info="Format: x1,y1,x2,y2 where (0,0)=top-left, (1,1)=bottom-right"
                )
                # Manual box controls
                gr.Markdown("**Manual Box Builder:**")
                with gr.Row():
                    x1 = gr.Number(value=0.3, minimum=0, maximum=1, label="Left (X1)", step=0.01)
                    y1 = gr.Number(value=0.3, minimum=0, maximum=1, label="Top (Y1)", step=0.01)
                    x2 = gr.Number(value=0.7, minimum=0, maximum=1, label="Right (X2)", step=0.01)
                    y2 = gr.Number(value=0.7, minimum=0, maximum=1, label="Bottom (Y2)", step=0.01)
                with gr.Row():
                    add_box_btn = gr.Button("➕ Add Box", variant="secondary", size="sm")
                    remove_box_btn = gr.Button("❌ Remove Last", variant="secondary", size="sm")
                    clear_boxes_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
            # Step 3: Advanced Controls
            with gr.Group():
                gr.Markdown("## Step 3: Advanced Channel & Technical Controls")
                with gr.Row():
                    ch0_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=0.0,
                        label="Channel 1 (Luminance/Color)",
                        info="Controls brightness and overall color balance"
                    )
                    ch1_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=1.0,
                        label="Channel 2 (Blue+, Red-)",
                        info="Positive = Sky Blue bias, Negative = Red bias"
                    )
                with gr.Row():
                    ch2_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=1.0,
                        label="Channel 3 (Yellow+, Blue-)",
                        info="Positive = Yellow bias, Negative = Blue bias"
                    )
                    ch3_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=0.0,
                        label="Channel 4 (Luminance/Color)",
                        info="Secondary luminance and color control"
                    )
                gr.Markdown("**Technical Parameters:**")
                with gr.Row():
                    intensity = gr.Slider(
                        minimum=0.5, maximum=3.0, value=1.0,
                        label="🎯 Effect Intensity",
                        info="Multiplier for overall SAWNA effect strength"
                    )
                    shift_percent = gr.Slider(
                        minimum=0.01, maximum=0.15, value=0.07,
                        label="📊 Base Shift Percent",
                        info="Fundamental noise modification percentage (7% recommended)"
                    )
                    blur_sigma = gr.Slider(
                        minimum=0.0, maximum=5.0, value=0.0,
                        label="🌫️ Transition Blur Sigma",
                        info="Gaussian blur for region boundaries (0 = auto-calculate)"
                    )
            generate_btn = gr.Button("🎨 Generate Space-Aware Image", variant="primary", size="lg")
        # Right column - Preview and Results
        with gr.Column(scale=1):
            # Reserved regions preview
            with gr.Group():
                gr.Markdown("## 👁️ Reserved Regions Preview")
                bbox_preview = gr.Image(
                    # Start with the default region drawn, matching the
                    # coordinates pre-filled in the text box.
                    value=sync_text_to_canvas(DEFAULT_BOXES),
                    label="Yellow areas will be kept empty",
                    interactive=False,
                    type="pil",
                    height=300
                )
                gr.Markdown("*Yellow boxes show where content generation will be suppressed*")
            # Generated image output
            with gr.Group():
                gr.Markdown("## ✨ Generated Result")
                output_image = gr.Image(
                    label="Your space-aware image",
                    type="pil",
                    height=400
                )
    # Examples section
    with gr.Accordion("📚 Professional Use Case Examples", open=False):
        gr.Markdown("""
### Real-World Applications
Click any example below to see how SAWNA handles different professional design scenarios:
- **🦁 Product Photography**: Center space for product details
- **🏙️ Website Headers**: Top banner space for navigation
- **🚗 Marketing Banners**: Bottom space for call-to-action
- **🚀 Mobile UI**: Side panels for app interface elements
- **⌚ E-commerce**: Frame layout for product information
""")
        # Each example row lists values in the same order as `inputs` below.
        examples = gr.Examples(
            examples=[
                [
                    "A majestic lion in African savanna, professional product photography",
                    0.2, 0.3, 0.0, 0.0, 1.0, 25, 0.07, 0.0, "sd1.5", "",
                    "0.3,0.3,0.7,0.7"
                ],
                [
                    "Modern cityscape with skyscrapers at sunset, website header background",
                    -0.1, -0.3, 0.2, 0.1, 1.2, 30, 0.08, 0.0, "sdxl", "",
                    "0.0,0.0,1.0,0.3"
                ],
                [
                    "Vintage luxury car on mountain road, marketing banner",
                    0.1, 0.2, -0.1, -0.2, 0.9, 25, 0.06, 0.0, "sd1.5", "",
                    "0.0,0.7,1.0,1.0"
                ],
                [
                    "Space astronaut floating in colorful nebula, mobile app background",
                    0.0, 0.4, -0.2, 0.3, 1.1, 35, 0.09, 1.8, "sd2.1", "",
                    "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8"
                ],
                [
                    "Premium luxury watch product photography, e-commerce layout",
                    0.2, 0.0, 0.1, -0.1, 1.3, 40, 0.12, 2.5, "sdxl", "",
                    "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8"
                ]
            ],
            inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                    intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str]
        )
    # Event handlers
    generate_btn.click(
        fn=generate_tkg_dm_image,
        inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
        outputs=output_image
    )
    # Preset dropdown handler
    preset.change(
        fn=load_preset_handler,
        inputs=[preset],
        outputs=[bounding_boxes_str, bbox_preview]
    )
    # Box building handlers
    add_box_btn.click(
        fn=add_bounding_box,
        inputs=[bounding_boxes_str, x1, y1, x2, y2],
        outputs=[bounding_boxes_str, bbox_preview]
    )
    remove_box_btn.click(
        fn=remove_last_box,
        inputs=[bounding_boxes_str],
        outputs=[bounding_boxes_str, bbox_preview]
    )
    clear_boxes_btn.click(
        fn=clear_all_boxes,
        outputs=[bounding_boxes_str, bbox_preview]
    )
    # Text box sync handler: redraw the preview whenever the text changes
    bounding_boxes_str.change(
        fn=sync_text_to_canvas,
        inputs=[bounding_boxes_str],
        outputs=[bbox_preview]
    )

if __name__ == "__main__":
    # NOTE(review): this listens on all network interfaces and requests a
    # public share link - confirm both are intended for this deployment.
    demo.launch(share=True, server_name="0.0.0.0")