import gradio as gr
import torch
import spaces
from diffusers import DiffusionPipeline

# --- Model Configuration and Loading ---
MODEL_ID = "Manojb/stable-diffusion-2-1-base"
DTYPE = torch.bfloat16
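# Note: bfloat16 keeps float32's exponent range at half the memory; on GPUs
# without bfloat16 support, torch.float16 would be the usual substitute.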
try:
    # Load pipeline
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        use_safetensors=True,
    )
    pipe.to("cuda")
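    # Note (ZeroGPU): calling .to("cuda") at import time is the documented
    # Spaces pattern; a physical GPU is only attached while a @spaces.GPU
    # decorated function is actually running.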
    # --- Mandatory ZeroGPU AoT Compilation for Optimization ---
    # Extended duration so the one-time startup compilation can finish
    # within a single GPU allocation.
    @spaces.GPU(duration=1500)
    def compile_unet():
        print("Starting AoT compilation for UNet...")
        # Dummy inputs for 512x512 generation (B=1, 64x64 latents for the UNet)
        B, C, H, W = 1, 4, 64, 64
        sample = torch.randn(B, C, H, W, dtype=DTYPE, device="cuda")
        timestep = torch.tensor([999], dtype=torch.long, device="cuda")
        # Encoder hidden states (text embeddings): (B, 77, 1024) for SD 2.1
        EHS_DIM = 77
        EHS_HIDDEN = 1024
        encoder_hidden_states = torch.randn(B, EHS_DIM, EHS_HIDDEN, dtype=DTYPE, device="cuda")
        inputs = (sample, timestep, encoder_hidden_states)
        # Invoke the UNet inside the capture context so its call is intercepted,
        # then export the captured call and compile it ahead of time.
        with spaces.aoti_capture(pipe.unet) as call:
            pipe.unet(*inputs)
        exported = torch.export.export(pipe.unet, args=call.args, kwargs=call.kwargs)
        compiled_model = spaces.aoti_compile(exported)
        print("AoT compilation successful.")
        return compiled_model

    # Execute compilation during startup
    compiled_unet = compile_unet()
    spaces.aoti_apply(compiled_unet, pipe.unet)
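    # aoti_apply swaps the compiled artifact into pipe.unet in place, so every
    # later pipe(...) call runs the ahead-of-time compiled denoising graph.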
except Exception as e:
    print(f"⚠️ Warning: model initialization or AoT compilation failed ({e}). Falling back to an unoptimized pipeline.")
    # Fall back to loading the model without AoT if initialization itself failed
    if "pipe" not in globals():
        pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE, use_safetensors=True)
        pipe.to("cuda")
        print("Model loaded successfully without AoT.")
# Standard ZeroGPU allocation for inference
@spaces.GPU
def generate(prompt: str, num_images: int):
    """Generates images using the Stable Diffusion pipeline."""
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    # Repeat the prompt so the whole batch is generated in a single call
    prompt_list = [prompt] * int(num_images)
    # Generate images
    output = pipe(
        prompt_list,
        num_inference_steps=25,
        guidance_scale=9.0,
    )
    return output.images
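# Note: with guidance_scale > 1 the pipeline applies classifier-free guidance,
# so each denoising step runs the UNet on 2 * num_images latents; that doubled
# effective batch is why a small slider cap on num_images is sensible.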
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="SD 2.1 Base Generator") as demo:
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>Stable Diffusion 2.1 Base (512x512)</h1>
            <p>Model: Manojb/stable-diffusion-2-1-base | Optimized with ZeroGPU AoT</p>
            <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A detailed digital painting of a majestic dragon flying over a medieval castle, fantasy art",
                lines=3,
            )
            num_images = gr.Slider(
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                label="Number of Images to Generate (Max 4)",
                info="Generates multiple images in a single batch call.",
            )
            generate_btn = gr.Button("Generate Images", variant="primary")
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images (512x512)",
                height=512,
                columns=2,
                rows=2,
                object_fit="contain",
            )
    generate_btn.click(
        fn=generate,
        inputs=[prompt, num_images],
        outputs=output_gallery,
    )

    gr.Examples(
        examples=[
            ["A photorealistic portrait of a golden retriever wearing sunglasses on a beach, cinematic lighting", 2],
            ["Steampunk owl on a bookshelf, detailed brass gears, oil painting", 4],
            ["High contrast black and white photograph of an old lighthouse during a storm", 1],
        ],
        inputs=[prompt, num_images],
        outputs=output_gallery,
        fn=generate,
        cache_examples=True,
        cache_mode="eager",
    )
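    # Note: cache_mode="eager" runs `generate` for all examples at startup and
    # therefore needs a GPU allocation before the first request; "lazy" would
    # defer each example's generation until it is first selected.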
demo.queue()

if __name__ == "__main__":
    demo.launch()