# app.py
"""Solara app: top-k next-token predictions plus an embedding-neighborhood map.

Importing this module loads the tokenizer/model named by ``MODEL_ID`` and, when
present, the JSON embedding assets under ``assets/embeddings``.
"""
import html  # for escaping token text in the HTML label
import json
import threading
import time
from pathlib import Path

# anywidget/traitlets: robust hover/click events from the browser
import anywidget
import pandas as pd
import plotly.graph_objects as go
import solara
import torch
import traitlets as t
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------- Model ----------
MODEL_ID = "Qwen/Qwen3-0.6B"  # same as the working HF Space
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# ---------- Theme & layout (light blue / white / black accents) ----------
# NOTE: --mono must be an unquoted comma-separated family list; quoting the
# whole list makes `font-family: var(--mono)` name one non-existent font.
theme_css = """
:root{
  --primary:#38bdf8;
  --bg:#ffffff;
  --text:#0b0f14;
  --muted:#6b7280;
  --border:#e5e7eb;
  --mono:ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
}

/* Base */
body{ background:var(--bg); color:var(--text); margin:0;}
h1{ margin:6px 0 8px; }

/* Two-column layout */
.app-row { display:flex; align-items:flex-start; gap:16px; } /* was 24px */
.predictions-panel { flex:0 0 320px; position:relative; z-index:10;}/* was 360px */
.plot-panel { flex:1 1 auto; position:relative; z-index:1; overflow:hidden; }

/* Prediction rows (tighter) */
.rowbtn{
  width:100%;
  padding:6px 10px; /* was 10px 12px */
  border-radius:10px; /* was 12px */
  border:1px solid var(--border);
  background:#fff;
  color:var(--text);
  display:flex;
  justify-content:flex-start;
  align-items:center;
  text-align:left;
  cursor:pointer;
  user-select:none;
  font-family: var(--mono);
  font-size:13px; /* was default ~14–16 */
  line-height:1.15;
  letter-spacing:.2px;
  margin-bottom:6px; /* explicit, keeps list consistent */
}
.rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; }

/* New: 4-column grid inside each row button */
.rowbtn-grid{
  display:grid;
  grid-template-columns: 28px 72px 72px 1fr; /* # | probs | tokenID | token */
  column-gap:8px;
  align-items:center;
  width:100%;
  font-family: var(--mono);
  font-size:13px;
  line-height:1.15;
}

/* Neighbor chips (smaller) */
.badge{
  display:inline-block;
  padding:2px 6px; /* was 2px 8px */
  border:1px solid var(--border);
  border-radius:999px;
  margin:2px;
  font-size:12px;
  line-height:1.15;
}
"""

# ---------- Reactive state ----------
PRED_COLUMNS = ["probs", "id", "tok"]  # column layout of the predictions frame
DEBOUNCE_SECONDS = 0.25  # idle time before auto-predict fires

text_rx = solara.reactive("Twinkle, twinkle, little ")
preds_rx = solara.reactive(pd.DataFrame(columns=PRED_COLUMNS))
selected_token_id_rx = solara.reactive(None)
neighbor_list_rx = solara.reactive([])
last_hovered_id_rx = solara.reactive(None)
auto_running_rx = solara.reactive(True)
neigh_msg_rx = solara.reactive("")  # message shown when no neighborhood is available
# Must exist before the asset check below; it was previously referenced there
# without being defined, which raised NameError whenever the assets were missing.
notice_rx = solara.reactive("")

# ---------- Embedding assets ----------
ASSETS = Path("assets/embeddings")
COORDS_PATH = ASSETS / "pca_top5k_coords.json"
NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"

coords = {}
neighbors = {}
ids_set = set()
if COORDS_PATH.exists() and NEIGH_PATH.exists():
    coords = json.loads(COORDS_PATH.read_text(encoding="utf-8"))
    neighbors = json.loads(NEIGH_PATH.read_text(encoding="utf-8"))
    ids_set = {int(key) for key in coords}  # coords is keyed by stringified token id
else:
    notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")


# ---------- Helpers ----------
def display_token_from_id(tid: int) -> str:
    """Return a printable label for token id *tid*.

    One leading word-boundary marker ("▁" or "Ġ") of each kind is stripped,
    newlines are shown as "↵", and a blank/empty result is shown as "␠".
    """
    toks = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
    text = toks[0] if toks else ""
    for lead in ("▁", "Ġ"):
        if text.startswith(lead):
            text = text[len(lead):]
    text = text.replace("\n", "↵")
    return text if text.strip() else "␠"


