
Ruurd committed · Commit fb56411 · 1 Parent(s): 0e840df

Improve white space display

Files changed (3):
  1. app.py +50 -76
  2. infer.py +50 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -101,131 +101,105 @@ def format_chat_prompt(question):
         "<|start_header_id|>assistant<|end_header_id|>\n"
     )
 
+def render_html(label, text):
+    return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"
+
+def highlight_tokens(tokens, color_indices=None, color="green"):
+    highlighted = []
+    for j, tok in enumerate(tokens):
+        if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
+            continue
+        token_str = tokenizer.convert_tokens_to_string([tok])
+        if color_indices and j in color_indices:
+            highlighted.append(f'<span style="color:{color}">{token_str}</span>')
+        else:
+            highlighted.append(token_str)
+    return "".join(highlighted)
+
 # --- Inference Wrapper ---
 def diffusion_chat(question, max_it, pause_length, sharpness,
                    clustering, noise_start, use_confidence_noising,
                    noise_clipping, top_p, top_k):
-    placeholder = "What do you know about the city of Amsterdam?"
+
     if question.strip() == "":
-        question = placeholder
+        question = "What do you know about the city of Amsterdam?"
 
-    print('started generation')
     prompt = format_chat_prompt(question)
     input_ids = tokenizer.encode(prompt, add_special_tokens=False)
     answer_start = find_answer_start(input_ids, assistant_marker_ids)
     if answer_start is None:
-        yield "Error: Could not find Assistant marker in input."
+        yield render_html("Error", "Could not find Assistant marker in input.")
         return
-
-    if len(input_ids) < 256:
-        input_ids += [mask_token_id] * (256 - len(input_ids))
-    else:
-        input_ids = input_ids[:256]
 
+    input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
     ori_input_tokens = input_ids
+
     current_tokens, just_noised_indices = noisify_answer(
-        input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start = 1.0,
-    )
-    yield f"<b>Iteration 0 (initial noise):</b><br>" + tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).replace('\n', '<br>')
+        input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
+    )
+    yield render_html("Iteration 0 (initial noise)",
+                      tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True))
     time.sleep(pause_length)
-    last_tokens = []
-    prev_decoded_tokens = []
 
-    generation_start = time.time()
+    last_tokens = []
+    prev_tokens = []
 
     for i in range(max_it):
-        print('Generating output')
-
-        # Model step
         generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
-
-        elapsed = time.time() - generation_start
-        remaining = pause_length - elapsed
-        if remaining > 0:
-            time.sleep(remaining)
-
-        # Save full output for noising step
         current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
 
-        # --- GREEN HIGHLIGHT ---
-        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
-        highlighted = []
-        for j, tok in enumerate(decoded_tokens):
-            tok_id = tokenizer.convert_tokens_to_ids(tok)
-            if tok_id == eos_token_id:
-                continue
-            token_str = tokenizer.convert_tokens_to_string([tok])
-            if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
-                highlighted.append(f'<span style="color:green">{token_str}</span>')
-            else:
-                highlighted.append(token_str)
-
-        prev_decoded_tokens = decoded_tokens
-        yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
+        decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
+        diff_indices = [j for j in range(len(decoded)) if j >= len(prev_tokens) or decoded[j] != prev_tokens[j]]
+        prev_tokens = decoded
+
+        yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
+                          highlight_tokens(decoded, diff_indices, color="green"))
         time.sleep(pause_length)
 
-        # --- Early stopping ---
+        # Early stopping
         last_tokens.append(current_tokens)
         if len(last_tokens) > 3:
             last_tokens.pop(0)
-        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
-            yield f"<b>Stopped early after {i+1} iterations.</b>"
+        if len(last_tokens) == 3 and len(set(map(tuple, last_tokens))) == 1:
+            yield render_html("Stopped early", f"After {i+1} iterations.")
             break
 
-        previous_tokens = current_tokens.copy()
-
-        # --- NOISING STEP ---
+        # Noising step
         threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
         if use_confidence_noising:
             noised_answer, just_noised_indices = confidence_guided_noising(
-                current_tokens, answer_start, confidences, noise_clipping, threshold=threshold, noise_start=noise_start
+                current_tokens, answer_start, confidences, noise_clipping,
+                threshold=threshold, noise_start=noise_start
             )
-            # just_noised_indices = []
         else:
             noised_answer, just_noised_indices = noisify_answer(
-                current_tokens, answer_start, tokenizer, threshold=threshold, clustering=clustering, noise_start = noise_start,
+                current_tokens, answer_start, tokenizer,
+                threshold=threshold, clustering=clustering, noise_start=noise_start
            )
 
-        # --- RED HIGHLIGHT ---
-        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
-        highlighted = []
-        for j, tok in enumerate(decoded_tokens):
-            tok_id = tokenizer.convert_tokens_to_ids(tok)
-            if tok_id == eos_token_id:
-                continue
-            token_str = tokenizer.convert_tokens_to_string([tok])
-            abs_idx = answer_start + j
-            if abs_idx in just_noised_indices:
-                highlighted.append(f'<span style="color:red">{token_str}</span>')
-            else:
-                highlighted.append(token_str)
-
-        # Compose full input again: prompt + noised answer
-        current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
-
-        yield f"<b>Iteration {i+1}/{max_it} (before noising):</b><br>" + "".join(highlighted).replace('\n', '<br>')
-        generation_start = time.time()
+        decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
+        red_indices = [j for j in range(len(decoded)) if (answer_start + j) in just_noised_indices]
+        yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
+                          highlight_tokens(decoded, red_indices, color="red"))
 
+        current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
 
+    # Final output
     answer_ids = current_tokens[answer_start:]
     try:
-        eos_index = answer_ids.index(eos_token_id)
-        final_ids = answer_ids[:eos_index]
+        final_ids = answer_ids[:answer_ids.index(eos_token_id)]
     except ValueError:
         final_ids = answer_ids
-
-    num_tokens = len(final_ids)
+
     final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
-
-    print(final_output)
-    yield f"<b>Final Output ({num_tokens} tokens after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
+    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
 
 
 # --- Gradio Interface ---
 print("Loading model...")
 ckpt_path = hf_hub_download(
     repo_id="ruurd/tini_model",
-    filename="diffusion-model.pth",
+    filename="diffusion-model-8B.pth",
     token=os.getenv("HF_TOKEN")
 )
 model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
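
The gist of the white-space fix above: instead of substituting every `\n` with `<br>` on each yield, output is wrapped once in a `white-space: pre-wrap` container, and the per-iteration green/red highlighting is factored into `highlight_tokens`. A minimal standalone sketch of that rendering path, using made-up string tokens in place of real tokenizer output and omitting the EOS filtering the actual helpers do:

    # Hypothetical tokens; the real code feeds tokenizer output through these helpers.
    def render_html(label, text):
        # CSS preserves newlines and runs of spaces; no '\n' -> '<br>' rewriting needed.
        return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"

    def highlight_tokens(tokens, color_indices=None, color="green"):
        out = []
        for j, tok in enumerate(tokens):
            if color_indices and j in color_indices:
                out.append(f'<span style="color:{color}">{tok}</span>')
            else:
                out.append(tok)
        return "".join(out)

    prev = ["The", " canals", " of", "\n", "???"]
    curr = ["The", " canals", " of", "\n", "Amsterdam"]
    changed = [j for j in range(len(curr)) if j >= len(prev) or curr[j] != prev[j]]
    print(render_html("Iteration 1/8 (after generation)", highlight_tokens(curr, changed)))

Only the last token differs from the previous iteration, so only "Amsterdam" gets a green span, and the embedded newline survives via CSS rather than markup rewriting.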
infer.py CHANGED
@@ -6,6 +6,7 @@ import random
 import importlib
 import torch.nn as nn
 import os
+from IPython.display import display, HTML, Markdown, clear_output
 
 from transformers import AutoTokenizer
 
@@ -162,6 +163,55 @@ def calculate_answer_perplexity(prompt, answer, model_name='gpt2-large'):
     labels[0, :prompt_len] = -100
     loss = model(input_ids, labels=labels).loss
     return torch.exp(loss).item()
+
+
+def format_token_colored_inline(token_id, conf, tokenizer, mask_token_id=128000):
+    token_str = tokenizer.decode([token_id]).replace("\n", "<br>")
+    # token_str = token_str.replace(" ", "&nbsp;")  # Preserve spaces for inline display
+    # token_str = token_str.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")  # Replace tabs with spaces
+
+    if token_id == mask_token_id:
+        color = "black"
+    else:
+        color = f"hsl({int(conf * 120)}, 100%, 25%)"
+
+    return f"<span style='color:{color}' title='Conf: {conf:.2f}'>{token_str}</span>"
+
+
+def display_diffusion_output(i, max_it, question, ori_input_tokens, generated_tokens, confidences, answer_start, tokenizer):
+    clear_output(wait=True)
+    display(Markdown(f"### Iteration {i}/{max_it-1}"))
+    display(Markdown(f"**Question:** {tokenizer.decode(ori_input_tokens[:answer_start])}"))
+    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
+
+    output_html = ''.join([
+        format_token_colored_inline(tok, conf, tokenizer, mask_token_id)
+        for tok, conf in zip(generated_tokens[answer_start:], confidences[answer_start:])
+        if tok != 128001  # skip EOT
+    ])
+    output_html = f"<div style='white-space: pre-wrap'>{output_html}</div>"
+
+    html = HTML(f"<b>Diffusion Output with Confidence:</b><br><div style='line-height:1.8; white-space: pre-wrap'>{output_html}</div>")
+    display(html)
+
+    return output_html
+
+def save_html_colored_output(filename, html_content):
+    with open(filename, "w", encoding="utf-8") as f:
+        f.write(f"""
+        <html>
+        <head>
+            <meta charset="utf-8">
+            <style>
+                body {{ font-family: sans-serif; line-height: 1.6; }}
+                span {{ padding: 0 2px; }}
+            </style>
+        </head>
+        <body>
+            {html_content}
+        </body>
+        </html>
+        """)
 
 
 def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
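
For intuition, `format_token_colored_inline` maps confidence linearly onto an HSL hue: 0.0 lands at hue 0 (red) and 1.0 at hue 120 (green). A tiny sketch of just that mapping, with made-up confidence values (the real helper additionally short-circuits mask tokens to black and attaches the confidence as a hover tooltip):

    def conf_to_css_color(conf: float) -> str:
        # hue 0 = red (low confidence) ... hue 120 = green (high confidence)
        return f"hsl({int(conf * 120)}, 100%, 25%)"

    for conf in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(f"{conf:.2f} -> {conf_to_css_color(conf)}")
    # 0.00 -> hsl(0, 100%, 25%)
    # 0.25 -> hsl(30, 100%, 25%)
    # 0.50 -> hsl(60, 100%, 25%)
    # 0.75 -> hsl(90, 100%, 25%)
    # 1.00 -> hsl(120, 100%, 25%)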
requirements.txt CHANGED
@@ -6,3 +6,4 @@ accelerate>=0.24.1
 gradio>=4.10.0
 numpy
 load_dotenv
+ipython
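
The new `ipython` requirement backs the `from IPython.display import ...` line added to infer.py above. A hypothetical smoke test: in a notebook this renders a colored span, while in a plain interpreter `display` should degrade to a textual repr of the HTML object:

    from IPython.display import HTML, display
    display(HTML("<span style='color:hsl(120, 100%, 25%)'>high-confidence token</span>"))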