"""Gradio demo: a 3D spiral used as a toy picture of hidden states.

The words of a user-supplied sentence can optionally be drawn at evenly
spaced points along the spiral.
"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure


def plot_3d(max_t, num_points, sentence, show_words):
    """Build a 3D spiral figure, optionally labelled with the words of *sentence*.

    The curve is x = sin(t), y = cos(t), z = t, sampled at ``num_points``
    values of t evenly spaced over [0, max_t].

    Args:
        max_t: Upper end of the parameter range (the spiral length).
        num_points: Number of samples along the curve. Cast to ``int`` because
            UI slider components may deliver a float.
        sentence: Text whose whitespace-separated words are drawn along the
            curve. ``None`` is treated as an empty string.
        show_words: When truthy (and *sentence* contains words), draw the words.

    Returns:
        A ``matplotlib.figure.Figure`` holding a single 3D axes.
    """
    # np.linspace rejects a non-integer sample count, so normalise it here.
    t = np.linspace(0, max_t, int(num_points))
    x = np.sin(t)
    y = np.cos(t)
    z = t

    # A standalone Figure is not registered with pyplot's global figure
    # manager, so the figure made on every click can be garbage-collected
    # instead of accumulating there (pyplot state is also not thread-safe).
    fig = Figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.plot(x, y, z, label="3D Spiral")
    ax.legend()

    # str.split() with no argument already ignores leading/trailing
    # whitespace; ``or ""`` tolerates a cleared (None) textbox value.
    tokens = (sentence or "").split()
    if show_words and tokens and len(t) > 0:
        # Spread the words evenly over the sampled indices, first to last point.
        idxs = np.linspace(0, len(t) - 1, len(tokens), dtype=int)
        for i, token in zip(idxs, tokens):
            ax.text(x[i], y[i], z[i], token, fontsize=9, color="red")

    return fig


_INTRO_MD = """
This plot shows a **3D spiral**, generated using sine and cosine for the x/y axes
and a linear sequence for the z-axis:

- $x = \\sin(t)$
- $y = \\cos(t)$
- $z = t$

Together, these equations trace a spiral around the z-axis.
Think of it as an analogy for **hidden states in a neural network**, which evolve
over time (the z-axis), while oscillating in complex patterns (x & y axes).

✨ Try typing your own sentence — each word will be placed along the spiral,
showing how tokens could be mapped into hidden state space.
"""

# gr.Markdown renders only $$...$$ as LaTeX by default; also enable inline
# $...$ so the formulas above render. $$ is listed first so it wins the match.
_LATEX_DELIMITERS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]

with gr.Blocks() as demo:
    gr.Markdown("# 3D Hidden States Visualization")
    gr.Markdown(_INTRO_MD, latex_delimiters=_LATEX_DELIMITERS)

    with gr.Row():
        max_t = gr.Slider(5, 50, value=20, step=1, label="Spiral Length")
        num_points = gr.Slider(100, 2000, value=500, step=50, label="Number of Points")

    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
    show_words = gr.Checkbox(label="Show Tokens", value=True)

    plot = gr.Plot()
    btn = gr.Button("Generate Plot")
    btn.click(plot_3d, inputs=[max_t, num_points, sentence, show_words], outputs=plot)

if __name__ == "__main__":
    demo.launch()