def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
    """Format one plain-text prediction row: index, probability, token id, token text."""
    return f"{idx:<2} {prob:<7} {tid:<6} {tok_disp}"


# ---------- Prediction ----------
def predict_top10(prompt: str, k: int = 10) -> pd.DataFrame:
    """Return the *k* most likely next tokens for *prompt*.

    Args:
        prompt: Text to continue. An empty prompt yields an empty frame.
        k: Number of candidates to return (default 10, the historical value).

    Returns:
        DataFrame with columns ``probs`` (percentage string such as "4.12%"),
        ``id`` (token id) and ``tok`` (decoded token text, used when appending).
    """
    if not prompt:
        return pd.DataFrame(columns=PRED_COLUMNS)

    tokens = tokenizer(prompt, return_tensors="pt", padding=False)
    out = model.generate(
        **tokens,
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,  # greedy; temp/top_k are ignored (by design)
    )
    # out.scores holds one tensor per generated step; row 0 is the single prompt.
    scores = torch.softmax(out.scores[0], dim=-1)[0]
    top_probs, top_ids = torch.topk(scores, min(k, scores.shape[-1]))
    ids = top_ids.tolist()
    probs = top_probs.tolist()

    return pd.DataFrame(
        {
            "probs": [f"{p:.2%}" for p in probs],
            "id": ids,
            "tok": [tokenizer.decode([token_id]) for token_id in ids],  # for append
        }
    )


def on_predict():
    """Recompute predictions for the current text and refresh the plot."""
    df = predict_top10(text_rx.value)
    preds_rx.set(df)
    if df.empty:
        return

    selected = selected_token_id_rx.value
    if selected is None:
        preview_token(int(df.iloc[0]["id"]))  # only first time
    else:
        fig_rx.set(highlight(int(selected)))  # preserve selection


# ---------- Plot / neighborhood ----------
def base_scatter():
    """Build the background figure: every embedding point as a faint marker."""
    fig = go.Figure()
    if coords:
        xs, ys = zip(*coords.values())
        fig.add_trace(go.Scattergl(
            x=xs,
            y=ys,
            mode="markers",
            marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"),
            hoverinfo="skip",
        ))
    fig.update_layout(
        height=380,
        margin=dict(l=6, r=6, t=6, b=6),
        paper_bgcolor="white",
        plot_bgcolor="white",
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        showlegend=False,
    )
    return fig


fig_rx = solara.reactive(base_scatter())


def get_neighbor_list(token_id: int, k: int = 20):
    """Return up to *k* neighbor entries for *token_id*, or [] if it is not in the map."""
    if token_id not in ids_set:  # also covers the "no assets loaded" case
        return []
    raw = neighbors.get("neighbors", {}).get(str(token_id), [])
    return raw[:k]


def highlight(token_id: int):
    """Return a figure with *token_id* and its neighbors emphasized.

    Side effects: updates ``neighbor_list_rx`` (chips as ``(label, similarity)``
    pairs) and ``neigh_msg_rx`` (explanation when no neighborhood can be shown).
    """
    fig = base_scatter()

    # Not in map (or missing map) → clear chips and show message
    if not coords or token_id not in ids_set:
        neighbor_list_rx.set([])
        if not coords:
            neigh_msg_rx.set("Embedding map unavailable – add `assets/embeddings/*.json`.")
        else:
            neigh_msg_rx.set("Neighborhood unavailable for this token (not in the top-5k set).")
        return fig

    # In map → clear message and draw neighbors/target
    neigh_msg_rx.set("")
    # Skip neighbors with no coordinates instead of raising KeyError on them.
    nbrs = [
        (nid, sim)
        for nid, sim in get_neighbor_list(token_id, k=20)
        if str(nid) in coords
    ]
    if nbrs:
        fig.add_trace(go.Scattergl(
            x=[coords[str(nid)][0] for nid, _ in nbrs],
            y=[coords[str(nid)][1] for nid, _ in nbrs],
            mode="markers",
            marker=dict(size=6, color="rgba(56,189,248,0.75)"),
            hoverinfo="skip",
        ))
    neighbor_list_rx.set(
        [(display_token_from_id(int(nid)), float(sim)) for nid, sim in nbrs]
    )

    tx, ty = coords[str(token_id)]
    fig.add_trace(go.Scattergl(
        x=[tx],
        y=[ty],
        mode="markers",
        marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)),
        hoverinfo="skip",
    ))
    return fig


