Spaces:
Sleeping
Sleeping
# app.py | |
import json | |
import threading | |
import time | |
from pathlib import Path | |
import solara | |
import pandas as pd | |
import plotly.graph_objects as go | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# for robust hover/click from the browser | |
import anywidget | |
import traitlets as t | |
import html # for escaping token text in the HTML label | |
# ---------- Model ---------- | |
MODEL_ID = "Qwen/Qwen3-0.6B"  # same as the working HF Space

# Tokenizer first, then weights; both are created once at import time and are
# shared by every function below that reads `tokenizer` / `model`.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_ID),
    AutoModelForCausalLM.from_pretrained(MODEL_ID),
)
# ---------- Theme & layout (light blue / white / black accents) ---------- | |
# Global stylesheet injected by Page via solara.Style.
# NOTE: the --mono value must be a bare font list. Wrapping it in quotes turns the
# whole list into a single (nonexistent) family name, so `font-family: var(--mono)`
# would silently fall back to the default proportional font.
theme_css = """
:root{
  --primary:#38bdf8; --bg:#ffffff; --text:#0b0f14; --muted:#6b7280; --border:#e5e7eb;
  --mono:ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
}
/* Base */
body{ background:var(--bg); color:var(--text); margin:0;}
h1{ margin:6px 0 8px; }
/* Two-column layout */
.app-row { display:flex; align-items:flex-start; gap:16px; } /* was 24px */
.predictions-panel { flex:0 0 320px; position:relative; z-index:10;}/* was 360px */
.plot-panel { flex:1 1 auto; position:relative; z-index:1; overflow:hidden; }
/* Prediction rows (tighter) */
.rowbtn{
  width:100%;
  padding:6px 10px;                 /* was 10px 12px */
  border-radius:10px;               /* was 12px */
  border:1px solid var(--border);
  background:#fff; color:var(--text);
  display:flex; justify-content:flex-start; align-items:center;
  text-align:left; cursor:pointer; user-select:none;
  font-family: var(--mono);
  font-size:13px;                   /* was default ~14–16 */
  line-height:1.15;
  letter-spacing:.2px;
  margin-bottom:6px;                /* explicit, keeps list consistent */
}
.rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; }
/* New: 4-column grid inside each row button */
.rowbtn-grid{
  display:grid;
  grid-template-columns: 28px 72px 72px 1fr;  /* # | probs | tokenID | token */
  column-gap:8px;
  align-items:center;
  width:100%;
  font-family: var(--mono);
  font-size:13px;
  line-height:1.15;
}
/* Neighbor chips (smaller) */
.badge{
  display:inline-block; padding:2px 6px;     /* was 2px 8px */
  border:1px solid var(--border); border-radius:999px; margin:2px;
  font-size:12px; line-height:1.15;
}
"""
# ---------- Reactive state ---------- | |
# Shared UI state. Each value is a Solara reactive; components that read `.value`
# re-render when it is `.set(...)`.
text_rx = solara.reactive("Twinkle, twinkle, little ")
preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
selected_token_id_rx = solara.reactive(None)
neighbor_list_rx = solara.reactive([])
last_hovered_id_rx = solara.reactive(None)
auto_running_rx = solara.reactive(True)
neigh_msg_rx = solara.reactive("")  # message shown when no neighborhood is available
# Status notice written by the embedding-asset loader when its JSON files are
# missing. It must exist before that loader runs; otherwise the "files not found"
# path raises NameError at import time.
notice_rx = solara.reactive("")
# ---------- Embedding assets ---------- | |
ASSETS = Path("assets/embeddings")
COORDS_PATH = ASSETS / "pca_top5k_coords.json"
NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"

# Empty defaults: the rest of the app checks `coords` / `ids_set` to decide
# whether the embedding map and neighborhoods can be shown at all.
coords = {}
neighbors = {}
ids_set = set()

if not (COORDS_PATH.exists() and NEIGH_PATH.exists()):
    notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
