# Provenance (Hugging Face file-viewer header): PeterPinetree, "Update app.py",
# commit 2a55a01 (verified), 13.9 kB.
# app.py
import json
import threading
import time
from pathlib import Path
import solara
import pandas as pd
import plotly.graph_objects as go
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# for robust hover/click from the browser
import anywidget
import traitlets as t
import html # for escaping token text in the HTML label
# ---------- Model ----------
# Hugging Face model identifier passed to both `from_pretrained` calls below.
MODEL_ID = "Qwen/Qwen3-0.6B" # same as the working HF Space
# Both objects are created once at import time and used as module-level
# globals by the prediction and token-display helpers in this file.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# ---------- Theme & layout (light blue / white / black accents) ----------
# Stylesheet injected by `Page` via `solara.Style`.
# NOTE: `--mono` must be an unquoted comma-separated font list. Wrapping the
# whole list in quotes turns it into a single (non-existent) family name, so
# `font-family: var(--mono)` would silently fall back to the default font and
# the fixed-width column layout of the prediction rows would be lost.
theme_css = """
:root{
--primary:#38bdf8; --bg:#ffffff; --text:#0b0f14; --muted:#6b7280; --border:#e5e7eb;
--mono: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
}
/* Base */
body{ background:var(--bg); color:var(--text); margin:0;}
h1{ margin:6px 0 8px; }
/* Two-column layout */
.app-row { display:flex; align-items:flex-start; gap:16px; } /* was 24px */
.predictions-panel { flex:0 0 320px; position:relative; z-index:10;}/* was 360px */
.plot-panel { flex:1 1 auto; position:relative; z-index:1; overflow:hidden; }
/* Prediction rows (tighter) */
.rowbtn{
width:100%;
padding:6px 10px; /* was 10px 12px */
border-radius:10px; /* was 12px */
border:1px solid var(--border);
background:#fff; color:var(--text);
display:flex; justify-content:flex-start; align-items:center;
text-align:left; cursor:pointer; user-select:none;
font-family: var(--mono);
font-size:13px; /* was default ~14–16 */
line-height:1.15;
letter-spacing:.2px;
margin-bottom:6px; /* explicit, keeps list consistent */
}
.rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; }
/* New: 4-column grid inside each row button */
.rowbtn-grid{
display:grid;
grid-template-columns: 28px 72px 72px 1fr; /* # | probs | tokenID | token */
column-gap:8px;
align-items:center;
width:100%;
font-family: var(--mono);
font-size:13px;
line-height:1.15;
}
/* Neighbor chips (smaller) */
.badge{
display:inline-block; padding:2px 6px; /* was 2px 8px */
border:1px solid var(--border); border-radius:999px; margin:2px;
font-size:12px; line-height:1.15;
}
"""
# ---------- Reactive state ----------
# Prompt text; bound to the input box in `Page` and read by the predictor.
text_rx = solara.reactive("Twinkle, twinkle, little ")
# Latest prediction table (columns: formatted probability, token id, decoded token).
preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
# Token id whose neighbourhood is currently highlighted (None until the first preview).
selected_token_id_rx = solara.reactive(None)
# (display token, similarity) pairs rendered as chips under the plot.
neighbor_list_rx = solara.reactive([])
# Last id handled by `preview_token`; used there to skip repeated previews of the same token.
last_hovered_id_rx = solara.reactive(None)
# Gate for the debounced auto-predict watcher.
auto_running_rx = solara.reactive(True)
neigh_msg_rx = solara.reactive("") # message shown when no neighborhood is available
# ---------- Embedding assets ----------
# Optional precomputed embedding assets. `coords` is keyed by token-id strings
# and each value unpacks to an (x, y) pair; `neighbors` holds a "neighbors"
# mapping from token-id string to a list of (neighbor_id, similarity) pairs.
# All three containers stay empty when the files are absent, which the plot
# helpers treat as "map unavailable".
ASSETS = Path("assets/embeddings")
COORDS_PATH = ASSETS / "pca_top5k_coords.json"
NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"
coords = {}
neighbors = {}
ids_set = set()
if COORDS_PATH.exists() and NEIGH_PATH.exists():
    coords = json.loads(COORDS_PATH.read_text(encoding="utf-8"))
    neighbors = json.loads(NEIGH_PATH.read_text(encoding="utf-8"))
    # Integer view of the keys for fast membership tests on token ids.
    ids_set = set(map(int, coords.keys()))
else:
    # The original code called `notice_rx.set(...)`, but no `notice_rx` is
    # defined in this file, so a missing asset crashed the app with NameError
    # at import time. Report through the existing neighbourhood-message state.
    neigh_msg_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
