# weights / app.py
# broadfield-dev's picture
# Update app.py
# 124df05 verified
import base64
import html
import io
import logging

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from captum.attr import IntegratedGradients
from sklearn.manifold import TSNE
from transformers import BertTokenizer, BertModel
# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the pretrained 'bert-base-uncased' tokenizer and encoder. The attention
# implementation is pinned to "eager" (presumably so attention weights can be
# returned by the forward pass -- confirm). Any load failure is logged and
# re-raised, so the module cannot be imported without a working model.
try:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', attn_implementation="eager")
    model.eval()
except Exception as exc:
    logger.error("Failed to load BERT model: %s", exc)
    raise

# Module-level store filled by the forward hooks: module name -> raw output.
activations = {}
def hook_fn(module, input, output, name):
    """
    Forward-hook callback that records a module's output.
    Parameters:
    - module: The module that just ran (unused)
    - input: The inputs the module received (unused)
    - output: The value the module returned; stored unchanged
    - name: Identifier of the module; converted to ``str`` and used as the key
    Side effect:
    - Writes ``output`` into the module-level ``activations`` dict
    """
    key = str(name)  # keys are always plain strings
    activations[key] = output
# Attach the recording hook to every submodule whose qualified name contains
# 'layer' or 'embeddings'.
for name, layer in model.named_modules():
    if not ('layer' in name or 'embeddings' in name):
        continue
    # The current name is bound as a default argument so each hook keeps its
    # own key instead of sharing the loop variable's final value.
    layer.register_forward_hook(
        lambda mod, args, out, n=name: hook_fn(mod, args, out, n)
    )
def convert_dict_keys_to_str(d):
    """
    Recursively convert all dictionary keys to strings and NumPy values to
    plain Python values.
    Parameters:
    - d: Any value; dicts, lists and tuples are walked recursively
    Returns:
    - dict: new dict with ``str`` keys and converted values
    - list / tuple: new list / plain tuple with converted items
    - numpy.ndarray: nested Python lists (``ndarray.tolist()``)
    - NumPy scalar (e.g. ``np.float32``): the equivalent Python scalar
    - anything else: returned unchanged
    """
    if isinstance(d, dict):
        return {str(k): convert_dict_keys_to_str(v) for k, v in d.items()}
    if isinstance(d, list):
        return [convert_dict_keys_to_str(item) for item in d]
    if isinstance(d, tuple):
        # Tuples may hold dicts too; tuple subclasses come back as plain tuples.
        return tuple(convert_dict_keys_to_str(item) for item in d)
    if isinstance(d, np.ndarray):
        return d.tolist()  # Convert numpy arrays to lists
    if isinstance(d, np.generic):
        # NumPy scalars are not ndarray instances; unwrap them as well so the
        # result contains only built-in Python types.
        return d.item()
    return d
def _error_view(data_message, html_message=None):
    """Return an ``(html, data)`` pair describing an error (HTML text is escaped)."""
    shown = data_message if html_message is None else html_message
    return f"<p>Error: {html.escape(shown)}</p>", {"Error": [data_message]}


def _frame_to_dict(frame):
    """Turn a DataFrame into a nested dict with string keys."""
    return convert_dict_keys_to_str(frame.to_dict())


def _token_row_labels(tokens):
    """Return one unique row label per token (position prefix + token text)."""
    # Raw tokens repeat (e.g. "the" twice); duplicate index labels would
    # overwrite each other in DataFrame.to_dict().
    return [f"{i}_{token}" for i, token in enumerate(tokens)]


def _render_plot(draw, alt_text):
    """Draw on a fresh 8x6 figure and return it as an inline base64 PNG ``<img>`` tag."""
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        draw(ax)
        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing or saving raises.
        plt.close(fig)
    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return (
        f'<img src="data:image/png;base64,{img_base64}" '
        f'alt="{html.escape(alt_text)}" style="max-width:100%;"/>'
    )


def _token_bar_plot(values, tokens, title, alt_text):
    """Render a bar chart with one bar per token."""
    def draw(ax):
        positions = range(len(values))
        ax.bar(positions, values)
        ax.set_xticks(positions)
        ax.set_xticklabels(tokens, rotation=45)
        ax.set_title(title)

    return _render_plot(draw, alt_text)