else:
    # coords: str(token_id) -> [x, y]
    coords = json.loads(COORDS_PATH.read_text("utf-8"))
    # neighbors: {"neighbors": {str(token_id): [[neighbor_id, similarity], ...]}}
    neighbors = json.loads(NEIGH_PATH.read_text("utf-8"))
    # Integer ids of every token that has a position on the map.
    ids_set = {int(key) for key in coords}
# ---------- Helpers ---------- | |
def display_token_from_id(tid: int) -> str:
    """Return a compact, human-readable label for a single vocabulary id.

    The raw vocabulary piece is fetched with special tokens skipped. One leading
    word-boundary marker ("▁" for SentencePiece, "Ġ" for byte-level BPE) is
    removed, and line breaks are shown as "↵". A piece that ends up empty or
    whitespace-only is shown as "␠" so it stays visible in the UI.
    """
    # `piece` instead of `t`: `t` is the module-level alias for traitlets.
    pieces = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
    piece = pieces[0] if pieces else ""
    for marker in ("▁", "Ġ"):
        if piece.startswith(marker):
            piece = piece[len(marker):]
    # Byte-level BPE vocabularies (GPT-2 style byte mapping, as used by Qwen) store
    # the newline byte as "Ċ", so checking only for a literal "\n" never matches them.
    piece = piece.replace("\n", "↵").replace("Ċ", "↵")
    return piece if piece.strip() else "␠"
def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
    """Lay out one prediction as a fixed-width text line.

    Columns, separated by single spaces: index (left-aligned, width 2),
    probability string (width 7), token id (width 6), then the token text unpadded.
    """
    columns = (
        format(idx, "<2"),
        format(prob, "<7"),
        format(tid, "<6"),
        format(tok_disp),
    )
    return " ".join(columns)
# ---------- Prediction ---------- | |
def predict_top10(prompt: str, k: int = 10) -> pd.DataFrame:
    """Return the model's top-``k`` candidates for the token that follows ``prompt``.

    Args:
        prompt: Text to continue. An empty string yields an empty frame.
        k: Number of candidates to return (default 10); clamped to the vocabulary size.

    Returns:
        DataFrame with columns ``probs`` (probability formatted as a percentage
        string, e.g. "4.12%"), ``id`` (token id) and ``tok`` (decoded token text,
        used verbatim when appending), ordered by descending probability.
    """
    if not prompt:
        return pd.DataFrame(columns=["probs", "id", "tok"])

    encoded = tokenizer(prompt, return_tensors="pt", padding=False)
    # A single forward pass yields the next-token logits directly. No generation
    # loop or logits processors are involved, so this is the model's unmodified
    # distribution at the last position.
    with torch.inference_mode():
        logits = model(**encoded).logits[0, -1]
    # Softmax in float32 so the percentages are stable regardless of weight dtype.
    probs = torch.softmax(logits.float(), dim=-1)
    top = torch.topk(probs, min(int(k), probs.shape[-1]))

    ids = top.indices.tolist()
    return pd.DataFrame({
        "probs": [f"{p:.2%}" for p in top.values.tolist()],
        "id": ids,
        "tok": [tokenizer.decode([i]) for i in ids],  # for append
    })
def on_predict():
    """Refresh predictions for the current text and keep the neighborhood view in sync.

    If no token has been selected yet, the top-ranked prediction is previewed.
    Otherwise the existing selection is kept and only the figure is redrawn around it.
    """
    frame = predict_top10(text_rx.value)
    preds_rx.set(frame)
    if frame.shape[0] == 0:
        return
    current = selected_token_id_rx.value
    if current is not None:
        fig_rx.set(highlight(int(current)))  # preserve selection
        return
    preview_token(int(frame.iloc[0]["id"]))  # only first time
