"""Gradio demo for inspecting a pretrained BERT encoder.

For a piece of input text the demo can show a t-SNE projection of the final
token embeddings, the last-layer attention map, or the activations captured by
forward hooks, and it always adds an Integrated Gradients attribution plot.
"""

import base64
import html
import io
import logging

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from captum.attr import IntegratedGradients
from sklearn.manifold import TSNE
from transformers import BertModel, BertTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize BERT model and tokenizer with eager attention
try:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', attn_implementation="eager")
    model.eval()
except Exception as e:
    logger.error(f"Failed to load BERT model: {e}")
    raise

# Store intermediate activations, keyed by module name; filled by forward hooks.
activations = {}


def hook_fn(module, input, output, name):
    """Forward hook: record ``output`` of the module under its (string) name."""
    activations[str(name)] = output  # Ensure name is a string


def _is_hooked_module(name):
    """Return True for module names that get an activation-capturing hook."""
    return 'layer' in name or 'embeddings' in name


# Register hooks for BERT layers. ``n=name`` binds the current name as a
# default argument so every hook keeps its own name (no late-binding closure).
for name, layer in model.named_modules():
    if _is_hooked_module(name):
        layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))


def convert_dict_keys_to_str(d):
    """Recursively convert all dictionary keys to strings.

    Lists are processed element-wise and numpy arrays are converted to plain
    lists; any other value is returned unchanged.
    """
    if isinstance(d, dict):
        return {str(k): convert_dict_keys_to_str(v) for k, v in d.items()}
    if isinstance(d, list):
        return [convert_dict_keys_to_str(item) for item in d]
    if isinstance(d, np.ndarray):
        return d.tolist()  # Convert numpy arrays to lists
    return d


def _error_html(message):
    """Wrap an error message in a paragraph, escaping it for safe HTML display."""
    return f"<p>Error: {html.escape(str(message))}</p>"


def _figure_to_img_tag(fig, alt_text):
    """Render ``fig`` as an inline base64 PNG ``<img>`` tag and close the figure.

    The figure is closed even if saving fails, so repeated requests do not
    accumulate open matplotlib figures.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)
    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return (
        f'<img src="data:image/png;base64,{img_base64}" '
        f'alt="{html.escape(alt_text)}" style="max-width: 100%;">'
    )


def _plot_tsne(emb, tokens):
    """Return a figure with a 2-D t-SNE projection of per-token embeddings."""
    # Perplexity must stay below the number of samples, hence the min().
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, emb.shape[0] - 1))
    reduced = tsne.fit_transform(emb)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(reduced[:, 0], reduced[:, 1], c='blue')
    for i, token in enumerate(tokens):
        ax.annotate(token, (reduced[i, 0], reduced[i, 1]))
    ax.set_title("t-SNE of Token Embeddings")
    return fig


def _plot_attention(attn, tokens):
    """Return a heatmap figure of a token-by-token attention matrix."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
    ax.set_title("Attention Weights (Last Layer, Head 0)")
    ax.tick_params(axis='x', labelrotation=45)
    ax.tick_params(axis='y', labelrotation=0)
    return fig


def _plot_token_bars(values, tokens, title):
    """Return a bar-chart figure with one bar per token."""
    fig, ax = plt.subplots(figsize=(8, 6))
    positions = list(range(len(values)))
    ax.bar(positions, values)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=45)
    ax.set_title(title)
    return fig


def _token_attributions(input_ids, attention_mask, target):
    """Compute Integrated Gradients attributions per token.

    Attributions are taken with respect to the word embeddings and explain
    column ``target`` of the model's ``pooler_output``. The per-dimension
    attributions are summed over the hidden axis to give one value per token.
    """
    # ``inputs_embeds`` must be the raw word embeddings: the model adds the
    # position/token-type embeddings and LayerNorm itself.
    with torch.no_grad():
        word_embeddings = model.get_input_embeddings()(input_ids)
    word_embeddings.requires_grad_(True)

    def forward_func(embeds, mask=None):
        # Return the full pooled output; Captum selects the column via `target`.
        return model(inputs_embeds=embeds, attention_mask=mask).pooler_output

    ig = IntegratedGradients(forward_func)
    attributions, _ = ig.attribute(
        inputs=word_embeddings,
        additional_forward_args=(attention_mask,),
        target=target,
        return_convergence_delta=True,
    )
    return attributions[0].detach().numpy().sum(axis=1)  # Sum over hidden size


