import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA
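
# Load the tokenizer and encoder once at startup. output_hidden_states=True
# makes the model return the activations of the embedding layer and every
# transformer block; eval() disables dropout for deterministic inference.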
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
model.eval()
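

# Build a 3D matplotlib Figure for the selected visualization mode.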
def plot_hidden_states(mode, max_tokens, sentence, show_words, focus_token):
    # Gradio sliders deliver floats, so cast before handing the value to the
    # tokenizer's max_length.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=int(max_tokens))
    with torch.no_grad():
        outputs = model(**inputs)
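
    # outputs.hidden_states is a tuple of (1, seq_len, hidden_dim) tensors,
    # one for the embeddings plus one per transformer layer; stacking and
    # squeezing the batch dim gives (num_layers + 1, seq_len, hidden_dim).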
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    all_layers = torch.stack(outputs.hidden_states).squeeze(1).numpy()

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
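
    # Mode 1: project each token's final-layer vector to 2D with PCA and use
    # token position in the sequence as the third axis.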
    if mode == "Per-token trajectory":
        hs = outputs.last_hidden_state.squeeze(0).numpy()
        xy = PCA(n_components=2).fit_transform(hs)
        x, y = xy[:, 0], xy[:, 1]
        z = np.arange(len(x))
ax.plot(x, y, z, label="Hidden state trajectory") |
|
ax.legend() |
|
|
|
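
        # Optionally label each point with its wordpiece token.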
        if show_words:
            for i, tok in enumerate(tokens):
                ax.text(x[i], y[i], z[i], tok, fontsize=9, color="red")
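
    # Mode 2: follow one token's hidden state through the layers, with the
    # layer index as the third axis.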
    elif mode == "Per-layer trajectory":
        # This uncased checkpoint lowercases its tokens, so normalize the
        # lookup; fall back to the first token ([CLS]) when there is no match.
        focus = focus_token.strip().lower()
        idx = tokens.index(focus) if focus in tokens else 0
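
        # One row per layer for the chosen token: (num_layers + 1, hidden_dim).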
        path_layers = all_layers[:, idx, :]
        xy = PCA(n_components=2).fit_transform(path_layers)
        x, y = xy[:, 0], xy[:, 1]
        z = np.arange(len(x))
ax.plot(x, y, z, label=f"Layer evolution for '{tokens[idx]}'") |
|
ax.legend() |
|
|
|
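
        # Mark each layer along the path (L0 is the embedding output).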
        for i in range(len(z)):
            ax.text(x[i], y[i], z[i], f"L{i}", fontsize=8, color="blue")
ax.set_xlabel("PC1") |
|
ax.set_ylabel("PC2") |
|
ax.set_zlabel("Index") |
|
ax.set_title(mode) |
|
plt.tight_layout() |
|
return fig |
|
|
|
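

# Assemble the Gradio UI: mode selector, text inputs, and a 3D plot output.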
with gr.Blocks() as demo:
    gr.Markdown("# 3D Hidden States Explorer")
    gr.Markdown(
        """
Visualize **transformer hidden states** in 3D.
Choose between two modes:

- **Per-token trajectory:** the final-layer hidden states of every token in the sentence.
- **Per-layer trajectory:** how a single token's hidden state moves through the layers.
"""
    )

    with gr.Row():
        mode = gr.Radio(
            ["Per-token trajectory", "Per-layer trajectory"],
            value="Per-token trajectory",
            label="Mode",
        )
        max_tokens = gr.Slider(10, 64, value=32, step=1, label="Max Tokens")

    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
    show_words = gr.Checkbox(label="Show Tokens (per-token mode)", value=True)
    focus_token = gr.Textbox(value="hidden", label="Focus Token (per-layer mode)")

    plot = gr.Plot()
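
    # Wire the button to the plotting function; the returned Figure is
    # rendered by the gr.Plot component.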
    btn = gr.Button("Generate Plot")
    btn.click(
        plot_hidden_states,
        inputs=[mode, max_tokens, sentence, show_words, focus_token],
        outputs=plot,
    )


if __name__ == "__main__":
    demo.launch()