import os
import time

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout

hf_token = os.getenv("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
eos_token_id = tokenizer.eos_token_id
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
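# Note: the Llama tokenizer has no dedicated mask token, so the literal
# string 'MASK' is encoded and its first sub-token id serves as
# mask_token_id throughout.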
def load_model():
    """Download the diffusion checkpoint and load it into CustomTransformerModel."""
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CustomTransformerModel(CustomTransformerConfig())
    full_model = torch.load(ckpt_path, map_location=device)

    # The checkpoint may hold either a full nn.Module or a raw state dict.
    try:
        state_dict = full_model.state_dict()
    except AttributeError:
        state_dict = full_model

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)

    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
rng = np.random.default_rng()


def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")
def find_answer_start(input_ids, marker_ids):
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None
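# Usage sketch: for the token ids of "User: ...\nAssistant:",
# find_answer_start returns the index just past the "Assistant:" marker
# (the first answer position), or None if the marker is absent.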
def get_noising_schedule(i, max_it, sharpness=5.0):
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
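# Rough shape of get_noising_schedule at the default sharpness=5.0 (values
# rounded): the re-masking fraction decays from 1.0 toward 0.0 over the run.
#   i/max_it:  0.00  0.25  0.50  0.75  1.00
#   schedule:  1.00  0.28  0.08  0.02  0.00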
def noisify_answer(input_ids, answer_start, threshold=1.0, clustering=0.5, noise_start=1.0):
    """Re-mask a fraction of the answer tokens, optionally in contiguous clusters."""
    noised = input_ids.copy()
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)

    if num_to_noise == 0:
        return noised, []

    # Higher clustering -> fewer, larger contiguous spans of masked tokens.
    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))

    noised_indices = set()
    for _ in range(num_clusters):
        center = rng.integers(answer_start, len(noised))
        span_start = max(answer_start, center - cluster_size // 2)
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))

    noised_indices = sorted(noised_indices)[:num_to_noise]

    for idx in noised_indices:
        noised[idx] = mask_token_id

    return noised, noised_indices
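# Illustration of the clustering parameter in noisify_answer: with
# num_to_noise=10, clustering=0.0 yields 10 clusters of size 1 (scattered
# single-token masks), clustering=0.5 yields 5 clusters of size 2, and
# clustering=1.0 yields one contiguous span of 10 masked tokens.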
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
    """Re-mask answer tokens, weighted toward low-confidence positions."""
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise == 0:
        return noised, []

    all_indices = np.arange(answer_start, len(input_ids))
    eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
    non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]

    # Split the noise budget between EOS and non-EOS positions in proportion
    # to their counts, so runs of EOS padding cannot absorb the whole budget.
    num_non_eos_to_noise = int(num_to_noise * len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5))
    num_eos_to_noise = num_to_noise - num_non_eos_to_noise

    noised_indices = []

    if non_eos_indices:
        # Weight positions by (1 - confidence), clipped from below so that
        # high-confidence tokens keep a nonzero chance of being re-masked.
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()

        chosen = rng.choice(non_eos_indices, size=min(num_non_eos_to_noise, len(non_eos_indices)), replace=False, p=weights)
        noised_indices.extend(chosen.tolist())

    if eos_indices and num_eos_to_noise > 0:
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()

        chosen = rng.choice(eos_indices, size=min(num_eos_to_noise, len(eos_indices)), replace=False, p=weights)
        noised_indices.extend(chosen.tolist())

    for idx in noised_indices:
        noised[idx] = mask_token_id

    return noised, sorted(noised_indices)
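# Design note: noisify_answer masks positions chosen uniformly at random
# (optionally clustered), whereas confidence_guided_noising concentrates new
# masks on low-confidence positions, so later iterations mostly revisit
# tokens the model is still unsure about.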
def filter_logits(logits, top_k=0, top_p=0.0):
    """Filter logits per position for top-k / nucleus (top-p) sampling. Assumes batch size 1."""
    logits = logits.clone()
    batch_size, seq_len, vocab_size = logits.shape

    for i in range(seq_len):
        token_logits = logits[0, i]

        if top_k > 0:
            # Keep only the top_k highest logits; mask the rest to -inf.
            top_values, _ = torch.topk(token_logits, top_k)
            threshold = top_values[-1]
            token_logits[token_logits < threshold] = float("-inf")

        if top_p > 0.0:
            # Nucleus filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p.
            sorted_logits, sorted_indices = torch.sort(token_logits, descending=True)
            cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

            # Shift the mask right so the first token past the threshold is kept.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False

            token_logits[sorted_indices[sorted_indices_to_remove]] = float("-inf")

        logits[0, i] = token_logits

    return logits
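# Usage sketch: filter_logits(logits, top_k=100, top_p=0.9) first restricts
# each position to its 100 highest-scoring tokens, then drops the
# low-probability tail outside the 0.9 nucleus; generate_diffusion_text
# below samples from the filtered distribution.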
@spaces.GPU
def generate_diffusion_text(input_ids, top_p, top_k):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
        logits = logits.clamp(min=-1e8, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()

        # Per-token confidence: the probability assigned to each sampled token.
        conf = probs[torch.arange(len(sampled)), sampled].cpu().numpy()
    return sampled, conf
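# diffusion_chat drives the iterative loop: start from a fully masked answer,
# then alternate (1) denoising every position in parallel, (2) streaming the
# intermediate answer to the UI, and (3) re-masking a shrinking fraction of
# tokens according to get_noising_schedule, until max_it iterations pass or
# the output is identical for three consecutive iterations.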
def diffusion_chat(question, max_it, pause_length, sharpness,
                   clustering, noise_start, use_confidence_noising,
                   noise_clipping, top_p, top_k):
    placeholder = "What do you know about the city of Amsterdam?"
    if question.strip() == "":
        question = placeholder

    print('started generation')
    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return

    # Pad (or truncate) to a fixed length of 256 tokens; the padding consists
    # of mask tokens, which form the fully noised initial answer.
    if len(input_ids) < 256:
        input_ids += [mask_token_id] * (256 - len(input_ids))
    else:
        input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, threshold=1.0, clustering=clustering, noise_start=1.0,
    )
    yield "<b>Iteration 0 (initial noise):</b><br>" + tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).replace('\n', '<br>')
    time.sleep(pause_length)
    last_tokens = []
    prev_decoded_tokens = []

    generation_start = time.time()

    for i in range(max_it):
        print('Generating output')

        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)

        # Keep each iteration on screen for at least pause_length seconds.
        elapsed = time.time() - generation_start
        remaining = pause_length - elapsed
        if remaining > 0:
            time.sleep(remaining)

        # Keep the prompt fixed; only the answer region is denoised.
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # Show the denoised answer, highlighting changed tokens in green.
        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        highlighted = []
        for j, tok in enumerate(decoded_tokens):
            tok_id = tokenizer.convert_tokens_to_ids(tok)
            if tok_id == eos_token_id:
                continue
            token_str = tokenizer.convert_tokens_to_string([tok])
            if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
                highlighted.append(f'<span style="color:green">{token_str}</span>')
            else:
                highlighted.append(token_str)

        prev_decoded_tokens = decoded_tokens
        yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        time.sleep(pause_length)

        # Stop early if the output has been identical for three iterations.
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break

        # Remember the denoised tokens: current_tokens is re-noised below, so
        # the final output must be built from this copy.
        previous_tokens = current_tokens.copy()

        # Re-mask a shrinking fraction of the answer for the next iteration.
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            noised_answer, just_noised_indices = confidence_guided_noising(
                current_tokens, answer_start, confidences, noise_clipping, threshold=threshold, noise_start=noise_start
            )
        else:
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, threshold=threshold, clustering=clustering, noise_start=noise_start,
            )

        # Show which tokens are about to be masked, highlighted in red.
        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        highlighted = []
        for j, tok in enumerate(decoded_tokens):
            tok_id = tokenizer.convert_tokens_to_ids(tok)
            if tok_id == eos_token_id:
                continue
            token_str = tokenizer.convert_tokens_to_string([tok])
            abs_idx = answer_start + j
            if abs_idx in just_noised_indices:
                highlighted.append(f'<span style="color:red">{token_str}</span>')
            else:
                highlighted.append(token_str)

        current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

        yield f"<b>Iteration {i+1}/{max_it} (before noising):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        generation_start = time.time()

    # Build the final answer from the last denoised tokens (after a full run,
    # current_tokens has been re-noised and may contain fresh mask tokens) and
    # cut it at the first EOS token.
    answer_ids = previous_tokens[answer_start:]
    try:
        eos_index = answer_ids.index(eos_token_id)
        final_ids = answer_ids[:eos_index]
    except ValueError:
        final_ids = answer_ids

    num_tokens = len(final_ids)
    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)

    print(final_output)
    yield f"<b>Final Output ({num_tokens} tokens after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
print("Loading model...") |
|
model = load_model() |
|
print("✅ Model loaded.") |
|
|
|
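# The sliders below are passed to diffusion_chat positionally, so their order
# must match its signature: max_it, pause_length, sharpness, clustering,
# noise_start, use_confidence_noising, noise_clipping, top_p, top_k.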
demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
        gr.Slider(1, 512, value=64, step=1, label="Number of iterations: ↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations: ↑ = longer pause"),
        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
        gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Noise start fraction: ↑ = more noise"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p: ↑ = more random answers"),
        gr.Slider(1, 1000, value=100, step=1, label="Top-k: ↑ = more random answers"),
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model to generate answers progressively.",
)

demo.launch(share=True, allowed_paths=["."], ssr_mode=False)