Spaces: Running on Zero

Ruurd committed
Commit 20ff8b2 · 1 Parent(s): bd9baef

Fix loading of model and tokenizer

Change location of inference functions

Files changed (2):
  1. app.py +8 -78
  2. infer.py +28 -16
app.py CHANGED
@@ -29,56 +29,8 @@ hf_token = os.getenv("HF_TOKEN")
 if hf_token is None:
     raise ValueError("HF_TOKEN is not set")
 
-# --- Load tokenizer ---
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B", use_fast=True, token=hf_token)
-vocab_size = len(tokenizer)
-eos_token_id = tokenizer.eos_token_id
-mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
-assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)
-
 rng = np.random.default_rng()
 
-# --- Utility Functions ---
-def decode_tokens_safe(token_ids):
-    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")
-
-def find_answer_start(input_ids, marker_ids):
-    for i in range(len(input_ids) - len(marker_ids) + 1):
-        if input_ids[i:i + len(marker_ids)] == marker_ids:
-            return i + len(marker_ids)
-    return None
-
-def get_noising_schedule(i, max_it, sharpness=5.0):
-    x = i / max_it
-    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
-
-def noisify_answer(input_ids, answer_start, threshold=1.0, clustering=0.5, noise_start = 1.0):
-    noised = input_ids.copy()
-    answer_len = len(noised) - answer_start
-    num_to_noise = int(threshold * answer_len * noise_start)
-    mask_token_id = tokenizer.encode('MASK', add_special_tokens = False)[0]
-
-    if num_to_noise == 0:
-        return noised, []
-
-    num_clusters = max(1, int((1 - clustering) * num_to_noise))
-    cluster_size = max(1, int(num_to_noise / num_clusters))
-
-    noised_indices = set()
-    for _ in range(num_clusters):
-        center = rng.integers(answer_start, len(noised))
-        span_start = max(answer_start, center - cluster_size // 2)
-        span_end = min(len(noised), span_start + cluster_size)
-        noised_indices.update(range(span_start, span_end))
-
-    noised_indices = sorted(list(noised_indices))[:num_to_noise]
-
-    for idx in noised_indices:
-        noised[idx] = mask_token_id
-
-    return noised, noised_indices
-
-
 # Add new noising function
 def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
     noised = input_ids.copy()
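Note: the hunk above removes app.py's local tokenizer setup and utility functions. The matching import is not part of this diff; presumably app.py now pulls these helpers from infer.py, roughly as below (the exact import list is an assumption):

# Assumed import near the top of app.py after this refactor -- not shown in the diff.
from infer import (
    decode_tokens_safe,     # signature is now (token_ids, tokenizer)
    find_answer_start,
    get_noising_schedule,
    noisify_answer,         # signature is now (input_ids, answer_start, tokenizer, ...)
    filter_logits,
)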
@@ -121,33 +73,6 @@ def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping,
     noised_indices = sorted(noised_indices)
     return noised, noised_indices
 
-def filter_logits(logits, top_k=0, top_p=0.0):
-    """Filter logits per position for top-k / nucleus (top-p) sampling."""
-    logits = logits.clone()  # don't modify in-place
-    batch_size, seq_len, vocab_size = logits.shape
-
-    for i in range(seq_len):
-        token_logits = logits[0, i]
-
-        if top_k > 0:
-            top_values, _ = torch.topk(token_logits, top_k)
-            threshold = top_values[-1]
-            token_logits[token_logits < threshold] = float("-inf")
-
-        if top_p > 0.0:
-            sorted_logits, sorted_indices = torch.sort(token_logits, descending=True)
-            cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
-
-            sorted_indices_to_remove = cumulative_probs > top_p
-            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
-            sorted_indices_to_remove[0] = 0  # always keep at least 1 token
-
-            token_logits[sorted_indices[sorted_indices_to_remove]] = float("-inf")
-
-        logits[0, i] = token_logits
-
-    return logits
-
 @spaces.GPU
 def generate_diffusion_text(input_ids, top_p, top_k):
     with torch.no_grad():
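Note: this deletes app.py's copy of filter_logits in favor of the infer.py version, whose hunk header below shows an added temperature parameter and a top_p default of 1.0 instead of 0.0. A minimal usage sketch, assuming the infer.py signature and the [1, seq_len, vocab_size] logits shape the removed code expected:

import torch
from infer import filter_logits

# Dummy logits; 128256 is the Llama-3.1 vocab size implied by the Space's tokenizer.
logits = torch.randn(1, 16, 128256)
filtered = filter_logits(logits, top_k=50, top_p=0.9)   # top-k cut, then nucleus cut
probs = torch.softmax(filtered, dim=-1)                 # filtered-out entries get probability 0
next_ids = torch.multinomial(probs[0], num_samples=1)   # one sampled token per position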
@@ -198,7 +123,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
 
     ori_input_tokens = input_ids
     current_tokens, just_noised_indices = noisify_answer(
-        input_ids, answer_start, threshold=1.0, clustering=clustering, noise_start = 1.0,
+        input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start = 1.0,
     )
     yield f"<b>Iteration 0 (initial noise):</b><br>" + tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).replace('\n', '<br>')
     time.sleep(pause_length)
@@ -257,7 +182,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
         #     just_noised_indices = []
         else:
             noised_answer, just_noised_indices = noisify_answer(
-                current_tokens, answer_start, threshold=threshold, clustering=clustering, noise_start = noise_start,
+                current_tokens, answer_start, tokenizer, threshold=threshold, clustering=clustering, noise_start = noise_start,
             )
 
         # --- RED HIGHLIGHT ---
@@ -302,9 +227,14 @@ ckpt_path = hf_hub_download(
     filename="diffusion-model.pth",
     token=os.getenv("HF_TOKEN")
 )
-model = load_trained_model(checkpoint_path=ckpt_path)
+model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
 print("✅ Model loaded.")
 
+vocab_size = len(tokenizer)
+eos_token_id = tokenizer.eos_token_id
+mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
+assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)
+
 demo = gr.Interface(
     fn=diffusion_chat,
     inputs=[
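Note: this hunk is the fix named in the commit title: load_trained_model now returns the tokenizer alongside the model, and the tokenizer-derived globals are computed only after loading rather than at import time. The loader's body is outside this diff; a hypothetical sketch of its shape, assuming it rebuilds the Llama-3.1 tokenizer that app.py used to load directly (build_diffusion_model is a stand-in name for however the Space constructs its network):

import os
import torch
from transformers import AutoTokenizer

# Hypothetical -- the real load_trained_model is not shown in this commit.
def load_trained_model(checkpoint_path):
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.1-8B", use_fast=True, token=os.getenv("HF_TOKEN")
    )
    model = build_diffusion_model(len(tokenizer))  # placeholder for the Space's own constructor
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer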
infer.py CHANGED
@@ -82,8 +82,8 @@ def filter_logits(logits, top_k=0, top_p=1.0, temperature=1.0):
 
     return logits
 
-
-def decode_tokens_safe(tokenizer, token_ids):
+# --- Utility Functions ---
+def decode_tokens_safe(token_ids, tokenizer):
     return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")
 
 def find_answer_start(input_ids, marker_ids):
@@ -92,24 +92,36 @@ def find_answer_start(input_ids, marker_ids):
             return i + len(marker_ids)
     return None
 
-def noisify_answer(input_ids, answer_start, threshold=1.0, is_unmasked=None, mask_token_id=128002):
-    noised = input_ids.copy()
-    total_len = len(input_ids)
-    candidates = [
-        i for i in range(answer_start, total_len)
-        if is_unmasked is None or not is_unmasked[i]
-    ]
-    num_to_add = int(threshold * total_len)
-    if num_to_add > 0 and len(candidates) > 0:
-        newly_masked = rng.choice(candidates, size=min(num_to_add, len(candidates)), replace=False)
-        for idx in newly_masked:
-            noised[idx] = mask_token_id
-    return noised
-
 def get_noising_schedule(i, max_it, sharpness=5.0):
     x = i / max_it
     return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
 
+def noisify_answer(input_ids, answer_start, tokenizer, threshold=1.0, clustering=0.5, noise_start = 1.0):
+    noised = input_ids.copy()
+    answer_len = len(noised) - answer_start
+    num_to_noise = int(threshold * answer_len * noise_start)
+    mask_token_id = tokenizer.encode('MASK', add_special_tokens = False)[0]
+
+    if num_to_noise == 0:
+        return noised, []
+
+    num_clusters = max(1, int((1 - clustering) * num_to_noise))
+    cluster_size = max(1, int(num_to_noise / num_clusters))
+
+    noised_indices = set()
+    for _ in range(num_clusters):
+        center = rng.integers(answer_start, len(noised))
+        span_start = max(answer_start, center - cluster_size // 2)
+        span_end = min(len(noised), span_start + cluster_size)
+        noised_indices.update(range(span_start, span_end))
+
+    noised_indices = sorted(list(noised_indices))[:num_to_noise]
+
+    for idx in noised_indices:
+        noised[idx] = mask_token_id
+
+    return noised, noised_indices
+
 import torch.nn.functional as F
 
 def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
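Note: get_noising_schedule (unchanged here, visible as context above) is a normalized exponential decay over the iteration index. A quick check of its endpoints, runnable as-is:

import numpy as np

def get_noising_schedule(i, max_it, sharpness=5.0):
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))

print(get_noising_schedule(0, 10))             # 1.0  -> fully noised at the first iteration
print(round(get_noising_schedule(5, 10), 4))   # 0.0759 -> decays quickly at sharpness=5
print(get_noising_schedule(10, 10))            # 0.0  -> no noise at the final iteration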
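Note: the relocated noisify_answer matches the two call-site changes in app.py above: the tokenizer is now threaded in explicitly, and the function returns which positions it masked so the UI can highlight them. clustering trades many scattered single-token masks (near 0) for fewer contiguous spans (near 1). A minimal usage sketch, assuming access to the gated Llama-3.1 tokenizer:

from transformers import AutoTokenizer
from infer import noisify_answer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B", use_fast=True)
input_ids = tokenizer.encode("Question: what is 2+2? Answer: two plus two equals four",
                             add_special_tokens=False)
answer_start = len(input_ids) - 5   # pretend the last five tokens are the answer

# threshold * noise_start sets the fraction of answer tokens to mask.
noised, idx = noisify_answer(input_ids, answer_start, tokenizer,
                             threshold=1.0, clustering=0.5, noise_start=0.6)
print(idx)                                        # indices replaced by the MASK token
print(tokenizer.decode(noised[answer_start:]))    # answer text with MASK holes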