Noisify without remasking
Files changed:
- .gitignore +4 -1
- app.py +57 -68
- infer.py +65 -0
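
In short: app.py gains a third noising mode, "permanent unmasking". The inference loop keeps an `unmasked_mask` over all positions; any answer position that survives a noising step is recorded there and is never remasked in later iterations. `confidence_guided_noising` moves from app.py into infer.py (now taking the `tokenizer` so it can derive the mask and EOS token ids itself), the new `noisify_answer_without_remasking` lands next to it, `highlight_tokens` is reworked to take raw token ids plus absolute changed indices, and checkpoint loading now branches on whether the app runs on Spaces (`SPACE_ID` set) or locally.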
.gitignore CHANGED

@@ -83,4 +83,7 @@ vendor/
 # Environment files #
 #####################
 .env
-.env.*
+.env.*
+
+*.pem
+*.pth
app.py CHANGED

@@ -6,7 +6,9 @@ import time
 from transformers import AutoTokenizer
 import os
 import importlib
+import os
 from huggingface_hub import hf_hub_download
+
 import spaces
 from dotenv import load_dotenv
 from infer import (
@@ -15,7 +17,9 @@ from infer import (
     get_noising_schedule,
     noisify_answer,
     generate_diffusion_text,
-    filter_logits
+    filter_logits,
+    confidence_guided_noising,
+    noisify_answer_without_remasking
 )
 from models import CustomTransformerModel
 from model_config import CustomTransformerConfig
@@ -31,48 +35,6 @@ if hf_token is None:

 rng = np.random.default_rng()

-# Add new noising function
-def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
-    noised = input_ids.copy()
-    answer_len = len(input_ids) - answer_start
-    num_to_noise = int(threshold * answer_len * noise_start)
-    if num_to_noise == 0:
-        return noised, []
-
-    all_indices = np.arange(answer_start, len(input_ids))
-    eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
-    non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]
-
-    # Proportionally split how many to noise
-    num_non_eos_to_noise = int(num_to_noise * len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5))
-    num_eos_to_noise = num_to_noise - num_non_eos_to_noise
-
-    noised_indices = []
-
-    # --- Non-EOS ---
-    if non_eos_indices:
-        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
-        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
-        weights = raw_weights / raw_weights.sum()
-
-        chosen = rng.choice(non_eos_indices, size=min(num_non_eos_to_noise, len(non_eos_indices)), replace=False, p=weights)
-        noised_indices.extend(chosen.tolist())
-
-    # --- EOS ---
-    if eos_indices and num_eos_to_noise > 0:
-        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
-        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
-        weights = raw_weights / raw_weights.sum()
-
-        chosen = rng.choice(eos_indices, size=min(num_eos_to_noise, len(eos_indices)), replace=False, p=weights)
-        noised_indices.extend(chosen.tolist())
-
-    for idx in noised_indices:
-        noised[idx] = mask_token_id
-
-    noised_indices = sorted(noised_indices)
-    return noised, noised_indices
-
 @spaces.GPU
 def generate_diffusion_text(input_ids, top_p, top_k):
     with torch.no_grad():
@@ -104,22 +66,23 @@ def format_chat_prompt(question):
 def render_html(label, text):
     return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"

-def highlight_tokens(tokens, changed_indices, color):
+def highlight_tokens(token_ids, answer_start, changed_indices, color):
+    tokens = tokenizer.convert_ids_to_tokens(token_ids)
     highlighted = []
     for j, tok in enumerate(tokens):
         if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
             continue
-
-        if j in changed_indices:
-            highlighted.append(f'<span style="color:{color}">{tok}</span>')
+        tok_str = tokenizer.convert_tokens_to_string([tok])
+        if (answer_start + j) in changed_indices:
+            highlighted.append(f'<span style="color:{color}">{tok_str}</span>')
         else:
-            highlighted.append(tok)
+            highlighted.append(tok_str)
     return "".join(highlighted)

-# --- Inference Wrapper ---
 def diffusion_chat(question, max_it, pause_length, sharpness,
                    clustering, noise_start, use_confidence_noising,
-                   noise_clipping, top_p, top_k):
+                   use_permanent_unmasking, noise_clipping, top_p,
+                   top_k):

     if question.strip() == "":
         question = "What do you know about the city of Amsterdam?"
@@ -134,53 +97,69 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
     input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
     ori_input_tokens = input_ids

+    # Initial noising
     current_tokens, just_noised_indices = noisify_answer(
         input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
     )
     yield render_html("Iteration 0 (initial noise)",
-                      …
+                      highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
     time.sleep(pause_length)

     last_tokens = []
-    …
+    prev_decoded = []
+
+    unmasked_mask = [False] * len(current_tokens)

     for i in range(max_it):
         generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
         current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

-        …
-        …
-        …
+        # GREEN highlighting: compare to previous tokens
+        new_decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
+        diff_indices = {
+            answer_start + j for j, tok in enumerate(new_decoded)
+            if j >= len(prev_decoded) or tok != prev_decoded[j]
+        }
+        prev_decoded = new_decoded

         yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
-                          highlight_tokens(…
+                          highlight_tokens(current_tokens[answer_start:], answer_start, diff_indices, color="green"))
         time.sleep(pause_length)

         # Early stopping
         last_tokens.append(current_tokens)
         if len(last_tokens) > 3:
             last_tokens.pop(0)
-        if len(last_tokens) == 3 and …
+        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
             yield render_html("Stopped early", f"After {i+1} iterations.")
             break

-        # …
+        # NOISING
         threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
         if use_confidence_noising:
             noised_answer, just_noised_indices = confidence_guided_noising(
-                current_tokens, answer_start, confidences, noise_clipping,
+                current_tokens, answer_start, tokenizer, confidences, noise_clipping,
                 threshold=threshold, noise_start=noise_start
             )
+        elif use_permanent_unmasking:
+            noised_answer, just_noised_indices = noisify_answer_without_remasking(
+                current_tokens, answer_start, tokenizer, threshold=threshold,
+                noise_start=noise_start, unmasked_mask=unmasked_mask
+            )
         else:
             noised_answer, just_noised_indices = noisify_answer(
                 current_tokens, answer_start, tokenizer,
                 threshold=threshold, clustering=clustering, noise_start=noise_start
             )
+
+        for idx in range(answer_start, len(current_tokens)):
+            if noised_answer[idx] != mask_token_id:
+                unmasked_mask[idx] = True
+
+

-        decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
-        red_indices = [j for j in range(len(decoded)) if (answer_start + j) in just_noised_indices]
         yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
-                          highlight_tokens(…
+                          highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))

         current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

@@ -195,13 +174,22 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
     yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)


-…
+def is_running_on_spaces():
+    return os.getenv("SPACE_ID") is not None
+
 print("Loading model...")
-…
-…
-…
-…
-…
+
+if is_running_on_spaces():
+    # Load from Hugging Face Hub
+    ckpt_path = hf_hub_download(
+        repo_id="ruurd/tini_model",
+        filename="diffusion-model-8B.pth",
+        token=os.getenv("HF_TOKEN")
+    )
+else:
+    # Load from local path
+    ckpt_path = "diffusion-model-3B.pth"  # change to your actual local path
+
 model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
 print("✅ Model loaded.")

@@ -220,6 +208,7 @@ demo = gr.Interface(
         gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
         gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Noise start fraction: ↑ = more noise"),
         gr.Checkbox(value=False, label="Use confidence-guided noising"),
+        gr.Checkbox(value=False, label="Use permanent unmasking"),
         gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
         gr.Slider(1, 1000, value = 100, step = 1, label = "Top-p: ↑ = more random answers"),
         gr.Slider(0.0, 1.0, value = 0.9, step = 0.01, label = "Top-k: ↑ = more random answers")
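
To make the `unmasked_mask` bookkeeping concrete, here is a minimal, self-contained sketch of the permanent-unmasking loop. This is not the Space's code: ids are toy integers, `model_fill` is a hypothetical stand-in for `generate_diffusion_text`, and the noise schedule is a bare countdown. One difference is deliberate: the sketch clamps the sample size with `min(...)`, whereas the committed `rng.choice(eligible_indices, size=num_to_noise, replace=False)` in infer.py would raise a ValueError once `num_to_noise` exceeds the shrinking eligible pool.

import numpy as np

rng = np.random.default_rng(0)
MASK = 0  # toy mask token id

def noisify_without_remasking(tokens, answer_start, num_to_noise, unmasked_mask):
    # Mask only positions that were never locked in (toy analogue of the diff's helper).
    noised = tokens.copy()
    eligible = [i for i in range(answer_start, len(tokens)) if not unmasked_mask[i]]
    if num_to_noise <= 0 or not eligible:
        return noised
    for i in rng.choice(eligible, size=min(num_to_noise, len(eligible)), replace=False):
        noised[i] = MASK
    return noised

def model_fill(tokens):
    # Hypothetical stand-in for generate_diffusion_text: fill every MASK with a random id.
    return [t if t != MASK else int(rng.integers(1, 10)) for t in tokens]

answer_start = 2
tokens = [7, 8] + [MASK] * 6          # prompt + fully masked answer
unmasked = [False] * len(tokens)

for it in range(4):
    tokens = model_fill(tokens)                                   # denoise step
    noised = noisify_without_remasking(tokens, answer_start, 4 - it, unmasked)
    for i in range(answer_start, len(tokens)):                    # lock surviving positions
        if noised[i] != MASK:
            unmasked[i] = True
    tokens = noised
    print(f"iter {it}: tokens={tokens} locked={unmasked[answer_start:]}")

Each iteration the model refills the masks, a decaying number of still-unlocked positions is remasked, and everything that survives is frozen for good, so the answer converges monotonically instead of churning.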
infer.py CHANGED

@@ -125,6 +125,71 @@ def noisify_answer(input_ids, answer_start, tokenizer, threshold=1.0, clustering

 import torch.nn.functional as F

+def noisify_answer_without_remasking(input_ids, answer_start, tokenizer, threshold=1.0, noise_start=1.0, unmasked_mask=None):
+    noised = input_ids.copy()
+    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
+
+    eligible_indices = list(range(answer_start, len(noised)))
+
+    if unmasked_mask is not None:
+        eligible_indices = [i for i in eligible_indices if not unmasked_mask[i]]
+
+    answer_len = len(noised) - answer_start
+    num_to_noise = int(threshold * answer_len * noise_start)
+
+    if num_to_noise == 0 or len(eligible_indices) == 0:
+        return noised, []
+
+    selected = rng.choice(eligible_indices, size=num_to_noise, replace=False).tolist()
+
+    for idx in selected:
+        noised[idx] = mask_token_id
+
+    return noised, selected
+
+def confidence_guided_noising(input_ids, answer_start, tokenizer, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
+    noised = input_ids.copy()
+    answer_len = len(input_ids) - answer_start
+    num_to_noise = int(threshold * answer_len * noise_start)
+    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
+    eos_token_id = tokenizer.eos_token_id
+    if num_to_noise == 0:
+        return noised, []
+
+    all_indices = np.arange(answer_start, len(input_ids))
+    eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
+    non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]
+
+    # Proportionally split how many to noise
+    num_non_eos_to_noise = int(num_to_noise * len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5))
+    num_eos_to_noise = num_to_noise - num_non_eos_to_noise
+
+    noised_indices = []
+
+    # --- Non-EOS ---
+    if non_eos_indices:
+        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
+        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
+        weights = raw_weights / raw_weights.sum()
+
+        chosen = rng.choice(non_eos_indices, size=min(num_non_eos_to_noise, len(non_eos_indices)), replace=False, p=weights)
+        noised_indices.extend(chosen.tolist())
+
+    # --- EOS ---
+    if eos_indices and num_eos_to_noise > 0:
+        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
+        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
+        weights = raw_weights / raw_weights.sum()
+
+        chosen = rng.choice(eos_indices, size=min(num_eos_to_noise, len(eos_indices)), replace=False, p=weights)
+        noised_indices.extend(chosen.tolist())
+
+    for idx in noised_indices:
+        noised[idx] = mask_token_id
+
+    noised_indices = sorted(noised_indices)
+    return noised, noised_indices
+
 def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
                             eos_token_id=None, eos_boost=0.0):
     model.eval()
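
To see what `noise_clipping` buys in `confidence_guided_noising`, here is a standalone sketch of its selection step (fake confidences, positions only; the committed function additionally splits the masking budget proportionally between EOS and non-EOS positions before sampling). The clipping floors the raw weights so even high-confidence tokens keep a small chance of being remasked:

import numpy as np

rng = np.random.default_rng(0)

# Per-token model confidences over a toy 8-token answer region.
confidences = np.array([0.95, 0.20, 0.80, 0.10, 0.60, 0.99, 0.30, 0.70])
noise_clipping = 0.05   # lower bound on the raw weights
num_to_noise = 3

# Same weighting as the diff: low confidence -> high remasking probability.
raw_weights = np.clip(1.0 - confidences, a_min=noise_clipping, a_max=None)
weights = raw_weights / raw_weights.sum()

chosen = rng.choice(np.arange(len(confidences)), size=num_to_noise, replace=False, p=weights)
print(sorted(chosen.tolist()))  # low-confidence positions 1, 3, 6 dominate the draw

Lowering `noise_clipping` sharpens the distribution toward the least confident tokens (hence the UI label "↓ = more confidence guidance"); raising it toward 1.0 flattens the weights back to uniform sampling.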
|