import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA

# Load model & tokenizer once (tiny DistilBERT for speed on Spaces)
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
model.eval()


def plot_hidden_states(mode, max_tokens, sentence, show_words, focus_token):
    # Tokenize (cast the slider value to int; Gradio may deliver it as a float)
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=int(max_tokens))
    with torch.no_grad():
        outputs = model(**inputs)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Stack the per-layer tuple and drop the batch dim: [num_layers+1, seq_len, hidden_dim]
    all_layers = torch.stack(outputs.hidden_states).squeeze(1).numpy()

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    if mode == "Per-token trajectory":
        # PCA on the final layer's states; z axis = token position in the sentence
        hs = outputs.last_hidden_state.squeeze(0).numpy()
        xy = PCA(n_components=2).fit_transform(hs)
        x, y = xy[:, 0], xy[:, 1]
        z = np.arange(len(x))
        ax.plot(x, y, z, label="Hidden state trajectory")
        ax.legend()
        if show_words:
            for i, tok in enumerate(tokens):
                ax.text(x[i], y[i], z[i], tok, fontsize=9, color="red")
    elif mode == "Per-layer trajectory":
        # Follow a single token across layers; fall back to the first token
        # ([CLS]) if the requested focus token isn't in the tokenized sentence
        if focus_token.strip() in tokens:
            idx = tokens.index(focus_token.strip())
        else:
            idx = 0
        path_layers = all_layers[:, idx, :]  # [num_layers+1, hidden_dim]
        xy = PCA(n_components=2).fit_transform(path_layers)
        x, y = xy[:, 0], xy[:, 1]
        z = np.arange(len(x))  # z axis = layer index (L0 = embedding output)
        ax.plot(x, y, z, label=f"Layer evolution for '{tokens[idx]}'")
        ax.legend()
        for i in range(len(z)):
            ax.text(x[i], y[i], z[i], f"L{i}", fontsize=8, color="blue")

    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_zlabel("Index")
    ax.set_title(mode)
    plt.tight_layout()
    return fig


with gr.Blocks() as demo:
    gr.Markdown("# 🌀 3D Hidden States Explorer")
    gr.Markdown(
        """
        Visualize **transformer hidden states** in 3D. Choose between two modes:
        - **Per-token trajectory:** how the tokens of a sentence trace a path through the final layer's representation space.
        - **Per-layer trajectory:** how a single token's representation moves across all layers.
        """
    )
    with gr.Row():
        mode = gr.Radio(
            ["Per-token trajectory", "Per-layer trajectory"],
            value="Per-token trajectory",
            label="Mode",
        )
        max_tokens = gr.Slider(10, 64, value=32, step=1, label="Max Tokens")
    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
    show_words = gr.Checkbox(label="Show Tokens (per-token mode)", value=True)
    focus_token = gr.Textbox(value="hidden", label="Focus Token (per-layer mode)")
    plot = gr.Plot()
    btn = gr.Button("Generate Plot")
    btn.click(
        plot_hidden_states,
        inputs=[mode, max_tokens, sentence, show_words, focus_token],
        outputs=plot,
    )

if __name__ == "__main__":
    demo.launch()
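
# ---------------------------------------------------------------------------
# Reference: shape of `outputs.hidden_states` (a minimal sketch, assuming the
# same DistilBERT model/tokenizer loaded above; commented out, not part of the
# app). With `output_hidden_states=True` the model returns a tuple containing
# one tensor per layer, plus the embedding output at index 0:
#
#   inputs = tokenizer("hello world", return_tensors="pt")
#   with torch.no_grad():
#       out = model(**inputs)
#   len(out.hidden_states)      # 7 for DistilBERT: embeddings + 6 layers
#   out.hidden_states[0].shape  # torch.Size([1, 4, 768]) -> [batch, seq_len, hidden_dim]
#
# This is why `all_layers` in plot_hidden_states has shape
# [num_layers+1, seq_len, hidden_dim] after stacking and squeezing the batch dim.
# ---------------------------------------------------------------------------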