def process_input(input_text, layer_name, visualize_option, attribution_target=0):
    """
    Process input text, compute embeddings, activations, attention, and attribution.

    Parameters:
    - input_text: User-provided text
    - layer_name: Selected layer for activation visualization
    - visualize_option: 'Embeddings', 'Attention', or 'Activations'
    - attribution_target: Index of the ``pooler_output`` column to attribute
      (the UI offers 0 or 1; the bare BERT encoder has no classification head)

    Returns:
    - HTML string with base64-encoded image(s)
    - List of dataframe dictionaries with string keys
    - Status message

    Any exception is caught, logged, and reported through the three return
    values instead of being raised.
    """
    global activations
    activations = {}  # Reset activations

    try:
        # Validate input
        if not input_text.strip():
            return (
                "<p>Error: Input text cannot be empty.</p>",
                [{"Error": ["Input text cannot be empty."]}],
                "Error: Input text cannot be empty.",
            )

        # Tokenize input
        inputs = tokenizer(
            input_text, return_tensors='pt', padding=True, truncation=True, max_length=512
        )
        input_ids = inputs['input_ids'].to(dtype=torch.long)  # Ensure LongTensor
        attention_mask = inputs['attention_mask']

        # Forward pass (also fills `activations` through the registered hooks)
        with torch.no_grad():
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=True,
                output_hidden_states=True,
            )
        embeddings = outputs.last_hidden_state  # [batch, seq_len, hidden_size]
        attentions = outputs.attentions  # Per-layer attention weights

        # Convert token IDs to tokens
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

        # Initialize outputs
        html_plots = []
        dataframes = []

        def record_error(message):
            """Report a non-fatal problem in both the table and HTML outputs."""
            dataframes.append({"Error": [message]})
            html_plots.append(_error_html(message))

        # Visualization: Embeddings (t-SNE)
        if visualize_option == "Embeddings":
            emb = embeddings[0].detach().numpy()  # [seq_len, hidden_size]
            if emb.shape[0] > 1:
                try:
                    fig = _plot_tsne(emb, tokens)
                    html_plots.append(_figure_to_img_tag(fig, "t-SNE Plot"))
                except Exception:
                    logger.exception("t-SNE computation failed")
                    record_error("t-SNE computation failed.")
            else:
                record_error("Too few tokens for t-SNE.")

        # Visualization: Attention Weights
        elif visualize_option == "Attention":
            if attentions:
                attn = attentions[-1][0, 0].detach().numpy()  # Last layer, first head
                fig = _plot_attention(attn, tokens)
                html_plots.append(_figure_to_img_tag(fig, "Attention Heatmap"))
            else:
                record_error("No attention weights available.")

        # Visualization: Activations
        elif visualize_option == "Activations":
            if layer_name in activations:
                act = activations[layer_name]
                if isinstance(act, tuple):
                    act = act[0]
                act = act[0].detach().numpy()  # [seq_len, hidden_size]
                # Prefix the position so repeated tokens keep separate rows;
                # duplicate index labels would collapse in to_dict().
                row_labels = [f"{i}_{token}" for i, token in enumerate(tokens)]
                dataframe = pd.DataFrame(
                    act,
                    index=row_labels,
                    columns=[f"dim_{i}" for i in range(act.shape[1])],
                )
                dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))

                # Plot mean activation per token
                mean_act = np.mean(act, axis=1)
                fig = _plot_token_bars(mean_act, tokens, f"Mean Activations in {layer_name}")
                html_plots.append(_figure_to_img_tag(fig, "Activation Plot"))
            else:
                record_error(f"Layer {layer_name} not found.")

        # Attribution: Integrated Gradients on embeddings
        try:
            attr = _token_attributions(input_ids, attention_mask, int(attribution_target))
            attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
            attr_df.index = [f"idx_{i}" for i in range(len(attr_df))]  # String indices
            dataframes.append(convert_dict_keys_to_str(attr_df.to_dict()))

            # Plot attributions
            fig = _plot_token_bars(attr, tokens, "Integrated Gradients Attribution")
            html_plots.append(_figure_to_img_tag(fig, "Attribution Plot"))
        except Exception:
            logger.exception("Attribution computation failed")
            record_error("Attribution computation failed.")

        # Combine HTML plots
        html_output = "<br>".join(html_plots)
        return html_output, dataframes, "Processing completed successfully."

    except Exception as e:
        logger.exception("Error while processing input")
        return _error_html(e), [{"Error": [str(e)]}], f"Error: {e}"


# Gradio Interface
def create_gradio_interface():
    """Build and return the Gradio Blocks app wired to ``process_input``."""
    with gr.Blocks(title="Neural Network Visualization Demo") as demo:
        gr.Markdown("# Neural Network Visualization Demo")
        gr.Markdown(
            "Analyze BERT's neural network paths. Enter text, select a layer, "
            "and choose a visualization."
        )

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input Text",
                    value="The quick brown fox jumps over the lazy dog.",
                    placeholder="Enter text here...",
                )
                layer_name = gr.Dropdown(
                    label="Select Layer",
                    choices=[
                        str(name)
                        for name, _ in model.named_modules()
                        if _is_hooked_module(name)
                    ],
                    value="embeddings",
                )
                visualize_option = gr.Radio(
                    label="Visualization Type",
                    choices=["Embeddings", "Attention", "Activations"],
                    value="Embeddings",
                )
                attribution_target = gr.Slider(
                    label="Attribution Target Class (0 or 1)",
                    minimum=0,
                    maximum=1,
                    step=1,
                    value=0,
                )
                submit_btn = gr.Button("Analyze")

            with gr.Column():
                plot_output = gr.HTML(label="Visualizations")
                # NOTE(review): process_input returns a list of dicts here;
                # confirm gr.Dataframe renders that shape as intended.
                dataframe_output = gr.Dataframe(label="Data Outputs")
                text_output = gr.Textbox(label="Messages")

        submit_btn.click(
            fn=process_input,
            inputs=[input_text, layer_name, visualize_option, attribution_target],
            outputs=[plot_output, dataframe_output, text_output],
        )

    return demo


# Launch the demo locally
if __name__ == "__main__":
    try:
        demo = create_gradio_interface()
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        logger.error(f"Failed to launch Gradio demo: {e}")
        print(f"Error launching demo: {e}. Try running locally with a different port or without share=True.")