|
import math

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw

from tkg_dm import TKGDMPipeline
|
|
|
|
|
def create_canvas_image(width=512, height=512):
    """Return a blank preview canvas: light grey, 64px grid, two caption lines.

    Args:
        width: canvas width in pixels.
        height: canvas height in pixels.

    Returns:
        A new RGB ``PIL.Image`` used as the background of the region preview.
    """
    background = (240, 240, 240)
    grid_color = (200, 200, 200)
    caption_color = (100, 100, 100)
    step = 64

    canvas = Image.new('RGB', (width, height), background)
    pen = ImageDraw.Draw(canvas)

    # Vertical grid lines first, then horizontal ones, starting at the origin.
    for gx in range(0, width, step):
        pen.line([(gx, 0), (gx, height)], fill=grid_color, width=1)
    for gy in range(0, height, step):
        pen.line([(0, gy), (width, gy)], fill=grid_color, width=1)

    captions = (
        ((10, 10), "Reserved regions preview"),
        ((10, 25), "Yellow boxes show where content will be suppressed"),
    )
    for position, text in captions:
        pen.text(position, text, fill=caption_color)

    return canvas
|
|
|
|
|
def draw_boxes_on_canvas(boxes, width=512, height=512):
    """Draw reserved-region boxes on a fresh preview canvas.

    Args:
        boxes: iterable of ``(x1, y1, x2, y2)`` tuples in fractional image
            coordinates (0 = left/top, 1 = right/bottom), e.g. the output of
            ``parse_bounding_boxes``.
        width: canvas width in pixels.
        height: canvas height in pixels.

    Returns:
        An RGB ``PIL.Image``. Each box gets a red outer outline and, when it is
        at least 2px on each side, a yellow inset outline. It also gets a
        translucent yellow tint and a "Reserved Region N" label.
    """
    img = create_canvas_image(width, height)
    draw = ImageDraw.Draw(img)

    for i, (x1, y1, x2, y2) in enumerate(boxes, start=1):
        # Order the pixel corners: Pillow's rectangle() raises ValueError when
        # the second corner lies above/left of the first.
        px1, px2 = sorted((int(x1 * width), int(x2 * width)))
        py1, py2 = sorted((int(y1 * height), int(y2 * height)))

        draw.rectangle([px1, py1, px2, py2], outline='red', width=3)
        # The 1px inset only exists for boxes at least 2px wide and tall; for
        # thinner (or zero-area) boxes it would be inverted and crash Pillow.
        if px2 - px1 >= 2 and py2 - py1 >= 2:
            draw.rectangle([px1 + 1, py1 + 1, px2 - 1, py2 - 1], outline='yellow', width=2)

        # Translucent yellow fill, composited once per box, so overlapping
        # regions end up more strongly tinted.
        overlay = Image.new('RGBA', (width, height), (0, 0, 0, 0))
        ImageDraw.Draw(overlay).rectangle([px1, py1, px2, py2], fill=(255, 255, 0, 80))
        img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')
        # alpha_composite returned a new image; re-bind the drawing context.
        draw = ImageDraw.Draw(img)

        # Black label over a 1px-offset white copy as a simple drop shadow.
        label = f"Reserved Region {i}"
        draw.text((px1 + 5, py1 + 5), label, fill='white')
        draw.text((px1 + 4, py1 + 4), label, fill='black')

    return img
|
|
|
|
|
def sync_text_to_canvas(bbox_str):
    """Render the preview image for the box string typed by the user.

    Args:
        bbox_str: ``;``-separated ``x1,y1,x2,y2`` box list.

    Returns:
        The canvas with the parsed boxes drawn on it, or a blank canvas when
        no box could be parsed.
    """
    parsed = parse_bounding_boxes(bbox_str)
    return draw_boxes_on_canvas(parsed) if parsed else create_canvas_image()
