import base64
import io
import spaces
import gradio as gr
from PIL import Image
import requests
import numpy as np
import PIL
from concept_attention import ConceptAttentionFluxPipeline
# concept_attention_default_args = {
# "model_name": "flux-schnell",
# "device": "cuda",
# "layer_indices": list(range(10, 19)),
# "timesteps": list(range(2, 4)),
# "num_samples": 4,
# "num_inference_steps": 4
# }
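# Display size (in pixels) of each concept heatmap thumbnail.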
IMG_SIZE = 250
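# Fetch an example image over HTTP and decode it with PIL.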
def download_image(url):
return Image.open(io.BytesIO(requests.get(url).content))
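# (prompt, input image, comma-separated concepts, seed) rows for the gr.Examples widget below.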
EXAMPLES = [
[
"A dog by a tree", # prompt
download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dog_by_tree.png?raw=true"),
"tree, dog, grass, background", # words
42, # seed
],
[
"A dragon", # prompt
download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dragon_image.png?raw=true"),
"dragon, sky, rock, cloud", # words
42, # seed
],
[
"A hot air balloon", # prompt
download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/hot_air_balloon.png?raw=true"),
"balloon, sky, water, tree", # words
42, # seed
]
]
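# Instantiate the ConceptAttention pipeline once at module load so the FLUX-schnell weights are loaded a single time.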
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
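# Either encode a user-supplied image or generate one from the prompt, then return the output image,
# an HTML grid of per-concept saliency maps, and None to clear the uploaded image.
# The spaces.GPU decorator requests a GPU worker for up to 60 seconds per call.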
@spaces.GPU(duration=60)
def process_inputs(prompt, input_image, word_list, seed, num_samples, layer_start_index, timestep_start_index):
print("Processing inputs")
prompt = prompt.strip()
    if not word_list.strip():
        # Return a value for each of the three outputs wired to the click handler.
        return None, "Please enter comma-separated words", None
    concepts = [w.strip() for w in word_list.split(",") if w.strip()]
if input_image is not None:
        # Gradio may supply a numpy array or a PIL image; normalize to a 1024x1024 RGB PIL image.
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)
        input_image = input_image.convert("RGB")
        input_image = input_image.resize((1024, 1024))
pipeline_output = pipeline.encode_image(
image=input_image,
concepts=concepts,
prompt=prompt,
width=1024,
height=1024,
seed=seed,
num_samples=num_samples,
layer_indices=list(range(layer_start_index, 19)),
)
else:
pipeline_output = pipeline.generate_image(
prompt=prompt,
concepts=concepts,
width=1024,
height=1024,
seed=seed,
timesteps=list(range(timestep_start_index, 4)),
num_inference_steps=4,
layer_indices=list(range(layer_start_index, 19)),
)
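    # Both branches yield an output image and one saliency heatmap per concept.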
output_image = pipeline_output.image
concept_heatmaps = pipeline_output.concept_heatmaps
html_elements = []
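    # Render one labeled, base64-embedded thumbnail per concept.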
for concept, heatmap in zip(concepts, concept_heatmaps):
img = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
        # Wrap each heatmap (inlined as a base64 PNG) with its concept label.
        html = f"""
        <div style='text-align: center; margin: 5px;'>
            <h1 style='margin-bottom: 10px;'>{concept}</h1>
            <img src='data:image/png;base64,{img_str}' style='width: {IMG_SIZE}px; display: inline-block;'>
        </div>
        """
        html_elements.append(html)

    combined_html = "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>" + "".join(html_elements) + "</div>"

    return output_image, combined_html, None  # Third output clears the uploaded input image
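# Build the demo UI: paper header and abstract, prompt/concept inputs and optional image upload on the left,
# the generated image on the right, and the per-concept saliency maps below.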
with gr.Blocks(
css="""
.container { max-width: 1200px; margin: 0 auto; padding: 20px; }
.title { text-align: center; margin-bottom: 10px; }
.authors { text-align: center; margin-bottom: 10px; }
.affiliations { text-align: center; color: #666; margin-bottom: 10px; }
.content { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
.section { }
.input-image { width: 100%; height: 200px; }
.abstract { text-align: center; margin-bottom: 40px; }
"""
) as demo:
with gr.Column(elem_classes="container"):
gr.Markdown("# ConceptAttention: Diffusion Transformers Learn Highly Interpretable Features", elem_classes="title")
gr.Markdown("### Alec Helbling¹, Tuna Meral², Ben Hoover¹³, Pinar Yanardag², Duen Horng (Polo) Chau¹", elem_classes="authors")
gr.Markdown("### ¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
gr.Markdown(
"""
            We introduce ConceptAttention, an approach for interpreting the intermediate representations of diffusion transformers.
            Given a list of textual concepts, ConceptAttention produces a set of saliency maps depicting
            the location and intensity of those concepts in generated images. Check out our paper [here](https://arxiv.org/abs/2502.04320).
""",
elem_classes="abstract"
)
with gr.Row(elem_classes="content"):
with gr.Column(elem_classes="section"):
gr.Markdown("### Input")
prompt = gr.Textbox(label="Enter your prompt")
words = gr.Textbox(label="Enter a list of concepts (comma-separated)")
                # gr.HTML("<div style='text-align: center;'>Or</div>")  # optional divider between the prompt and the image upload
image_input = gr.Image(type="numpy", label="Upload image (optional)", elem_classes="input-image")
# Set up advanced options
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
num_samples = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Samples", value=4)
layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
with gr.Column(elem_classes="section"):
gr.Markdown("### Output")
output_image = gr.Image(type="numpy", label="Output image")
with gr.Row():
submit_btn = gr.Button("Process")
with gr.Row(elem_classes="section"):
saliency_display = gr.HTML(label="Saliency Maps")
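        # Wire the Process button; image_input is listed as an output so it can be cleared after a run.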
        submit_btn.click(
            fn=process_inputs,
            inputs=[prompt, image_input, words, seed, num_samples, layer_start_index, timestep_start_index],
            outputs=[output_image, saliency_display, image_input],
        )
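        # Example rows prefill the inputs; with cache_examples=False nothing is precomputed.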
        gr.Examples(
            examples=EXAMPLES,
            inputs=[prompt, image_input, words, seed],
            outputs=[output_image, saliency_display],
            fn=process_inputs,
            cache_examples=False,
        )
if __name__ == "__main__":
demo.launch(max_threads=1)