# NextStep-1-Large — Gradio image-generation demo for Hugging Face Spaces (ZeroGPU)
import gradio as gr
import numpy as np
import spaces
from PIL import Image
import torch
from torch.amp import autocast
from transformers import AutoTokenizer, AutoModel

from models.gen_pipeline import NextStepPipeline

HF_HUB = "stepfun-ai/NextStep-1-Large"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
model = AutoModel.from_pretrained(
    HF_HUB,
    local_files_only=False,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_POSITIVE_PROMPT = None
DEFAULT_NEGATIVE_PROMPT = None
DEFAULT_CFG = 7.5
def _ensure_pil(x):
    """Ensure the returned image is a PIL.Image.Image."""
    if isinstance(x, Image.Image):
        return x
    if hasattr(x, "detach"):  # torch.Tensor -> float ndarray in [0, 1]
        x = x.detach().float().clamp(0, 1).cpu().numpy()
    if isinstance(x, np.ndarray):
        if x.dtype != np.uint8:
            x = (x * 255.0).clip(0, 255).astype(np.uint8)
        if x.ndim == 3 and x.shape[0] in (1, 3, 4):  # CHW -> HWC
            x = np.moveaxis(x, 0, -1)
        return Image.fromarray(x)
    raise TypeError("Unsupported image type returned by pipeline.")
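
# A minimal sketch of what _ensure_pil accepts. The shapes below are
# illustrative assumptions, not values the pipeline is guaranteed to return;
# this helper is never called and exists only to document the accepted inputs.
def _ensure_pil_examples():
    chw = torch.rand(3, 64, 64)                               # CHW float in [0, 1]
    hwc = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # HWC uint8
    assert _ensure_pil(chw).size == (64, 64)
    assert _ensure_pil(hwc).size == (64, 64)
    assert isinstance(_ensure_pil(Image.new("RGB", (64, 64))), Image.Image)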
def infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress):
    """Core inference logic; the GPU decorators live on the tier wrappers below."""
    if not prompt:
        gr.Warning("⚠️ Please enter a prompt!")
        return None
    with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
        imgs = pipeline.generate_image(
            prompt,
            hw=(int(height), int(width)),
            num_images_per_caption=1,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            cfg=float(cfg),
            cfg_img=1.0,
            cfg_schedule="constant",
            use_norm=False,
            num_sampling_steps=int(num_inference_steps),
            timesteps_shift=1.0,
            seed=int(seed),
            progress=True,
        )
    return _ensure_pil(imgs[0])
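
# Optional local smoke test (an assumed invocation, not wired into the UI);
# the small size and step count keep it quick. Run manually, e.g. from a REPL.
def _smoke_test():
    img = infer_core(
        "a red apple on a wooden table",  # illustrative prompt
        seed=0, width=256, height=256, num_inference_steps=10,
        cfg=DEFAULT_CFG, positive_prompt=None, negative_prompt=None,
        progress=None,
    )
    img.save("smoke_test.png")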
# Each tier wraps infer_core with a different ZeroGPU time budget via
# @spaces.GPU. The duration values are assumptions; tune them to the
# runtimes you actually observe for each resolution/step combination.

# Tier 1: very small images with few steps
@spaces.GPU(duration=20)
def infer_tiny(prompt=None, seed=0, width=512, height=512, num_inference_steps=24, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)

# Tier 2: small-to-medium images with standard steps
@spaces.GPU(duration=40)
def infer_fast(prompt=None, seed=0, width=512, height=512, num_inference_steps=24, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)

# Tier 3: standard generation for the most common cases
@spaces.GPU(duration=60)
def infer_std(prompt=None, seed=0, width=512, height=512, num_inference_steps=28, cfg=DEFAULT_CFG,
              positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
              progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)

# Tier 4: larger images or more steps
@spaces.GPU(duration=90)
def infer_long(prompt=None, seed=0, width=512, height=512, num_inference_steps=36, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)

# Tier 5: maximum quality with many steps
@spaces.GPU(duration=120)
def infer_max(prompt=None, seed=0, width=512, height=512, num_inference_steps=45, cfg=DEFAULT_CFG,
              positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
              progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)
# JS dispatcher: estimates job complexity from resolution and step count,
# then clicks the hidden button for the matching tier.
js_dispatch = """
function(width, height, steps){
    const w = Number(width);
    const h = Number(height);
    const s = Number(steps);

    // Complexity score: megapixels times sampling steps.
    const pixels = w * h;
    const megapixels = pixels / 1000000;
    const complexity = megapixels * s;

    let target = 'btn-std'; // default tier

    if (pixels <= 256*256 && s <= 20) {
        // Very small images with few steps
        target = 'btn-tiny';
    } else if (complexity < 5) {
        // Small images or few steps (e.g., 384x384 @ 24 steps = 3.5)
        target = 'btn-fast';
    } else if (complexity < 8) {
        // Standard generation (e.g., 512x512 @ 28 steps = 7.3)
        target = 'btn-std';
    } else if (complexity < 12) {
        // Larger or more steps (e.g., 512x512 @ 40 steps = 10.5)
        target = 'btn-long';
    } else {
        // Maximum complexity
        target = 'btn-max';
    }

    // Overrides for extreme values
    if (s >= 45) {
        target = 'btn-max';  // many steps always need the largest budget
    } else if (pixels >= 512*512 && s >= 35) {
        target = 'btn-long'; // large images with many steps
    }

    console.log(`Resolution: ${w}x${h}, Steps: ${s}, Complexity: ${complexity.toFixed(2)}, Selected: ${target}`);
    const b = document.getElementById(target);
    if (b) b.click();
}
"""
css = """ | |
#col-container { | |
margin: 0 auto; | |
max-width: 800px; | |
} | |
/* Hide the dispatcher buttons */ | |
#btn-tiny, #btn-fast, #btn-std, #btn-long, #btn-max { | |
display: none !important; | |
} | |
""" | |
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# NextStep-1-Large — Image generation")
        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=2, placeholder="Enter your prompt",
                             container=False)
            run_button = gr.Button("Run", scale=0, variant="primary")
            cancel_button = gr.Button("Cancel", scale=0, variant="secondary")
        with gr.Row():
            with gr.Accordion("Advanced Settings", open=True):
                positive_prompt = gr.Text(label="Positive Prompt", show_label=True,
                                          placeholder="Optional: add positives")
                negative_prompt = gr.Text(label="Negative Prompt", show_label=True,
                                          placeholder="Optional: add negatives")
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=3407)
                    num_inference_steps = gr.Slider(label="Sampling steps", minimum=10, maximum=50, step=1, value=28)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=512, step=64, value=512)
                    height = gr.Slider(label="Height", minimum=256, maximum=512, step=64, value=512)
                cfg = gr.Slider(label="CFG (guidance scale)", minimum=0.0, maximum=20.0, step=0.5, value=DEFAULT_CFG,
                                info="Higher = closer to text, lower = more creative")
        with gr.Row():
            result_1 = gr.Image(label="Result", format="png", interactive=False)

        # Hidden dispatcher buttons targeted by the JS dispatcher
        with gr.Row(visible=False):
            btn_tiny = gr.Button(visible=False, elem_id="btn-tiny")
            btn_fast = gr.Button(visible=False, elem_id="btn-fast")
            btn_std = gr.Button(visible=False, elem_id="btn-std")
            btn_long = gr.Button(visible=False, elem_id="btn-long")
            btn_max = gr.Button(visible=False, elem_id="btn-max")
    examples = [
        ["Studio portrait of an elderly sailor with a weathered face, dramatic Rembrandt lighting, shallow depth of field",
         101, 512, 512, 32, 7.5,
         "photorealistic, sharp eyes, detailed skin texture, soft rim light, 85mm lens",
         "over-smoothed skin, plastic look, extra limbs, watermark"],
        ["Isometric cozy coffee shop interior with hanging plants and warm Edison bulbs",
         202, 512, 384, 30, 8.5,
         "isometric view, clean lines, stylized, warm ambience, detailed furniture",
         "text, logo, watermark, perspective distortion"],
        ["Ultra-wide desert canyon at golden hour with long shadows and dust in the air",
         303, 512, 320, 28, 7.0,
         "cinematic, volumetric light, natural colors, high dynamic range",
         "over-saturated, haze artifacts, blown highlights"],
        ["Oil painting of a stormy sea with a lighthouse, thick impasto brushwork",
         707, 384, 512, 34, 7.0,
         "textured canvas, visible brush strokes, dramatic sky, moody lighting",
         "smooth digital look, airbrush, neon colors"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt],
        label="Click & Fill Examples (Exact Size)",
    )
    # Wire each hidden dispatcher button to its tier function; keep the event
    # handles so the Cancel button can abort any of them.
    gen_inputs = [prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt]
    ev_tiny = btn_tiny.click(infer_tiny, inputs=gen_inputs, outputs=[result_1])
    ev_fast = btn_fast.click(infer_fast, inputs=gen_inputs, outputs=[result_1])
    ev_std = btn_std.click(infer_std, inputs=gen_inputs, outputs=[result_1])
    ev_long = btn_long.click(infer_long, inputs=gen_inputs, outputs=[result_1])
    ev_max = btn_max.click(infer_max, inputs=gen_inputs, outputs=[result_1])
    # Trigger the JS dispatcher on the Run button or on prompt submit
    run_button.click(None, inputs=[width, height, num_inference_steps], outputs=[], js=js_dispatch)
    prompt.submit(None, inputs=[width, height, num_inference_steps], outputs=[], js=js_dispatch)

    # Cancel button aborts every tier event
    cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[ev_tiny, ev_fast, ev_std, ev_long, ev_max])
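    # Note: event cancellation relies on Gradio's queue, which recent Gradio
    # versions enable by default; `cancels` aborts the listed events whether
    # they are still queued or already running.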
if __name__ == "__main__":
    demo.launch()