# rdave88's picture
# Switch to Gradio version
# ddf6886
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
def plot_3d(max_t, num_points, sentence, show_words):
    """Draw a 3D spiral and optionally label points on it with words.

    The curve is ``x = sin(t)``, ``y = cos(t)``, ``z = t`` sampled at
    ``num_points`` evenly spaced values of ``t`` in ``[0, max_t]``.

    Args:
        max_t: Upper bound of the curve parameter ``t``.
        num_points: Number of samples along the curve. UI sliders may
            deliver this as a float, so it is converted with ``int()``
            before use.
        sentence: Text split on whitespace into tokens. ``None`` is treated
            as an empty string.
        show_words: When true, every token is drawn in red at a sample of
            the curve; tokens are spread evenly from the first sample to
            the last.

    Returns:
        The matplotlib figure holding the plot.
    """
    # np.linspace requires an integer sample count.
    sample_count = int(num_points)

    # Generate spiral
    t = np.linspace(0, max_t, sample_count)
    x = np.sin(t)
    y = np.cos(t)
    z = t

    # Plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.plot(x, y, z, label="3D Spiral")
    ax.legend()

    # Add token labels if requested. str.split() with no arguments already
    # ignores leading/trailing whitespace, so blank input yields no tokens.
    tokens = (sentence or "").split() if show_words else []
    if tokens and t.size:
        idxs = np.linspace(0, t.size - 1, len(tokens), dtype=int)
        for i, token in zip(idxs, tokens):
            ax.text(x[i], y[i], z[i], token, fontsize=9, color="red")

    # pyplot keeps a reference to every figure it creates until that figure
    # is closed; without this, each call would leave one more figure in
    # memory. The Figure object itself remains usable by the caller.
    plt.close(fig)
    return fig
# Introductory text shown above the controls. Blank lines separate the
# paragraphs from the bullet list so Markdown does not fold the trailing
# paragraphs into the last list item.
_INTRO_MARKDOWN = """
This plot shows a **3D spiral**, generated using sine and cosine for the x/y axes
and a linear sequence for the z-axis:

- $x = \\sin(t)$
- $y = \\cos(t)$
- $z = t$

Together, these equations trace a spiral around the z-axis.
Think of it as an analogy for **hidden states in a neural network**,
which evolve over time (the z-axis), while oscillating in complex patterns (x & y axes).

✨ Try typing your own sentence — each word will be placed along the spiral,
showing how tokens could be mapped into hidden state space.
"""

# The intro text writes its equations with single-dollar delimiters, so they
# are declared explicitly here alongside the double-dollar display form.
_LATEX_DELIMITERS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]

with gr.Blocks() as demo:
    gr.Markdown("# 3D Hidden States Visualization")
    gr.Markdown(_INTRO_MARKDOWN, latex_delimiters=_LATEX_DELIMITERS)

    # NOTE(review): the original indentation was lost; the two sliders are
    # assumed to be the only children of this row -- confirm the layout.
    with gr.Row():
        max_t = gr.Slider(5, 50, value=20, step=1, label="Spiral Length")
        num_points = gr.Slider(100, 2000, value=500, step=50, label="Number of Points")

    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
    show_words = gr.Checkbox(label="Show Tokens", value=True)
    plot = gr.Plot()
    btn = gr.Button("Generate Plot")

    # Clicking the button passes the current control values to plot_3d and
    # shows the returned figure in the plot component.
    btn.click(plot_3d, inputs=[max_t, num_points, sentence, show_words], outputs=plot)

if __name__ == "__main__":
    demo.launch()