hub-6oa42x9a / app.py
Gertie01's picture
Deploy Gradio app with multiple files
ef5d3f9 verified
import gradio as gr
import torch
import spaces
import os
from diffusers import DiffusionPipeline
# --- Model setup: constants, pipeline load, optional AoT compilation ---
MODEL_ID = "Manojb/stable-diffusion-2-1-base"
DTYPE = torch.bfloat16


def _load_pipeline():
    """Return a freshly loaded pipeline for MODEL_ID (not yet moved to a device)."""
    return DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        use_safetensors=True,
    )


try:
    # Bind `pipe` before moving it to the GPU so the fallback below can tell
    # whether the load itself succeeded.
    pipe = _load_pipeline()
    pipe.to("cuda")

    # Long GPU allocation: this runs once, at startup, to compile the UNet.
    @spaces.GPU(duration=1500)
    def compile_unet():
        """Export ``pipe.unet`` on dummy inputs and return the AoT-compiled model.

        The example inputs are a (1, 4, 64, 64) latent, a single timestep of
        999 and a (1, 77, 1024) text-embedding tensor, all created on CUDA.

        NOTE(review): the graph is traced from one batch-1 call that passes
        its inputs positionally. Confirm the compiled UNet accepts the batch
        sizes and keyword arguments the pipeline uses at inference time.
        """
        print("Starting AoT compilation for UNet...")

        # Latent for a 512x512 image: batch 1, 4 channels, 64x64 spatial.
        latent_shape = (1, 4, 64, 64)
        # Text-embedding shape used for SD 2.1: (batch, 77 tokens, 1024 features).
        text_shape = (latent_shape[0], 77, 1024)

        # Keep the draw order (latent first, then embeddings) unchanged.
        sample = torch.randn(*latent_shape, dtype=DTYPE, device="cuda")
        timestep = torch.tensor([999], dtype=torch.long, device="cuda")
        encoder_hidden_states = torch.randn(*text_shape, dtype=DTYPE, device="cuda")

        # Record the exact args/kwargs of one UNet invocation for export.
        with spaces.aoti_capture(pipe.unet) as call:
            call(sample, timestep, encoder_hidden_states)

        exported = torch.export.export(
            pipe.unet,
            args=call.args,
            kwargs=call.kwargs,
        )
        compiled_model = spaces.aoti_compile(exported)
        print("AoT compilation successful.")
        return compiled_model

    # Compile during startup and patch the result into the live UNet.
    compiled_unet = compile_unet()
    spaces.aoti_apply(compiled_unet, pipe.unet)
except Exception as e:
    print(f"⚠️ Warning: Model initialization or AoT compilation failed ({e}). Running without optimization or skipping initialization if severe.")
    # Only reload when the failure happened before `pipe` was ever bound;
    # otherwise the already-loaded, uncompiled pipeline is used as is.
    if "pipe" not in locals():
        pipe = _load_pipeline()
        pipe.to("cuda")
        print("Model loaded successfully without AoT.")
# Upper bound on images per request; mirrors the "Max 4" limit of the UI slider.
_MAX_IMAGES = 4


@spaces.GPU(duration=60)  # Standard GPU allocation for inference
def generate(prompt: str, num_images: int):
    """Generate a batch of images for one prompt with the global pipeline.

    Args:
        prompt: Text description of the image. Must contain at least one
            non-whitespace character.
        num_images: Requested number of images. The value is coerced to an
            integer and clamped to the range 1.._MAX_IMAGES, so floats and
            out-of-range values sent by direct API callers cannot produce an
            empty or oversized batch.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        gr.Error: If the prompt is empty or blank, or if ``num_images``
            cannot be interpreted as a number.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty.")

    # List repetition needs a real int; a float here would raise TypeError.
    try:
        requested = int(num_images)
    except (TypeError, ValueError, OverflowError) as exc:
        raise gr.Error("Number of images must be a whole number.") from exc

    # Never build an empty batch, and never exceed what the UI advertises.
    batch_size = max(1, min(requested, _MAX_IMAGES))

    # One copy of the prompt per requested image, rendered in a single call.
    prompt_list = [prompt] * batch_size

    output = pipe(
        prompt_list,
        num_inference_steps=25,
        guidance_scale=9.0,
    )
    return output.images
# --- Gradio Interface ---
# Static page header, kept at column 0 so the markup text is unchanged.
_PAGE_HEADER = """
<div style="text-align: center; margin-bottom: 20px;">
<h1>Stable Diffusion 2.1 Base (512x512)</h1>
<p>Model: Manojb/stable-diffusion-2-1-base | Optimized with ZeroGPU AoT</p>
<p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
</div>
"""

# Each row is [prompt, number of images], matching the inputs of `generate`.
_EXAMPLE_ROWS = [
    ["A photorealistic portrait of a golden retriever wearing sunglasses on a beach, cinematic lighting", 2],
    ["Steampunk owl on a bookshelf, detailed brass gears, oil painting", 4],
    ["High contrast black and white photograph of an old lighthouse during a storm", 1],
]

with gr.Blocks(title="SD 2.1 Base Generator", theme=gr.themes.Soft()) as demo:
    gr.HTML(_PAGE_HEADER)

    with gr.Row():
        # Left column: the controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                lines=3,
                label="Prompt",
                placeholder="A detailed digital painting of a majestic dragon flying over a medieval castle, fantasy art",
            )
            num_images = gr.Slider(
                label="Number of Images to Generate (Max 4)",
                info="Generates multiple images in a single batch call.",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
            )
            generate_btn = gr.Button("Generate Images", variant="primary")

        # Right column: the results, laid out as a 2x2 grid.
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images (512x512)",
                columns=2,
                rows=2,
                height=512,
                object_fit="contain",
            )

    # Clicking the button runs one batched generation and fills the gallery.
    generate_btn.click(
        fn=generate,
        inputs=[prompt, num_images],
        outputs=output_gallery,
    )

    # Example prompts; their outputs are cached eagerly rather than on first use.
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        fn=generate,
        inputs=[prompt, num_images],
        outputs=output_gallery,
        cache_examples=True,
        cache_mode="eager",
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()