import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA
# Load model & tokenizer once (tiny DistilBERT for speed on Spaces)
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
model.eval()
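# eval() puts DistilBERT in inference mode (dropout off), so repeated runs over
# the same sentence yield identical hidden states.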

def plot_hidden_states(mode, max_tokens, sentence, show_words, focus_token):
    # Tokenize (Gradio slider values can arrive as floats, so cast for the tokenizer)
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=int(max_tokens))
    with torch.no_grad():
        outputs = model(**inputs)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    all_layers = torch.stack(outputs.hidden_states).squeeze(1).numpy()  # [num_layers+1, seq_len, hidden_dim]
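    # `outputs.hidden_states` is a tuple with one tensor per layer plus the
    # embedding output, each [batch, seq_len, hidden_dim]; distilbert-base-uncased
    # has 6 layers, so the stack holds 7 entries.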

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    if mode == "Per-token trajectory":
        # One point per token: PCA of the final layer's states; z is the token position
        hs = outputs.last_hidden_state.squeeze(0).numpy()
        xy = PCA(n_components=2).fit_transform(hs)
        x, y = xy[:, 0], xy[:, 1]
        z = np.arange(len(x))
        ax.plot(x, y, z, label="Hidden state trajectory")
        ax.legend()
        if show_words:
            for i, tok in enumerate(tokens):
                ax.text(x[i], y[i], z[i], tok, fontsize=9, color="red")
    elif mode == "Per-layer trajectory":
        # Fall back to the first token ([CLS]) if the focus token isn't in the sentence
        if focus_token.strip() in tokens:
            idx = tokens.index(focus_token.strip())
        else:
            idx = 0
        path_layers = all_layers[:, idx, :]  # [num_layers+1, hidden_dim]
        xy = PCA(n_components=2).fit_transform(path_layers)
        x, y = xy[:, 0], xy[:, 1]
        z = np.arange(len(x))
        ax.plot(x, y, z, label=f"Layer evolution for '{tokens[idx]}'")
        ax.legend()
        for i in range(len(z)):
            ax.text(x[i], y[i], z[i], f"L{i}", fontsize=8, color="blue")
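        # PCA is refit on each call, so coordinates are only meaningful within a
        # single plot, not comparable across runs, tokens, or modes.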

    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_zlabel("Index")  # token position or layer index, depending on mode
    ax.set_title(mode)
    plt.tight_layout()
    return fig
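
# A quick smoke test outside Gradio (hypothetical call, not part of the app flow):
#   fig = plot_hidden_states("Per-layer trajectory", 32,
#                            "I love hidden states in transformers", False, "hidden")
#   fig.savefig("layer_trajectory.png")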

with gr.Blocks() as demo:
    gr.Markdown("# πŸŒ€ 3D Hidden States Explorer")
    gr.Markdown(
        """
        Visualize **transformer hidden states** in 3D.
        Choose between two modes:

        - **Per-token trajectory:** the path traced by the final layer's hidden states across the tokens of a sentence.
        - **Per-layer trajectory:** how a single token's representation moves across all layers.
        """
    )
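    # Concretely: per-token mode fits PCA on the final layer's [seq_len, 768]
    # matrix; per-layer mode fits PCA on one token's [7, 768] stack of layers
    # (768 is DistilBERT's hidden size; 7 = 6 layers + the embedding output).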
    with gr.Row():
        mode = gr.Radio(["Per-token trajectory", "Per-layer trajectory"], value="Per-token trajectory", label="Mode")
        max_tokens = gr.Slider(10, 64, value=32, step=1, label="Max Tokens")
    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
    show_words = gr.Checkbox(label="Show Tokens (per-token mode)", value=True)
    focus_token = gr.Textbox(value="hidden", label="Focus Token (per-layer mode)")
    plot = gr.Plot()
    btn = gr.Button("Generate Plot")
    btn.click(plot_hidden_states, inputs=[mode, max_tokens, sentence, show_words, focus_token], outputs=plot)

if __name__ == "__main__":
    demo.launch()
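
# Dependencies inferred from the imports above: a Space running this file would
# typically list gradio, torch, transformers, scikit-learn, and matplotlib in
# its requirements.txt.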