Spaces:
Running on Zero

Ruurd committed on
Commit
31e34c4
·
1 Parent(s): fb56411

Noisify without remasking

Browse files
Files changed (3) hide show
  1. .gitignore +4 -1
  2. app.py +57 -68
  3. infer.py +65 -0
.gitignore CHANGED
@@ -83,4 +83,7 @@ vendor/
83
  # Environment files #
84
  #####################
85
  .env
86
- .env.*
 
 
 
 
83
  # Environment files #
84
  #####################
85
  .env
86
+ .env.*
87
+
88
+ *.pem
89
+ *.pth
app.py CHANGED
@@ -6,7 +6,9 @@ import time
6
  from transformers import AutoTokenizer
7
  import os
8
  import importlib
 
9
  from huggingface_hub import hf_hub_download
 
10
  import spaces
11
  from dotenv import load_dotenv
12
  from infer import (
@@ -15,7 +17,9 @@ from infer import (
15
  get_noising_schedule,
16
  noisify_answer,
17
  generate_diffusion_text,
18
- filter_logits
 
 
19
  )
20
  from models import CustomTransformerModel
21
  from model_config import CustomTransformerConfig
@@ -31,48 +35,6 @@ if hf_token is None:
31
 
32
  rng = np.random.default_rng()
33
 
34
- # Add new noising function
35
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
    """Re-mask part of the answer, favouring low-confidence positions.

    The answer region is ``input_ids[answer_start:]``. The number of positions
    to mask is ``int(threshold * answer_len * noise_start)``; that budget is
    split between non-EOS and EOS positions in proportion to how many of each
    the answer contains. Inside each group, positions are drawn without
    replacement with probability proportional to
    ``max(1 - confidence, noise_clipping)``.

    Relies on the module-level ``rng``, ``eos_token_id`` and ``mask_token_id``.

    Args:
        input_ids: List of token ids (prompt followed by answer); not mutated.
        answer_start: Index of the first answer token.
        confidences: Per-answer-token confidences, indexed relative to
            ``answer_start`` (presumably in [0, 1] -- confirm with the caller).
        noise_clipping: Lower bound applied to every sampling weight.
        threshold: Noising-schedule factor for the current iteration.
        noise_start: Additional scaling of the number of masked positions.

    Returns:
        ``(noised, noised_indices)``: a copy of ``input_ids`` with the chosen
        positions set to ``mask_token_id`` and the sorted list of those
        absolute positions.
    """
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise <= 0 or answer_len <= 0:
        return noised, []

    answer_positions = range(answer_start, len(input_ids))
    eos_indices = [i for i in answer_positions if input_ids[i] == eos_token_id]
    non_eos_indices = [i for i in answer_positions if input_ids[i] != eos_token_id]

    # Proportional split using integer arithmetic. The previous float form,
    # int(x / (total + 1e-5)), rounded an exact quotient down by one, so an
    # answer without EOS tokens was masked one position short of the budget.
    total = len(eos_indices) + len(non_eos_indices)
    num_non_eos_to_noise = num_to_noise * len(non_eos_indices) // total
    num_eos_to_noise = num_to_noise - num_non_eos_to_noise

    noised_indices = []
    # Non-EOS positions are sampled first, then EOS positions.
    for group, quota in ((non_eos_indices, num_non_eos_to_noise),
                         (eos_indices, num_eos_to_noise)):
        count = min(quota, len(group))
        if count <= 0:
            continue
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in group])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()
        chosen = rng.choice(group, size=count, replace=False, p=weights)
        noised_indices.extend(chosen.tolist())

    for idx in noised_indices:
        noised[idx] = mask_token_id

    return noised, sorted(noised_indices)
75
-
76
  @spaces.GPU
77
  def generate_diffusion_text(input_ids, top_p, top_k):
78
  with torch.no_grad():
@@ -104,22 +66,23 @@ def format_chat_prompt(question):
104
  def render_html(label, text):
