# Hugging Face Spaces — status: Running
import os | |
import random | |
import time | |
from typing import Optional | |
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler | |
# -----------------------------
# Device & Precision
# -----------------------------
# float16 is used only when a CUDA device is available; CPU runs use float32.
USE_CUDA = torch.cuda.is_available()
DTYPE = torch.float16 if USE_CUDA else torch.float32
DEVICE = "cuda" if USE_CUDA else "cpu"

# Hub repo id (or local path) of the checkpoint, overridable via the MODEL_ID
# environment variable. The original "runwayml/stable-diffusion-v1-5" repo was
# taken down from the Hugging Face Hub, so the default points at the community
# mirror of the same v1.5 weights. An empty / whitespace-only MODEL_ID falls
# back to the default instead of being passed to from_pretrained("").
_DEFAULT_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"
MODEL_ID = os.environ.get("MODEL_ID", "").strip() or _DEFAULT_MODEL_ID

# Set by load_pipeline(); None until the model has been loaded and configured.
pipe: Optional[StableDiffusionPipeline] = None
def load_pipeline():
    """Load and configure the Stable Diffusion pipeline once at startup.

    Loads ``MODEL_ID`` in ``DTYPE`` with the safety checker disabled, swaps in
    an Euler-Ancestral scheduler, moves the pipeline to ``DEVICE`` and only then
    publishes it in the module-level ``pipe``.

    On CUDA, memory optimizations are applied best-effort. Each one is tried
    independently. A failure is reported and skipped rather than aborting
    startup or silently disabling the other optimization.
    """
    global pipe
    t0 = time.perf_counter()
    pipeline = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        safety_checker=None,  # Keep None for faster demos
    )
    # Use a fast, good-quality scheduler
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(DEVICE)

    # Optional memory optimizations on GPU. They are tried separately so that,
    # e.g., xformers not being installed does not also skip attention slicing
    # (and vice versa).
    if USE_CUDA:
        optimizations = (
            ("attention slicing", pipeline.enable_attention_slicing),
            ("xformers memory-efficient attention", pipeline.enable_xformers_memory_efficient_attention),
        )
        for label, enable in optimizations:
            try:
                enable()
            except Exception as exc:  # best-effort: report and continue
                print(f"Skipping {label}: {exc!r}")

    # Publish only the fully configured pipeline.
    pipe = pipeline
    elapsed = time.perf_counter() - t0
    print(f"Pipeline loaded in {elapsed:.2f}s on {DEVICE} (dtype={DTYPE}).")


# Load on import (Space boot)
load_pipeline()
def _resolve_seed(seed) -> int:
    """Return *seed* as an int, drawing a random 31-bit seed for -1 or an empty field."""
    # A cleared gr.Number arrives as None; API callers may send floats.
    if seed is None or seed == -1:
        return random.randint(0, 2**31 - 1)
    return int(seed)


def _snap_dimension(value) -> int:
    """Clamp an image side to [256, 1024] and round it down to a multiple of 8.

    The diffusers pipeline rejects sizes not divisible by 8. The UI sliders use
    step=64, but the ``generate`` API endpoint accepts arbitrary numbers.
    """
    clamped = max(256, min(1024, int(value)))
    return clamped - clamped % 8  # 256 and 1024 are multiples of 8, so range holds


def generate_image(
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: int,
):
    """Generate one image from *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        negative_prompt: Optional text to steer away from; blank means none.
        steps: Number of denoising steps (values below 1 are raised to 1).
        guidance: Classifier-free guidance scale.
        width: Requested width; clamped to [256, 1024], rounded down to a
            multiple of 8.
        height: Requested height; same handling as ``width``.
        seed: RNG seed; -1 (or an empty field) picks a random seed.

    Returns:
        ``(image, seed)``: the first generated image and the integer seed
        actually used, so a random result can be reproduced.

    Raises:
        gr.Error: If the prompt is blank or the pipeline has not been loaded.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    if pipe is None:
        raise gr.Error("The model is not loaded yet. Please try again shortly.")

    width = _snap_dimension(width)
    height = _snap_dimension(height)
    seed = _resolve_seed(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # No torch.autocast here: the pipeline already runs in DTYPE, and diffusers
    # advises against wrapping its pipelines in autocast.
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt and negative_prompt.strip() else None,
        num_inference_steps=max(1, int(steps)),
        guidance_scale=float(guidance),
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image, seed
# -----------------------------
# Gradio UI
# -----------------------------
# Each row follows generate_image's parameter order:
# prompt, negative prompt, steps, guidance, width, height, seed.
EXAMPLE_ROWS = [
    [
        "ultra-detailed watercolor of a koi fish swirling through clouds, ethereal, pastel palette",
        "lowres, noisy, text",
        28, 7.5, 512, 512, 1234,
    ],
    [
        "cozy cyberpunk alley coffee shop at dusk, volumetric lighting, rain reflections, 4k",
        "low quality, oversaturated",
        25, 6.5, 640, 384, -1,
    ],
    [
        "studio photo of a cute corgi wearing sunglasses, soft light, shallow depth of field",
        "text, watermark, blurry",
        22, 7.0, 512, 512, 2024,
    ],
]

INTRO_MARKDOWN = """
# 🧠 Stable Diffusion Image Generator
Type a prompt and generate an image using **Stable Diffusion v1.5**.
**Tip:** For consistent results, set a fixed seed. Use `-1` for random seed.
"""

with gr.Blocks(title="Stable Diffusion Image Generator", css="footer {visibility: hidden}") as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        # Left: prompts, sampling settings and the trigger button.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="a cinematic portrait of an astronaut relaxing in a tropical cafe, 35mm photo, bokeh, soft light",
                lines=3,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="blurry, low quality, extra fingers, text, watermark",
                lines=2,
            )
            with gr.Row():
                steps = gr.Slider(5, 50, value=25, step=1, label="Steps")
                guidance = gr.Slider(0.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
            with gr.Row():
                width = gr.Slider(256, 1024, value=512, step=64, label="Width")
                height = gr.Slider(256, 1024, value=512, step=64, label="Height")
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 for random)")
            generate_btn = gr.Button("Generate", variant="primary")

        # Right: generated image and the seed that produced it.
        with gr.Column(scale=4):
            out_image = gr.Image(label="Result", type="pil")
            out_seed = gr.Number(label="Used Seed", interactive=False)

    # Shared by the examples table and the button; order matches generate_image.
    generation_inputs = [prompt, negative_prompt, steps, guidance, width, height, seed]

    examples = gr.Examples(examples=EXAMPLE_ROWS, inputs=generation_inputs)

    generate_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=[out_image, out_seed],
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()