Spaces:
Running on Zero

Ruurd committed on
Commit
a494446
·
1 Parent(s): 9756472

Fix generation

Browse files
Files changed (1) hide show
  1. app.py +25 -16
app.py CHANGED
@@ -130,7 +130,7 @@ def confidence_guided_noising(input_ids, answer_start, confidences, threshold, e
130
 
131
 
132
  @spaces.GPU
133
- def generate_diffusion_text(input_ids, answer_start):
134
  with torch.no_grad():
135
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
136
  logits = model(input_ids=input_tensor)["logits"]
@@ -170,15 +170,24 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
170
 
171
  for i in range(max_it):
172
  print('Generating output')
173
- generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)
174
- current_tokens = generated_tokens
175
 
176
- # --- Decode and highlight changed tokens in GREEN ---
177
- decoded_ids = current_tokens[answer_start:]
178
- decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
 
 
 
179
 
 
 
 
 
 
180
  highlighted = []
181
  for j, tok in enumerate(decoded_tokens):
 
 
 
182
  token_str = tokenizer.convert_tokens_to_string([tok])
183
  if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
184
  highlighted.append(f'<span style="color:green">{token_str}</span>')
@@ -189,27 +198,29 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
189
  yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
190
  time.sleep(0.1)
191
 
192
- # --- Apply noising and highlight RED tokens ---
193
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
194
  if use_confidence_noising:
195
- current_tokens = confidence_guided_noising(
196
  generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping
197
  )
198
- just_noised_indices = [] # Optional: could extract from confidence scores
199
  else:
200
- current_tokens, just_noised_indices = noisify_answer(
201
  generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering
202
  )
203
 
204
- decoded_ids = current_tokens[answer_start:]
205
- decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
 
206
 
 
 
207
  highlighted = []
208
  for j, tok in enumerate(decoded_tokens):
209
  tok_id = tokenizer.convert_tokens_to_ids(tok)
210
  if tok_id == eot_token_id:
211
- continue # Skip EOT tokens in display
212
-
213
  token_str = tokenizer.convert_tokens_to_string([tok])
214
  abs_idx = answer_start + j
215
  if abs_idx in just_noised_indices:
@@ -228,8 +239,6 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
228
  yield f"<b>Stopped early after {i+1} iterations.</b>"
229
  break
230
 
231
-
232
-
233
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
234
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
235
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
 
130
 
131
 
132
  @spaces.GPU
133
+ def generate_diffusion_text(input_ids):
134
  with torch.no_grad():
135
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
136
  logits = model(input_ids=input_tensor)["logits"]
 
170
 
171
  for i in range(max_it):
172
  print('Generating output')
 
 
173
 
174
+ # Compose full input: original prompt + current answer
175
+ full_input_tokens = ori_input_tokens[:answer_start] + current_tokens[answer_start:]
176
+ full_input_tokens = full_input_tokens[:256] + [pad_token] * max(0, 256 - len(full_input_tokens))
177
+
178
+ # Model step
179
+ generated_tokens, confidences = generate_diffusion_text(full_input_tokens)
180
 
181
+ # Save full output for noising step
182
+ current_tokens = generated_tokens
183
+
184
+ # --- GREEN HIGHLIGHT ---
185
+ decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
186
  highlighted = []
187
  for j, tok in enumerate(decoded_tokens):
188
+ tok_id = tokenizer.convert_tokens_to_ids(tok)
189
+ if tok_id == eot_token_id:
190
+ continue
191
  token_str = tokenizer.convert_tokens_to_string([tok])
192
  if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
193
  highlighted.append(f'<span style="color:green">{token_str}</span>')
 
198
  yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
199
  time.sleep(0.1)
200
 
201
+ # --- NOISING STEP ---
202
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
203
  if use_confidence_noising:
204
+ noised_answer = confidence_guided_noising(
205
  generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping
206
  )
207
+ just_noised_indices = []
208
  else:
209
+ noised_answer, just_noised_indices = noisify_answer(
210
  generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering
211
  )
212
 
213
+ # Compose full input again: prompt + noised answer
214
+ current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
215
+ current_tokens = current_tokens[:256] + [pad_token] * max(0, 256 - len(current_tokens))
216
 
217
+ # --- RED HIGHLIGHT ---
218
+ decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
219
  highlighted = []
220
  for j, tok in enumerate(decoded_tokens):
221
  tok_id = tokenizer.convert_tokens_to_ids(tok)
222
  if tok_id == eot_token_id:
223
+ continue
 
224
  token_str = tokenizer.convert_tokens_to_string([tok])
225
  abs_idx = answer_start + j
226
  if abs_idx in just_noised_indices:
 
239
  yield f"<b>Stopped early after {i+1} iterations.</b>"
240
  break
241
 
 
 
242
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
243
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
244
  final_output = tokenizer.convert_tokens_to_string(final_tokens)