105
  return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"
106
 
107
def highlight_tokens(tokens, color_indices=None, color="green"):
    """Join tokens into an HTML string, colouring the positions in ``color_indices``.

    Tokens whose id equals the module-level ``eos_token_id`` are dropped.
    Every remaining token is converted to text with the module-level
    ``tokenizer`` and HTML-escaped, so token text containing ``<``, ``>`` or
    ``&`` is displayed literally instead of being read as markup.

    Args:
        tokens: Sequence of token strings.
        color_indices: Positions (indices into ``tokens``) to wrap in a
            coloured ``<span>``; ``None`` or empty colours nothing.
        color: CSS colour used for the highlighted spans.

    Returns:
        The concatenated HTML string.
    """
    import html  # local import keeps this block self-contained

    # Build the lookup once; color_indices may be a list.
    marked = set(color_indices) if color_indices else set()
    highlighted = []
    for j, tok in enumerate(tokens):
        if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
            continue
        token_str = html.escape(tokenizer.convert_tokens_to_string([tok]), quote=False)
        if j in marked:
            highlighted.append(f'<span style="color:{color}">{token_str}</span>')
        else:
            highlighted.append(token_str)
    return "".join(highlighted)
118
 
119
- # --- Inference Wrapper ---
120
  def diffusion_chat(question, max_it, pause_length, sharpness,
121
  clustering, noise_start, use_confidence_noising,
122
- noise_clipping, top_p, top_k):
 
123
 
124
  if question.strip() == "":
125
  question = "What do you know about the city of Amsterdam?"
@@ -134,53 +97,69 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
134
  input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
135
  ori_input_tokens = input_ids
136
 
 
137
  current_tokens, just_noised_indices = noisify_answer(
138
  input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
139
  )
140
  yield render_html("Iteration 0 (initial noise)",
141
- tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True))
142
  time.sleep(pause_length)
143
 
144
  last_tokens = []
145
- prev_tokens = []
 
 
146
 
147
  for i in range(max_it):
148
  generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
149
  current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
150
 
151
- decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
152
- diff_indices = [j for j in range(len(decoded)) if j >= len(prev_tokens) or decoded[j] != prev_tokens[j]]
153
- prev_tokens = decoded
 
 
 
 
154
 
155
  yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
156
- highlight_tokens(decoded, diff_indices, color="green"))
157
  time.sleep(pause_length)
158
 
159
  # Early stopping
160
  last_tokens.append(current_tokens)
161
  if len(last_tokens) > 3:
162
  last_tokens.pop(0)
163
- if len(last_tokens) == 3 and len(set(map(tuple, last_tokens))) == 1:
164
  yield render_html("Stopped early", f"After {i+1} iterations.")
165
  break
166
 
167
- # Noising step
168
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
169
  if use_confidence_noising:
170
  noised_answer, just_noised_indices = confidence_guided_noising(
171
- current_tokens, answer_start, confidences, noise_clipping,
172
  threshold=threshold, noise_start=noise_start
173
  )
 
 
 
 
 
174
  else:
175
  noised_answer, just_noised_indices = noisify_answer(
176
  current_tokens, answer_start, tokenizer,
177
  threshold=threshold, clustering=clustering, noise_start=noise_start
178
  )
 
 
 
 
 
 
179
 
180
- decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
181
- red_indices = [j for j in range(len(decoded)) if (answer_start + j) in just_noised_indices]
182
  yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
183
- highlight_tokens(decoded, red_indices, color="red"))
184
 
185
  current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
186
 
@@ -195,13 +174,22 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
195
  yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
196
 
197
 
198
- # --- Gradio Interface ---
 
 
199
  print("Loading model...")
200
- ckpt_path = hf_hub_download(
201
- repo_id="ruurd/tini_model",
202
- filename="diffusion-model-8B.pth",
203
- token=os.getenv("HF_TOKEN")
204
- )
 
 
 
 
 
 
 
