import gradio as gr
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel
from sklearn.manifold import TSNE
import seaborn as sns
from captum.attr import IntegratedGradients
import io
import base64
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize BERT model and tokenizer with eager attention
try:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', attn_implementation="eager")
    model.eval()
except Exception as e:
    logger.error(f"Failed to load BERT model: {e}")
    raise
# Store intermediate activations
activations = {}

def hook_fn(module, input, output, name):
    activations[str(name)] = output  # Ensure name is a string

# Register hooks for BERT layers
for name, layer in model.named_modules():
    if 'layer' in name or 'embeddings' in name:
        layer.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))
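# Example layer names captured above (bert-base-uncased):
#   "embeddings"                              -> token/position/type embeddings
#   "encoder.layer.0" ... "encoder.layer.11"  -> the 12 transformer blocks
#   (plus their submodules, since 'layer' also substring-matches e.g. "encoder.layer.0.attention")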
def convert_dict_keys_to_str(d):
    """Recursively convert all dictionary keys to strings."""
    if isinstance(d, dict):
        return {str(k): convert_dict_keys_to_str(v) for k, v in d.items()}
    elif isinstance(d, list):
        return [convert_dict_keys_to_str(item) for item in d]
    elif isinstance(d, np.ndarray):
        return d.tolist()  # Convert numpy arrays to lists
    return d
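# Example:
#   convert_dict_keys_to_str({0: np.array([1.0, 2.0])}) -> {"0": [1.0, 2.0]}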
def process_input(input_text, layer_name, visualize_option, attribution_target=0):
    """
    Process input text, compute embeddings, activations, attention, and attribution.

    Parameters:
    - input_text: User-provided text
    - layer_name: Selected layer for activation visualization
    - visualize_option: 'Embeddings', 'Attention', or 'Activations'
    - attribution_target: Pooler-output dimension used as the attribution target (0 or 1);
      note BertModel has no classification head, so this is not a class label

    Returns:
    - HTML string with base64-encoded image(s)
    - List of dataframe dictionaries with string keys
    - Status message
    """
    global activations
    activations = {}  # Reset activations
    try:
        # Validate input
        if not input_text.strip():
            return "<p>Error: Input text cannot be empty.</p>", [{"Error": ["Input text cannot be empty."]}], "Error: Input text cannot be empty."

        # Tokenize input
        inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        input_ids = inputs['input_ids'].to(dtype=torch.long)  # Ensure LongTensor
        attention_mask = inputs['attention_mask']

        # Forward pass
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True, output_hidden_states=True)
        embeddings = outputs.last_hidden_state  # [batch, seq_len, hidden_size]
        attentions = outputs.attentions  # Tuple of per-layer attention weights

        # Convert token IDs to tokens
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
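        # e.g. "The quick brown fox..." -> ['[CLS]', 'the', 'quick', 'brown', 'fox', ..., '[SEP]']
        # (bert-base-uncased lowercases and adds special tokens)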
        # Initialize outputs
        html_plots = []
        dataframes = []
        # Visualization: Embeddings (t-SNE)
        if visualize_option == "Embeddings":
            emb = embeddings[0].detach().numpy()  # [seq_len, hidden_size]
            if emb.shape[0] > 1:
                try:
                    tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, emb.shape[0] - 1))
                    reduced = tsne.fit_transform(emb)
                    fig, ax = plt.subplots(figsize=(8, 6))
                    ax.scatter(reduced[:, 0], reduced[:, 1], c='blue')
                    for i, token in enumerate(tokens):
                        ax.annotate(token, (reduced[i, 0], reduced[i, 1]))
                    ax.set_title("t-SNE of Token Embeddings")
                    buf = io.BytesIO()
                    plt.savefig(buf, format='png', bbox_inches='tight')
                    buf.seek(0)
                    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
                    html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="t-SNE Plot" style="max-width:100%;"/>')
                    plt.close()
                    # Dataframe for coordinates
                    dataframe = pd.DataFrame({
                        "Token": tokens,
                        "t-SNE_X": reduced[:, 0],
                        "t-SNE_Y": reduced[:, 1]
                    })
                    dataframe.index = [f"idx_{i}" for i in range(len(dataframe))]  # String indices
                    dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
                except Exception as e:
                    logger.warning(f"t-SNE failed: {e}")
                    dataframes.append({"Error": [str(e)]})
                    html_plots.append("<p>Error: t-SNE computation failed.</p>")
            else:
                dataframes.append({"Error": ["Too few tokens for t-SNE."]})
                html_plots.append("<p>Error: Too few tokens for t-SNE.</p>")
        # Visualization: Attention Weights
        elif visualize_option == "Attention":
            if attentions:
                attn = attentions[-1][0, 0].detach().numpy()  # Last layer, first head
                fig, ax = plt.subplots(figsize=(8, 6))
                sns.heatmap(attn, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
                ax.set_title("Attention Weights (Last Layer, Head 0)")
                plt.xticks(rotation=45)
                plt.yticks(rotation=0)
                buf = io.BytesIO()
                plt.savefig(buf, format='png', bbox_inches='tight')
                buf.seek(0)
                img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
                html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attention Heatmap" style="max-width:100%;"/>')
                plt.close()
                # Dataframe for attention weights (rows: query tokens, columns: key positions)
                dataframe = pd.DataFrame(attn, index=tokens, columns=[f"token_{i}" for i in range(len(tokens))])
                dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
            else:
                dataframes.append({"Error": ["No attention weights available."]})
                html_plots.append("<p>Error: No attention weights available.</p>")
        # Visualization: Activations
        elif visualize_option == "Activations":
            if layer_name in activations:
                act = activations[layer_name]
                if isinstance(act, tuple):
                    act = act[0]  # Encoder layers return a tuple; hidden states come first
                act = act[0].detach().numpy()  # [seq_len, hidden_size]
                dataframe = pd.DataFrame(act, index=tokens, columns=[f"dim_{i}" for i in range(act.shape[1])])
                dataframes.append(convert_dict_keys_to_str(dataframe.to_dict()))
                # Plot mean activation per token
                fig, ax = plt.subplots(figsize=(8, 6))
                mean_act = np.mean(act, axis=1)
                ax.bar(range(len(mean_act)), mean_act)
                ax.set_xticks(range(len(mean_act)))
                ax.set_xticklabels(tokens, rotation=45)
                ax.set_title(f"Mean Activations in {layer_name}")
                buf = io.BytesIO()
                plt.savefig(buf, format='png', bbox_inches='tight')
                buf.seek(0)
                img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
                html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Activations Plot" style="max-width:100%;"/>')
                plt.close()
            else:
                dataframes.append({"Error": [f"Layer {layer_name} not found."]})
                html_plots.append(f"<p>Error: Layer {layer_name} not found.</p>")
        # Attribution: Integrated Gradients on the input embeddings
        def get_embeddings(input_ids, attention_mask=None):
            with torch.no_grad():
                # BertModel exposes its embedding layer directly (there is no .bert wrapper)
                return model.embeddings(input_ids)

        def forward_func(embeddings, attention_mask=None):
            outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask)
            # Select one dimension of the pooled [CLS] output as a scalar target per example
            return outputs.pooler_output[:, int(attribution_target)]

        ig = IntegratedGradients(forward_func)
        try:
            # Get float embeddings for input_ids and enable gradients for attribution
            embeddings = get_embeddings(input_ids, attention_mask).requires_grad_(True)
            # forward_func already returns a scalar per example, so no `target` argument is passed
            attributions, _ = ig.attribute(
                inputs=embeddings,
                additional_forward_args=(attention_mask,),
                return_convergence_delta=True
            )
            attr = attributions[0].detach().numpy().sum(axis=1)  # Sum over hidden size
            attr_df = pd.DataFrame({"Token": tokens, "Attribution": attr})
            attr_df.index = [f"idx_{i}" for i in range(len(attr_df))]  # String indices
            dataframes.append(convert_dict_keys_to_str(attr_df.to_dict()))
            # Plot attributions
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.bar(range(len(attr)), attr)
            ax.set_xticks(range(len(attr)))
            ax.set_xticklabels(tokens, rotation=45)
            ax.set_title("Integrated Gradients Attribution")
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight')
            buf.seek(0)
            img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
            html_plots.append(f'<img src="data:image/png;base64,{img_base64}" alt="Attribution Plot" style="max-width:100%;"/>')
            plt.close()
        except Exception as e:
            logger.warning(f"Integrated Gradients failed: {e}")
            dataframes.append({"Error": [str(e)]})
            html_plots.append("<p>Error: Attribution computation failed.</p>")

        # Combine HTML plots
        html_output = "<div>" + "".join(html_plots) + "</div>"
        return html_output, dataframes, "Processing complete."
    except Exception as e:
        logger.error(f"Processing failed: {e}")
        return f"<p>Error: {e}</p>", [{"Error": [str(e)]}], f"Error: {e}"
# Gradio Interface
def create_gradio_interface():
    with gr.Blocks(title="Neural Network Visualization Demo") as demo:
        gr.Markdown("# Neural Network Visualization Demo")
        gr.Markdown("Analyze BERT's neural network paths. Enter text, select a layer, and choose a visualization.")
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input Text",
                    value="The quick brown fox jumps over the lazy dog.",
                    placeholder="Enter text here..."
                )
                layer_name = gr.Dropdown(
                    label="Select Layer",
                    choices=[str(name) for name, _ in model.named_modules() if 'layer' in name or 'embeddings' in name],
                    value="embeddings"
                )
                visualize_option = gr.Radio(
                    label="Visualization Type",
                    choices=["Embeddings", "Attention", "Activations"],
                    value="Embeddings"
                )
                attribution_target = gr.Slider(
                    label="Attribution Target Dimension (0 or 1)",
                    minimum=0,
                    maximum=1,
                    step=1,
                    value=0
                )
                submit_btn = gr.Button("Analyze")
            with gr.Column():
                plot_output = gr.HTML(label="Visualizations")
                # process_input returns a list of nested dicts, which gr.Dataframe cannot
                # render; gr.JSON displays the string-keyed structure directly
                dataframe_output = gr.JSON(label="Data Outputs")
                text_output = gr.Textbox(label="Messages")
        submit_btn.click(
            fn=process_input,
            inputs=[input_text, layer_name, visualize_option, attribution_target],
            outputs=[plot_output, dataframe_output, text_output]
        )
    return demo
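# Example (calling the pipeline directly, without the UI):
#   html, frames, msg = process_input("Hello world", "embeddings", "Attention")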
# Launch the demo locally
if __name__ == "__main__":
    try:
        demo = create_gradio_interface()
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        logger.error(f"Failed to launch Gradio demo: {e}")
        print(f"Error launching demo: {e}. Try a different port if 7860 is already in use.")