File size: 1,960 Bytes
29570bf
132e707
 
 
ddf6886
 
1b07353
29570bf
 
 
132e707
1b07353
29570bf
1b07353
29570bf
 
132e707
ddf6886
 
 
 
 
 
 
29570bf
132e707
ddf6886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29570bf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

def plot_3d(max_t, num_points, sentence, show_words):
    """Build a 3D spiral figure, optionally labelled with the words of a sentence.

    The curve is x = sin(t), y = cos(t), z = t, sampled at ``num_points``
    evenly spaced values of t in [0, max_t].

    Args:
        max_t: Upper bound of t (the extent of the spiral along the z-axis).
        num_points: Number of samples along the curve. Coerced to ``int``
            because UI sliders may deliver a float and ``np.linspace``
            requires an integer sample count.
        sentence: Text whose whitespace-separated words are used as labels.
            ``None`` is treated as an empty sentence.
        show_words: When truthy (and the sentence contains at least one
            word), each word is drawn in red at sample indices spread evenly
            from the first to the last point of the curve.

    Returns:
        The matplotlib figure containing the 3D plot.
    """
    num_points = int(num_points)

    # Generate spiral
    t = np.linspace(0, max_t, num_points)
    x = np.sin(t)
    y = np.cos(t)
    z = t

    # Plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.plot(x, y, z, label="3D Spiral")
    ax.legend()

    # Add token labels if requested. str.split() with no arguments already
    # ignores leading/trailing whitespace, so a blank sentence yields no tokens.
    tokens = (sentence or "").split() if show_words else []
    if tokens and t.size:
        # One sample index per token, spread from the first to the last point.
        idxs = np.linspace(0, t.size - 1, len(tokens), dtype=int)
        for i, token in zip(idxs, tokens):
            ax.text(x[i], y[i], z[i], token, fontsize=9, color="red")

    # pyplot keeps a global reference to every figure created via plt.figure();
    # in a long-running server that leaks one figure per request. Closing only
    # unregisters it from pyplot -- the Figure object stays usable by the caller.
    plt.close(fig)
    return fig

# UI definition: a title, an explanatory blurb, the spiral controls, the
# sentence input, the plot output, and a button that wires them to plot_3d.
with gr.Blocks() as demo:
    gr.Markdown("# 3D Hidden States Visualization")
    gr.Markdown(
        """
        This plot shows a **3D spiral**, generated using sine and cosine for the x/y axes 
        and a linear sequence for the z-axis:

        - $x = \\sin(t)$  
        - $y = \\cos(t)$  
        - $z = t$  

        Together, these equations trace a spiral around the z-axis.  
        Think of it as an analogy for **hidden states in a neural network**, 
        which evolve over time (the z-axis), while oscillating in complex patterns (x & y axes).  

        ✨ Try typing your own sentence — each word will be placed along the spiral, 
        showing how tokens could be mapped into hidden state space.
        """,
        # By default gr.Markdown only renders $$...$$ as LaTeX, so the inline
        # $...$ formulas above would show up as raw text. Register the
        # single-dollar delimiter explicitly ($$ is listed first so it wins
        # when both could match).
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )

    # Spiral shape controls, shown side by side.
    with gr.Row():
        max_t = gr.Slider(5, 50, value=20, step=1, label="Spiral Length")
        num_points = gr.Slider(100, 2000, value=500, step=50, label="Number of Points")

    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
    show_words = gr.Checkbox(label="Show Tokens", value=True)

    plot = gr.Plot()

    # The plot is only (re)generated when the button is clicked.
    btn = gr.Button("Generate Plot")
    btn.click(plot_3d, inputs=[max_t, num_points, sentence, show_words], outputs=plot)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()