205
  model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
206
  print("✅ Model loaded.")
207
 
@@ -220,6 +208,7 @@ demo = gr.Interface(
220
  gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
221
  gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Noise start fraction: ↑ = more noise"),
222
  gr.Checkbox(value=False, label="Use confidence-guided noising"),
 
223
  gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
224
  gr.Slider(1, 1000, value = 100, step = 1, label = "Top-p: ↑ = more random answers"),
225
  gr.Slider(0.0, 1.0, value = 0.9, step = 0.01, label = "Top-k: ↑ = more random answers")
 
6
  from transformers import AutoTokenizer
7
  import os
8
  import importlib
9
+ import os
10
  from huggingface_hub import hf_hub_download
11
+
12
  import spaces
13
  from dotenv import load_dotenv
14
  from infer import (
 
17
  get_noising_schedule,
18
  noisify_answer,
19
  generate_diffusion_text,
20
+ filter_logits,
21
+ confidence_guided_noising,
22
+ noisify_answer_without_remasking
23
  )
24
  from models import CustomTransformerModel
25
  from model_config import CustomTransformerConfig
 
35
 
36
  rng = np.random.default_rng()
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  @spaces.GPU
39
  def generate_diffusion_text(input_ids, top_p, top_k):
40
  with torch.no_grad():
 
66
  def render_html(label, text):
67
  return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"
68
 
69
def highlight_tokens(token_ids, answer_start, changed_indices, color):
    """Render token ids as an HTML string, colouring the changed positions.

    ``token_ids`` is converted to tokens with the module-level ``tokenizer``.
    Tokens whose id equals the module-level ``eos_token_id`` are dropped.
    Token ``j`` is highlighted when ``answer_start + j`` is in
    ``changed_indices``. Token text is HTML-escaped, so text containing
    ``<``, ``>`` or ``&`` is displayed literally instead of being read as
    markup.

    Args:
        token_ids: Token ids to render (callers in this file pass the answer
            slice of the sequence).
        answer_start: Offset added to each local index before the lookup in
            ``changed_indices``.
        changed_indices: Iterable of absolute positions to highlight; ``None``
            or empty highlights nothing.
        color: CSS colour used for the highlighted spans.

    Returns:
        The concatenated HTML string.
    """
    import html  # local import keeps this block self-contained

    # Build the lookup once; callers pass either a list or a set.
    changed = set(changed_indices) if changed_indices else set()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    highlighted = []
    for j, tok in enumerate(tokens):
        if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
            continue
        tok_str = html.escape(tokenizer.convert_tokens_to_string([tok]), quote=False)
        if (answer_start + j) in changed:
            highlighted.append(f'<span style="color:{color}">{tok_str}</span>')
        else:
            highlighted.append(tok_str)
    return "".join(highlighted)
81
 
 
82
  def diffusion_chat(question, max_it, pause_length, sharpness,
83
  clustering, noise_start, use_confidence_noising,
84
+ use_permanent_unmasking, noise_clipping, top_p,
85
+ top_k):
86
 
87
  if question.strip() == "":
88
  question = "What do you know about the city of Amsterdam?"
 
97
  input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
98
  ori_input_tokens = input_ids
99
 
100
+ # Initial noising
101
  current_tokens, just_noised_indices = noisify_answer(
102
  input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
103
  )
104
  yield render_html("Iteration 0 (initial noise)",
105
+ highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
106
  time.sleep(pause_length)
107
 
108
  last_tokens = []
109
+ prev_decoded = []
110
+
111
+ unmasked_mask = [False] * len(current_tokens)
112
 
113
  for i in range(max_it):
114
  generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
115
  current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
116
 
117
+ # GREEN highlighting: compare to previous tokens
118
+ new_decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
119
+ diff_indices = {
120
+ answer_start + j for j, tok in enumerate(new_decoded)
121
+ if j >= len(prev_decoded) or tok != prev_decoded[j]
122
+ }
123
+ prev_decoded = new_decoded
124
 
125
  yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
126
+ highlight_tokens(current_tokens[answer_start:], answer_start, diff_indices, color="green"))
127
  time.sleep(pause_length)
128
 
129
  # Early stopping
130
  last_tokens.append(current_tokens)
131
  if len(last_tokens) > 3:
132
  last_tokens.pop(0)
133
+ if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
134
  yield render_html("Stopped early", f"After {i+1} iterations.")
135
  break
136
 
137
+ # NOISING
138
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
139
  if use_confidence_noising:
140
  noised_answer, just_noised_indices = confidence_guided_noising(
141
+ current_tokens, answer_start, tokenizer, confidences, noise_clipping,
142
  threshold=threshold, noise_start=noise_start
143
  )
144
+ elif use_permanent_unmasking:
145
+ noised_answer, just_noised_indices = noisify_answer_without_remasking(
146
+ current_tokens, answer_start, tokenizer, threshold=threshold,
147
+ noise_start=noise_start, unmasked_mask=unmasked_mask
148
+ )
149
  else:
150
  noised_answer, just_noised_indices = noisify_answer(
151
  current_tokens, answer_start, tokenizer,
152
  threshold=threshold, clustering=clustering, noise_start=noise_start
153
  )
154
+
155
+ for idx in range(answer_start, len(current_tokens)):
156
+ if noised_answer[idx] != mask_token_id:
157
+ unmasked_mask[idx] = True
158
+
159
+
160
 
 
 
161
  yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
162
+ highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
163
 
164
  current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
165
 
 
174
  yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
175
 
176
 
177
def is_running_on_spaces():
    """Return True when the ``SPACE_ID`` environment variable is set.

    Any value counts, including an empty string. The variable is presumably
    provided by the Hugging Face Spaces runtime -- confirm against its docs.
    """
    return "SPACE_ID" in os.environ
179
+
180
  print("Loading model...")
181
+
182
+ if is_running_on_spaces():
183
+ # Load from Hugging Face Hub
184
+ ckpt_path = hf_hub_download(
185
+ repo_id="ruurd/tini_model",
186
+ filename="diffusion-model-8B.pth",
187
+ token=os.getenv("HF_TOKEN")
188
+ )
189
+ else:
190
+ # Load from local path
191
+ ckpt_path = "diffusion-model-3B.pth" # change to your actual local path
192
+
193
  model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
194
  print("✅ Model loaded.")
195
 
 
208
  gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
209
  gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Noise start fraction: ↑ = more noise"),
210
  gr.Checkbox(value=False, label="Use confidence-guided noising"),
211
+ gr.Checkbox(value=False, label="Use permanent unmasking"),
212
  gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
213
  gr.Slider(1, 1000, value = 100, step = 1, label = "Top-p: ↑ = more random answers"),
214
  gr.Slider(0.0, 1.0, value = 0.9, step = 0.01, label = "Top-k: ↑ = more random answers")
infer.py CHANGED
@@ -125,6 +125,71 @@ def noisify_answer(input_ids, answer_start, tokenizer, threshold=1.0, clustering
125
 
126
  import torch.nn.functional as F
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
129
  eos_token_id=None, eos_boost=0.0):
130
  model.eval()
 
125
 
126
  import torch.nn.functional as F
127
 
128
def noisify_answer_without_remasking(input_ids, answer_start, tokenizer, threshold=1.0, noise_start=1.0, unmasked_mask=None):
    """Mask random answer positions, skipping positions flagged in ``unmasked_mask``.

    The number of positions to mask is ``int(threshold * answer_len * noise_start)``,
    where ``answer_len`` is the length of ``input_ids[answer_start:]``. Positions
    are drawn uniformly without replacement (via the module-level ``rng``) from
    the answer positions whose ``unmasked_mask`` entry is falsy. When fewer
    eligible positions remain than requested, all of them are masked instead
    of raising.

    Args:
        input_ids: List of token ids (prompt followed by answer); not mutated.
        answer_start: Index of the first answer token.
        tokenizer: Object whose ``encode('MASK', add_special_tokens=False)[0]``
            yields the mask token id.
        threshold: Noising-schedule factor for the current iteration.
        noise_start: Additional scaling of the number of masked positions.
        unmasked_mask: Optional sequence indexed by absolute position; truthy
            entries are excluded from masking. ``None`` makes every answer
            position eligible.

    Returns:
        ``(noised, selected)``: a copy of ``input_ids`` with the selected
        positions replaced by the mask token id, and the list of selected
        absolute positions (in sampling order).
    """
    noised = input_ids.copy()
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise <= 0:
        return noised, []

    eligible_indices = [
        i for i in range(answer_start, len(noised))
        if unmasked_mask is None or not unmasked_mask[i]
    ]
    if not eligible_indices:
        return noised, []

    # Sampling without replacement cannot draw more items than exist; the
    # budget is computed from the full answer length, so cap it here.
    num_to_noise = min(num_to_noise, len(eligible_indices))

    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    selected = rng.choice(eligible_indices, size=num_to_noise, replace=False).tolist()
    for idx in selected:
        noised[idx] = mask_token_id

    return noised, selected
149
+
150
def confidence_guided_noising(input_ids, answer_start, tokenizer, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
    """Re-mask part of the answer, favouring low-confidence positions.

    The answer region is ``input_ids[answer_start:]``. The number of positions
    to mask is ``int(threshold * answer_len * noise_start)``; that budget is
    split between non-EOS and EOS positions in proportion to how many of each
    the answer contains. Inside each group, positions are drawn without
    replacement (via the module-level ``rng``) with probability proportional
    to ``max(1 - confidence, noise_clipping)``.

    Args:
        input_ids: List of token ids (prompt followed by answer); not mutated.
        answer_start: Index of the first answer token.
        tokenizer: Provides ``eos_token_id`` and the mask token id via
            ``encode('MASK', add_special_tokens=False)[0]``.
        confidences: Per-answer-token confidences, indexed relative to
            ``answer_start`` (presumably in [0, 1] -- confirm with the caller).
        noise_clipping: Lower bound applied to every sampling weight.
        threshold: Noising-schedule factor for the current iteration.
        noise_start: Additional scaling of the number of masked positions.

    Returns:
        ``(noised, noised_indices)``: a copy of ``input_ids`` with the chosen
        positions set to the mask token id and the sorted list of those
        absolute positions.
    """
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise <= 0 or answer_len <= 0:
        return noised, []

    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    eos_token_id = tokenizer.eos_token_id

    answer_positions = range(answer_start, len(input_ids))
    eos_indices = [i for i in answer_positions if input_ids[i] == eos_token_id]
    non_eos_indices = [i for i in answer_positions if input_ids[i] != eos_token_id]

    # Proportional split using integer arithmetic. The previous float form,
    # int(x / (total + 1e-5)), rounded an exact quotient down by one, so an
    # answer without EOS tokens was masked one position short of the budget.
    total = len(eos_indices) + len(non_eos_indices)
    num_non_eos_to_noise = num_to_noise * len(non_eos_indices) // total
    num_eos_to_noise = num_to_noise - num_non_eos_to_noise

    noised_indices = []
    # Non-EOS positions are sampled first, then EOS positions.
    for group, quota in ((non_eos_indices, num_non_eos_to_noise),
                         (eos_indices, num_eos_to_noise)):
        count = min(quota, len(group))
        if count <= 0:
            continue
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in group])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()
        chosen = rng.choice(group, size=count, replace=False, p=weights)
        noised_indices.extend(chosen.tolist())

    for idx in noised_indices:
        noised[idx] = mask_token_id

    return noised, sorted(noised_indices)
192
+
193
  def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
194
  eos_token_id=None, eos_boost=0.0):
195
  model.eval()