# ---------- Helpers ----------
def display_token_from_id(tid: int) -> str:
    """Return a printable label for the vocabulary entry with id ``tid``.

    The raw token string is looked up through the module-level tokenizer
    (special tokens are skipped, which can leave nothing to show). A leading
    "▁" and/or "Ġ" marker is removed, newlines are rendered as "↵", and a
    label that is empty or whitespace-only is replaced by "␠".
    """
    pieces = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
    label = pieces[0] if pieces else ""
    # Presumably these are the tokenizer's word-boundary prefixes; each is
    # stripped at most once, in this order.
    for marker in ("▁", "Ġ"):
        label = label[len(marker):] if label.startswith(marker) else label
    label = label.replace("\n", "↵")
    if not label.strip():
        return "␠"
    return label
def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
    """Format one prediction as a fixed-width text row.

    Columns are, in order: row index (left-aligned, width 2), probability
    (width 7), token id (width 6) and the token text, separated by single
    spaces.
    """
    return "{:<2} {:<7} {:<6} {}".format(idx, prob, tid, tok_disp)
# ---------- Prediction ----------
def predict_top10(prompt: str) -> pd.DataFrame:
    """Return the ten most probable next tokens for ``prompt``.

    The result has three columns: ``probs`` (probability formatted as a
    percentage string with two decimals), ``id`` (token id) and ``tok`` (the
    decoded token text, used when appending to the prompt). An empty prompt
    yields an empty frame with the same columns.
    """
    if not prompt:
        return pd.DataFrame(columns=["probs", "id", "tok"])

    encoded = tokenizer(prompt, return_tensors="pt", padding=False)
    # Generate a single step only so the scores of that step can be read back.
    generation = model.generate(
        **encoded,
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,  # greedy; temp/top_k are ignored (by design)
    )
    distribution = torch.softmax(generation.scores[0], dim=-1)
    best = torch.topk(distribution, 10)

    # Row 0 is the only sequence in the batch.
    token_ids = best.indices[0].tolist()
    token_probs = best.values[0].tolist()
    return pd.DataFrame(
        {
            "probs": [f"{p:.2%}" for p in token_probs],
            "id": token_ids,
            "tok": [tokenizer.decode([tid]) for tid in token_ids],  # for append
        }
    )
def on_predict():
    """Recompute predictions for the current prompt and refresh the plot.

    Stores the new prediction table in ``preds_rx``. If nothing has been
    selected yet, the top prediction is previewed; otherwise the figure is
    redrawn for the already-selected token so the selection is preserved.
    Does nothing further when there are no predictions.
    """
    predictions = predict_top10(text_rx.value)
    preds_rx.set(predictions)
    if len(predictions) == 0:
        return

    current = selected_token_id_rx.value
    if current is not None:
        # Keep the user's selection across re-predictions.
        fig_rx.set(highlight(int(current)))
        return
    # First run only: default to the most likely token.
    preview_token(int(predictions["id"].iloc[0]))
# ---------- Plot / neighborhood ----------
def base_scatter():
    """Build the background figure: every mapped token as a faint dot.

    When no coordinates are loaded the figure contains no traces. Axes and
    legend are hidden and the layout is fixed at 380px height on white.
    """
    figure = go.Figure()
    if coords:
        # Each value unpacks to (x, y); transpose into two coordinate tuples.
        xs, ys = zip(*coords.values())
        background = go.Scattergl(
            x=xs,
            y=ys,
            mode="markers",
            marker={"size": 3, "opacity": 1.0, "color": "rgba(56,189,248,0.15)"},
            hoverinfo="skip",
        )
        figure.add_trace(background)
    figure.update_layout(
        height=380,
        margin={"l": 6, "r": 6, "t": 6, "b": 6},
        paper_bgcolor="white",
        plot_bgcolor="white",
        xaxis={"visible": False},
        yaxis={"visible": False},
        showlegend=False,
    )
    return figure
# Figure currently shown in the plot panel; starts as the plain background
# scatter and is replaced by `highlight(...)` results when a token is previewed.
fig_rx = solara.reactive(base_scatter())
def get_neighbor_list(token_id: int, k: int = 20):
    """Return up to ``k`` stored neighbour entries for ``token_id``.

    An empty list is returned when the token is not part of the loaded map
    or has no entry in the neighbour table.
    """
    if ids_set and token_id in ids_set:
        table = neighbors.get("neighbors", {})
        return table.get(str(token_id), [])[:k]
    return []