# ---------- Plot / neighborhood ---------- | |
def base_scatter():
    """Build the background figure: every mapped token as a faint marker, axes hidden."""
    figure = go.Figure()
    if coords:
        # coords maps str(token_id) -> [x, y]; split into parallel x / y sequences.
        xs, ys = zip(*list(coords.values()))
        figure.add_trace(
            go.Scattergl(
                x=xs,
                y=ys,
                mode="markers",
                marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"),
                hoverinfo="skip",
            )
        )
    figure.update_layout(
        height=380,
        margin=dict(l=6, r=6, t=6, b=6),
        paper_bgcolor="white",
        plot_bgcolor="white",
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        showlegend=False,
    )
    return figure


# Figure shown in the plot panel; starts as the bare background map.
fig_rx = solara.reactive(base_scatter())
def get_neighbor_list(token_id: int, k: int = 20):
    """Return up to ``k`` precomputed (neighbor_id, similarity) pairs for ``token_id``.

    Returns ``[]`` when no embedding map is loaded or the token is not on it.
    """
    if ids_set and token_id in ids_set:
        table = neighbors.get("neighbors", {})
        return table.get(str(token_id), [])[:k]
    return []
def highlight(token_id: int):
    """Return a figure centred on ``token_id`` and refresh the neighbor chips.

    Side effects: sets ``neighbor_list_rx`` (chip labels with similarities) and
    ``neigh_msg_rx`` (explanatory message, cleared when the token is on the map).

    Neighbors whose id has no coordinate entry are still listed as chips but are
    not plotted. The coordinate and neighbor files are loaded independently, and
    a mismatch between them must not crash the hover preview with a KeyError.
    """
    fig = base_scatter()

    # Not in map (or missing map) -> clear chips and show message
    if not coords or token_id not in ids_set:
        neighbor_list_rx.set([])
        if not coords:
            neigh_msg_rx.set("Embedding map unavailable – add `assets/embeddings/*.json`.")
        else:
            neigh_msg_rx.set("Neighborhood unavailable for this token (not in the top-5k set).")
        return fig

    # In map -> clear message and draw neighbors/target
    neigh_msg_rx.set("")
    nbrs = get_neighbor_list(token_id, k=20)
    if nbrs:
        placed = [coords[str(nid)] for nid, _ in nbrs if str(nid) in coords]
        if placed:
            nx, ny = zip(*placed)
            fig.add_trace(go.Scattergl(
                x=nx, y=ny, mode="markers",
                marker=dict(size=6, color="rgba(56,189,248,0.75)"),
                hoverinfo="skip",
            ))
        chips = [(display_token_from_id(int(nid)), float(sim)) for nid, sim in nbrs]
        neighbor_list_rx.set(chips)
    else:
        neighbor_list_rx.set([])

    # The selected token itself, drawn last so it sits on top of its neighbors.
    tx, ty = coords[str(token_id)]
    fig.add_trace(go.Scattergl(
        x=[tx], y=[ty], mode="markers",
        marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)),
        hoverinfo="skip",
    ))
    return fig
def preview_token(token_id: int):
    """Select ``token_id`` and redraw its neighborhood, ignoring repeat previews.

    The browser list reports hover through several events for the same row, so
    a request for the id that was previewed last is a no-op.
    """
    tid = int(token_id)
    if last_hovered_id_rx.value != tid:
        last_hovered_id_rx.set(tid)
        selected_token_id_rx.set(tid)
        fig_rx.set(highlight(tid))
def append_token(token_id: int):
    """Append the decoded text of ``token_id`` to the prompt, preview it, then re-predict."""
    tid = int(token_id)
    text_rx.set(text_rx.value + tokenizer.decode([tid]))
    preview_token(tid)
    on_predict()