def _embedding_view(embeddings, tokens):
    """Build the t-SNE scatter plot and coordinate table for the final hidden states."""
    emb = embeddings[0].detach().numpy()  # first batch item: [seq_len, hidden_size]
    if emb.shape[0] <= 1:
        return _error_view("Too few tokens for t-SNE.")
    try:
        # Perplexity is capped at n_samples - 1, mirroring the token count.
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, emb.shape[0] - 1))
        reduced = tsne.fit_transform(emb)

        def draw(ax):
            ax.scatter(reduced[:, 0], reduced[:, 1], c='blue')
            for i, token in enumerate(tokens):
                ax.annotate(token, (reduced[i, 0], reduced[i, 1]))
            ax.set_title("t-SNE of Token Embeddings")

        plot_html = _render_plot(draw, "t-SNE Plot")
        frame = pd.DataFrame({
            "Token": tokens,
            "t-SNE_X": reduced[:, 0],
            "t-SNE_Y": reduced[:, 1],
        })
        frame.index = [f"idx_{i}" for i in range(len(frame))]  # String indices
        return plot_html, _frame_to_dict(frame)
    except Exception as exc:
        logger.warning("t-SNE failed: %s", exc)
        return _error_view(str(exc), "t-SNE computation failed.")


def _attention_view(attentions, tokens):
    """Build the heatmap and table for the last layer's first attention head."""
    if not attentions:
        return _error_view("No attention weights available.")
    attn = attentions[-1][0, 0].detach().numpy()  # Last layer, first head

    def draw(ax):
        sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
        ax.set_title("Attention Weights (Last Layer, Head 0)")
        ax.tick_params(axis='x', labelrotation=45)
        ax.tick_params(axis='y', labelrotation=0)

    plot_html = _render_plot(draw, "Attention Heatmap")
    frame = pd.DataFrame(
        attn,
        index=_token_row_labels(tokens),
        columns=[f"token_{i}" for i in range(len(tokens))],
    )
    return plot_html, _frame_to_dict(frame)


def _activation_view(layer_name, tokens):
    """Build the mean-activation bar chart and table for one hooked layer."""
    if layer_name not in activations:
        return _error_view(f"Layer {layer_name} not found.")
    act = activations[layer_name]
    if isinstance(act, tuple):
        act = act[0]  # modules returning tuples: use the first element
    # Only [batch, seq_len, features] outputs can be tabulated per token; other
    # hooked outputs (e.g. 4-D attention tensors) get a clear error instead of
    # failing the whole request.
    if not torch.is_tensor(act) or act.dim() != 3 or act.shape[1] != len(tokens):
        return _error_view(f"Layer {layer_name} does not produce per-token activations.")
    act = act[0].detach().numpy()  # [seq_len, features]
    frame = pd.DataFrame(
        act,
        index=_token_row_labels(tokens),
        columns=[f"dim_{i}" for i in range(act.shape[1])],
    )
    data = _frame_to_dict(frame)
    mean_act = np.mean(act, axis=1)
    plot_html = _token_bar_plot(
        mean_act, tokens, f"Mean Activations in {layer_name}", "Activations Plot"
    )
    return plot_html, data


def _attribution_view(input_ids, attention_mask, tokens, attribution_target):
    """Run Integrated Gradients over the word embeddings and build its plot and table."""
    def forward_func(inputs_embeds, attention_mask=None):
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        # Return the full pooled output; Captum picks the column via ``target``.
        # Indexing here as well would leave a 1-D output that Captum cannot
        # apply ``target`` to.
        return outputs.pooler_output

    try:
        # ``inputs_embeds`` replaces only the word-embedding lookup; the model
        # still adds position/type embeddings itself, so feed it the
        # word-embedding table output rather than the embedding module output.
        with torch.no_grad():
            word_embeddings = model.get_input_embeddings()(input_ids)
        word_embeddings.requires_grad_(True)

        ig = IntegratedGradients(forward_func)
        attributions, _ = ig.attribute(
            inputs=word_embeddings,
            additional_forward_args=(attention_mask,),
            target=int(attribution_target),
            return_convergence_delta=True,
        )
        attr = attributions[0].detach().numpy().sum(axis=1)  # Sum over hidden size
        frame = pd.DataFrame({"Token": tokens, "Attribution": attr})
        frame.index = [f"idx_{i}" for i in range(len(frame))]  # String indices
        data = _frame_to_dict(frame)
        plot_html = _token_bar_plot(
            attr, tokens, "Integrated Gradients Attribution", "Attribution Plot"
        )
        return plot_html, data
    except Exception as exc:
        logger.warning("Integrated Gradients failed: %s", exc)
        return _error_view(str(exc), "Attribution computation failed.")


