Spaces: Running on Zero

Ruurd committed · Commit db84545 · verified · 1 Parent(s): 12738e5

Add top_p and top_k sliders

Files changed (1)
  1. app.py +14 -11
app.py CHANGED
@@ -179,11 +179,11 @@ def filter_logits(logits, top_k=0, top_p=0.0):
     return logits
 
 @spaces.GPU
-def generate_diffusion_text(input_ids):
+def generate_diffusion_text(input_ids, top_p, top_k):
     with torch.no_grad():
         input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
         logits = model(input_ids=input_tensor)["logits"]
-        logits = filter_logits(logits, top_k=100, top_p=0.9)
+        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
         logits = logits.clamp(min=-1e8, max=1e4)
         probs = torch.nn.functional.softmax(logits, dim=-1)[0]
         probs = torch.clamp(probs, min=1e-8, max=1.0)
@@ -196,7 +196,9 @@ def generate_diffusion_text(input_ids):
     return sampled, conf
 
 # --- Inference Wrapper ---
-def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_start, use_confidence_noising, noise_clipping):
+def diffusion_chat(question, max_it, pause_length, sharpness,
+                   clustering, noise_start, use_confidence_noising,
+                   noise_clipping, top_p, top_k):
     placeholder = "What do you know about the city of Amsterdam?"
     if question.strip() == "":
         question = placeholder
@@ -229,7 +231,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
         print('Generating output')
 
         # Model step
-        generated_tokens, confidences = generate_diffusion_text(current_tokens)
+        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
 
         elapsed = time.time() - generation_start
         remaining = pause_length - elapsed
@@ -322,14 +324,15 @@ demo = gr.Interface(
     fn=diffusion_chat,
     inputs=[
         gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
-        gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
-        gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
-        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="↓ = more noising (sharpness)"),
-        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
-        gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="↑ = more noise (noise start)"),
+        gr.Slider(1, 512, value=64, step=1, label="Number of iterations: ↑ = more iterations"),
+        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations: ↑ = longer pause"),
+        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
+        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
+        gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Noise start fraction: ↑ = more noise"),
         gr.Checkbox(value=False, label="Use confidence-guided noising"),
-        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
-
+        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
+        gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p: ↑ = more random answers"),
+        gr.Slider(1, 1000, value=100, step=1, label="Top-k: ↑ = more random answers"),
     ],
     outputs=[gr.HTML(label="Diffusion Output")],
     title="Diffusion Language Model Chat",