def preview_token(token_id: int):
    """Select *token_id* and redraw the plot, unless it was the last token previewed."""
    token_id = int(token_id)
    if last_hovered_id_rx.value == token_id:
        return  # repeated hover events for the same row are no-ops
    last_hovered_id_rx.set(token_id)
    selected_token_id_rx.set(token_id)
    fig_rx.set(highlight(token_id))


def append_token(token_id: int):
    """Append the decoded token to the prompt text, preview it, and re-predict."""
    decoded = tokenizer.decode([int(token_id)])
    text_rx.set(text_rx.value + decoded)
    preview_token(int(token_id))
    on_predict()


# ---------- Debounced auto-predict ----------
@solara.component
def AutoPredictWatcher():
    """Invisible component that re-runs prediction after the text stops changing."""
    text = text_rx.value
    auto = auto_running_rx.value

    def effect():
        if not auto:
            return None

        cancel = threading.Event()
        snap = text

        def worker():
            # wait() returns True as soon as cleanup sets the event, so a
            # superseded worker exits immediately instead of sleeping it out.
            if not cancel.wait(DEBOUNCE_SECONDS) and snap == text_rx.value:
                on_predict()

        threading.Thread(target=worker, daemon=True).start()
        return cancel.set  # cleanup: runs when text/auto change or on unmount

    solara.use_effect(effect, [text, auto])
    return solara.Text("", style={"display": "none"})


# ---------- Hover-enabled list (browser) ----------
class HoverList(anywidget.AnyWidget):
    """
    Renders the prediction rows in the browser and streams hover/click events
    back to Python via synced traitlets. Supports HTML row labels via `label_html`.
    """

    # Front-end module. `label_html` is assigned to innerHTML, so callers must
    # HTML-escape any token text they embed in it.
    _esm = """
    export function render({ model, el }) {
      const renderList = () => {
        const items = model.get('items') || [];
        el.innerHTML = "";
        const wrap = document.createElement('div');
        wrap.style.display = 'flex';
        wrap.style.flexDirection = 'column';

        items.forEach((item) => {
          const { tid, label, label_html } = item;
          const btn = document.createElement('button');
          btn.className = 'rowbtn';
          btn.setAttribute('type', 'button');
          btn.setAttribute('role', 'button');
          btn.setAttribute('tabindex', '0');

          // Prefer HTML layout if provided; fall back to plain text
          if (label_html) {
            btn.innerHTML = label_html;
          } else {
            btn.textContent = label || "";
          }

          // Hover → preview (bind several events for reliability)
          const preview = () => {
            model.set('hovered_id', tid);
            model.save_changes();
          };
          btn.addEventListener('mouseenter', preview);
          btn.addEventListener('mouseover', preview);
          btn.addEventListener('mousemove', preview);
          btn.addEventListener('focus', preview);

          // Click → append
          btn.addEventListener('click', () => {
            model.set('clicked_id', tid);
            model.save_changes();
          });

          wrap.appendChild(btn);
        });

        el.appendChild(wrap);
      };

      renderList();
      model.on('change:items', renderList);
    }
    """

    items = t.List(trait=t.Dict()).tag(sync=True)  # [{tid:int, label?:str, label_html?:str}, ...]
    # NOTE(review): Int traits default to 0, so a first hover/click on token id 0
    # produces no change event — consider default_value=None if that matters.
    hovered_id = t.Int(allow_none=True).tag(sync=True)
    clicked_id = t.Int(allow_none=True).tag(sync=True)


# ---------- Predictions list (uses HoverList) ---------- @solara.component def PredictionsList(): df = preds_rx.value with solara.Column(gap="6px", style={"maxWidth": "720px"}): solara.Markdown("### Prediction") solara.Text( " # probs tokenID next predicted", style={ "color": "var(--muted)", "fontFamily": 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace', }, ) # Build items for the browser widget items = [] for i, row in df.iterrows(): tid = int(row["id"]) prob = row["probs"] # already a formatted string like "4.12%" tok_disp = display_token_from_id(tid) tok_safe = html.escape(tok_disp) # protect the HTML label label_html = ( f'