# ---------- Debounced auto-predict ---------- | |
def AutoPredictWatcher():
    """Re-run prediction about 0.25 s after the text stops changing (debounce).

    Each change of ``text_rx`` or ``auto_running_rx`` re-runs the effect. Each run
    starts a short-lived daemon thread, and the effect's cleanup sets that
    thread's stop flag. A thread that has been superseded therefore exits
    without predicting. The thread also re-checks that the text is unchanged
    before calling ``on_predict``. Renders an invisible placeholder.

    NOTE(review): ``on_predict`` runs on a plain ``threading.Thread`` and reads/writes
    Solara reactive variables from there. Confirm those writes reach the right
    session under ``solara run`` (``solara.use_thread`` exists for this purpose).
    """
    text = text_rx.value
    auto = auto_running_rx.value

    def effect():
        if not auto:
            return
        stop = threading.Event()
        snap = text

        def worker():
            time.sleep(0.25)
            if stop.is_set() or snap != text_rx.value:
                return
            on_predict()

        threading.Thread(target=worker, daemon=True).start()
        return stop.set  # cleanup: mark this run as superseded

    solara.use_effect(effect, [text, auto])
    return solara.Text("", style={"display": "none"})
# ---------- Hover-enabled list (browser) ---------- | |
class HoverList(anywidget.AnyWidget):
    """
    Renders the prediction rows in the browser and streams hover/click events
    back to Python via synced traitlets. Supports HTML row labels via `label_html`.

    Hover updates are only synced when the hovered id actually changes, so the
    continuous `mousemove` stream does not flood the kernel. The `items` listener
    is detached when a view is removed.
    """

    _esm = """
export function render({ model, el }) {
  const renderList = () => {
    const items = model.get('items') || [];
    el.innerHTML = "";
    const wrap = document.createElement('div');
    wrap.style.display = 'flex';
    wrap.style.flexDirection = 'column';
    items.forEach((item) => {
      const { tid, label, label_html } = item;
      const btn = document.createElement('button');
      btn.className = 'rowbtn';
      btn.setAttribute('type', 'button');
      btn.setAttribute('role', 'button');
      btn.setAttribute('tabindex', '0');
      // Prefer HTML layout if provided; fall back to plain text
      if (label_html) { btn.innerHTML = label_html; }
      else { btn.textContent = label || ""; }
      // Hover -> preview. Several events are bound for reliability, but
      // mousemove fires continuously, so only sync when the id changes.
      const preview = () => {
        if (model.get('hovered_id') === tid) { return; }
        model.set('hovered_id', tid);
        model.save_changes();
      };
      btn.addEventListener('mouseenter', preview);
      btn.addEventListener('mouseover', preview);
      btn.addEventListener('mousemove', preview);
      btn.addEventListener('focus', preview);
      // Click -> append
      btn.addEventListener('click', () => {
        model.set('clicked_id', tid);
        model.save_changes();
      });
      wrap.appendChild(btn);
    });
    el.appendChild(wrap);
  };
  renderList();
  model.on('change:items', renderList);
  // anywidget calls the returned function when this view is removed;
  // detach the listener so a discarded view stops re-rendering.
  return () => model.off('change:items', renderList);
}
"""
    # Rows to render: [{"tid": int, "label": str (optional), "label_html": str (optional)}, ...]
    items = t.List(trait=t.Dict()).tag(sync=True)
    # Token id of the row most recently hovered/focused in the browser.
    hovered_id = t.Int(allow_none=True).tag(sync=True)
    # Token id of the row most recently clicked in the browser.
    clicked_id = t.Int(allow_none=True).tag(sync=True)
