import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA
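
# Assumed runtime dependencies (nothing is pinned in this file): gradio, torch,
# transformers, scikit-learn, matplotlib. DistilBERT is small enough for CPU-only Spaces.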

# Load model & tokenizer once (tiny DistilBERT for speed on Spaces)
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
model.eval()

def plot_hidden_states(mode, max_tokens, sentence, show_words, focus_token):
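    """Project hidden states to 2D with PCA and plot them along a third axis.

    "Per-token trajectory": PCA over the final layer's token vectors,
    with token position on the z-axis.
    "Per-layer trajectory": PCA over a single token's vector at every
    layer, with layer index on the z-axis (layer 0 = embeddings).
    """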
    # Tokenize (Gradio sliders may deliver floats, so cast the token limit)
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=int(max_tokens))
    with torch.no_grad():
        outputs = model(**inputs)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    all_layers = torch.stack(outputs.hidden_states).squeeze(1).numpy()  # [num_layers+1, seq_len, hidden_dim]
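    # For distilbert-base-uncased that is 7 slices: the embedding output plus 6 transformer layers.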

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    if mode == "Per-token trajectory":
        hs = outputs.last_hidden_state.squeeze(0).numpy()
        xy = PCA(n_components=2).fit_transform(hs)
        x, y = xy[:, 0], xy[:, 1]
        z = np.arange(len(x))

        ax.plot(x, y, z, label="Hidden state trajectory")
        ax.legend()

        if show_words:
            for i, tok in enumerate(tokens):
                ax.text(x[i], y[i], z[i], tok, fontsize=9, color="red")

    elif mode == "Per-layer trajectory":
        # The uncased tokenizer lowercases everything, so normalize the query;
        # fall back to the first token ([CLS]) when there is no exact match.
        query = focus_token.strip().lower()
        idx = tokens.index(query) if query in tokens else 0

        path_layers = all_layers[:, idx, :]  # [num_layers+1, hidden_dim]
        xy = PCA(n_components=2).fit_transform(path_layers)
        x, y = xy[:, 0], xy[:, 1]
        z = np.arange(len(x))

        ax.plot(x, y, z, label=f"Layer evolution for '{tokens[idx]}'")
        ax.legend()

        for i in range(len(z)):
            ax.text(x[i], y[i], z[i], f"L{i}", fontsize=8, color="blue")

    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_zlabel("Token position" if mode == "Per-token trajectory" else "Layer index")
    ax.set_title(mode)
    fig.tight_layout()  # act on this figure, not the implicit "current" one
    return fig

with gr.Blocks() as demo:
    gr.Markdown("# πŸŒ€ 3D Hidden States Explorer")
    gr.Markdown(
        """
        Visualize **transformer hidden states** in 3D.  
        Choose between two modes:  
        - **Per-token trajectory:** how tokens in a sentence evolve in the final layer.  
        - **Per-layer trajectory:** how one token moves across all layers.  
        """
    )

    with gr.Row():
        mode = gr.Radio(["Per-token trajectory", "Per-layer trajectory"], value="Per-token trajectory", label="Mode")
        max_tokens = gr.Slider(10, 64, value=32, step=1, label="Max Tokens")

    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
    show_words = gr.Checkbox(label="Show Tokens (per-token mode)", value=True)
    focus_token = gr.Textbox(value="hidden", label="Focus Token (per-layer mode)")

    plot = gr.Plot()

    btn = gr.Button("Generate Plot")
    btn.click(plot_hidden_states, inputs=[mode, max_tokens, sentence, show_words, focus_token], outputs=plot)

if __name__ == "__main__":
    demo.launch()
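
# To run outside Spaces (a sketch, assuming this file is saved as app.py per the
# Spaces convention):
#   pip install gradio torch transformers scikit-learn matplotlib
#   python app.py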