# tini-lad / app.py
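"""Gradio Space demo ("LAD Chat"): the answer region of the prompt is filled
with MASK tokens, then iteratively denoised by the model and partially
re-noised on a schedule, streaming each intermediate state as HTML."""
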
import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
import importlib
from huggingface_hub import hf_hub_download
import spaces
from dotenv import load_dotenv
from infer import (
    load_trained_model,
    find_answer_start,
    get_noising_schedule,
    noisify_answer,
    filter_logits,
    confidence_guided_noising,
    noisify_answer_without_remasking,
)
from models import CustomTransformerModel
from model_config import CustomTransformerConfig
# Load .env only when running locally
if os.getenv("HF_TOKEN") is None:
    load_dotenv()

hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    raise ValueError("HF_TOKEN is not set")
rng = np.random.default_rng()
@spaces.GPU
def generate_diffusion_text(input_ids, top_p, top_k):
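    """Single denoising pass: forward the whole sequence once, filter logits
    with top-k/top-p, sample every position in parallel, and return the
    sampled token ids plus their sampled probabilities as confidences."""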
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        with torch.cuda.amp.autocast(dtype=torch.float16):
            logits = model(input_ids=input_tensor)["logits"]
        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
        logits = logits.clamp(min=-1e8, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        # assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        # assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
        conf = probs[range(len(sampled)), sampled].cpu().numpy()
        return sampled, conf
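
# Llama-3-style chat template: system and user turns followed by an open
# assistant header; the model fills in the assistant's answer.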
def format_chat_prompt(question):
    return (
        "<|begin_of_text|>\n"
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant.\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{question}\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )

def render_html(label, text):
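    """Wrap `text` in a labeled, whitespace-preserving HTML block."""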
return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"
def highlight_tokens(token_ids, answer_start, changed_indices, color):
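    """Decode `token_ids` to text, coloring tokens whose absolute positions
    appear in `changed_indices`; EOS tokens are dropped from the display."""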
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    highlighted = []
    for j, tok in enumerate(tokens):
        if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
            continue
        tok_str = tokenizer.convert_tokens_to_string([tok])
        if (answer_start + j) in changed_indices:
            highlighted.append(f'<span style="color:{color}">{tok_str}</span>')
        else:
            highlighted.append(tok_str)
    return "".join(highlighted)

def diffusion_chat(question, noising, enable_pause, max_it):
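    """Generator behind the Gradio UI: mask the answer region, then alternate
    full-sequence denoising (changes shown in green) with scheduled partial
    re-noising (re-masked tokens shown in red), yielding HTML at each step
    and stopping early once three consecutive iterations agree."""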
    sharpness = 3.0
    noise_start = 0.5
    top_p = 1.0
    top_k = 10
    clustering = False
    pause_length = 1.0 if enable_pause else 0.0

    if question.strip() == "":
        question = "What do you know about Amsterdam?"

    prompt = format_chat_prompt(question)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield render_html("Error", "Could not find Assistant marker in input.")
        return

    # Pad the answer region with MASK tokens up to the fixed sequence length of 256
    input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
    ori_input_tokens = input_ids

    # Initial noising
    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
    )
    yield render_html(
        "Iteration 0 (initial noise)",
        highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"),
    )

    start = time.perf_counter()
    last_tokens = []
    prev_decoded = []
    unmasked_mask = [False] * len(current_tokens)

    for i in range(max_it):
        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k)
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # GREEN highlighting: compare to previous tokens
        new_decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        diff_indices = {
            answer_start + j
            for j, tok in enumerate(new_decoded)
            if j >= len(prev_decoded) or tok != prev_decoded[j]
        }
        prev_decoded = new_decoded

        time.sleep(max(pause_length - (time.perf_counter() - start), 0))
        yield render_html(
            f"Iteration {i+1}/{max_it} (after generation)",
            highlight_tokens(current_tokens[answer_start:], answer_start, diff_indices, color="green"),
        )
        time.sleep(pause_length)

        # Early stopping: halt once three consecutive iterations produce identical tokens
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield render_html("Stopped early", f"After {i+1} iterations.")
            break

        # NOISING: re-mask part of the answer before the next denoising pass
        if i < max_it - 1 and noising:
            threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, tokenizer,
                threshold=threshold, clustering=clustering, noise_start=noise_start,
            )
            for idx in range(answer_start, len(current_tokens)):
                if noised_answer[idx] != mask_token_id:
                    unmasked_mask[idx] = True
            yield render_html(
                f"Iteration {i+1}/{max_it} (before noising)",
                highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"),
            )
            start = time.perf_counter()
            current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

    # Final output: truncate at the first EOS token if present
    answer_ids = current_tokens[answer_start:]
    try:
        final_ids = answer_ids[:answer_ids.index(eos_token_id)]
    except ValueError:
        final_ids = answer_ids
    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)  # type: ignore

def is_running_on_spaces():
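    """True when running inside a Hugging Face Space (SPACE_ID is set)."""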
    return os.getenv("SPACE_ID") is not None

print("Loading model...")
if is_running_on_spaces():
# Load from Hugging Face Hub
ckpt_path = hf_hub_download(
repo_id="ruurd/tini_model",
filename="diffusion-model-8B.pth",
token=os.getenv("HF_TOKEN")
)
else:
# Load from local path
ckpt_path = "diffusion-model-8B.pth" # change to your actual local path
model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
print("✅ Model loaded.")
vocab_size = len(tokenizer)
eos_token_id = tokenizer.eos_token_id
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)
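
# Gradio UI: question box, noising/pause toggles, and an iteration-count slider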
demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(
            label="User Question",
            lines=2,
            placeholder="What do you know about Amsterdam?",
        ),
        gr.Checkbox(label="Enable intermediate noising", value=True),
        gr.Checkbox(label="Pause between iterations", value=False),
        gr.Slider(1, 512, value=64, step=1, label="Maximum number of iterations"),
    ],
    outputs=gr.HTML(label="Diffusion Output"),
    title="LAD Chat",
    allow_flagging="never",
    live=False,  # ensures the Stop button appears properly
)

demo.launch(share=True, allowed_paths=["."], ssr_mode=False)