Improve white space display
- app.py  +50 -76
- infer.py  +50 -0
- requirements.txt  +1 -0
app.py
CHANGED
@@ -101,131 +101,105 @@ def format_chat_prompt(question):
Before (removed lines are marked with "-"; some removed lines are truncated in the source and are left as-is):

        "<|start_header_id|>assistant<|end_header_id|>\n"
    )

# --- Inference Wrapper ---
def diffusion_chat(question, max_it, pause_length, sharpness,
                   clustering, noise_start, use_confidence_noising,
                   noise_clipping, top_p, top_k):
-
    if question.strip() == "":
-       question =

-   print('started generation')
    prompt = format_chat_prompt(question)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
-       yield "Error
        return
-
-   if len(input_ids) < 256:
-       input_ids += [mask_token_id] * (256 - len(input_ids))
-   else:
-       input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    current_tokens, just_noised_indices = noisify_answer(
-
-
-   yield
    time.sleep(pause_length)
-   last_tokens = []
-   prev_decoded_tokens = []

-

    for i in range(max_it):
-       print('Generating output')
-
-       # Model step
        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
-
-       elapsed = time.time() - generation_start
-       remaining = pause_length - elapsed
-       if remaining > 0:
-           time.sleep(remaining)
-
-       # Save full output for noising step
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

-
-
-
-
-
-               continue
-           token_str = tokenizer.convert_tokens_to_string([tok])
-           if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
-               highlighted.append(f'<span style="color:green">{token_str}</span>')
-           else:
-               highlighted.append(token_str)
-
-       prev_decoded_tokens = decoded_tokens
-       yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        time.sleep(pause_length)

-       #
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
-       if len(last_tokens) == 3 and last_tokens
-           yield
            break

-
-
-       # --- NOISING STEP ---
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            noised_answer, just_noised_indices = confidence_guided_noising(
-               current_tokens, answer_start, confidences, noise_clipping,
            )
-           # just_noised_indices = []
        else:
            noised_answer, just_noised_indices = noisify_answer(
-               current_tokens, answer_start, tokenizer,
            )

-
-
-
-
-           tok_id = tokenizer.convert_tokens_to_ids(tok)
-           if tok_id == eos_token_id:
-               continue
-           token_str = tokenizer.convert_tokens_to_string([tok])
-           abs_idx = answer_start + j
-           if abs_idx in just_noised_indices:
-               highlighted.append(f'<span style="color:red">{token_str}</span>')
-           else:
-               highlighted.append(token_str)
-
-       # Compose full input again: prompt + noised answer
-       current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
-
-       yield f"<b>Iteration {i+1}/{max_it} (before noising):</b><br>" + "".join(highlighted).replace('\n', '<br>')
-       generation_start = time.time()

    answer_ids = current_tokens[answer_start:]
    try:
-
-       final_ids = answer_ids[:eos_index]
    except ValueError:
        final_ids = answer_ids
-
-   num_tokens = len(final_ids)
    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
-
-   print(final_output)
-   yield f"<b>Final Output ({num_tokens} tokens after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')


# --- Gradio Interface ---
print("Loading model...")
ckpt_path = hf_hub_download(
    repo_id="ruurd/tini_model",
-   filename="diffusion-model.pth",
    token=os.getenv("HF_TOKEN")
)
model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
After (added lines are marked with "+"):

          "<|start_header_id|>assistant<|end_header_id|>\n"
      )

+ def render_html(label, text):
+     return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"
+
+ def highlight_tokens(tokens, color_indices=None, color="green"):
+     highlighted = []
+     for j, tok in enumerate(tokens):
+         if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
+             continue
+         token_str = tokenizer.convert_tokens_to_string([tok])
+         if color_indices and j in color_indices:
+             highlighted.append(f'<span style="color:{color}">{token_str}</span>')
+         else:
+             highlighted.append(token_str)
+     return "".join(highlighted)
+
  # --- Inference Wrapper ---
  def diffusion_chat(question, max_it, pause_length, sharpness,
                     clustering, noise_start, use_confidence_noising,
                     noise_clipping, top_p, top_k):
+
      if question.strip() == "":
+         question = "What do you know about the city of Amsterdam?"

      prompt = format_chat_prompt(question)
      input_ids = tokenizer.encode(prompt, add_special_tokens=False)
      answer_start = find_answer_start(input_ids, assistant_marker_ids)
      if answer_start is None:
+         yield render_html("Error", "Could not find Assistant marker in input.")
          return

+     input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
      ori_input_tokens = input_ids
+
      current_tokens, just_noised_indices = noisify_answer(
+         input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
+     )
+     yield render_html("Iteration 0 (initial noise)",
+                       tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True))
      time.sleep(pause_length)

+     last_tokens = []
+     prev_tokens = []

      for i in range(max_it):
          generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
          current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

+         decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
+         diff_indices = [j for j in range(len(decoded)) if j >= len(prev_tokens) or decoded[j] != prev_tokens[j]]
+         prev_tokens = decoded
+
+         yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
+                           highlight_tokens(decoded, diff_indices, color="green"))
          time.sleep(pause_length)

+         # Early stopping
          last_tokens.append(current_tokens)
          if len(last_tokens) > 3:
              last_tokens.pop(0)
+         if len(last_tokens) == 3 and len(set(map(tuple, last_tokens))) == 1:
+             yield render_html("Stopped early", f"After {i+1} iterations.")
              break

+         # Noising step
          threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
          if use_confidence_noising:
              noised_answer, just_noised_indices = confidence_guided_noising(
+                 current_tokens, answer_start, confidences, noise_clipping,
+                 threshold=threshold, noise_start=noise_start
              )
          else:
              noised_answer, just_noised_indices = noisify_answer(
+                 current_tokens, answer_start, tokenizer,
+                 threshold=threshold, clustering=clustering, noise_start=noise_start
              )

+         decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
+         red_indices = [j for j in range(len(decoded)) if (answer_start + j) in just_noised_indices]
+         yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
+                           highlight_tokens(decoded, red_indices, color="red"))

+         current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

+     # Final output
      answer_ids = current_tokens[answer_start:]
      try:
+         final_ids = answer_ids[:answer_ids.index(eos_token_id)]
      except ValueError:
          final_ids = answer_ids
+
      final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
+     yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)


  # --- Gradio Interface ---
  print("Loading model...")
  ckpt_path = hf_hub_download(
      repo_id="ruurd/tini_model",
+     filename="diffusion-model-8B.pth",
      token=os.getenv("HF_TOKEN")
  )
  model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
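The white-space fix hinges on wrapping the streamed HTML in a `white-space: pre-wrap` container, so spaces and newlines in the decoded tokens survive Gradio's HTML rendering. A minimal, self-contained sketch of that idea (the token list and highlight indices below are made-up placeholders, not part of the commit):

def render_html(label, text):
    # Same pattern as the new app.py helper: pre-wrap preserves spaces and newlines.
    return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"

demo_tokens = ["The", " answer", "\n", " is", " 42", "."]   # hypothetical decoded tokens
changed = {4}                                               # hypothetical indices to color green

spans = [
    f'<span style="color:green">{tok}</span>' if j in changed else tok
    for j, tok in enumerate(demo_tokens)
]
print(render_html("Iteration 1/8 (after generation)", "".join(spans)))
# Feeding this string to a gr.HTML output should show the newline and leading spaces intact.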
infer.py
CHANGED
@@ -6,6 +6,7 @@ import random
Before:

  import importlib
  import torch.nn as nn
  import os

  from transformers import AutoTokenizer

After (added line marked with "+"):

  import importlib
  import torch.nn as nn
  import os
+ from IPython.display import display, HTML, Markdown, clear_output

  from transformers import AutoTokenizer

@@ -162,6 +163,55 @@ def calculate_answer_perplexity(prompt, answer, model_name='gpt2-large'):

Before:

      labels[0, :prompt_len] = -100
      loss = model(input_ids, labels=labels).loss
      return torch.exp(loss).item()


  def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,

After (added lines marked with "+"):

      labels[0, :prompt_len] = -100
      loss = model(input_ids, labels=labels).loss
      return torch.exp(loss).item()
+
+
+ def format_token_colored_inline(token_id, conf, tokenizer, mask_token_id=128000):
+     token_str = tokenizer.decode([token_id]).replace("\n", "<br>")
+     # token_str = token_str.replace(" ", "&nbsp;")  # Preserve spaces for inline display
+     # token_str = token_str.replace("\t", "&nbsp;&nbsp;")  # Replace tabs with spaces
+
+     if token_id == mask_token_id:
+         color = "black"
+     else:
+         color = f"hsl({int(conf * 120)}, 100%, 25%)"
+
+     return f"<span style='color:{color}' title='Conf: {conf:.2f}'>{token_str}</span>"
+
+
+ def display_diffusion_output(i, max_it, question, ori_input_tokens, generated_tokens, confidences, answer_start, tokenizer):
+     clear_output(wait=True)
+     display(Markdown(f"### Iteration {i}/{max_it-1}"))
+     display(Markdown(f"**Question:** {tokenizer.decode(ori_input_tokens[:answer_start])}"))
+     mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
+
+     output_html = ''.join([
+         format_token_colored_inline(tok, conf, tokenizer, mask_token_id)
+         for tok, conf in zip(generated_tokens[answer_start:], confidences[answer_start:])
+         if tok != 128001  # skip EOT
+     ])
+     output_html = f"<div style='white-space: pre-wrap'>{output_html}</div>"
+
+     html = HTML(f"<b>Diffusion Output with Confidence:</b><br><div style='line-height:1.8; white-space: pre-wrap'>{output_html}</div>")
+     display(html)
+
+     return output_html
+
+ def save_html_colored_output(filename, html_content):
+     with open(filename, "w", encoding="utf-8") as f:
+         f.write(f"""
+         <html>
+         <head>
+             <meta charset="utf-8">
+             <style>
+                 body {{ font-family: sans-serif; line-height: 1.6; }}
+                 span {{ padding: 0 2px; }}
+             </style>
+         </head>
+         <body>
+             {html_content}
+         </body>
+         </html>
+         """)


  def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
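A rough, model-free smoke test for the new notebook helpers (the tokenizer choice, output file name, and mask_token_id=-1 sentinel are assumptions for illustration, not part of this commit; it only exercises the confidence coloring and the HTML export):

from transformers import AutoTokenizer
from infer import format_token_colored_inline, save_html_colored_output

tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for this demo
ids = tok.encode("Diffusion output:\n  indented line\nnewlines and spaces survive")
confs = [min(1.0, 0.1 * i) for i in range(len(ids))]  # fake confidences in [0, 1]

html = ''.join(
    format_token_colored_inline(t, c, tok, mask_token_id=-1)  # -1: treat no token as a mask
    for t, c in zip(ids, confs)
)
# Wrap in pre-wrap, as display_diffusion_output does, so whitespace survives in the saved page.
save_html_colored_output("demo_output.html", f"<div style='white-space: pre-wrap'>{html}</div>")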
requirements.txt
CHANGED
@@ -6,3 +6,4 @@ accelerate>=0.24.1
  gradio>=4.10.0
  numpy
  load_dotenv
+ ipython