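"""Gradio demo for NextStep-1-Large text-to-image generation.

Loads the model from the Hugging Face Hub, exposes prompt/seed/size/steps/CFG
controls, and routes each request to one of five @spaces.GPU duration tiers via
a small JavaScript dispatcher so short jobs do not reserve long GPU slots.
"""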
import gradio as gr
import numpy as np
import spaces
from PIL import Image
import torch
from torch.amp import autocast

from transformers import AutoTokenizer, AutoModel
from models.gen_pipeline import NextStepPipeline

HF_HUB = "stepfun-ai/NextStep-1-Large"
device = "cuda" if torch.cuda.is_available() else "cpu"

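# Load the tokenizer and model from the Hub; trust_remote_code is required for the
# custom NextStep modeling code, and bfloat16 keeps the memory footprint down.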
tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)

model = AutoModel.from_pretrained(
    HF_HUB,
    local_files_only=False,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)

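# Wrap tokenizer and model in the repo's generation pipeline on the same device and dtype.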
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)

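# UI defaults: seed slider bound, optional positive/negative prompt add-ons, and guidance scale.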
MAX_SEED = np.iinfo(np.int16).max
DEFAULT_POSITIVE_PROMPT = None
DEFAULT_NEGATIVE_PROMPT = None
DEFAULT_CFG = 7.5


def _ensure_pil(x):
    """Ensure returned image is a PIL.Image.Image."""
    if isinstance(x, Image.Image):
        return x
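    # Torch tensors: detach, clamp to [0, 1], and convert to a NumPy array.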
    if hasattr(x, "detach"):
        x = x.detach().float().clamp(0, 1).cpu().numpy()
    if isinstance(x, np.ndarray):
        if x.dtype != np.uint8:
            x = (x * 255.0).clip(0, 255).astype(np.uint8)
        if x.ndim == 3 and x.shape[0] in (1, 3, 4):  # CHW -> HWC
            x = np.moveaxis(x, 0, -1)
        return Image.fromarray(x)
    raise TypeError("Unsupported image type returned by pipeline.")


def infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress):
    """Core inference logic without GPU decorators."""
    if not prompt or not prompt.strip():
        gr.Warning("⚠️ Please enter a prompt!")
        return None
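    # Generate under bf16 autocast; note that hw is (height, width).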
    with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
        imgs = pipeline.generate_image(
            prompt,
            hw=(int(height), int(width)),
            num_images_per_caption=1,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            cfg=float(cfg),
            cfg_img=1.0,
            cfg_schedule="constant",
            use_norm=False,
            num_sampling_steps=int(num_inference_steps),
            timesteps_shift=1.0,
            seed=int(seed),
            progress=True,
        )
    return _ensure_pil(imgs[0])


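# ZeroGPU tiers: each wrapper requests a different GPU time budget via
# @spaces.GPU(duration=...); the JS dispatcher below picks the cheapest tier
# that should fit the requested resolution and step count.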
# Tier 1: Very small images with few steps
@spaces.GPU(duration=90)
def infer_tiny(prompt=None, seed=0, width=512, height=512, num_inference_steps=24, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)


# Tier 2: Small to medium images with standard steps
@spaces.GPU(duration=150)
def infer_fast(prompt=None, seed=0, width=512, height=512, num_inference_steps=24, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)


# Tier 3: Standard generation for most common cases
@spaces.GPU(duration=200)
def infer_std(prompt=None, seed=0, width=512, height=512, num_inference_steps=28, cfg=DEFAULT_CFG,
              positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
              progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)


# Tier 4: Larger images or more steps
@spaces.GPU(duration=300)
def infer_long(prompt=None, seed=0, width=512, height=512, num_inference_steps=36, cfg=DEFAULT_CFG,
               positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
               progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)


# Tier 5: Maximum quality with many steps
@spaces.GPU(duration=400)
def infer_max(prompt=None, seed=0, width=512, height=512, num_inference_steps=45, cfg=DEFAULT_CFG,
              positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
              progress=gr.Progress(track_tqdm=True)):
    return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)


# JS dispatcher: scores a request by megapixels * steps and clicks the hidden
# button for the matching GPU-duration tier.
js_dispatch = """
function(width, height, steps){
  const w = Number(width);
  const h = Number(height);
  const s = Number(steps);

  // Calculate total pixels and complexity score
  const pixels = w * h;
  const megapixels = pixels / 1000000;

  // Complexity score combines image size and steps;
  // runtime grows roughly in proportion to megapixels * steps.
  const complexity = megapixels * s;

  let target = 'btn-std';  // Default

  // Select appropriate tier based on complexity
  if (pixels <= 256*256 && s <= 20) {
    // Very small images with few steps
    target = 'btn-tiny';
  } else if (complexity < 5) {
    // Small images or few steps (e.g., 384x384 @ 24 steps = 3.5)
    target = 'btn-fast';
  } else if (complexity < 8) {
    // Standard generation (e.g., 512x512 @ 28 steps = 7.3)
    target = 'btn-std';
  } else if (complexity < 12) {
    // Larger or more steps (e.g., 512x512 @ 40 steps = 10.5)
    target = 'btn-long';
  } else {
    // Maximum complexity
    target = 'btn-max';
  }

  // Special cases: override based on extreme values
  if (s >= 45) {
    target = 'btn-max';  // Many steps always need more time
  } else if (pixels >= 512*512 && s >= 35) {
    target = 'btn-long';  // Large images with many steps
  }

  console.log(`Resolution: ${w}x${h}, Steps: ${s}, Complexity: ${complexity.toFixed(2)}, Selected: ${target}`);

  // elem_id may land on a wrapper element instead of the <button>, depending on the Gradio version.
  const el = document.getElementById(target);
  const btn = el && (el.tagName === 'BUTTON' ? el : el.querySelector('button'));
  if (btn) btn.click();
}
"""

css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
/* Hide the dispatcher buttons */
#btn-tiny, #btn-fast, #btn-std, #btn-long, #btn-max {
    display: none !important;
}
"""

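# Gradio UI: prompt row, advanced settings, result image, hidden dispatcher buttons, and examples.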
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# NextStep-1-Large — Image generation")

        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=2, placeholder="Enter your prompt",
                             container=False)
            run_button = gr.Button("Run", scale=0, variant="primary")
            cancel_button = gr.Button("Cancel", scale=0, variant="secondary")

        with gr.Row():
            with gr.Accordion("Advanced Settings", open=True):
                positive_prompt = gr.Text(label="Positive Prompt", show_label=True,
                                          placeholder="Optional: add positives")
                negative_prompt = gr.Text(label="Negative Prompt", show_label=True,
                                          placeholder="Optional: add negatives")
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=3407)
                    num_inference_steps = gr.Slider(label="Sampling steps", minimum=10, maximum=50, step=1, value=28)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=512, step=64, value=512)
                    height = gr.Slider(label="Height", minimum=256, maximum=512, step=64, value=512)
                cfg = gr.Slider(label="CFG (guidance scale)", minimum=0.0, maximum=20.0, step=0.5, value=DEFAULT_CFG,
                                info="Higher = closer to text, lower = more creative")

        with gr.Row():
            result_1 = gr.Image(label="Result", format="png", interactive=False)

        # Hidden dispatcher buttons
        with gr.Row(visible=False):
            btn_tiny = gr.Button(visible=False, elem_id="btn-tiny")
            btn_fast = gr.Button(visible=False, elem_id="btn-fast")
            btn_std = gr.Button(visible=False, elem_id="btn-std")
            btn_long = gr.Button(visible=False, elem_id="btn-long")
            btn_max = gr.Button(visible=False, elem_id="btn-max")

        examples = [
            [
                "Studio portrait of an elderly sailor with a weathered face, dramatic Rembrandt lighting, shallow depth of field",
                101, 512, 512, 32, 7.5,
                "photorealistic, sharp eyes, detailed skin texture, soft rim light, 85mm lens",
                "over-smoothed skin, plastic look, extra limbs, watermark"],
            ["Isometric cozy coffee shop interior with hanging plants and warm Edison bulbs",
             202, 512, 384, 30, 8.5,
             "isometric view, clean lines, stylized, warm ambience, detailed furniture",
             "text, logo, watermark, perspective distortion"],
            ["Ultra-wide desert canyon at golden hour with long shadows and dust in the air",
             303, 512, 320, 28, 7.0,
             "cinematic, volumetric light, natural colors, high dynamic range",
             "over-saturated, haze artifacts, blown highlights"],
            ["Oil painting of a stormy sea with a lighthouse, thick impasto brushwork",
             707, 384, 512, 34, 7.0,
             "textured canvas, visible brush strokes, dramatic sky, moody lighting",
             "smooth digital look, airbrush, neon colors"],
        ]

        gr.Examples(
            examples=examples,
            inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt],
            label="Click & Fill Examples (Exact Size)",
        )

        # Wire up the dispatcher buttons to their respective functions
        ev_tiny = btn_tiny.click(infer_tiny,
                                 inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
                                         negative_prompt],
                                 outputs=[result_1])
        ev_fast = btn_fast.click(infer_fast,
                                 inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
                                         negative_prompt],
                                 outputs=[result_1])
        ev_std = btn_std.click(infer_std,
                               inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
                                       negative_prompt],
                               outputs=[result_1])
        ev_long = btn_long.click(infer_long,
                                 inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
                                         negative_prompt],
                                 outputs=[result_1])
        ev_max = btn_max.click(infer_max,
                               inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
                                       negative_prompt],
                               outputs=[result_1])

        # Trigger JS dispatcher on run button or prompt submit
        run_button.click(None, inputs=[width, height, num_inference_steps], outputs=[], js=js_dispatch)
        prompt.submit(None, inputs=[width, height, num_inference_steps], outputs=[], js=js_dispatch)

        # Cancel button cancels all possible events
        cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[ev_tiny, ev_fast, ev_std, ev_long, ev_max])

if __name__ == "__main__":
    demo.launch()