def highlight(token_id: int):
    """Return a figure with ``token_id`` and its neighbours emphasised.

    Side effects: updates ``neighbor_list_rx`` with (display token,
    similarity) chips and ``neigh_msg_rx`` with an explanatory message when
    the token cannot be shown (no map loaded, or token not in the map). In
    that case the plain background figure is returned.
    """
    figure = base_scatter()

    in_map = bool(coords) and token_id in ids_set
    if not in_map:
        # Nothing to draw: clear chips and explain why.
        neighbor_list_rx.set([])
        if coords:
            neigh_msg_rx.set("Neighborhood unavailable for this token (not in the top-5k set).")
        else:
            neigh_msg_rx.set("Embedding map unavailable – add `assets/embeddings/*.json`.")
        return figure

    neigh_msg_rx.set("")
    nearest = get_neighbor_list(token_id, k=20)
    chips = []
    if nearest:
        points = [coords[str(nid)] for nid, _ in nearest]
        figure.add_trace(
            go.Scattergl(
                x=[point[0] for point in points],
                y=[point[1] for point in points],
                mode="markers",
                marker={"size": 6, "color": "rgba(56,189,248,0.75)"},
                hoverinfo="skip",
            )
        )
        chips = [(display_token_from_id(int(nid)), float(sim)) for nid, sim in nearest]
    neighbor_list_rx.set(chips)

    # The selected token itself is drawn last so it sits on top.
    target_x, target_y = coords[str(token_id)]
    figure.add_trace(
        go.Scattergl(
            x=[target_x],
            y=[target_y],
            mode="markers",
            marker={"size": 10, "color": "rgba(34,211,238,1.0)", "line": {"width": 1}},
            hoverinfo="skip",
        )
    )
    return figure
def preview_token(token_id: int):
    """Select ``token_id`` and redraw the neighbourhood plot for it.

    Repeated calls with the id that was previewed last are ignored, so the
    figure is not rebuilt when hover events repeat for the same row.
    """
    tid = int(token_id)
    if last_hovered_id_rx.value != tid:
        last_hovered_id_rx.set(tid)
        selected_token_id_rx.set(tid)
        fig_rx.set(highlight(tid))
def append_token(token_id: int):
    """Append the decoded token to the prompt, preview it, and re-predict."""
    tid = int(token_id)
    piece = tokenizer.decode([tid])
    text_rx.set(text_rx.value + piece)
    preview_token(tid)
    on_predict()
# ---------- Debounced auto-predict ----------
@solara.component
def AutoPredictWatcher():
    """Invisible component that re-runs prediction after the text settles.

    Each time the prompt text or the auto-run flag changes, a daemon thread
    is started that waits 0.25 s and then calls ``on_predict`` only if the
    effect has not been cleaned up in the meantime and the prompt still
    equals the text captured when the thread was started.
    """
    current_text = text_rx.value
    enabled = auto_running_rx.value

    def schedule_prediction():
        if not enabled:
            return None

        superseded = False
        captured = current_text

        def delayed_predict():
            time.sleep(0.25)
            # Skip if a newer effect replaced this one or the text moved on.
            if superseded or captured != text_rx.value:
                return
            on_predict()

        def mark_superseded():
            nonlocal superseded
            superseded = True

        threading.Thread(target=delayed_predict, daemon=True).start()
        return mark_superseded

    solara.use_effect(schedule_prediction, [current_text, enabled])
    return solara.Text("", style={"display": "none"})
# ---------- Hover-enabled list (browser) ----------
class HoverList(anywidget.AnyWidget):
    """
    Browser-side list of prediction rows that reports hover and click events.

    Each entry of `items` is rendered as a button with class ``rowbtn``. When
    an entry carries `label_html` it is inserted as HTML, otherwise `label`
    is used as plain text. Hovering/focusing a row writes its `tid` to
    `hovered_id`; clicking a row writes it to `clicked_id`. Both traits are
    synced, so Python code can observe them. The list is re-rendered whenever
    `items` changes.

    NOTE(review): `label_html` is assigned to `innerHTML` unescaped, so
    callers must escape any untrusted text before building it.
    """
    # Front-end module executed by anywidget in the browser.
    _esm = """
export function render({ model, el }) {
const renderList = () => {
const items = model.get('items') || [];
el.innerHTML = "";
const wrap = document.createElement('div');
wrap.style.display = 'flex';
wrap.style.flexDirection = 'column';
items.forEach((item) => {
const { tid, label, label_html } = item;
const btn = document.createElement('button');
btn.className = 'rowbtn';
btn.setAttribute('type', 'button');
btn.setAttribute('role', 'button');
btn.setAttribute('tabindex', '0');
// Prefer HTML layout if provided; fall back to plain text
if (label_html) { btn.innerHTML = label_html; }
else { btn.textContent = label || ""; }
// Hover → preview (bind several events for reliability)
const preview = () => {
model.set('hovered_id', tid);
model.save_changes();
};
btn.addEventListener('mouseenter', preview);
btn.addEventListener('mouseover', preview);
btn.addEventListener('mousemove', preview);
btn.addEventListener('focus', preview);
// Click → append
btn.addEventListener('click', () => {
model.set('clicked_id', tid);
model.save_changes();
});
wrap.appendChild(btn);
});
el.appendChild(wrap);
};
renderList();
model.on('change:items', renderList);
}
"""
    # Rows to render, sent from Python to the browser.
    items = t.List(trait=t.Dict()).tag(sync=True) # [{tid:int, label?:str, label_html?:str}, ...]
    # Token id of the row last hovered/focused in the browser (None initially).
    hovered_id = t.Int(allow_none=True).tag(sync=True)
    # Token id of the row last clicked in the browser (None initially).
    clicked_id = t.Int(allow_none=True).tag(sync=True)
