Spaces:
Running on Zero

Ruurd committed on
Commit
b5f844d
·
verified ·
1 Parent(s): a3a4100

Fix clamping and introduce top-k and top-p filtering

Browse files
Files changed (1) hide show
  1. app.py +28 -1
app.py CHANGED
@@ -151,13 +151,40 @@ def confidence_guided_noising(input_ids, answer_start, confidences, noise_clippi
151
  noised_indices = sorted(noised_indices)
152
  return noised, noised_indices
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  @spaces.GPU
156
  def generate_diffusion_text(input_ids):
157
  with torch.no_grad():
158
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
159
  logits = model(input_ids=input_tensor)["logits"]
160
- logits = logits.clamp(min=-1e4, max=1e4)
 
161
  probs = torch.nn.functional.softmax(logits, dim=-1)[0]
162
  probs = torch.clamp(probs, min=1e-8, max=1.0)
163
  assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
 
151
  noised_indices = sorted(noised_indices)
152
  return noised, noised_indices
153
 
154
def filter_logits(logits, top_k=0, top_p=0.0):
    """Filter logits for top-k and/or nucleus (top-p) sampling.

    Tokens that fall outside the top-k set, or outside the smallest set of
    tokens whose cumulative probability exceeds ``top_p``, get their logits
    set to ``-inf`` so a subsequent softmax assigns them zero probability.

    Args:
        logits: float tensor of shape ``(batch, seq_len, vocab_size)``.
        top_k: keep only the ``k`` highest-scoring tokens per position;
            ``0`` disables top-k filtering.
        top_p: keep the smallest prefix of probability-sorted tokens whose
            cumulative probability exceeds ``top_p``; ``0.0`` disables
            nucleus filtering.

    Returns:
        A new tensor of the same shape; the caller's tensor is not modified.
    """
    logits = logits.clone()  # don't modify the caller's tensor in-place
    vocab_size = logits.size(-1)

    if top_k > 0:
        # Clamp k so torch.topk cannot raise when top_k > vocab_size.
        k = min(top_k, vocab_size)
        # k-th largest logit per position, kept as a broadcastable column.
        kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
        # Vectorized over batch AND sequence (the original only filtered
        # batch element 0, leaving other rows of a batch unfiltered).
        logits[logits < kth_value] = float("-inf")

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Mask tokens once cumulative probability exceeds top_p, shifted
        # right by one so the first token crossing the threshold is kept.
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False  # always keep at least one token

        # Scatter the mask back from sorted order to vocabulary order.
        mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
        logits[mask] = float("-inf")

    return logits
180
 
181
  @spaces.GPU
182
  def generate_diffusion_text(input_ids):
183
  with torch.no_grad():
184
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
185
  logits = model(input_ids=input_tensor)["logits"]
186
+ logits = filter_logits(logits, top_k=top_k, top_p=top_p)
187
+ logits = logits.clamp(min=-1e8, max=1e4)
188
  probs = torch.nn.functional.softmax(logits, dim=-1)[0]
189
  probs = torch.clamp(probs, min=1e-8, max=1.0)
190
  assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"