import os

import numpy as np
import pandas as pd
import sentencepiece as spm
import streamlit as st

st.set_page_config(layout="wide")
st.title("Gemini Token Probabilities")


# --- Load the SentencePiece tokenizer once ---
@st.cache_resource
def load_tokenizer():
    # Determine the directory that this script lives in (i.e. src/)
    here = os.path.dirname(__file__)
    # Build the absolute path to the gemini-1.5-pro-002 folder inside src/
    model_dir = os.path.join(here, "gemini-1.5-pro-002")
    model_path = os.path.join(model_dir, "gemini-1.5-pro-002.spm.model")
    if not os.path.isfile(model_path):
        st.error(f"Cannot find model at:\n{model_path}")
        st.stop()
    sp = spm.SentencePieceProcessor()
    sp.Load(model_path)
    return sp


sp = load_tokenizer()


# --- Precompute global min and max raw log-probs over the entire vocab ---
@st.cache_data
def compute_vocab_min_max(_sp: spm.SentencePieceProcessor):
    scores = np.array(
        [_sp.GetScore(i) for i in range(_sp.GetPieceSize())], dtype=float
    )
    return float(scores.min()), float(scores.max())


global_min, global_max = compute_vocab_min_max(sp)

# --- User input area ---
text = st.text_area("Enter text to tokenize:", "")

if st.button("Tokenize"):
    if not text.strip():
        st.warning("Enter some text first.")
    else:
        # 1) Tokenize into subword pieces and IDs
        pieces = sp.EncodeAsPieces(text)
        ids = sp.EncodeAsIds(text)

        # 2) Retrieve the raw log-probability for each input piece
        raw_scores = np.array([sp.GetScore(tid) for tid in ids], dtype=float)

        # 3) Normalize each raw score from [global_min, global_max] into [0, 1]
        if global_max != global_min:
            normalized_0_1 = (raw_scores - global_min) / (global_max - global_min)
        else:
            normalized_0_1 = np.zeros_like(raw_scores)

        # 4) Build a DataFrame; the 0-1 values feed the "Global Normalized" column
        df = pd.DataFrame({
            "Token": pieces,
            "Global Normalized": normalized_0_1,
        })

        # 5) Display with progress bars (as percentages)
        st.dataframe(
            df,
            use_container_width=True,
            column_config={
                "Global Normalized": st.column_config.ProgressColumn(
                    "Score (percent)",
                    help="Raw log-prob min-max normalized over full vocab, shown as %",
                    format="percent",
                    min_value=0.0,
                    max_value=1.0,
                )
            },
        )
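
# --- Usage note (a sketch, assuming the layout implied above) ---
# Assuming this file is saved as src/app.py with the tokenizer folder at
# src/gemini-1.5-pro-002/ (as the path-building code above implies), the app
# can be launched from the project root with:
#
#   streamlit run src/app.py
#
# A quick way to sanity-check the tokenizer outside Streamlit, using only the
# SentencePiece calls already used above (paths here are assumptions):
#
#   import sentencepiece as spm
#   sp = spm.SentencePieceProcessor()
#   sp.Load("src/gemini-1.5-pro-002/gemini-1.5-pro-002.spm.model")
#   print(sp.EncodeAsPieces("hello world"))
#   print([sp.GetScore(i) for i in sp.EncodeAsIds("hello world")])
#
# Caveat: GetScore returns the piece's log-probability only for unigram
# SentencePiece models; for BPE models the score reflects merge order instead,
# so the "probability" framing here assumes a unigram model.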