Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from PIL import Image | |
from utils import encode_image_to_datauri, cot_with_gpt, extract_instructions, infer_with_DiT, roi_localization, fusion | |
import openai | |
import os | |
import uuid | |
from src.flux.generate import generate, seed_everything | |
def _apply_edit_step(category, instruction, current_image_path):
    """Run the editing pipeline for one planned step and return the edited image.

    Args:
        category: Edit category string produced by the planner
            (e.g. 'Add', 'Remove', 'Move', 'Style Change').
        instruction: The instruction text for this step.
        current_image_path: Path of the image file this step operates on.

    Returns:
        The result of the final ``infer_with_DiT`` call for the category; the
        caller saves it with ``.save()`` and copies it with ``.copy()``.

    Raises:
        gr.Error: If ``category`` is not one of the supported categories.
    """
    if category == 'Add':
        return infer_with_DiT('RoI Editing', current_image_path, instruction, category)

    if category in ('Remove', 'Replace'):
        mask_image = roi_localization(current_image_path, instruction, category)
        return infer_with_DiT('RoI Inpainting', mask_image, instruction, category)

    if category == 'Action Change':
        # Inpaint the located region away, re-generate the instance, then
        # fuse and composite. Call order is kept as-is because the calls
        # share the globally seeded RNG state.
        mask_image = roi_localization(current_image_path, instruction, category)
        inpainted = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
        changed_instance, x0, y1, scale = infer_with_DiT(
            'RoI Editing', current_image_path, instruction, category
        )
        fusion_image = fusion(inpainted, changed_instance, x0, y1, scale)
        return infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)

    if category in ('Move', 'Resize'):
        # For these categories roi_localization also returns the instance
        # and its placement values, which are forwarded to fusion().
        mask_image, changed_instance, x0, y1, scale = roi_localization(
            current_image_path, instruction, category
        )
        inpainted = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
        fusion_image = fusion(inpainted, changed_instance, x0, y1, scale)
        return infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)

    if category in ('Appearance Change', 'Background Change', 'Color Change',
                    'Material Change', 'Expression Change'):
        return infer_with_DiT('RoI Editing', current_image_path, instruction, category)

    if category in ('Tone Transfer', 'Style Change'):
        return infer_with_DiT('Global Transformation', current_image_path, instruction, category)

    raise gr.Error(f"Invalid category returned: '{category}'")


def process_image(api_key, seed, image, prompt):
    """Edit ``image`` step by step according to ``prompt``, streaming progress.

    The prompt is decomposed by ``cot_with_gpt`` into parallel lists of
    categories and instructions; each pair is executed in order, with every
    step operating on the file saved by the previous one.

    Args:
        api_key: OpenAI API key entered by the user.
        seed: Random seed from the UI slider; coerced to ``int``.
        image: Uploaded image (the UI passes a PIL image), or ``None`` when
            nothing was uploaded.
        prompt: Natural-language editing instruction.

    Yields:
        ``(intermediate_images, latest_image)`` — first the untouched input,
        then one update per completed editing step.

    Raises:
        gr.Error: On missing image, empty plan, unknown category, or any
            failure inside the pipeline.
    """
    if image is None:
        # Without this guard `image.save` fails with a bare AttributeError.
        raise gr.Error("Please upload an image before submitting.")

    # NOTE(review): this sets a process-global key on the openai module, so
    # concurrent requests from different users can overwrite each other's
    # key. Fixing that requires passing the key through the utils helpers.
    openai.api_key = api_key
    # Generate a unique image ID to avoid file name conflict
    image_id = str(uuid.uuid4())
    # The slider value may arrive as a float; int() is a no-op for ints.
    seed_everything(int(seed))

    input_path = f"input_{image_id}.png"
    written_paths = []  # every file this request creates, removed in `finally`
    intermediate_images = []
    current_image_path = input_path
    try:
        # Register the path before saving so a partial file is cleaned up too.
        written_paths.append(input_path)
        image.save(input_path)

        uri = encode_image_to_datauri(input_path)
        categories, instructions = cot_with_gpt(uri, prompt)
        if not categories or not instructions:
            raise gr.Error("No editing steps returned by GPT. Try a more specific instruction.")

        intermediate_images.append(image)
        yield intermediate_images, image

        for i, (category, instruction) in enumerate(zip(categories, instructions)):
            print(f"[Step {i}] Category: {category} | Instruction: {instruction}")
            edited_image = _apply_edit_step(category, instruction, current_image_path)

            # Persist the result so the next step can read it from disk.
            current_image_path = f"{image_id}_{i}.png"
            written_paths.append(current_image_path)
            edited_image.save(current_image_path)

            intermediate_images.append(edited_image.copy())
            yield intermediate_images, edited_image
    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Processing failed: {str(e)}") from e
    finally:
        # The UI receives image objects, not paths, so the scratch files are
        # no longer needed once the generator finishes or is closed.
        for path in written_paths:
            try:
                os.remove(path)
            except OSError:
                pass
# Gradio UI
# Layout: API key row, seed row, then a two-column row with the inputs on the
# left and the streamed results on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 🖼️ IEAP: Image Editing As Programs")

    with gr.Row():
        # type="password" masks the key in the browser.
        api_key_input = gr.Textbox(label="🔑 OpenAI API Key", type="password", placeholder="sk-...")

    with gr.Row():
        seed_slider = gr.Slider(
            label="🎲 Random Seed",
            minimum=0,
            maximum=1000000,
            value=3407,
            step=1,
            info="Drag to set the random seed for reproducibility"
        )

    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio pass a PIL image to process_image.
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(label="Instruction", placeholder="e.g., Remove the dog.")
            submit_button = gr.Button("Submit")
        with gr.Column():
            result_gallery = gr.Gallery(label="Intermediate Steps", columns=2, height="auto")
            final_output = gr.Image(label="✅ Final Result")

    # process_image is a generator, so each yield updates both outputs.
    # The order of `inputs` must match its (api_key, seed, image, prompt)
    # parameters, and `outputs` the (gallery, latest image) tuple it yields.
    submit_button.click(
        fn=process_image,
        inputs=[api_key_input, seed_slider, image_input, prompt_input],
        outputs=[result_gallery, final_output]
    )

if __name__ == "__main__":
    demo.launch()