import spaces
import os
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download

from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer
from utils import align_face

# ----------------------
# Configuration (update paths as needed)
# ----------------------
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = snapshot_download(
    "orpatashnik/NestedAttentionEncoder", allow_patterns=["image_encoder/**"]
)
image_encoder_path = os.path.join(image_encoder_path, "image_encoder")
personalization_ckpt = hf_hub_download(
    "orpatashnik/NestedAttentionEncoder", "personalization_encoder/model.safetensors"
)
device = "cuda"

# Special token settings
placeholder_token = "<person>"
initializer_token = "person"

# ----------------------
# Load models
# ----------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
)

# Register the placeholder token with the tokenizer, initialized from "person"
add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token)

ip_model = NestedAdapterInference(
    pipe,
    image_encoder_path,
    personalization_ckpt,
    1024,
    vq_normalize_factor=2.0,
    device=device,
)

# Generation defaults
negative_prompt = "bad anatomy, monochrome, lowres, worst quality, low quality"
num_inference_steps = 30
guidance_scale = 5.0

# ----------------------
# Inference function with alignment
# ----------------------
@spaces.GPU
def generate_images(img1, img2, img3, prompt, w, num_samples, seed):
    # Collect non-empty reference images
    refs = [img for img in (img1, img2, img3) if img is not None]
    if not refs:
        return []

    # Align faces directly on the PIL images
    aligned_refs = [align_face(img) for img in refs]

    # Resize to model resolution
    pil_images = [aligned.resize((512, 512)) for aligned in aligned_refs]

    placeholder_token_ids = ip_model.pipe.tokenizer.convert_tokens_to_ids([placeholder_token])

    # Generate personalized samples
    results = ip_model.generate(
        pil_image=pil_images,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_samples=int(num_samples),  # Gradio sliders may deliver floats
        num_inference_steps=num_inference_steps,
        placeholder_token_ids=placeholder_token_ids,
        seed=int(seed) if seed > 0 else None,
        guidance_scale=guidance_scale,
        multiple_images=True,
        special_token_weight=w,
    )
    return results

# ----------------------
# Gradio UI
# ----------------------
with gr.Blocks() as demo:
    gr.Markdown("## Nested Attention: Semantic-aware Attention Values for Concept Personalization")
    gr.Markdown(
        "Upload up to 3 reference images. "
        "Faces will be auto-aligned before personalization. Include the placeholder token (e.g., \\<person\\>) in your prompt, "
        "set the token weight, and choose how many outputs you want."
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Reference images
            with gr.Row():
                img1 = gr.Image(type="pil", label="Reference Image 1")
                img2 = gr.Image(type="pil", label="Reference Image 2 (optional)")
                img3 = gr.Image(type="pil", label="Reference Image 3 (optional)")

            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., an abstract pencil drawing of a <person>")
            w_input = gr.Slider(minimum=1.0, maximum=5.0, step=0.5, value=1.0, label="Special Token Weight (w)")
            num_samples_input = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of Images to Generate")
            seed_input = gr.Slider(minimum=-1, maximum=100000, step=1, value=-1, label="Random Seed (-1 for random)")
            generate_button = gr.Button("Generate Images")

            # Add examples
            gr.Examples(
                examples=[
                    ["example_images/01.jpg", None, None, "a watercolor painting of a <person>, closeup", 1.0, 4, 1],
                    ["example_images/02.jpg", None, None, "an abstract pencil drawing of a <person>", 1.5, 4, 30],
                    ["example_images/01.jpg", None, None, "a high quality photo of a <person> as a firefighter", 3.0, 4, 10],
                    ["example_images/02.jpg", None, None, "a high quality photo of a <person> smiling in the snow", 2.0, 4, 40],
                    ["example_images/01.jpg", None, None, "a pop figure of a <person>, she stands on a white background", 2.0, 4, 20],
                ],
                inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
                label="Example Prompts",
            )

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", columns=3)

    generate_button.click(
        fn=generate_images,
        inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
        outputs=output_gallery,
    )

demo.launch()
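
# ----------------------
# Optional: programmatic usage sketch (illustrative; not part of the original
# demo). Assumes the bundled example image exists and that ip_model.generate
# returns a list of PIL images. Kept commented out because demo.launch()
# above blocks; outside Hugging Face Spaces, @spaces.GPU is a no-op.
# ----------------------
# from PIL import Image
# face = Image.open("example_images/01.jpg")
# images = generate_images(face, None, None,
#                          "a watercolor painting of a <person>, closeup",
#                          w=1.0, num_samples=1, seed=42)
# images[0].save("output.png")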