Space: Ruurd · Running on Zero

Ruurd committed · Commit 63d4168 · 1 parent: ec83427

Simplified interface

Files changed (3):
  1. app.py +27 -40
  2. infer.py +0 -90
  3. requirements.txt +1 -0
app.py CHANGED

@@ -16,7 +16,6 @@ from infer import (
     find_answer_start,
     get_noising_schedule,
     noisify_answer,
-    generate_diffusion_text,
     filter_logits,
     confidence_guided_noising,
     noisify_answer_without_remasking
@@ -39,17 +38,17 @@ rng = np.random.default_rng()
 def generate_diffusion_text(input_ids, top_p, top_k):
     with torch.no_grad():
         input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
-        with torch.amp.autocast('cuda', dtype=torch.float16):
+
+        with torch.cuda.amp.autocast(dtype=torch.float16):
             logits = model(input_ids=input_tensor)["logits"]
-        logits = filter_logits(logits, top_k=top_p, top_p=top_k)
+
+        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
         logits = logits.clamp(min=-1e8, max=1e4)
         probs = torch.nn.functional.softmax(logits, dim=-1)[0]
         probs = torch.clamp(probs, min=1e-8, max=1.0)
-        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
-        assert (probs >= 0).all(), "Negative probs!"
+        # assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
+        # assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
-
-        # Extract confidence of selected tokens
        conf = probs[range(len(sampled)), sampled].cpu().numpy()
        return sampled, conf
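Note the argument fix in the hunk above: the old call passed top_p into top_k and vice versa. filter_logits itself is defined in infer.py and is not part of this diff, so the following is only a generic sketch of what top-k / top-p (nucleus) logit filtering usually looks like, not the repository's implementation.

# Illustrative sketch only -- the real filter_logits lives in infer.py and is not shown here.
import torch

def filter_logits_sketch(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """Mask logits outside the top-k / top-p sets with -inf, leaving the rest untouched."""
    logits = logits.clone()
    if top_k > 0:
        # Keep only the k highest-scoring tokens per position.
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p < 1.0:
        # Nucleus filtering: drop tokens once cumulative probability exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cumulative > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # always keep the single best token
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = logits.scatter(-1, sorted_idx, sorted_logits)
    return logits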
 
@@ -79,10 +78,14 @@ def highlight_tokens(token_ids, answer_start, changed_indices, color):
         highlighted.append(tok_str)
     return "".join(highlighted)

-def diffusion_chat(question, max_it, pause_length, sharpness,
-                   clustering, noise_start, use_confidence_noising,
-                   use_permanent_unmasking, noise_clipping, top_p,
-                   top_k):
+def diffusion_chat(question, noising, max_it, pause_length):
+
+    pause_length = 0
+    sharpness = 3.0
+    noise_start = 0.5
+    top_p = 1.0
+    top_k = 10
+    clustering = False

     if question.strip() == "":
         question = "What do you know about the city of Amsterdam?"
@@ -111,6 +114,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
     unmasked_mask = [False] * len(current_tokens)

     for i in range(max_it):
+
         generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
         current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
@@ -133,25 +137,15 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
         if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
             yield render_html("Stopped early", f"After {i+1} iterations.")
             break
-
+
         # NOISING
-        if i < max_it-1:
+        if i < max_it-1 and noising:
             threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
-            if use_confidence_noising:
-                noised_answer, just_noised_indices = confidence_guided_noising(
-                    current_tokens, answer_start, tokenizer, confidences, noise_clipping,
-                    threshold=threshold, noise_start=noise_start
-                )
-            elif use_permanent_unmasking:
-                noised_answer, just_noised_indices = noisify_answer_without_remasking(
-                    current_tokens, answer_start, tokenizer, threshold=threshold,
-                    noise_start=noise_start, unmasked_mask=unmasked_mask
-                )
-            else:
-                noised_answer, just_noised_indices = noisify_answer(
-                    current_tokens, answer_start, tokenizer,
-                    threshold=threshold, clustering=clustering, noise_start=noise_start
-                )
+
+            noised_answer, just_noised_indices = noisify_answer(
+                current_tokens, answer_start, tokenizer,
+                threshold=threshold, clustering=clustering, noise_start=noise_start
+            )

         for idx in range(answer_start, len(current_tokens)):
             if noised_answer[idx] != mask_token_id:
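get_noising_schedule and noisify_answer come from infer.py and their bodies are not part of this diff. As a rough mental model only (the names, decay shape, and random sampling below are assumptions, not the repository's code), the schedule returns a remasking fraction that shrinks as iterations progress, and the noising step replaces that fraction of answer tokens with the MASK token:

# Rough sketch under stated assumptions -- not the infer.py implementation.
import math
import random

def get_noising_schedule_sketch(i: int, max_it: int, sharpness: float = 3.0) -> float:
    """Fraction of answer tokens to re-mask at iteration i; decays from ~1 towards 0."""
    progress = i / max(max_it - 1, 1)
    return math.exp(-sharpness * progress) * (1.0 - progress)

def noisify_answer_sketch(tokens, answer_start, mask_token_id, threshold, noise_start=0.5):
    """Re-mask a random subset of the answer region; returns (noised tokens, masked indices)."""
    answer_indices = list(range(answer_start, len(tokens)))
    n_mask = int(len(answer_indices) * noise_start * threshold)
    masked = sorted(random.sample(answer_indices, n_mask)) if n_mask else []
    noised = list(tokens)
    for idx in masked:
        noised[idx] = mask_token_id
    return noised, masked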
@@ -172,7 +166,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
         final_ids = answer_ids

     final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
-    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
+    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)  # type: ignore


 def is_running_on_spaces():
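render_html is defined elsewhere in app.py and is not touched by this commit; presumably it wraps a heading and a body into markup for the gr.HTML output. A minimal stand-in under that assumption:

# Assumed helper -- the real render_html is elsewhere in app.py and not shown in this diff.
import html

def render_html_sketch(title: str, body: str) -> str:
    return f"<b>{html.escape(title)}</b><br><div>{html.escape(body)}</div>"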
@@ -197,22 +191,15 @@ print("✅ Model loaded.")
 vocab_size = len(tokenizer)
 eos_token_id = tokenizer.eos_token_id
 mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
-assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)
+assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)

 demo = gr.Interface(
     fn=diffusion_chat,
     inputs=[
         gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
-        gr.Slider(1, 512, value=64, step=1, label="Number of iterations: ↑ = more iterations"),
-        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations: ↑ = longer pause"),
-        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↑ = more noise in later iterations"),
-        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
-        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
-        gr.Checkbox(value=False, label="Use confidence-guided noising"),
-        gr.Checkbox(value=False, label="Use permanent unmasking"),
-        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
-        gr.Slider(1, 1000, value=3, step=1, label="Top-p: ↑ = more random answers"),
-        gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Top-k: ↑ = more random answers")
+        gr.Checkbox(label="Enable noising", value=True, info="If disabled, the model will not apply any intermediate noise."),
+        gr.Slider(1, 512, value=64, step=1, label="Increase the maximum number of iterations to run."),
+        gr.Slider(0, 5, value=0, step=0.01, label="Increase the pause between iterations to visualize the process.")
     ],
     outputs=[gr.HTML(label="Diffusion Output")],
     title="Diffusion Language Model Chat",
 
infer.py CHANGED

@@ -190,26 +190,6 @@ def confidence_guided_noising(input_ids, answer_start, tokenizer, confidences, n
     noised_indices = sorted(noised_indices)
     return noised, noised_indices

-def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
-                            eos_token_id=None, eos_boost=0.0):
-    model.eval()
-    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
-        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
-        logits = model(input_ids=input_tensor)["logits"]  # (1, seq_len, vocab_size)
-
-        # Optionally boost or suppress EOS token
-        if eos_token_id is not None and eos_boost != 0.0:
-            logits[:, :, eos_token_id] += eos_boost
-
-        # Filter and sample
-        filtered_logits = filter_logits(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-        probs = F.softmax(filtered_logits, dim=-1).squeeze()  # (seq_len, vocab_size)
-        probs = torch.clamp(probs, min=1e-8, max=1.0)
-        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
-        confidences = probs.gather(1, sampled.unsqueeze(-1)).squeeze(-1)
-
-        return input_ids[:answer_start] + sampled[answer_start:].tolist(), confidences
-

 def calculate_answer_perplexity(prompt, answer, model_name='gpt2-large'):
     from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -277,73 +257,3 @@ def save_html_colored_output(filename, html_content):
     </body>
     </html>
     """)
-
-
-def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
-                    noising_sharpness=5.0, max_length=256, top_k=100, top_p=1.0,
-                    temperature=1.0, eos_token_id=None, eos_boost=0.0) -> str:
-
-    if eos_token_id is None:
-        eos_token_id = tokenizer.eos_token_id
-    # Format prompt with LLaMA 3 chat template
-    prompt = (
-        "<|begin_of_text|>\n"
-        "<|start_header_id|>system<|end_header_id|>\n"
-        "You are a helpful assistant.\n"
-        "<|eot_id|>\n"
-        "<|start_header_id|>user<|end_header_id|>\n"
-        f"{question.strip()}\n"
-        "<|start_header_id|>assistant<|end_header_id|>\n"
-    )
-    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
-    marker = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)
-
-    def find_answer_start(ids, marker):
-        for i in range(len(ids) - len(marker) + 1):
-            if ids[i:i+len(marker)] == marker:
-                return i + len(marker)
-        return None
-
-    answer_start = find_answer_start(input_ids, marker)
-    if answer_start is None:
-        raise ValueError("Assistant marker not found in prompt.")
-
-    # Pad to max length
-    pad_token = tokenizer.eos_token_id
-    mask_token = tokenizer.encode("MASK", add_special_tokens=False)[0]
-    input_ids = input_ids[:max_length]
-    if len(input_ids) < max_length:
-        input_ids += [mask_token] * (max_length - len(input_ids))
-
-    ori_tokens = input_ids
-    current_tokens = noisify_answer(ori_tokens, answer_start, threshold=1.0, mask_token_id=mask_token)
-
-    last_tokens = []
-    for step in range(max_it):
-        # Generate a new prediction
-        current_tokens, confidence_scores = generate_diffusion_text(
-            model, current_tokens, answer_start,
-            top_k=top_k, top_p=top_p, temperature=temperature,
-            eos_token_id=eos_token_id, eos_boost=eos_boost
-        )
-
-        # Display for debugging / tracking
-        display_diffusion_output(
-            step, max_it, question,
-            ori_tokens, current_tokens, confidence_scores,
-            answer_start, tokenizer
-        )
-
-        # Early stopping
-        last_tokens.append(current_tokens)
-        if len(last_tokens) > 4:
-            last_tokens.pop(0)
-        if all(t == last_tokens[0] for t in last_tokens):
-            break
-
-        # Re-apply noise for next iteration
-        if step < max_it - 1:
-            threshold = noise_start * get_noising_schedule(step, max_it, sharpness=noising_sharpness)
-            current_tokens = noisify_answer(current_tokens, answer_start, threshold=threshold, mask_token_id=mask_token)
-
-    return tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).strip()
 
requirements.txt CHANGED

@@ -7,3 +7,4 @@ gradio>=4.10.0
 numpy
 load_dotenv
 ipython
+spaces
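requirements.txt gains the spaces package, which Hugging Face ZeroGPU Spaces (this Space is listed as "Running on Zero") use to attach a GPU only while a decorated function runs. Whether and where app.py applies the decorator is not visible in this diff; a typical, hypothetical wiring looks like this:

# Hypothetical sketch, not shown in this diff: on ZeroGPU the GPU-bound function is
# usually wrapped with spaces.GPU so a device is allocated only for the call.
import spaces  # added to requirements.txt in this commit

@spaces.GPU
def run_model_step(prompt: str) -> str:
    # model inference would go here; placeholder body for the sketch
    return prompt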