import torch
import torch.nn as nn
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import io
from PIL import Image

# -------- 1) Define a Tiny RNN Model (LSTM) and Vocab --------
# For demonstration, we keep the model untrained with small dimensions.

# A small toy vocab; "<pad>" and "<unk>" are special tokens for padding
# and out-of-vocabulary words, respectively.
vocab_list = ["<pad>", "<unk>", "the", "cat", "dog", "was", "chasing",
              "and", "it", "fell", "over", "hello", "world"]
vocab_dict = {word: i for i, word in enumerate(vocab_list)}
vocab_size = len(vocab_list)  # e.g., 13

embedding_dim = 8
hidden_dim = 8

# Simple LSTM model
class TinyRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, input_ids):
        # input_ids: (batch_size, seq_len)
        embeds = self.embedding(input_ids)  # -> (batch_size, seq_len, embedding_dim)
        outputs, (h_n, c_n) = self.lstm(embeds)
        # outputs: (batch_size, seq_len, hidden_dim) -> the hidden state at *each* time step
        # h_n:     (1, batch_size, hidden_dim)       -> final hidden state
        return outputs, (h_n, c_n)

# Initialize the model (untrained, random weights)
tiny_rnn = TinyRNN(vocab_size, embedding_dim, hidden_dim)
tiny_rnn.eval()  # Not training, just forward passes for visualization

# -------- 2) Tokenizer / Indexing Functions --------
def simple_tokenize(text):
    # Very naive whitespace tokenizer
    tokens = text.lower().split()
    return tokens

def numericalize(tokens):
    # Convert tokens to vocab indices; use "<unk>" for OOV tokens
    indices = []
    for t in tokens:
        if t in vocab_dict:
            indices.append(vocab_dict[t])
        else:
            indices.append(vocab_dict["<unk>"])
    return indices

# -------- 3) Visualization Function --------
def visualize_rnn_states(input_text):
    """
    1) Tokenize input_text
    2) Convert to vocab indices
    3) Forward pass through the LSTM
    4) Plot a heatmap of hidden states across timesteps
    """
    # Tokenize & numericalize
    tokens = simple_tokenize(input_text)
    if len(tokens) == 0:
        tokens = ["<pad>"]
    indices = numericalize(tokens)

    # Convert to tensor of shape (batch_size=1, seq_len)
    input_tensor = torch.tensor(indices).unsqueeze(0)

    # LSTM forward pass
    with torch.no_grad():
        outputs, (h_n, c_n) = tiny_rnn(input_tensor)
        # outputs shape: (1, seq_len, hidden_dim)
    outputs = outputs.squeeze(0).cpu().numpy()  # shape: (seq_len, hidden_dim)

    # Create heatmap
    seq_len, hidden_dim_ = outputs.shape
    plt.figure(figsize=(6, max(3, seq_len * 0.4)))  # dynamic height for long inputs
    sns.heatmap(
        outputs,
        yticklabels=tokens,
        xticklabels=[f"h{i}" for i in range(hidden_dim_)],
        cmap="coolwarm",
        center=0
    )
    plt.title("RNN Hidden States Heatmap")
    plt.ylabel("Tokens")
    plt.xlabel(f"Hidden State Dimensions (size={hidden_dim_})")
    plt.tight_layout()

    # Convert the plot to an image for Gradio
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close()
    return Image.open(buf)

# -------- 4) Gradio Interface --------
demo = gr.Interface(
    fn=visualize_rnn_states,
    inputs=gr.Textbox(lines=2, label="Input Text", value="The cat was chasing the dog"),
    outputs="image",
    title="RNN (LSTM) Hidden States Visualizer",
    description=(
        "Visualize how an untrained LSTM's hidden state (dim=8) changes "
        "for each token in your input text. Rows=timesteps, Columns=hidden dim."
    ),
)

demo.launch(debug=True, share=True)
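
# -------- Optional: quick shape sanity check --------
# A minimal sketch (not part of the demo) for verifying the tensor shapes the
# heatmap relies on. Since demo.launch(debug=True) blocks, run these lines in a
# separate session, or temporarily comment out the launch call above. The
# example sentence is the same default used by the Textbox.
#
#   tokens = simple_tokenize("the cat was chasing the dog")
#   indices = numericalize(tokens)                     # 6 in-vocab tokens -> 6 indices
#   input_tensor = torch.tensor(indices).unsqueeze(0)  # shape: (1, 6)
#   with torch.no_grad():
#       outputs, (h_n, c_n) = tiny_rnn(input_tensor)
#   print(outputs.shape)  # torch.Size([1, 6, 8]) -> (batch, seq_len, hidden_dim)
#   print(h_n.shape)      # torch.Size([1, 1, 8]) -> (num_layers, batch, hidden_dim)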