|
|
|
|
|
def add_bounding_box(bbox_str, x1, y1, x2, y2):
    """Append a box built from the manual-builder fields to the box string.

    Args:
        bbox_str: current ``;``-separated box string (may be empty or None).
        x1, y1, x2, y2: corner coordinates as fractions of the image size.
            They may arrive in any order, and may be None when a Number
            field has been cleared.

    Returns:
        ``(updated box string, preview image)``. The string comes back
        unchanged when a coordinate is missing, or when the clamped box is
        narrower or shorter than 0.02 of the image.
    """
    bbox_str = bbox_str or ""

    # gr.Number hands back None for an emptied field; min()/max() would raise.
    if any(v is None for v in (x1, y1, x2, y2)):
        return bbox_str, sync_text_to_canvas(bbox_str)

    def _clamp(v):
        return min(1.0, max(0.0, v))

    # Clamp each coordinate first, then order the pair, so the box is always
    # well formed and inside the unit square.
    x1, x2 = sorted((_clamp(x1), _clamp(x2)))
    y1, y2 = sorted((_clamp(y1), _clamp(y2)))

    # Ignore boxes too small to matter (under 2% of the image on either side).
    if x2 - x1 < 0.02 or y2 - y1 < 0.02:
        return bbox_str, sync_text_to_canvas(bbox_str)

    new_box = f"{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f}"

    # Trim whitespace and trailing separators so exactly one ';' joins the
    # existing boxes and the new one.
    existing = bbox_str.strip().rstrip("; ")
    updated_str = f"{existing};{new_box}" if existing else new_box
    return updated_str, sync_text_to_canvas(updated_str)
|
|
|
|
|
def remove_last_box(bbox_str):
    """Drop the last non-empty box from the ``;``-separated box string.

    Blank segments, e.g. from a trailing ``;``, are ignored. Without this,
    "Remove Last" would remove an empty segment and appear to do nothing.

    Args:
        bbox_str: current box string (may be empty or None).

    Returns:
        ``(updated box string, preview image)``; ``("", blank canvas)`` when
        there was nothing to remove.
    """
    segments = [seg.strip() for seg in (bbox_str or "").split(';') if seg.strip()]
    if not segments:
        return "", create_canvas_image()

    segments.pop()
    updated_str = ';'.join(segments)
    return updated_str, sync_text_to_canvas(updated_str)
|
|
|
|
|
def clear_all_boxes():
    """Reset the box string and the preview to their empty state.

    Returns:
        ``("", blank canvas)``.
    """
    empty_spec = ""
    blank_preview = create_canvas_image()
    return empty_spec, blank_preview
|
|
|
|
|
def load_preset_boxes(preset_name):
    """Return the box string for a named preset layout.

    Args:
        preset_name: one of ``center_box``, ``top_strip``, ``bottom_strip``,
            ``left_right``, ``corners`` or ``frame``.

    Returns:
        The ``;``-separated ``x1,y1,x2,y2`` string for the preset, or ``""``
        for an unknown name.
    """
    layouts = dict(
        center_box="0.3,0.3,0.7,0.7",
        top_strip="0.0,0.0,1.0,0.3",
        bottom_strip="0.0,0.7,1.0,1.0",
        left_right="0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8",
        corners="0.0,0.0,0.4,0.4;0.6,0.0,1.0,0.4;0.0,0.6,0.4,1.0;0.6,0.6,1.0,1.0",
        frame="0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8",
    )
    try:
        return layouts[preset_name]
    except KeyError:
        return ""
|
|
|
|
|
def load_preset_handler(preset_name):
    """Load a preset's boxes and render their preview.

    Args:
        preset_name: preset key selected in the dropdown (may be None/empty).

    Returns:
        ``(box string, preview image)``. A falsy or unknown preset name yields
        ``""`` and a blank canvas.
    """
    # All presets, "center_box" included, come from load_preset_boxes so the
    # coordinates are defined in exactly one place.
    preset_str = load_preset_boxes(preset_name) if preset_name else ""
    return preset_str, sync_text_to_canvas(preset_str)
|
|
|
|
|
def parse_bounding_boxes(bbox_str):
    """Parse a ``;``-separated list of ``x1,y1,x2,y2`` boxes.

    Coordinates are fractions of the image size. Each coordinate is clamped
    to ``[0, 1]`` before the pairs are ordered. The result therefore always
    has ``x1 <= x2`` and ``y1 <= y2``, even for inputs like ``1.5,0,2,1``.

    A segment is skipped, without affecting the others, when it is blank,
    does not hold exactly four values, contains a non-numeric token, or
    contains a non-finite value (nan/inf). A half-typed entry therefore does
    not discard the valid boxes before it.

    Args:
        bbox_str: the box string (may be None or empty).

    Returns:
        A list of ``(x1, y1, x2, y2)`` float tuples, or ``None`` when no valid
        box was found.
    """
    if not bbox_str or not bbox_str.strip():
        return None

    def _clamp(v):
        return min(1.0, max(0.0, v))

    boxes = []
    for segment in bbox_str.split(';'):
        if not segment.strip():
            continue
        try:
            # float() itself tolerates surrounding whitespace.
            coords = [float(part) for part in segment.split(',')]
        except ValueError:
            continue
        if len(coords) != 4 or not all(math.isfinite(c) for c in coords):
            continue
        x1, y1, x2, y2 = (_clamp(c) for c in coords)
        x1, x2 = sorted((x1, x2))
        y1, y2 = sorted((y1, y2))
        boxes.append((x1, y1, x2, y2))

    return boxes or None
