Spaces:
Running
on
Zero
Running
on
Zero
Improve confidence guided noising and show number of tokens generated
Browse files
app.py
CHANGED
@@ -110,35 +110,48 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, clustering=0.5, noise
|
|
110 |
|
111 |
|
112 |
# Add new noising function
|
113 |
-
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start
|
114 |
noised = input_ids.copy()
|
115 |
answer_len = len(input_ids) - answer_start
|
116 |
num_to_noise = int(threshold * answer_len * noise_start)
|
117 |
-
|
118 |
if num_to_noise == 0:
|
119 |
return noised
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
# Avoid zero-probability weights for selection
|
125 |
-
# If noise clipping == 1, all tokens have equal chance to be noised.
|
126 |
-
# If noise_clipping == 0.00001, all tokens are noised according to the confidence of the past prediction
|
127 |
-
raw_weights = np.clip(raw_weights, a_min = noise_clipping, a_max = None)
|
128 |
|
129 |
-
|
|
|
130 |
|
131 |
-
|
132 |
-
|
|
|
|
|
133 |
|
134 |
-
|
135 |
-
|
136 |
-
size=
|
137 |
replace=False,
|
138 |
-
p=
|
139 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
-
for idx in
|
142 |
noised[idx] = mask_token_id
|
143 |
|
144 |
return noised
|
@@ -256,11 +269,19 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
|
|
256 |
time.sleep(pause_length)
|
257 |
|
258 |
|
259 |
-
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
print(final_output)
|
263 |
-
yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
|
|
|
264 |
|
265 |
# --- Gradio Interface ---
|
266 |
print("Loading model...")
|
@@ -271,11 +292,11 @@ demo = gr.Interface(
|
|
271 |
fn=diffusion_chat,
|
272 |
inputs=[
|
273 |
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
|
274 |
-
gr.Slider(1, 512, value=
|
275 |
gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
|
276 |
-
gr.Slider(1.0, 20.0, value=
|
277 |
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
|
278 |
-
gr.Slider(0.0, 1.0, value=0.
|
279 |
gr.Checkbox(value=False, label="Use confidence-guided noising"),
|
280 |
gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
|
281 |
|
|
|
110 |
|
111 |
|
112 |
# Add new noising function
|
113 |
+
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
|
114 |
noised = input_ids.copy()
|
115 |
answer_len = len(input_ids) - answer_start
|
116 |
num_to_noise = int(threshold * answer_len * noise_start)
|
|
|
117 |
if num_to_noise == 0:
|
118 |
return noised
|
119 |
|
120 |
+
all_indices = np.arange(answer_start, len(input_ids))
|
121 |
+
eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
|
122 |
+
non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]
|
|
|
|
|
|
|
|
|
123 |
|
124 |
+
num_non_eos_to_noise = int(num_to_noise * (len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5)))
|
125 |
+
num_eos_to_noise = num_to_noise - num_non_eos_to_noise
|
126 |
|
127 |
+
# === Non-EOS sampling ===
|
128 |
+
raw_weights_non_eos = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
|
129 |
+
raw_weights_non_eos = np.clip(raw_weights_non_eos, a_min=noise_clipping, a_max=None)
|
130 |
+
weights_non_eos = raw_weights_non_eos / raw_weights_non_eos.sum() if raw_weights_non_eos.sum() > 0 else None
|
131 |
|
132 |
+
chosen_non_eos = rng.choice(
|
133 |
+
non_eos_indices,
|
134 |
+
size=min(num_non_eos_to_noise, len(non_eos_indices)),
|
135 |
replace=False,
|
136 |
+
p=weights_non_eos
|
137 |
+
) if weights_non_eos is not None else []
|
138 |
+
|
139 |
+
# === EOS sampling ===
|
140 |
+
if eos_indices:
|
141 |
+
raw_weights_eos = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
|
142 |
+
raw_weights_eos = np.clip(raw_weights_eos, a_min=noise_clipping, a_max=None)
|
143 |
+
weights_eos = raw_weights_eos / raw_weights_eos.sum() if raw_weights_eos.sum() > 0 else None
|
144 |
+
|
145 |
+
chosen_eos = rng.choice(
|
146 |
+
eos_indices,
|
147 |
+
size=min(num_eos_to_noise, len(eos_indices)),
|
148 |
+
replace=False,
|
149 |
+
p=weights_eos
|
150 |
+
) if weights_eos is not None else []
|
151 |
+
else:
|
152 |
+
chosen_eos = []
|
153 |
|
154 |
+
for idx in list(chosen_non_eos) + list(chosen_eos):
|
155 |
noised[idx] = mask_token_id
|
156 |
|
157 |
return noised
|
|
|
269 |
time.sleep(pause_length)
|
270 |
|
271 |
|
272 |
+
answer_ids = current_tokens[answer_start:]
|
273 |
+
try:
|
274 |
+
eos_index = answer_ids.index(eos_token_id)
|
275 |
+
final_ids = answer_ids[:eos_index]
|
276 |
+
except ValueError:
|
277 |
+
final_ids = answer_ids
|
278 |
+
|
279 |
+
num_tokens = len(final_ids)
|
280 |
+
final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
|
281 |
+
|
282 |
print(final_output)
|
283 |
+
yield f"<b>Final Output ({num_tokens} tokens after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
|
284 |
+
|
285 |
|
286 |
# --- Gradio Interface ---
|
287 |
print("Loading model...")
|
|
|
292 |
fn=diffusion_chat,
|
293 |
inputs=[
|
294 |
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
|
295 |
+
gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
|
296 |
gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
|
297 |
+
gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="↓ = more noising (sharpness)"),
|
298 |
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
|
299 |
+
gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="↑ = more noise (noise start)"),
|
300 |
gr.Checkbox(value=False, label="Use confidence-guided noising"),
|
301 |
gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
|
302 |
|