def process_input(input_text, layer_name, visualize_option, attribution_target=0):
    """
    Process input text, compute embeddings, activations, attention, and attribution.
    Parameters:
    - input_text: User-provided text
    - layer_name: Selected layer for activation visualization
    - visualize_option: 'Embeddings', 'Attention', or 'Activations'; any other
      value skips the visualization and only runs attribution
    - attribution_target: Index into BERT's pooled output vector used as the
      Integrated Gradients target
    Returns:
    - HTML string with base64-encoded image(s)
    - List of dataframe dictionaries with string keys
    - Status message
    Notes:
    - Never raises: failures are logged and reported through the three return
      values instead.
    """
    global activations
    activations = {}  # Reset activations
    try:
        # Validate input (None and whitespace-only both count as empty)
        if not input_text or not input_text.strip():
            return (
                "<p>Error: Input text cannot be empty.</p>",
                [{"Error": ["Input text cannot be empty."]}],
                "Error: Input text cannot be empty.",
            )

        # Tokenize input
        inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        input_ids = inputs['input_ids'].to(dtype=torch.long)  # Ensure LongTensor
        attention_mask = inputs['attention_mask']

        # Forward pass (also fills ``activations`` through the forward hooks)
        with torch.no_grad():
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=True,
                output_hidden_states=True,
            )

        # Convert token IDs to tokens
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

        # Selected visualization. It must run before attribution, because the
        # attribution forward passes overwrite the hooked activations.
        view = None
        if visualize_option == "Embeddings":
            view = _embedding_view(outputs.last_hidden_state, tokens)
        elif visualize_option == "Attention":
            view = _attention_view(outputs.attentions, tokens)
        elif visualize_option == "Activations":
            view = _activation_view(layer_name, tokens)

        # Attribution is always computed, whatever the visualization choice.
        attribution = _attribution_view(input_ids, attention_mask, tokens, attribution_target)

        views = [v for v in (view, attribution) if v is not None]
        html_plots = [plot_html for plot_html, _ in views]
        dataframes = [data for _, data in views]

        # Combine HTML plots
        html_output = "<div>" + "".join(html_plots) + "</div>"
        return html_output, dataframes, "Processing complete."
    except Exception as exc:
        logger.exception("Processing failed: %s", exc)
        # Exception text can contain markup characters; escape it for the HTML pane.
        return f"<p>Error: {html.escape(str(exc))}</p>", [{"Error": [str(exc)]}], f"Error: {exc}"
# Gradio Interface
def create_gradio_interface():
    """
    Build the Gradio Blocks UI for the demo.
    The left column holds the inputs (text, layer, visualization type,
    attribution target) and the "Analyze" button; the right column shows the
    HTML plots, the data outputs and a status message. Clicking the button
    calls ``process_input`` with the four inputs.
    Returns:
    - The assembled ``gr.Blocks`` app (not launched)
    """
    # Offer only modules that can record activations: nn.ModuleList containers
    # are never called directly, so a forward hook on them never fires and
    # selecting one could only ever report "not found".
    layer_choices = [
        str(name)
        for name, module in model.named_modules()
        if ('layer' in name or 'embeddings' in name)
        and not isinstance(module, torch.nn.ModuleList)
    ]

    with gr.Blocks(title="Neural Network Visualization Demo") as demo:
        gr.Markdown("# Neural Network Visualization Demo")
        gr.Markdown("Analyze BERT's neural network paths. Enter text, select a layer, and choose a visualization.")
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input Text",
                    value="The quick brown fox jumps over the lazy dog.",
                    placeholder="Enter text here...",
                )
                layer_name = gr.Dropdown(
                    label="Select Layer",
                    choices=layer_choices,
                    value="embeddings",
                )
                visualize_option = gr.Radio(
                    label="Visualization Type",
                    choices=["Embeddings", "Attention", "Activations"],
                    value="Embeddings",
                )
                attribution_target = gr.Slider(
                    label="Attribution Target Class (0 or 1)",
                    minimum=0,
                    maximum=1,
                    step=1,
                    value=0,
                )
                submit_btn = gr.Button("Analyze")
            with gr.Column():
                plot_output = gr.HTML(label="Visualizations")
                # process_input returns a list of nested dicts (one per table),
                # which is not a tabular value; a JSON viewer shows it as-is.
                dataframe_output = gr.JSON(label="Data Outputs")
                text_output = gr.Textbox(label="Messages")
        submit_btn.click(
            fn=process_input,
            inputs=[input_text, layer_name, visualize_option, attribution_target],
            outputs=[plot_output, dataframe_output, text_output],
        )
    return demo
# Script entry point: build the UI and serve it on 0.0.0.0:7860 without a
# public share link.
if __name__ == "__main__":
    try:
        demo = create_gradio_interface()
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as exc:
        # Launch problems are reported (log + stdout) rather than propagated.
        logger.error("Failed to launch Gradio demo: %s", exc)
        print(f"Error launching demo: {exc}. Try running locally with a different port or without share=True.")