
Commit a8d72d4 (parent: ec83427), committed by Ruurd

add eos_bias

Files changed (1): app.py (+11 -4)
app.py CHANGED
@@ -36,11 +36,16 @@ if hf_token is None:
rng = np.random.default_rng()

@spaces.GPU
-def generate_diffusion_text(input_ids, top_p, top_k):
+def generate_diffusion_text(input_ids, top_p, top_k, eos_bias=0.0):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
-        with torch.amp.autocast('cuda', dtype=torch.float16):
+        with torch.cuda.amp.autocast(dtype=torch.float16):
            logits = model(input_ids=input_tensor)["logits"]
+
+        # Apply eos_bias
+        if eos_bias != 0.0:
+            logits[0, :, eos_token_id] += eos_bias
+
        logits = filter_logits(logits, top_k=top_p, top_p=top_k)
        logits = logits.clamp(min=-1e8, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

@@ -79,11 +84,12 @@ def highlight_tokens(token_ids, answer_start, changed_indices, color):
        highlighted.append(tok_str)
    return "".join(highlighted)

-def diffusion_chat(question, max_it, pause_length, sharpness,
+def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
                   clustering, noise_start, use_confidence_noising,
                   use_permanent_unmasking, noise_clipping, top_p,
                   top_k):

+    eos_bias = -eos_bias
    if question.strip() == "":
        question = "What do you know about the city of Amsterdam?"

@@ -111,7 +117,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
    unmasked_mask = [False] * len(current_tokens)

    for i in range(max_it):
-        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
+        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k, eos_bias = eos_bias)
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # GREEN highlighting: compare to previous tokens

@@ -205,6 +211,7 @@ demo = gr.Interface(
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
        gr.Slider(1, 512, value=64, step=1, label="Number of iterarions: ↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iteration ↑ = longer pause"),
+        gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label="Generation length: ↑ = more output tokens by decreasing eos token probability"),
        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
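
For context on what the new argument does: the bias is added to the end-of-sequence logit at every position before softmax, and the slider value is negated first, so a higher "Generation length" setting pushes the EOS probability down and yields longer answers. A minimal sketch of that idea, using an illustrative eos_token_id and dummy tensor shapes rather than the ones in app.py:

import torch

# Illustrative values; app.py takes eos_token_id from the loaded tokenizer.
EOS_TOKEN_ID = 2
VOCAB_SIZE = 100

def apply_eos_bias(logits: torch.Tensor, eos_bias: float) -> torch.Tensor:
    """Shift the EOS logit at every position of a (1, seq_len, vocab) tensor."""
    if eos_bias != 0.0:
        logits[0, :, EOS_TOKEN_ID] += eos_bias
    return logits

slider_value = 2.0                       # user raises the "Generation length" slider
eos_bias = -slider_value                 # negation done at the top of diffusion_chat
logits = torch.zeros(1, 8, VOCAB_SIZE)   # dummy (batch, seq_len, vocab) logits
logits = apply_eos_bias(logits, eos_bias)
probs = torch.softmax(logits, dim=-1)
print(float(probs[0, 0, EOS_TOKEN_ID]))  # now below the uniform 1/VOCAB_SIZE baseline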
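
filter_logits itself is not part of this diff. The sketch below assumes it is a standard top-k / nucleus (top-p) filter over the vocabulary dimension; the real helper in app.py may differ. Note that the call site passes top_k=top_p and top_p=top_k, so as written the two sampling controls reach the filter swapped.

import torch

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """Assumed behaviour: mask everything outside the top-k tokens and the top-p mass."""
    logits = logits.clone()
    if top_k > 0:
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        drop = cumulative > top_p
        drop[..., 1:] = drop[..., :-1].clone()  # always keep the highest-probability token
        drop[..., 0] = False
        logits = logits.masked_fill(drop.scatter(-1, sorted_idx, drop), float("-inf"))
    return logits

# Example: keep the top 10 tokens, then the nucleus covering 90% of the mass.
probs = torch.softmax(filter_logits(torch.randn(1, 4, 50), top_k=10, top_p=0.9), dim=-1)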
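
The new slider is the fourth input component, and gr.Interface passes input values to the handler positionally, which is why eos_bias sits between pause_length and sharpness in the diffusion_chat signature. A stripped-down sketch of that wiring (most inputs and the real handler body are omitted):

import gradio as gr

def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness):
    # Inputs arrive here in the same order the components are listed below.
    return f"eos_bias slider value: {eos_bias}"

demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2),
        gr.Slider(1, 512, value=64, step=1, label="Number of iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations"),
        gr.Slider(-5.0, 5.0, value=0.0, step=0.1,
                  label="Generation length: ↑ = more output tokens by decreasing eos token probability"),
        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness"),
    ],
    outputs=gr.Textbox(label="Answer"),
)

if __name__ == "__main__":
    demo.launch()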