import json
import os
import time

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout
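
# Hugging Face token for the gated Llama tokenizer and the model checkpoint.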
hf_token = os.getenv("HF_TOKEN")
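
# Tokenizer and the special-token ids used throughout generation.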
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
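
# Per-token-id probabilities used as the noise distribution when corrupting the answer span.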
with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
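

# Build the model, fetch its checkpoint from the Hub, and prepare it for inference.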
def load_model():
    config = CustomTransformerConfig(vocab_size=vocab_size)
    model = CustomTransformerModel(config)
    # torch.hub.load_state_dict_from_url() accepts no auth headers, so fetch the
    # gated checkpoint through huggingface_hub and load the state dict from disk.
    weights_path = hf_hub_download(repo_id="Ruurd/tini_model", filename="diffusion-model.pth", token=hf_token)
    device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU if no GPU
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
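

# Shared NumPy generator for all noising steps.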
rng = np.random.default_rng()
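

# Decode ids to plain text, dropping special tokens and flattening newlines.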
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")
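

# Return the index just past the "Assistant:" marker, i.e. where the answer begins.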
def find_answer_start(input_ids, marker_ids):
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None
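

# Fraction of answer tokens to re-noise at step i: decays from 1 at i=0 to 0 at
# i=max_it; higher sharpness makes the decay steeper.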
def get_noising_schedule(i, max_it, sharpness=5.0):
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
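

# Replace a `threshold` fraction of answer-span tokens with draws from the noise
# distribution; eot_weight rescales the EOT probability to bias answer length.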
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len)
    if num_to_noise > 0:
        indices = rng.choice(np.arange(answer_start, len(input_ids)), size=num_to_noise, replace=False)

        mixed_probs = token_probabilities.copy()
        mixed_probs[eot_token_id] *= eot_weight
        mixed_probs /= mixed_probs.sum()

        noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
        for idx, val in zip(indices, noise):
            noised[idx] = val
    return noised
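

# One denoising pass: a single forward pass over the full sequence, resampling every
# position from the model's output distribution; the prompt portion is kept verbatim.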
def generate_diffusion_text(model, input_ids, answer_start):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
    return input_ids[:answer_start] + sampled[answer_start:]
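

# Gradio streaming handler: iteratively denoise the answer span, yielding HTML that
# highlights tokens changed since the previous iteration in green.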
def diffusion_chat(question, eot_weight, max_it, sharpness, model):
    placeholder = "What do you know about the city of New York?"
    if question.strip() == "":
        question = placeholder

    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return
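
    # Pad or truncate the sequence to a fixed length of 256 tokens.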
    if len(input_ids) < 256:
        input_ids += [pad_token] * (256 - len(input_ids))
    else:
        input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight)
    prev_decoded_tokens = []
    last_tokens = []
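
    # Denoising loop: the answer starts fully noised (threshold=1.0 above); each pass
    # regenerates it, streams the diff, then re-noises a shrinking fraction.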
    for i in range(max_it):
        generated_tokens = generate_diffusion_text(model, current_tokens, answer_start)
        current_tokens = generated_tokens

        decoded_ids = current_tokens[answer_start:]
        decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
        filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
        filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []

        if filtered_prev_tokens:
            highlighted = []
            for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
                if tok_new != tok_old:
                    highlighted.append(f'<span style="color:green">{tokenizer.convert_tokens_to_string([tok_new])}</span>')
                else:
                    highlighted.append(tokenizer.convert_tokens_to_string([tok_new]))
        else:
            highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]

        prev_decoded_tokens = decoded_tokens
        yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + "".join(highlighted)
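
        # Early stop: three identical consecutive generations mean the sampler has
        # converged, so further iterations would not change the answer.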
        last_tokens.append(generated_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break
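
        # Re-noise part of the answer according to the decaying schedule before the next pass.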
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
        time.sleep(0.01)

    # Decode generated_tokens rather than current_tokens: when the loop finishes without
    # an early stop, current_tokens has just been re-noised and is no longer clean output.
    final_tokens = tokenizer.convert_ids_to_tokens(generated_tokens[answer_start:])
    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output
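

# Load the model once at startup; gr.State passes it into each chat request.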
model_state = gr.State(load_model())

demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
        gr.Slider(0, 1, value=0.4, step=0.05, label="EOT weight (lower = longer answers)"),
        gr.Slider(1, 512, value=64, step=1, label="Max iterations"),
        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Sharpness (lower = more noising)"),
        model_state,
    ],
    outputs=gr.HTML(label="Diffusion Output"),
    title="Diffusion Language Model Chat",
    description="This interface runs a diffusion-based language model that refines its answer over successive denoising iterations.",
)

demo.launch()