# ---------- Predictions list (uses HoverList) ---------- | |
@solara.component
def PredictionsList():
    """Render the current predictions as a hover/click-aware list in the browser.

    Hovering a row previews that token's semantic neighborhood; clicking it
    appends the token to the prompt.

    As a component, the ``HoverList`` widget is created once (``use_memo``) and
    its observers are attached once, with cleanup (``use_effect``). Only
    ``w.items`` is updated when ``preds_rx`` changes. Without this, a new widget
    and a new pair of observers would be built on every parent re-render.
    """
    df = preds_rx.value
    w = solara.use_memo(HoverList, dependencies=[])

    def _wire_events():
        # Hover -> preview (updates plot + neighbor chips)
        def _on_hover(change):
            tid = change["new"]
            if tid is not None:
                preview_token(int(tid))

        # Click -> append
        def _on_click(change):
            tid = change["new"]
            if tid is None:
                return
            # Reset first: traitlets only notify on a value change, so without
            # this a second click on the same token id would be ignored.
            w.clicked_id = None
            append_token(int(tid))

        w.observe(_on_hover, names="hovered_id")
        w.observe(_on_click, names="clicked_id")

        def _unwire():
            w.unobserve(_on_hover, names="hovered_id")
            w.unobserve(_on_click, names="clicked_id")

        return _unwire

    solara.use_effect(_wire_events, [w])

    # Style keys are CSS property names (kebab-case), matching "flex-wrap" used
    # elsewhere; camelCase names like "maxWidth" are not valid inline CSS.
    with solara.Column(gap="6px", style={"max-width": "720px"}):
        solara.Markdown("### Prediction")
        solara.Text(
            " # probs tokenID next predicted",
            style={
                "color": "var(--muted)",
                "font-family": 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace',
            },
        )
        # Build items for the browser widget
        items = []
        for i, row in df.iterrows():
            tid = int(row["id"])
            prob = row["probs"]  # already a formatted string like "4.12%"
            tok_safe = html.escape(display_token_from_id(tid))  # protect the HTML label
            label_html = (
                f'<div class="rowbtn-grid">'
                f' <span class="c0">{i}</span>'
                f' <span class="c1">{prob}</span>'
                f' <span class="c2">{tid}</span>'
                f' <span class="c3">{tok_safe}</span>'
                f'</div>'
            )
            items.append({"tid": tid, "label_html": label_html})
        w.items = items
        solara.display(w)
# ---------- Page ---------- | |
@solara.component
def Page():
    """Top-level page: prompt input, prediction list and semantic-neighborhood panel.

    Declared as a Solara component because its render path reaches hooks
    (``solara.use_effect`` via ``AutoPredictWatcher``).
    """
    solara.Style(theme_css)
    with solara.Column(margin=8, gap="10px"):
        solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
        solara.Markdown(
            "Type text to see AI's top predictions for the next token. "
            "Click a token to append it to your text. "
            "Hover over a token to preview its **semantic neighborhood**."
        )
        # CSS property names in style dicts (kebab-case), as with "flex-wrap" below.
        solara.InputText(
            "Enter text",
            value=text_rx,
            continuous_update=True,
            style={"min-width": "520px"},
        )
        with solara.Row(classes=["app-row"]):
            with solara.Column(classes=["predictions-panel"]):
                PredictionsList()
            with solara.Column(classes=["plot-panel"]):
                solara.Markdown("### Semantic Neighborhood")
                if not coords:
                    solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
                else:
                    solara.FigurePlotly(fig_rx.value)
                    if neighbor_list_rx.value:
                        solara.Markdown("**Nearest neighbors:**")
                        with solara.Row(style={"flex-wrap": "wrap"}):
                            for tok, sim in neighbor_list_rx.value:
                                # Token text is arbitrary vocabulary content ("<", "&", ...);
                                # escape it before injecting into raw HTML.
                                solara.HTML(
                                    tag="span",
                                    unsafe_innerHTML=f'<span class="badge">{html.escape(tok)} {(sim*100):.1f}%</span>',
                                )
                    elif neigh_msg_rx.value:
                        solara.Text(neigh_msg_rx.value, style={"color": "var(--muted)"})
        AutoPredictWatcher()
# ---------- Kickoff ---------- | |
# Populate the prediction state once at import time, before any page is rendered.
on_predict()
# No module-level `Page()` call: `solara run app.py` discovers `Page` itself.
# Calling it here is not needed to serve the app. If `Page` is a plain function,
# the call would run its body at import, including hooks such as
# `solara.use_effect`, outside any render context.
# To show the app in a notebook, end a cell with `Page()` instead.