Simplified interface

Files changed:
- app.py (+27 −40)
- infer.py (+0 −90)
- requirements.txt (+1 −0)
app.py CHANGED

@@ -16,7 +16,6 @@ from infer import (
     find_answer_start,
     get_noising_schedule,
     noisify_answer,
-    generate_diffusion_text,
     filter_logits,
     confidence_guided_noising,
     noisify_answer_without_remasking
@@ -39,17 +38,17 @@ rng = np.random.default_rng()
 def generate_diffusion_text(input_ids, top_p, top_k):
     with torch.no_grad():
         input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
-
-        logits = model(input_ids=input_tensor)["logits"]
-
+
+        with torch.cuda.amp.autocast(dtype=torch.float16):
+            logits = model(input_ids=input_tensor)["logits"]
+
+        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
         logits = logits.clamp(min=-1e8, max=1e4)
         probs = torch.nn.functional.softmax(logits, dim=-1)[0]
         probs = torch.clamp(probs, min=1e-8, max=1.0)
-        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
-        assert (probs >= 0).all(), "Negative probs!"
+        # assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
+        # assert (probs >= 0).all(), "Negative probs!"
         sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
-
-        # Extract confidence of selected tokens
         conf = probs[range(len(sampled)), sampled].cpu().numpy()
         return sampled, conf
@@ -79,10 +78,14 @@ def highlight_tokens(token_ids, answer_start, changed_indices, color):
         highlighted.append(tok_str)
     return "".join(highlighted)
 
-def diffusion_chat(question, max_it, pause_length, sharpness,
-                   …
-                   …
-                   …
+def diffusion_chat(question, noising, max_it, pause_length):
+
+    pause_length = 0
+    sharpness = 3.0
+    noise_start = 0.5
+    top_p = 1.0
+    top_k = 10
+    clustering = False
 
     if question.strip() == "":
         question = "What do you know about the city of Amsterdam?"
@@ -111,6 +114,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
     unmasked_mask = [False] * len(current_tokens)
 
     for i in range(max_it):
+
         generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
         current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
 
@@ -133,25 +137,15 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
         if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
             yield render_html("Stopped early", f"After {i+1} iterations.")
             break
-
+
         # NOISING
-        if i < max_it-1:
+        if i < max_it-1 and noising:
             threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
-            …
-            …
-            …
-            …
-            …
-            elif use_permanent_unmasking:
-                noised_answer, just_noised_indices = noisify_answer_without_remasking(
-                    current_tokens, answer_start, tokenizer, threshold=threshold,
-                    noise_start=noise_start, unmasked_mask=unmasked_mask
-                )
-            else:
-                noised_answer, just_noised_indices = noisify_answer(
-                    current_tokens, answer_start, tokenizer,
-                    threshold=threshold, clustering=clustering, noise_start=noise_start
-                )
+
+            noised_answer, just_noised_indices = noisify_answer(
+                current_tokens, answer_start, tokenizer,
+                threshold=threshold, clustering=clustering, noise_start=noise_start
+            )
 
         for idx in range(answer_start, len(current_tokens)):
             if noised_answer[idx] != mask_token_id:
@@ -172,7 +166,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
         final_ids = answer_ids
 
     final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
-    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
+    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)  # type: ignore
 
 
 def is_running_on_spaces():
@@ -197,22 +191,15 @@ print("✅ Model loaded.")
 vocab_size = len(tokenizer)
 eos_token_id = tokenizer.eos_token_id
 mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
-assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id…
+assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)
 
 demo = gr.Interface(
     fn=diffusion_chat,
     inputs=[
         gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
-        gr.…
-        gr.Slider(…
-        gr.Slider(…
-        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
-        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
-        gr.Checkbox(value=False, label="Use confidence-guided noising"),
-        gr.Checkbox(value=False, label="Use permanent unmasking"),
-        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
-        gr.Slider(1, 1000, value=3, step=1, label="Top-p: ↑ = more random answers"),
-        gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Top-k: ↑ = more random answers")
+        gr.Checkbox(label="Enable noising", value=True, info="If disabled, the model will not apply any intermediate noise."),
+        gr.Slider(1, 512, value=64, step=1, label="Increase the maximum number of iterations to run."),
+        gr.Slider(0, 5, value=0, step=0.01, label="Increase the pause between iterations to visualize the process.")
     ],
     outputs=[gr.HTML(label="Diffusion Output")],
     title="Diffusion Language Model Chat",
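The substantive app.py change is the slimmed-down sampling path in `generate_diffusion_text`: run the model under autocast, pass the logits through `filter_logits` before the softmax, sample one token per position, and keep each sampled token's probability as its confidence. Below is a minimal, self-contained sketch of that flow; the local `filter_logits` and the random stand-in logits are assumptions for illustration (the app imports the real helper from infer.py and feeds real model output), and the autocast wrapper is omitted because no model is loaded here.

```python
import torch
import torch.nn.functional as F


def filter_logits(logits, top_k=0, top_p=1.0):
    """Top-k / nucleus (top-p) filtering over the vocab dimension (assumed semantics)."""
    logits = logits.clone()
    if top_k > 0:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits[logits < kth] = float("-inf")
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the top token always survives
        remove[..., 0] = False
        logits.scatter_(-1, sorted_idx, sorted_logits.masked_fill(remove, float("-inf")))
    return logits


def sample_with_confidence(logits, top_k=10, top_p=1.0):
    """Mirror the new app.py flow: filter, clamp, softmax, sample, keep per-token confidence."""
    logits = filter_logits(logits, top_k=top_k, top_p=top_p)
    logits = logits.clamp(min=-1e8, max=1e4)
    probs = F.softmax(logits, dim=-1)[0]               # (seq_len, vocab)
    probs = torch.clamp(probs, min=1e-8, max=1.0)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    conf = probs[torch.arange(len(sampled)), sampled]  # probability of each chosen token
    return sampled.tolist(), conf.numpy()


if __name__ == "__main__":
    fake_logits = torch.randn(1, 8, 32)  # (batch=1, seq_len=8, vocab=32) stand-in for model output
    tokens, conf = sample_with_confidence(fake_logits, top_k=10, top_p=0.9)
    print(tokens)
    print(conf.round(3))
```

Filtering before the softmax means the reported confidences are probabilities under the truncated top-k/top-p distribution that was actually sampled from, not under the full vocabulary distribution.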
infer.py CHANGED

@@ -190,26 +190,6 @@ def confidence_guided_noising(input_ids, answer_start, tokenizer, confidences, n
     noised_indices = sorted(noised_indices)
     return noised, noised_indices
 
-def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
-                            eos_token_id=None, eos_boost=0.0):
-    model.eval()
-    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
-        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
-        logits = model(input_ids=input_tensor)["logits"]  # (1, seq_len, vocab_size)
-
-        # Optionally boost or suppress EOS token
-        if eos_token_id is not None and eos_boost != 0.0:
-            logits[:, :, eos_token_id] += eos_boost
-
-        # Filter and sample
-        filtered_logits = filter_logits(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-        probs = F.softmax(filtered_logits, dim=-1).squeeze()  # (seq_len, vocab_size)
-        probs = torch.clamp(probs, min=1e-8, max=1.0)
-        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
-        confidences = probs.gather(1, sampled.unsqueeze(-1)).squeeze(-1)
-
-    return input_ids[:answer_start] + sampled[answer_start:].tolist(), confidences
-
 
 def calculate_answer_perplexity(prompt, answer, model_name='gpt2-large'):
     from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -277,73 +257,3 @@ def save_html_colored_output(filename, html_content):
     </body>
     </html>
     """)
-
-
-def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
-                    noising_sharpness=5.0, max_length=256, top_k=100, top_p=1.0,
-                    temperature=1.0, eos_token_id=None, eos_boost=0.0) -> str:
-
-    if eos_token_id is None:
-        eos_token_id = tokenizer.eos_token_id
-    # Format prompt with LLaMA 3 chat template
-    prompt = (
-        "<|begin_of_text|>\n"
-        "<|start_header_id|>system<|end_header_id|>\n"
-        "You are a helpful assistant.\n"
-        "<|eot_id|>\n"
-        "<|start_header_id|>user<|end_header_id|>\n"
-        f"{question.strip()}\n"
-        "<|start_header_id|>assistant<|end_header_id|>\n"
-    )
-    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
-    marker = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)
-
-    def find_answer_start(ids, marker):
-        for i in range(len(ids) - len(marker) + 1):
-            if ids[i:i+len(marker)] == marker:
-                return i + len(marker)
-        return None
-
-    answer_start = find_answer_start(input_ids, marker)
-    if answer_start is None:
-        raise ValueError("Assistant marker not found in prompt.")
-
-    # Pad to max length
-    pad_token = tokenizer.eos_token_id
-    mask_token = tokenizer.encode("MASK", add_special_tokens=False)[0]
-    input_ids = input_ids[:max_length]
-    if len(input_ids) < max_length:
-        input_ids += [mask_token] * (max_length - len(input_ids))
-
-    ori_tokens = input_ids
-    current_tokens = noisify_answer(ori_tokens, answer_start, threshold=1.0, mask_token_id=mask_token)
-
-    last_tokens = []
-    for step in range(max_it):
-        # Generate a new prediction
-        current_tokens, confidence_scores = generate_diffusion_text(
-            model, current_tokens, answer_start,
-            top_k=top_k, top_p=top_p, temperature=temperature,
-            eos_token_id=eos_token_id, eos_boost=eos_boost
-        )
-
-        # Display for debugging / tracking
-        display_diffusion_output(
-            step, max_it, question,
-            ori_tokens, current_tokens, confidence_scores,
-            answer_start, tokenizer
-        )
-
-        # Early stopping
-        last_tokens.append(current_tokens)
-        if len(last_tokens) > 4:
-            last_tokens.pop(0)
-        if all(t == last_tokens[0] for t in last_tokens):
-            break
-
-        # Re-apply noise for next iteration
-        if step < max_it - 1:
-            threshold = noise_start * get_noising_schedule(step, max_it, sharpness=noising_sharpness)
-            current_tokens = noisify_answer(current_tokens, answer_start, threshold=threshold, mask_token_id=mask_token)
-
-    return tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).strip()
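Both the remaining app.py code and the `generate_answer` helper deleted above need to know where the assistant's answer begins in the padded token sequence. The deleted code inlines that lookup as `find_answer_start`, the same function app.py still imports from infer.py. A standalone sketch with toy token ids (not real tokenizer output):

```python
# Locate the position right after the first occurrence of the assistant-marker
# token sequence; everything from that index on is treated as the answer region.
# (Same logic as the helper inlined in the deleted generate_answer; the ids
# below are toy values, not real tokenizer output.)
def find_answer_start(ids, marker):
    for i in range(len(ids) - len(marker) + 1):
        if ids[i:i + len(marker)] == marker:
            return i + len(marker)
    return None


prompt_ids = [101, 7, 1, 2, 3, 0, 0, 0]  # toy prompt; trailing 0s play the role of MASK padding
marker = [1, 2, 3]                       # toy stand-in for the assistant-marker ids
assert find_answer_start(prompt_ids, marker) == 5  # answer tokens start at index 5
```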
requirements.txt CHANGED

@@ -7,3 +7,4 @@ gradio>=4.10.0
 numpy
 load_dotenv
 ipython
+spaces
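The added `spaces` requirement is the Hugging Face package used on ZeroGPU Spaces, where the GPU is attached per call via the `spaces.GPU` decorator instead of being held for the whole process. A minimal sketch of the usual wiring, with a placeholder handler rather than this app's real `diffusion_chat`:

```python
import spaces  # Hugging Face package that backs ZeroGPU Spaces
import torch


@spaces.GPU  # requests a GPU only for the duration of this call on ZeroGPU hardware
def run_model(prompt: str) -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Placeholder body: the real app would run the diffusion model here.
    return f"ran on {device}: {prompt}"
```

On ZeroGPU hardware the decorator is typically applied to the function that Gradio calls, so GPU time is only allocated while a request is being served.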