|
|
|
|
|
# Single-slot cache of the most recently loaded pipeline, keyed by
# (model_id, model_type, device).
_PIPELINE_CACHE = {}


def _get_pipeline(model_id, model_type, device):
    """Return a TKGDMPipeline for these settings, reusing the last one built.

    Only one pipeline is kept, so switching models drops the reference to the
    previous one. A pipeline whose ``pipe`` is None (load failed) is not
    cached, so the next request retries the load.
    NOTE(review): assumes a TKGDMPipeline instance can be called repeatedly.
    """
    key = (model_id, model_type, device)
    cached = _PIPELINE_CACHE.get(key)
    if cached is not None:
        return cached
    _PIPELINE_CACHE.clear()
    pipeline = TKGDMPipeline(model_id=model_id, model_type=model_type, device=device)
    if pipeline.pipe is not None:
        _PIPELINE_CACHE[key] = pipeline
    return pipeline


def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str):
    """Generate an image with TKG-DM, falling back to a demo visualization.

    Args:
        prompt: text prompt.
        ch0_shift..ch3_shift: per-latent-channel shift values passed to the
            pipeline as ``channel_shifts``.
        intensity: multiplier applied to ``shift_percent``.
        steps: number of inference steps (cast to int; sliders may give floats).
        shift_percent: base target shift percent.
        blur_sigma: transition blur; ``0`` is passed as ``None`` (auto).
        model_type: architecture key forwarded to ``TKGDMPipeline``.
        custom_model_id: optional model id; blank/None uses the default.
        bounding_boxes_str: ``;``-separated ``x1,y1,x2,y2`` reserved regions;
            when none parse, the center box ``(0.3, 0.3, 0.7, 0.7)`` is used.

    Returns:
        The pipeline's output image, or a demo image from
        ``create_demo_visualization`` if anything in the pipeline path fails.
    """
    # Parsed outside the try so the demo fallback always has the boxes bound.
    bounding_boxes = parse_bounding_boxes(bounding_boxes_str) or [(0.3, 0.3, 0.7, 0.7)]

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_id = (custom_model_id or "").strip() or None
        pipeline = _get_pipeline(model_id, model_type, device)
        if pipeline.pipe is None:
            raise RuntimeError("Pipeline not available")

        return pipeline(
            prompt=prompt,
            channel_shifts=[ch0_shift, ch1_shift, ch2_shift, ch3_shift],
            bounding_boxes=bounding_boxes,
            target_shift_percent=shift_percent * intensity,
            blur_sigma=None if blur_sigma == 0 else blur_sigma,
            num_inference_steps=int(steps),
            guidance_scale=7.5,
        )
    except Exception as e:  # deliberate best-effort fallback for the demo UI
        print(f"Using demo mode due to: {e}")
        return create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes)
|
|
|
|
|
def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes=None):
    """Render a 512x512 placeholder image for when the real pipeline fails.

    The background colour is a rough visual stand-in for the first three
    channel shifts: each maps ``[-1, 1]`` onto an RGB byte around 128.
    ``ch3_shift`` only appears in the caption.

    Args:
        prompt: text prompt; shown truncated to 40 characters.
        ch0_shift..ch3_shift: channel shift values.
        bounding_boxes: ``(x1, y1, x2, y2)`` fractional boxes; defaults to the
            center box ``(0.3, 0.3, 0.7, 0.7)`` when empty/None.

    Returns:
        An RGB ``PIL.Image``.
    """
    size = 512

    def _shift_to_byte(shift):
        return max(0, min(255, 128 + int(shift * 127)))

    approx_color = tuple(_shift_to_byte(s) for s in (ch0_shift, ch1_shift, ch2_shift))
    img = Image.new('RGB', (size, size), approx_color)
    draw = ImageDraw.Draw(img)

    if not bounding_boxes:
        bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]

    for i, (x1, y1, x2, y2) in enumerate(bounding_boxes, start=1):
        px1, py1 = int(x1 * size), int(y1 * size)
        px2, py2 = int(x2 * size), int(y2 * size)
        draw.rectangle([px1, py1, px2, py2], outline='yellow', width=3)
        draw.text((px1 + 5, py1 + 5), f"Reserved {i}", fill='white')

    # Only add an ellipsis when the prompt was actually truncated.
    prompt_text = prompt or ""
    shown_prompt = prompt_text if len(prompt_text) <= 40 else f"{prompt_text[:40]}..."

    draw.text((10, 10), "TKG-DM Demo", fill='white')
    draw.text((10, 30), f"Prompt: {shown_prompt}", fill='white')
    draw.text((10, 480), f"Channels: [{ch0_shift:+.2f},{ch1_shift:+.2f},{ch2_shift:+.2f},{ch3_shift:+.2f}]", fill='white')

    return img
