Spaces:
Running
on
Zero
Running
on
Zero
add eos_bias
Browse files
app.py
CHANGED
@@ -36,11 +36,16 @@ if hf_token is None:
|
|
36 |
rng = np.random.default_rng()
|
37 |
|
38 |
@spaces.GPU
|
39 |
-
def generate_diffusion_text(input_ids, top_p, top_k):
|
40 |
with torch.no_grad():
|
41 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
42 |
-
with torch.amp.autocast(
|
43 |
logits = model(input_ids=input_tensor)["logits"]
|
|
|
|
|
|
|
|
|
|
|
44 |
logits = filter_logits(logits, top_k=top_p, top_p=top_k)
|
45 |
logits = logits.clamp(min=-1e8, max=1e4)
|
46 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
@@ -79,11 +84,12 @@ def highlight_tokens(token_ids, answer_start, changed_indices, color):
|
|
79 |
highlighted.append(tok_str)
|
80 |
return "".join(highlighted)
|
81 |
|
82 |
-
def diffusion_chat(question, max_it, pause_length, sharpness,
|
83 |
clustering, noise_start, use_confidence_noising,
|
84 |
use_permanent_unmasking, noise_clipping, top_p,
|
85 |
top_k):
|
86 |
|
|
|
87 |
if question.strip() == "":
|
88 |
question = "What do you know about the city of Amsterdam?"
|
89 |
|
@@ -111,7 +117,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
|
|
111 |
unmasked_mask = [False] * len(current_tokens)
|
112 |
|
113 |
for i in range(max_it):
|
114 |
-
generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
|
115 |
current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
|
116 |
|
117 |
# GREEN highlighting: compare to previous tokens
|
@@ -205,6 +211,7 @@ demo = gr.Interface(
|
|
205 |
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
|
206 |
gr.Slider(1, 512, value=64, step=1, label="Number of iterarions: ↑ = more iterations"),
|
207 |
gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iteration ↑ = longer pause"),
|
|
|
208 |
gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
|
209 |
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
|
210 |
gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
|
|
|
36 |
rng = np.random.default_rng()
|
37 |
|
38 |
@spaces.GPU
|
39 |
+
def generate_diffusion_text(input_ids, top_p, top_k, eos_bias=0.0):
|
40 |
with torch.no_grad():
|
41 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
42 |
+
with torch.cuda.amp.autocast(dtype=torch.float16):
|
43 |
logits = model(input_ids=input_tensor)["logits"]
|
44 |
+
|
45 |
+
# Apply eos_bias
|
46 |
+
if eos_bias != 0.0:
|
47 |
+
logits[0, :, eos_token_id] += eos_bias
|
48 |
+
|
49 |
logits = filter_logits(logits, top_k=top_p, top_p=top_k)
|
50 |
logits = logits.clamp(min=-1e8, max=1e4)
|
51 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
|
|
84 |
highlighted.append(tok_str)
|
85 |
return "".join(highlighted)
|
86 |
|
87 |
+
def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
|
88 |
clustering, noise_start, use_confidence_noising,
|
89 |
use_permanent_unmasking, noise_clipping, top_p,
|
90 |
top_k):
|
91 |
|
92 |
+
eos_bias = -eos_bias
|
93 |
if question.strip() == "":
|
94 |
question = "What do you know about the city of Amsterdam?"
|
95 |
|
|
|
117 |
unmasked_mask = [False] * len(current_tokens)
|
118 |
|
119 |
for i in range(max_it):
|
120 |
+
generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k, eos_bias = eos_bias)
|
121 |
current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
|
122 |
|
123 |
# GREEN highlighting: compare to previous tokens
|
|
|
211 |
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
|
212 |
gr.Slider(1, 512, value=64, step=1, label="Number of iterarions: ↑ = more iterations"),
|
213 |
gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iteration ↑ = longer pause"),
|
214 |
+
gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label="Generation length: ↑ = more output tokens by decreasing eos token probability"),
|
215 |
gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
|
216 |
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
|
217 |
gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
|