# ---------- Predictions list (uses HoverList) ----------
@solara.component
def PredictionsList():
    """Render the prediction table as a hover/click-enabled list.

    Builds one HTML grid row per prediction (index, probability, token id,
    escaped token text), hands the rows to a ``HoverList`` widget, and wires
    the widget's hover events to ``preview_token`` and its click events to
    ``append_token``.
    """
    predictions = preds_rx.value
    header_style = {
        "color": "var(--muted)",
        "fontFamily": 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace',
    }

    def forward_to(handler):
        """Wrap ``handler`` as a trait observer that ignores None values."""
        def _observer(change):
            new_id = change["new"]
            if new_id is not None:
                handler(int(new_id))
        return _observer

    with solara.Column(gap="6px", style={"maxWidth": "720px"}):
        solara.Markdown("### Prediction")
        solara.Text(" # probs tokenID next predicted", style=header_style)

        # Build items for the browser widget
        rows = []
        for position, raw_id, prob in zip(predictions.index, predictions["id"], predictions["probs"]):
            token_id = int(raw_id)
            # Token text comes from the vocabulary; escape it before it is
            # embedded in the HTML label.
            token_html = html.escape(display_token_from_id(token_id))
            markup = "".join(
                [
                    '<div class="rowbtn-grid">',
                    f' <span class="c0">{position}</span>',
                    f' <span class="c1">{prob}</span>',  # already formatted, e.g. "4.12%"
                    f' <span class="c2">{token_id}</span>',
                    f' <span class="c3">{token_html}</span>',
                    '</div>',
                ]
            )
            rows.append({"tid": token_id, "label_html": markup})

        widget = HoverList()
        widget.items = rows
        # Hover → preview (updates plot + neighbor chips)
        widget.observe(forward_to(preview_token), names="hovered_id")
        # Click → append
        widget.observe(forward_to(append_token), names="clicked_id")
        solara.display(widget)
# ---------- Page ----------
@solara.component
def Page():
    """Top-level page: prompt input, prediction list and neighbourhood plot.

    Layout is a two-column row: the predictions panel on the left and the
    plot panel on the right. Under the plot, the nearest neighbours of the
    selected token are shown as chips, or a short message when no
    neighbourhood is available. The hidden ``AutoPredictWatcher`` keeps the
    predictions in sync with the text box.
    """
    solara.Style(theme_css)
    with solara.Column(margin=8, gap="10px"):
        solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
        solara.Markdown(
            "Type text to see AI's top predictions for the next token. "
            "Click a token to append it to your text. "
            "Hover over a token to preview its **semantic neighborhood**."
        )
        solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth":"520px"})
        with solara.Row(classes=["app-row"]):
            with solara.Column(classes=["predictions-panel"]):
                PredictionsList()
            with solara.Column(classes=["plot-panel"]):
                solara.Markdown("### Semantic Neighborhood")
                if not coords:
                    solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
                else:
                    solara.FigurePlotly(fig_rx.value)
                if neighbor_list_rx.value:
                    solara.Markdown("**Nearest neighbors:**")
                    with solara.Row(style={"flex-wrap":"wrap"}):
                        for tok, sim in neighbor_list_rx.value:
                            # Token text comes straight from the vocabulary and
                            # may contain "<", ">" or "&"; escape it before it
                            # is injected as raw HTML (the prediction rows do
                            # the same).
                            tok_safe = html.escape(tok)
                            solara.HTML(
                                tag="span",
                                unsafe_innerHTML=f'<span class="badge">{tok_safe} &nbsp; {(sim*100):.1f}%</span>'
                            )
                elif neigh_msg_rx.value:
                    solara.Text(neigh_msg_rx.value, style={"color":"var(--muted)"})
        AutoPredictWatcher()
# ---------- Kickoff ----------
# Run one prediction at import time so the list and plot are populated for
# the default prompt before any user input.
on_predict()
# NOTE(review): instantiates the page component at module level; presumably
# required by how this Space is launched — confirm against the Solara runner.
Page()