|
|
|
|
|
|
|
# Default reserved region shown in the textbox and in the initial preview
# (same coordinates as the "center_box" preset).
_DEFAULT_BOXES_STR = "0.3,0.3,0.7,0.7"

with gr.Blocks(title="🎨 SAWNA: Space-Aware Text-to-Image Generation", theme=gr.themes.Default()) as demo:
    gr.Markdown("""
    # 🎨 SAWNA: Space-Aware Text-to-Image Generation

    ## What is SAWNA?
    **SAWNA** (Space-Aware Weighted Noise Addition) is an advanced AI image generation technique that gives you **precise control** over where content appears in your images. Unlike traditional text-to-image models that fill the entire canvas, SAWNA can **guarantee empty spaces** for your design elements.""")

    with gr.Row():
        # Left column: all inputs.
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("## 📝 Step 1: Basic Image Settings")
                prompt = gr.Textbox(
                    value="A majestic lion in a natural landscape",
                    label="Text Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=2
                )

                with gr.Row():
                    model_type = gr.Dropdown(
                        choices=["sd1.5", "sdxl", "sd2.1"],
                        value="sd1.5",
                        label="Model Architecture",
                        info="SDXL = highest quality, SD1.5 = fastest"
                    )
                    # step=1: without it Gradio derives a fractional step from
                    # the range, and inference steps must be whole numbers.
                    steps = gr.Slider(
                        minimum=10, maximum=100, value=25, step=1,
                        label="Generation Steps",
                        info="More steps = higher quality, slower generation"
                    )

                custom_model_id = gr.Textbox(
                    value="",
                    label="Custom Model ID (Optional)",
                    placeholder="e.g., dreamlike-art/dreamlike-diffusion-1.0",
                    info="Use any Hugging Face Stable Diffusion model"
                )

            with gr.Group():
                gr.Markdown("## 🎲 Step 2: Define Reserved Regions")

                with gr.Row():
                    preset = gr.Dropdown(
                        choices=[
                            ("Center Box", "center_box"),
                            ("Top Banner", "top_strip"),
                            ("Bottom Banner", "bottom_strip"),
                            ("Side Panels", "left_right"),
                            ("Corner Logos", "corners"),
                            ("Full Frame", "frame")
                        ],
                        value="center_box",
                        label="📋 Quick Presets",
                        info="Pre-designed layouts for common use cases"
                    )

                bounding_boxes_str = gr.Textbox(
                    value=_DEFAULT_BOXES_STR,
                    label="📍 Region Coordinates",
                    placeholder="x1,y1,x2,y2;x1,y1,x2,y2 (values 0-1)",
                    info="Format: x1,y1,x2,y2 where (0,0)=top-left, (1,1)=bottom-right"
                )

                gr.Markdown("**Manual Box Builder:**")
                with gr.Row():
                    x1 = gr.Number(value=0.3, minimum=0, maximum=1, label="Left (X1)", step=0.01)
                    y1 = gr.Number(value=0.3, minimum=0, maximum=1, label="Top (Y1)", step=0.01)
                    x2 = gr.Number(value=0.7, minimum=0, maximum=1, label="Right (X2)", step=0.01)
                    y2 = gr.Number(value=0.7, minimum=0, maximum=1, label="Bottom (Y2)", step=0.01)

                with gr.Row():
                    add_box_btn = gr.Button("➕ Add Box", variant="secondary", size="sm")
                    remove_box_btn = gr.Button("➖ Remove Last", variant="secondary", size="sm")
                    clear_boxes_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")

            with gr.Group():
                gr.Markdown("## Step 3: Advanced Channel & Technical Controls")

                with gr.Row():
                    ch0_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=0.0,
                        label="Channel 1 (Luminance/Color)",
                        info="Controls brightness and overall color balance"
                    )
                    ch1_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=1.0,
                        label="Channel 2 (Blue+, Red-)",
                        info="Positive = Sky Blue bias, Negative = Red bias"
                    )

                with gr.Row():
                    ch2_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=1.0,
                        label="Channel 3 (Yellow+, Blue-)",
                        info="Positive = Yellow bias, Negative = Blue bias"
                    )
                    ch3_shift = gr.Slider(
                        minimum=-1.0, maximum=1.0, value=0.0,
                        label="Channel 4 (Luminance/Color)",
                        info="Secondary luminance and color control"
                    )

                gr.Markdown("**Technical Parameters:**")
                with gr.Row():
                    intensity = gr.Slider(
                        minimum=0.5, maximum=3.0, value=1.0,
                        label="🎯 Effect Intensity",
                        info="Multiplier for overall SAWNA effect strength"
                    )
                    shift_percent = gr.Slider(
                        minimum=0.01, maximum=0.15, value=0.07,
                        label="📊 Base Shift Percent",
                        info="Fundamental noise modification percentage (7% recommended)"
                    )

                blur_sigma = gr.Slider(
                    minimum=0.0, maximum=5.0, value=0.0,
                    label="🌫️ Transition Blur Sigma",
                    info="Gaussian blur for region boundaries (0 = auto-calculate)"
                )

            generate_btn = gr.Button("🎨 Generate Space-Aware Image", variant="primary", size="lg")

        # Right column: region preview and generated output.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("## 👁️ Reserved Regions Preview")
                # Render the default boxes up front; the textbox .change event
                # only fires on edits, so a blank canvas would not match it.
                bbox_preview = gr.Image(
                    value=sync_text_to_canvas(_DEFAULT_BOXES_STR),
                    label="Yellow areas will be kept empty",
                    interactive=False,
                    type="pil",
                    height=300
                )
                gr.Markdown("*Yellow boxes show where content generation will be suppressed*")

            with gr.Group():
                gr.Markdown("## ✨ Generated Result")
                output_image = gr.Image(
                    label="Your space-aware image",
                    type="pil",
                    height=400
                )

    with gr.Accordion("💼 Professional Use Case Examples", open=False):
        gr.Markdown("""
        ### Real-World Applications

        Click any example below to see how SAWNA handles different professional design scenarios:

        - **📦 Product Photography**: Center space for product details
        - **🏙️ Website Headers**: Top banner space for navigation
        - **📈 Marketing Banners**: Bottom space for call-to-action
        - **📱 Mobile UI**: Side panels for app interface elements
        - **✅ E-commerce**: Frame layout for product information
        """)

        # Each row matches the `inputs` order below.
        examples = gr.Examples(
            examples=[
                [
                    "A majestic lion in African savanna, professional product photography",
                    0.2, 0.3, 0.0, 0.0, 1.0, 25, 0.07, 0.0, "sd1.5", "",
                    "0.3,0.3,0.7,0.7"
                ],
                [
                    "Modern cityscape with skyscrapers at sunset, website header background",
                    -0.1, -0.3, 0.2, 0.1, 1.2, 30, 0.08, 0.0, "sdxl", "",
                    "0.0,0.0,1.0,0.3"
                ],
                [
                    "Vintage luxury car on mountain road, marketing banner",
                    0.1, 0.2, -0.1, -0.2, 0.9, 25, 0.06, 0.0, "sd1.5", "",
                    "0.0,0.7,1.0,1.0"
                ],
                [
                    "Space astronaut floating in colorful nebula, mobile app background",
                    0.0, 0.4, -0.2, 0.3, 1.1, 35, 0.09, 1.8, "sd2.1", "",
                    "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8"
                ],
                [
                    "Premium luxury watch product photography, e-commerce layout",
                    0.2, 0.0, 0.1, -0.1, 1.3, 40, 0.12, 2.5, "sdxl", "",
                    "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8"
                ]
            ],
            inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                    intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str]
        )

    generate_btn.click(
        fn=generate_tkg_dm_image,
        inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
        outputs=output_image
    )

    preset.change(
        fn=load_preset_handler,
        inputs=[preset],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    add_box_btn.click(
        fn=add_bounding_box,
        inputs=[bounding_boxes_str, x1, y1, x2, y2],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    remove_box_btn.click(
        fn=remove_last_box,
        inputs=[bounding_boxes_str],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    clear_boxes_btn.click(
        fn=clear_all_boxes,
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Keep the preview in sync with manual edits to the coordinate textbox.
    bounding_boxes_str.change(
        fn=sync_text_to_canvas,
        inputs=[bounding_boxes_str],
        outputs=[bbox_preview]
    )
|
|
|
if __name__ == "__main__":
    # share=True requests a public Gradio share link; server_name="0.0.0.0"
    # listens on every network interface, not just localhost.
    launch_options = {"share": True, "server_name": "0.0.0.0"}
    demo.launch(**launch_options)
|
|