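"""Gradio demo for a diffusion-style language model (the `lad` Space, running on ZeroGPU).

Answers are generated by iterative denoising: the answer region starts fully
masked, the model predicts all positions in parallel, and a decaying fraction
of tokens is re-masked each iteration until the output stabilizes.
"""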
import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
import importlib
from huggingface_hub import hf_hub_download
import spaces
from dotenv import load_dotenv
# Note: generate_diffusion_text is defined locally below (wrapped in
# @spaces.GPU), so it is deliberately not imported from infer to avoid shadowing.
from infer import (
    load_trained_model,
    find_answer_start,
    get_noising_schedule,
    noisify_answer,
    filter_logits
)
from models import CustomTransformerModel
from model_config import CustomTransformerConfig
# Load .env only when running locally
if os.getenv("HF_TOKEN") is None:
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
raise ValueError("HF_TOKEN is not set")
rng = np.random.default_rng()
# Confidence-guided noising: preferentially re-mask answer tokens the model was unsure about.
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
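    """Re-mask part of the answer, biased toward low-confidence tokens.

    int(threshold * answer_len * noise_start) positions are replaced with the
    module-level mask token. The budget is split proportionally between EOS
    and non-EOS positions so that EOS padding cannot absorb all the noise;
    within each group, positions are drawn without replacement with weight
    proportional to (1 - confidence), floored at noise_clipping.
    Returns the noised ids and the sorted list of re-masked indices.
    """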
noised = input_ids.copy()
answer_len = len(input_ids) - answer_start
num_to_noise = int(threshold * answer_len * noise_start)
if num_to_noise == 0:
return noised, []
all_indices = np.arange(answer_start, len(input_ids))
eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]
# Proportionally split how many to noise
num_non_eos_to_noise = int(num_to_noise * len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5))
num_eos_to_noise = num_to_noise - num_non_eos_to_noise
noised_indices = []
# --- Non-EOS ---
if non_eos_indices:
raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
weights = raw_weights / raw_weights.sum()
chosen = rng.choice(non_eos_indices, size=min(num_non_eos_to_noise, len(non_eos_indices)), replace=False, p=weights)
noised_indices.extend(chosen.tolist())
# --- EOS ---
if eos_indices and num_eos_to_noise > 0:
raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
weights = raw_weights / raw_weights.sum()
chosen = rng.choice(eos_indices, size=min(num_eos_to_noise, len(eos_indices)), replace=False, p=weights)
noised_indices.extend(chosen.tolist())
for idx in noised_indices:
noised[idx] = mask_token_id
noised_indices = sorted(noised_indices)
return noised, noised_indices
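# Hypothetical sketch (values invented for illustration): for a three-token
# answer starting at index 3 whose middle token had low confidence, that
# position is the most likely one to be re-masked:
#   noised, idxs = confidence_guided_noising(
#       [5, 6, 7, 11, 12, 13], answer_start=3,
#       confidences=[0.9, 0.1, 0.8], noise_clipping=0.05,
#       threshold=0.7, noise_start=1.0)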
@spaces.GPU
def generate_diffusion_text(input_ids, top_k, top_p):
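    """One parallel denoising pass: a single forward call predicts every
    position at once, then a token is sampled per position.

    Returns the sampled token ids and each sampled token's probability,
    which confidence-guided noising reuses as a per-position confidence.
    """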
with torch.no_grad():
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
with torch.amp.autocast('cuda', dtype=torch.float16):
logits = model(input_ids=input_tensor)["logits"]
            logits = filter_logits(logits, top_k=top_k, top_p=top_p)
logits = logits.clamp(min=-1e8, max=1e4)
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
probs = torch.clamp(probs, min=1e-8, max=1.0)
assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
assert (probs >= 0).all(), "Negative probs!"
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
# Extract confidence of selected tokens
conf = probs[range(len(sampled)), sampled].cpu().numpy()
return sampled, conf
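# Build a Llama-3-style chat prompt around the user question (the header
# tokens follow the Llama 3 chat template).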
def format_chat_prompt(question):
return (
"<|begin_of_text|>\n"
"<|start_header_id|>system<|end_header_id|>\n"
"You are a helpful assistant.\n"
"<|start_header_id|>user<|end_header_id|>\n"
f"{question}\n"
"<|start_header_id|>assistant<|end_header_id|>\n"
)
def render_html(label, text):
return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"
def highlight_tokens(tokens, color_indices=None, color="green"):
highlighted = []
for j, tok in enumerate(tokens):
if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
continue
token_str = tokenizer.convert_tokens_to_string([tok])
if color_indices and j in color_indices:
highlighted.append(f'<span style="color:{color}">{token_str}</span>')
else:
highlighted.append(token_str)
return "".join(highlighted)
# --- Inference Wrapper ---
def diffusion_chat(question, max_it, pause_length, sharpness,
                   clustering, noise_start, use_confidence_noising,
                   noise_clipping, top_k, top_p):
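    """Streaming generator: denoise, display, partially re-noise, repeat.

    Starting from a fully masked answer region, each iteration runs one
    denoising pass (changed tokens shown in green) followed by a scheduled
    re-noising pass (re-masked tokens shown in red), yielding HTML after each
    phase. Stops early once three consecutive iterations are identical.
    """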
if question.strip() == "":
question = "What do you know about the city of Amsterdam?"
prompt = format_chat_prompt(question)
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
answer_start = find_answer_start(input_ids, assistant_marker_ids)
if answer_start is None:
yield render_html("Error", "Could not find Assistant marker in input.")
return
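    # Pad the prompt with mask tokens up to the fixed 256-token canvas (truncate anything longer).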
input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
ori_input_tokens = input_ids
current_tokens, just_noised_indices = noisify_answer(
input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
)
yield render_html("Iteration 0 (initial noise)",
tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True))
time.sleep(pause_length)
last_tokens = []
prev_tokens = []
for i in range(max_it):
        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_k, top_p)
current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
diff_indices = [j for j in range(len(decoded)) if j >= len(prev_tokens) or decoded[j] != prev_tokens[j]]
prev_tokens = decoded
yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
highlight_tokens(decoded, diff_indices, color="green"))
time.sleep(pause_length)
# Early stopping
last_tokens.append(current_tokens)
if len(last_tokens) > 3:
last_tokens.pop(0)
if len(last_tokens) == 3 and len(set(map(tuple, last_tokens))) == 1:
yield render_html("Stopped early", f"After {i+1} iterations.")
break
# Noising step
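        # The schedule gives the fraction of answer tokens to re-mask; higher
        # sharpness makes it decay faster across iterations.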
threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
if use_confidence_noising:
noised_answer, just_noised_indices = confidence_guided_noising(
current_tokens, answer_start, confidences, noise_clipping,
threshold=threshold, noise_start=noise_start
)
else:
noised_answer, just_noised_indices = noisify_answer(
current_tokens, answer_start, tokenizer,
threshold=threshold, clustering=clustering, noise_start=noise_start
)
decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
red_indices = [j for j in range(len(decoded)) if (answer_start + j) in just_noised_indices]
yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
highlight_tokens(decoded, red_indices, color="red"))
current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
# Final output
answer_ids = current_tokens[answer_start:]
try:
final_ids = answer_ids[:answer_ids.index(eos_token_id)]
except ValueError:
final_ids = answer_ids
final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
# --- Gradio Interface ---
print("Loading model...")
ckpt_path = hf_hub_download(
repo_id="ruurd/tini_model",
filename="diffusion-model-8B.pth",
token=os.getenv("HF_TOKEN")
)
model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
print("✅ Model loaded.")
vocab_size = len(tokenizer)
eos_token_id = tokenizer.eos_token_id
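# The token id of the literal string 'MASK' serves as the model's mask token.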
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
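# Token ids of the assistant header, used by find_answer_start to locate where the answer begins.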
assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)
demo = gr.Interface(
fn=diffusion_chat,
inputs=[
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
        gr.Slider(1, 512, value=64, step=1, label="Number of iterations: ↑ = more denoising steps"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations: ↑ = longer pause"),
gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Noise start fraction: ↑ = more noise"),
gr.Checkbox(value=False, label="Use confidence-guided noising"),
gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
        gr.Slider(1, 1000, value=100, step=1, label="Top-k: ↑ = more random answers"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p: ↑ = more random answers")
],
outputs=[gr.HTML(label="Diffusion Output")],
title="Diffusion Language Model Chat",
theme="default",
description="This interface runs a diffusion-based language model to generate answers progressively."
)
demo.launch(share=True, allowed_paths=["."], ssr_mode=False)