xarical's picture
Add type hints to infer for MCP compatibility
9750634
raw
history blame
6.29 kB
import gradio as gr
import numpy as np
import random
import spaces
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
import torch
# Hugging Face repository that provides both the pipeline weights and the
# scheduler configuration loaded below.
model_repo_id = "tensorart/stable-diffusion-3.5-large-TurboX"

# Use the GPU when torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision when CUDA is available, full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the pipeline, then replace its scheduler with the flow-matching Euler
# scheduler read from the repo's "scheduler" subfolder, passing shift=5.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_repo_id,
    subfolder="scheduler",
    shift=5,
)
pipe = pipe.to(device)
# Upper limit used for the width and height sliders, in pixels.
MAX_IMAGE_SIZE = 1024
# Largest seed value: the maximum of a signed 32-bit integer. Used both as the
# seed slider's maximum and as the upper bound for randomly drawn seeds.
MAX_SEED = np.iinfo(np.dtype("int32")).max
@spaces.GPU(duration=65)
def infer(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 1.5,
    num_inference_steps: int = 8,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple["PIL.Image.Image", int]:
    """Generate an image using Stable Diffusion 3.5 Large TurboX.

    Args:
        prompt: The prompt to generate an image from
        negative_prompt: The negative prompt
        seed: The generation seed
        randomize_seed: Whether to randomize the seed
        width: The width of the generated image in pixels
        height: The height of the generated image in pixels
        guidance_scale: How closely the image should align with the prompt
        num_inference_steps: Number of generation steps

    Returns:
        The generated image and the seed used for generation as a tuple of (PIL.Image.Image, seed)
    """
    # Callers (slider widgets, remote tool clients) may hand over whole
    # numbers as floats such as 42.0. torch.Generator.manual_seed needs a real
    # int, and the image size / step count are expected to be ints as well, so
    # normalize them once here instead of failing inside the pipeline.
    if randomize_seed:
        # The supplied seed is ignored; draw a fresh one in [0, MAX_SEED].
        seed = random.randint(0, MAX_SEED)
    else:
        seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    # A seeded generator makes the result reproducible for a given seed.
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Return the seed actually used so the caller can display or reuse it.
    return image, seed
# Sample prompts offered to the user; each entry fills only the prompt input.
examples: list[str] = [
    "A capybara wearing a suit holding a sign that reads Hello World",
    "A serene mountain lake at sunset with cherry blossoms floating on the water",
    "A magical crystal dragon with iridescent scales in a glowing forest",
    "A Victorian steampunk teapot with intricate brass gears and rose gold accents",
    "A futuristic neon cityscape with flying cars and holographic billboards",
    "A red panda painter creating a masterpiece with tiny paws in an art studio",
]
# Custom stylesheet handed to gr.Blocks(css=...). It sets a pastel gradient
# page background, styles the centered "#col-container" card (the elem_id of
# the main column), and applies gradient styling to primary buttons and h1.
# NOTE(review): the .gr-button-primary / .gr-form / .gr-accordion selectors
# assume Gradio emits these class names — confirm against the Gradio version
# in use, otherwise those rules have no effect.
css = """
body {
background: linear-gradient(135deg, #f9e2e6 0%, #e8f3fc 50%, #e2f9f2 100%);
background-attachment: fixed;
min-height: 100vh;
}
#col-container {
margin: 0 auto;
max-width: 640px;
background-color: rgba(255, 255, 255, 0.85);
border-radius: 16px;
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.1);
padding: 24px;
backdrop-filter: blur(10px);
}
.gradio-container {
background: transparent !important;
}
.gr-button-primary {
background: linear-gradient(90deg, #6b9dfc, #8c6bfc) !important;
border: none !important;
transition: all 0.3s ease;
}
.gr-button-primary:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(108, 99, 255, 0.3);
}
.gr-form {
border-radius: 12px;
background-color: rgba(255, 255, 255, 0.7);
}
.gr-accordion {
border-radius: 12px;
overflow: hidden;
}
h1 {
background: linear-gradient(90deg, #6b9dfc, #8c6bfc);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 800;
}
"""
# Assemble the web UI: a prompt row with a Run button, the result image, an
# "Advanced Settings" accordion with the generation parameters, and a set of
# example prompts. Both clicking Run and submitting the prompt call infer().
# NOTE(review): "apriel" is not a theme name I can confirm as built into
# Gradio — verify it resolves rather than silently falling back to a default.
with gr.Blocks(theme="apriel", css=css) as demo:
    # elem_id matches the "#col-container" rule in the css string.
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # TensorArt Stable Diffusion 3.5 Large TurboX")
        gr.Markdown("[8-step distilled turbo model](https://huggingface.co/tensorart/stable-diffusion-3.5-large-TurboX)")
        # Prompt box and Run button share one row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        # Displays the image returned by infer().
        result = gr.Image(label="Result", show_label=False)
        # Optional generation parameters, collapsed by default.
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            # Also listed as an output below, so it shows the seed that was
            # actually used after each run.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            # Checked by default, so infer() draws a random seed and the
            # slider value is ignored unless the user unticks this box.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=1.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=8,
                )
        # Only the prompt is supplied per example, so infer() runs with its own
        # defaults for every other parameter (seed=42, randomize_seed=False).
        # NOTE(review): cache_mode="lazy" presumably defers caching until an
        # example is first requested — confirm against the Gradio docs.
        gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True, cache_mode="lazy")
    # Run inference on button click or when the prompt is submitted. The input
    # order must match infer()'s positional parameters; infer() returns
    # (image, seed), which map onto the result image and the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )
# Start the app only when this file is executed as a script. mcp_server=True
# turns on Gradio's MCP server option alongside the web UI.
if __name__ == "__main__":
    demo.launch(mcp_server=True)