import gradio as gr
import torch
import numpy as np
import time
import os
from huggingface_hub import hf_hub_download
import spaces
from dotenv import load_dotenv
# Note: generate_diffusion_text is redefined locally below, so it is deliberately
# not imported from infer (the import would be shadowed anyway).
from infer import (
    load_trained_model,
    find_answer_start,
    get_noising_schedule,
    noisify_answer,
    filter_logits
)
# Imported for their side effect of registering the custom architecture,
# which is needed to deserialize the checkpoint.
from models import CustomTransformerModel
from model_config import CustomTransformerConfig
# Load .env only when running locally
if os.getenv("HF_TOKEN") is None:
    load_dotenv()

hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    raise ValueError("HF_TOKEN is not set")
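# For local runs this assumes a .env file next to app.py containing, e.g.:
# HF_TOKEN=hf_xxxxxxxxxxxxxxxx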
rng = np.random.default_rng()


# Confidence-guided noising: re-mask answer tokens, biased toward the positions
# the model was least confident about in the last generation step.
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise == 0:
        return noised, []

    all_indices = np.arange(answer_start, len(input_ids))
    eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
    non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]

    # Split the noise budget proportionally between EOS and non-EOS positions.
    num_non_eos_to_noise = int(num_to_noise * len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5))
    num_eos_to_noise = num_to_noise - num_non_eos_to_noise

    noised_indices = []

    # --- Non-EOS ---
    if non_eos_indices:
        # Sampling weight = 1 - confidence, floored at noise_clipping so that no
        # position's selection probability collapses to zero.
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()
        chosen = rng.choice(non_eos_indices, size=min(num_non_eos_to_noise, len(non_eos_indices)), replace=False, p=weights)
        noised_indices.extend(chosen.tolist())

    # --- EOS ---
    if eos_indices and num_eos_to_noise > 0:
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()
        chosen = rng.choice(eos_indices, size=min(num_eos_to_noise, len(eos_indices)), replace=False, p=weights)
        noised_indices.extend(chosen.tolist())

    for idx in noised_indices:
        noised[idx] = mask_token_id

    noised_indices = sorted(noised_indices)
    return noised, noised_indices
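# A minimal worked example of the weighting above (illustrative only): with
# confidences [0.9, 0.2, 0.6] and noise_clipping=0.05, the unnormalized weights
# are [0.1, 0.8, 0.4], so the low-confidence middle token is by far the most
# likely to be re-masked.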
# One denoising pass: predict a token (and its confidence) for every position.
def generate_diffusion_text(input_ids, top_k, top_p):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        with torch.amp.autocast('cuda', dtype=torch.float16):
            logits = model(input_ids=input_tensor)["logits"]
        logits = filter_logits(logits, top_k=top_k, top_p=top_p)  # see reference sketch below
        logits = logits.clamp(min=-1e8, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
        # Confidence of each sampled token = its probability under the model.
        conf = probs[range(len(sampled)), sampled].cpu().numpy()
    return sampled, conf
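# For readers without infer.py: filter_logits is assumed to implement standard
# top-k / nucleus (top-p) filtering. A minimal reference sketch under that
# assumption (hypothetical; the version imported above is what actually runs):
def _reference_filter_logits(logits, top_k=0, top_p=1.0):
    # Top-k: keep only the k highest logits per position.
    if top_k > 0:
        kth_value = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    # Top-p (nucleus): keep the smallest set of tokens whose cumulative probability exceeds p.
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        to_remove = cumulative > top_p
        to_remove[..., 1:] = to_remove[..., :-1].clone()  # shift right: always keep the top token
        to_remove[..., 0] = False
        mask = to_remove.scatter(-1, sorted_indices, to_remove)  # back to vocabulary order
        logits = logits.masked_fill(mask, float("-inf"))
    return logits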
def format_chat_prompt(question):
    return (
        "<|begin_of_text|>\n"
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant.\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{question}\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )
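# For question "Hi", format_chat_prompt returns:
# <|begin_of_text|>
# <|start_header_id|>system<|end_header_id|>
# You are a helpful assistant.
# <|start_header_id|>user<|end_header_id|>
# Hi
# <|start_header_id|>assistant<|end_header_id|>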
def render_html(label, text):
    return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"


# Render tokens as a string, wrapping the tokens at color_indices in a colored span.
def highlight_tokens(tokens, color_indices=None, color="green"):
    highlighted = []
    for j, tok in enumerate(tokens):
        if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
            continue
        token_str = tokenizer.convert_tokens_to_string([tok])
        if color_indices and j in color_indices:
            highlighted.append(f'<span style="color:{color}">{token_str}</span>')
        else:
            highlighted.append(token_str)
    return "".join(highlighted)
# --- Inference Wrapper ---
@spaces.GPU  # request a ZeroGPU slot for the duration of each call
def diffusion_chat(question, max_it, pause_length, sharpness,
                   clustering, noise_start, use_confidence_noising,
                   noise_clipping, top_k, top_p):
    if question.strip() == "":
        question = "What do you know about the city of Amsterdam?"

    prompt = format_chat_prompt(question)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield render_html("Error", "Could not find assistant marker in input.")
        return

    # Pad (or truncate) to a fixed sequence length of 256 using mask tokens.
    input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
    ori_input_tokens = input_ids

    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
    )
    yield render_html("Iteration 0 (initial noise)",
                      tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True))
    time.sleep(pause_length)

    last_tokens = []
    prev_tokens = []
    for i in range(max_it):
        # Denoise: predict every answer token in parallel.
        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_k, top_p)
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # Highlight tokens that changed since the previous iteration in green.
        decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        diff_indices = [j for j in range(len(decoded)) if j >= len(prev_tokens) or decoded[j] != prev_tokens[j]]
        prev_tokens = decoded
        yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
                          highlight_tokens(decoded, diff_indices, color="green"))
        time.sleep(pause_length)

        # Early stopping: halt once three consecutive iterations produce identical tokens.
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and len(set(map(tuple, last_tokens))) == 1:
            yield render_html("Stopped early", f"After {i+1} iterations.")
            break

        # Noising step: re-mask part of the answer according to the decay
        # schedule (see the schedule note after this function).
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            noised_answer, just_noised_indices = confidence_guided_noising(
                current_tokens, answer_start, confidences, noise_clipping,
                threshold=threshold, noise_start=noise_start
            )
        else:
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, tokenizer,
                threshold=threshold, clustering=clustering, noise_start=noise_start
            )

        # Highlight the tokens about to be re-masked in red.
        decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        red_indices = [j for j in range(len(decoded)) if (answer_start + j) in just_noised_indices]
        yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
                          highlight_tokens(decoded, red_indices, color="red"))

        current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
    # Final output: cut at the first EOS token, if any.
    answer_ids = current_tokens[answer_start:]
    try:
        final_ids = answer_ids[:answer_ids.index(eos_token_id)]
    except ValueError:
        final_ids = answer_ids

    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
# --- Model loading ---
print("Loading model...")
ckpt_path = hf_hub_download(
    repo_id="ruurd/tini_model",
    filename="diffusion-model-8B.pth",
    token=hf_token
)
model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
print("✅ Model loaded.")

vocab_size = len(tokenizer)
eos_token_id = tokenizer.eos_token_id
# Assumes the string 'MASK' encodes to a single token in this tokenizer's vocabulary.
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)
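# Optional local smoke test (an addition, not part of the original app): set
# DEBUG_SMOKE=1 to verify that the assistant marker is found before launching.
if os.getenv("DEBUG_SMOKE"):
    ids = tokenizer.encode(format_chat_prompt("Hi"), add_special_tokens=False)
    print("answer_start:", find_answer_start(ids, assistant_marker_ids))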
# --- Gradio Interface ---
demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
        gr.Slider(1, 512, value=64, step=1, label="Number of iterations: ↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations: ↑ = longer pause"),
        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
        gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Noise start fraction: ↑ = more noise"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
        gr.Slider(1, 1000, value=100, step=1, label="Top-k: ↑ = more random answers"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p: ↑ = more random answers")
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model to generate answers progressively."
)

demo.launch(share=True, allowed_paths=["."], ssr_mode=False)