# visualizer/streamlit_app.py
import streamlit as st
import pandas as pd
import json
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from pathlib import Path
import requests
import tempfile
import shutil

# Set page config
st.set_page_config(
    page_title="Attention Analysis Results Explorer",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)
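# NOTE: st.set_page_config must be the first Streamlit command executed in the
# script, and it may only be called once per page.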

# Custom CSS for better styling
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .section-header {
        font-size: 1.5rem;
        font-weight: bold;
        color: #ff7f0e;
        margin-top: 2rem;
        margin-bottom: 1rem;
    }
    .metric-container {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
    }
    .stSelectbox > div > div {
        background-color: white;
    }
</style>
""", unsafe_allow_html=True)


class AttentionResultsExplorer:
    def __init__(self, github_repo="ACMCMC/attention", use_cache=True):
        self.github_repo = github_repo
        self.use_cache = use_cache
        self.cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        self.base_path = self.cache_dir

        # Initialize cache directory
        if not self.cache_dir.exists():
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Download and cache data if needed
        if not self._cache_exists() or not use_cache:
            self._download_repository()

        self.languages = self._get_available_languages()
        self.relation_types = None

    def _cache_exists(self):
        """Check if cached data exists"""
        return (self.cache_dir / "results_en").exists()
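
    # NOTE: _cache_exists uses the English results directory ("results_en") as a
    # sentinel for whether the repository has been downloaded at all.
    #
    # The download helpers below call the GitHub Contents API without
    # authentication, which GitHub rate-limits (60 requests/hour per IP for
    # unauthenticated clients). For large result trees it may be worth sending a
    # token; a minimal sketch, assuming a GITHUB_TOKEN environment variable
    # (requires `import os`):
    #   headers = {"Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}"}
    #   requests.get(api_url, headers=headers)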

    def _download_repository(self):
        """Download repository data from GitHub"""
        st.info("🔄 Downloading results data from GitHub... This may take a moment.")

        # GitHub API to get the repository contents
        api_url = f"https://api.github.com/repos/{self.github_repo}/contents"

        try:
            # Get list of result directories
            response = requests.get(api_url)
            response.raise_for_status()
            contents = response.json()

            result_dirs = [item['name'] for item in contents
                           if item['type'] == 'dir' and item['name'].startswith('results_')]
            st.write(f"Found {len(result_dirs)} result directories: {', '.join(result_dirs)}")

            # Download each result directory
            progress_bar = st.progress(0)
            for i, result_dir in enumerate(result_dirs):
                st.write(f"Downloading {result_dir}...")
                self._download_directory(result_dir)
                progress_bar.progress((i + 1) / len(result_dirs))

            st.success("✅ Download completed!")
        except Exception as e:
            st.error(f"❌ Error downloading repository: {str(e)}")
            st.error("Please check the repository URL and your internet connection.")
            raise

    def _download_directory(self, dir_name, path=""):
        """Recursively download a directory from GitHub"""
        url = f"https://api.github.com/repos/{self.github_repo}/contents/{path}{dir_name}"

        try:
            response = requests.get(url)
            response.raise_for_status()
            contents = response.json()

            local_dir = self.cache_dir / path / dir_name
            local_dir.mkdir(parents=True, exist_ok=True)

            for item in contents:
                if item['type'] == 'file':
                    self._download_file(item, local_dir)
                elif item['type'] == 'dir':
                    self._download_directory(item['name'], f"{path}{dir_name}/")
        except Exception as e:
            st.warning(f"Could not download {dir_name}: {str(e)}")

    def _download_file(self, file_info, local_dir):
        """Download a single file from GitHub"""
        try:
            # Download file content
            response = requests.get(file_info['download_url'])
            response.raise_for_status()

            # Save to local cache
            local_file = local_dir / file_info['name']

            # Handle different file types
            if file_info['name'].endswith(('.csv', '.json')):
                with open(local_file, 'w', encoding='utf-8') as f:
                    f.write(response.text)
            else:  # Binary files like PDFs
                with open(local_file, 'wb') as f:
                    f.write(response.content)
        except Exception as e:
            st.warning(f"Could not download file {file_info['name']}: {str(e)}")

    def _get_available_languages(self):
        """Get all available language directories"""
        if not self.base_path.exists():
            return []
        result_dirs = [d.name for d in self.base_path.iterdir()
                       if d.is_dir() and d.name.startswith("results_")]
        languages = [d.replace("results_", "") for d in result_dirs]
        return sorted(languages)

    def _get_experimental_configs(self, language):
        """Get all experimental configurations for a language"""
        lang_dir = self.base_path / f"results_{language}"
        if not lang_dir.exists():
            return []
        configs = [d.name for d in lang_dir.iterdir() if d.is_dir()]
        return sorted(configs)

    def _get_models(self, language, config):
        """Get all models for a language and configuration"""
        config_dir = self.base_path / f"results_{language}" / config
        if not config_dir.exists():
            return []
        models = [d.name for d in config_dir.iterdir() if d.is_dir()]
        return sorted(models)

    def _parse_config_name(self, config_name):
        """Parse configuration name into readable format"""
        parts = config_name.split('+')
        config_dict = {}
        for part in parts:
            if '_' in part:
                key, value = part.split('_', 1)
                config_dict[key.replace('_', ' ').title()] = value
        return config_dict
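
    # Example (hypothetical config name; the exact naming scheme lives in the
    # results repo): "masking_none+window_5" parses to
    # {"Masking": "none", "Window": "5"}. Each '+'-separated part is split on
    # its first underscore into a key/value pair.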

    def _load_metadata(self, language, config, model):
        """Load metadata for a specific combination"""
        metadata_path = self.base_path / f"results_{language}" / config / model / "metadata" / "metadata.json"
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                return json.load(f)
        return None

    def _load_uas_scores(self, language, config, model):
        """Load UAS scores data"""
        uas_dir = self.base_path / f"results_{language}" / config / model / "uas_scores"
        if not uas_dir.exists():
            return {}

        uas_data = {}
        csv_files = list(uas_dir.glob("uas_*.csv"))

        if csv_files:
            progress_bar = st.progress(0)
            status_text = st.empty()

            for i, csv_file in enumerate(csv_files):
                relation = csv_file.stem.replace("uas_", "")
                status_text.text(f"Loading UAS data: {relation}")
                try:
                    df = pd.read_csv(csv_file, index_col=0)
                    uas_data[relation] = df
                except Exception as e:
                    st.warning(f"Could not load {csv_file.name}: {e}")
                progress_bar.progress((i + 1) / len(csv_files))

            progress_bar.empty()
            status_text.empty()

        return uas_data

    def _load_head_matching(self, language, config, model):
        """Load head matching data"""
        heads_dir = self.base_path / f"results_{language}" / config / model / "number_of_heads_matching"
        if not heads_dir.exists():
            return {}

        heads_data = {}
        csv_files = list(heads_dir.glob("heads_matching_*.csv"))

        if csv_files:
            progress_bar = st.progress(0)
            status_text = st.empty()

            for i, csv_file in enumerate(csv_files):
                relation = csv_file.stem.replace("heads_matching_", "").replace(f"_{model}", "")
                status_text.text(f"Loading head matching data: {relation}")
                try:
                    df = pd.read_csv(csv_file, index_col=0)
                    heads_data[relation] = df
                except Exception as e:
                    st.warning(f"Could not load {csv_file.name}: {e}")
                progress_bar.progress((i + 1) / len(csv_files))

            progress_bar.empty()
            status_text.empty()

        return heads_data

    def _load_variability(self, language, config, model):
        """Load variability data"""
        var_path = self.base_path / f"results_{language}" / config / model / "variability" / "variability_list.csv"
        if var_path.exists():
            try:
                return pd.read_csv(var_path, index_col=0)
            except Exception as e:
                st.warning(f"Could not load variability data: {e}")
        return None

    def _get_available_figures(self, language, config, model):
        """Get all available figure files"""
        figures_dir = self.base_path / f"results_{language}" / config / model / "figures"
        if not figures_dir.exists():
            return []
        return list(figures_dir.glob("*.pdf"))
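

# Expected cache layout, mirroring the GitHub repository (inferred from the
# loaders above):
#   results_<lang>/<config>/<model>/metadata/metadata.json
#   results_<lang>/<config>/<model>/uas_scores/uas_<relation>.csv
#   results_<lang>/<config>/<model>/number_of_heads_matching/heads_matching_<relation>_<model>.csv
#   results_<lang>/<config>/<model>/variability/variability_list.csv
#   results_<lang>/<config>/<model>/figures/*.pdf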


def main():
    # Title
    st.markdown('<div class="main-header">🔍 Attention Analysis Results Explorer</div>', unsafe_allow_html=True)

    # Sidebar for navigation
    st.sidebar.title("🔧 Configuration")

    # Cache management section
    st.sidebar.markdown("### 📁 Data Management")

    # Initialize explorer
    use_cache = st.sidebar.checkbox("Use cached data", value=True,
                                    help="Use previously downloaded data if available")

    if st.sidebar.button("🔄 Refresh Data", help="Download fresh data from GitHub"):
        # Clear cache and re-download
        cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
        st.rerun()

    # Show cache status
    cache_dir = Path(tempfile.gettempdir()) / "attention_results_cache"
    if cache_dir.exists():
        st.sidebar.success("✅ Data cached locally")
    else:
        st.sidebar.info("📥 Will download data from GitHub")

    st.sidebar.markdown("---")

    # Initialize explorer with error handling
    try:
        explorer = AttentionResultsExplorer(use_cache=use_cache)
    except Exception as e:
        st.error(f"❌ Failed to initialize data explorer: {str(e)}")
        st.error("Please check your internet connection and try again.")
        return

    # Check if any languages are available
    if not explorer.languages:
        st.error("❌ No result data found. Please check the GitHub repository.")
        return

    # Language selection
    selected_language = st.sidebar.selectbox(
        "Select Language",
        options=explorer.languages,
        help="Choose the language dataset to explore"
    )

    # Get configurations for selected language
    configs = explorer._get_experimental_configs(selected_language)
    if not configs:
        st.error(f"No configurations found for language: {selected_language}")
        return

    # Configuration selection
    selected_config = st.sidebar.selectbox(
        "Select Experimental Configuration",
        options=configs,
        help="Choose the experimental configuration"
    )

    # Parse and display configuration details
    config_details = explorer._parse_config_name(selected_config)
    st.sidebar.markdown("**Configuration Details:**")
    for key, value in config_details.items():
        st.sidebar.markdown(f"- **{key}**: {value}")

    # Get models for selected language and config
    models = explorer._get_models(selected_language, selected_config)
    if not models:
        st.error(f"No models found for {selected_language}/{selected_config}")
        return

    # Model selection
    selected_model = st.sidebar.selectbox(
        "Select Model",
        options=models,
        help="Choose the model to analyze"
    )

    # Main content area
    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "📊 Overview",
        "🎯 UAS Scores",
        "🧠 Head Matching",
        "📈 Variability",
        "🖼️ Figures"
    ])

    # Tab 1: Overview
    with tab1:
        st.markdown('<div class="section-header">Experiment Overview</div>', unsafe_allow_html=True)

        # Load metadata
        metadata = explorer._load_metadata(selected_language, selected_config, selected_model)

        if metadata:
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Total Samples", metadata.get('total_number', 'N/A'))
            with col2:
                st.metric("Processed Correctly", metadata.get('number_processed_correctly', 'N/A'))
            with col3:
                st.metric("Errors", metadata.get('number_errored', 'N/A'))
            with col4:
                success_rate = (metadata.get('number_processed_correctly', 0) /
                                metadata.get('total_number', 1)) * 100 if metadata.get('total_number') else 0
                st.metric("Success Rate", f"{success_rate:.1f}%")

            st.markdown(f"**Random Seed:** {metadata.get('random_seed', 'N/A')}")

            if metadata.get('errored_phrases'):
                st.markdown("**Errored Phrase IDs:**")
                st.write(metadata['errored_phrases'])
        else:
            st.warning("No metadata available for this configuration.")

        # Quick stats about available data
        st.markdown('<div class="section-header">Available Data</div>', unsafe_allow_html=True)
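
        # NOTE: these loaders run again inside the tabs below on every Streamlit
        # rerun; if load time becomes an issue, st.cache_data (or memoizing the
        # results on the explorer instance) would avoid re-reading the CSVs.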
        uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)
        heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)
        variability_data = explorer._load_variability(selected_language, selected_config, selected_model)
        figures = explorer._get_available_figures(selected_language, selected_config, selected_model)

        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("UAS Relations", len(uas_data))
        with col2:
            st.metric("Head Matching Relations", len(heads_data))
        with col3:
            st.metric("Variability Data", "✓" if variability_data is not None else "✗")
        with col4:
            st.metric("Figure Files", len(figures))

    # Tab 2: UAS Scores
    with tab2:
        st.markdown('<div class="section-header">UAS (Unlabeled Attachment Score) Analysis</div>', unsafe_allow_html=True)

        uas_data = explorer._load_uas_scores(selected_language, selected_config, selected_model)

        if uas_data:
            # Relation selection
            selected_relation = st.selectbox(
                "Select Dependency Relation",
                options=list(uas_data.keys()),
                help="Choose a dependency relation to visualize UAS scores"
            )

            if selected_relation and selected_relation in uas_data:
                df = uas_data[selected_relation]

                # Display the data table
                st.markdown("**UAS Scores Matrix (Layer × Head)**")
                st.dataframe(df, use_container_width=True)

                # Create heatmap
                fig = px.imshow(
                    df.values,
                    x=[f"Head {i}" for i in df.columns],
                    y=[f"Layer {i}" for i in df.index],
                    color_continuous_scale="Viridis",
                    title=f"UAS Scores Heatmap - {selected_relation}",
                    labels=dict(color="UAS Score")
                )
                fig.update_layout(height=600)
                st.plotly_chart(fig, use_container_width=True)

                # Statistics
                st.markdown("**Statistics**")
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Max Score", f"{df.values.max():.4f}")
                with col2:
                    st.metric("Min Score", f"{df.values.min():.4f}")
                with col3:
                    st.metric("Mean Score", f"{df.values.mean():.4f}")
                with col4:
                    st.metric("Std Dev", f"{df.values.std():.4f}")
        else:
            st.warning("No UAS score data available for this configuration.")

    # Tab 3: Head Matching
    with tab3:
        st.markdown('<div class="section-header">Attention Head Matching Analysis</div>', unsafe_allow_html=True)

        heads_data = explorer._load_head_matching(selected_language, selected_config, selected_model)

        if heads_data:
            # Relation selection
            selected_relation = st.selectbox(
                "Select Dependency Relation",
                options=list(heads_data.keys()),
                help="Choose a dependency relation to visualize head matching patterns",
                key="heads_relation"
            )

            if selected_relation and selected_relation in heads_data:
                df = heads_data[selected_relation]

                # Display the data table
                st.markdown("**Head Matching Counts Matrix (Layer × Head)**")
                st.dataframe(df, use_container_width=True)

                # Create heatmap
                fig = px.imshow(
                    df.values,
                    x=[f"Head {i}" for i in df.columns],
                    y=[f"Layer {i}" for i in df.index],
                    color_continuous_scale="Blues",
                    title=f"Head Matching Counts - {selected_relation}",
                    labels=dict(color="Match Count")
                )
                fig.update_layout(height=600)
                st.plotly_chart(fig, use_container_width=True)

                # Create bar chart of total matches per layer
                layer_totals = df.sum(axis=1)
                fig_bar = px.bar(
                    x=layer_totals.index,
                    y=layer_totals.values,
                    title=f"Total Matches per Layer - {selected_relation}",
                    labels={"x": "Layer", "y": "Total Matches"}
                )
                fig_bar.update_layout(height=400)
                st.plotly_chart(fig_bar, use_container_width=True)

                # Statistics
                st.markdown("**Statistics**")
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Matches", int(df.values.sum()))
                with col2:
                    st.metric("Max per Cell", int(df.values.max()))
                with col3:
                    best_layer = layer_totals.idxmax()
                    st.metric("Best Layer", f"Layer {best_layer}")
                with col4:
                    # Report the actual layer/head labels, not positional indices
                    best_head_idx = np.unravel_index(df.values.argmax(), df.values.shape)
                    st.metric("Best Head", f"L{df.index[best_head_idx[0]]}-H{df.columns[best_head_idx[1]]}")
        else:
            st.warning("No head matching data available for this configuration.")

    # Tab 4: Variability
    with tab4:
        st.markdown('<div class="section-header">Attention Variability Analysis</div>', unsafe_allow_html=True)

        variability_data = explorer._load_variability(selected_language, selected_config, selected_model)

        if variability_data is not None:
            # Display the data table
            st.markdown("**Variability Matrix (Layer × Head)**")
            st.dataframe(variability_data, use_container_width=True)

            # Create heatmap
            fig = px.imshow(
                variability_data.values,
                x=[f"Head {i}" for i in variability_data.columns],
                y=[f"Layer {i}" for i in variability_data.index],
                color_continuous_scale="Reds",
                title="Attention Variability Heatmap",
                labels=dict(color="Variability Score")
            )
            fig.update_layout(height=600)
            st.plotly_chart(fig, use_container_width=True)

            # Create line plot for variability trends
            fig_line = go.Figure()
            for col in variability_data.columns:
                fig_line.add_trace(go.Scatter(
                    x=variability_data.index,
                    y=variability_data[col],
                    mode='lines+markers',
                    name=f'Head {col}',
                    line=dict(width=2)
                ))
            fig_line.update_layout(
                title="Variability Trends Across Layers",
                xaxis_title="Layer",
                yaxis_title="Variability Score",
                height=500
            )
            st.plotly_chart(fig_line, use_container_width=True)

            # Statistics
            st.markdown("**Statistics**")
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Max Variability", f"{variability_data.values.max():.4f}")
            with col2:
                st.metric("Min Variability", f"{variability_data.values.min():.4f}")
            with col3:
                st.metric("Mean Variability", f"{variability_data.values.mean():.4f}")
            with col4:
                # Report the actual layer/head labels, not positional indices
                most_variable_idx = np.unravel_index(variability_data.values.argmax(), variability_data.values.shape)
                st.metric("Most Variable", f"L{variability_data.index[most_variable_idx[0]]}-H{variability_data.columns[most_variable_idx[1]]}")
        else:
            st.warning("No variability data available for this configuration.")

    # Tab 5: Figures
    with tab5:
        st.markdown('<div class="section-header">Generated Figures</div>', unsafe_allow_html=True)

        figures = explorer._get_available_figures(selected_language, selected_config, selected_model)

        if figures:
            st.markdown(f"**Available Figures: {len(figures)}**")

            # Group figures by relation type
            figure_groups = {}
            for fig_path in figures:
                # Extract relation from filename
                filename = fig_path.stem
                relation = filename.replace("heads_matching_", "").replace(f"_{selected_model}", "")
                if relation not in figure_groups:
                    figure_groups[relation] = []
                figure_groups[relation].append(fig_path)

            # Select relation to view
            selected_fig_relation = st.selectbox(
                "Select Relation for Figure View",
                options=list(figure_groups.keys()),
                help="Choose a dependency relation to view its figure"
            )

            if selected_fig_relation and selected_fig_relation in figure_groups:
                fig_path = figure_groups[selected_fig_relation][0]

                st.markdown(f"**Figure: {fig_path.name}**")
                st.markdown(f"**Path:** `{fig_path}`")

                # Note about PDF viewing
                st.info(
                    "📄 PDF figures are available in the results directory. "
                    "Due to Streamlit limitations, PDF files cannot be displayed directly in the browser. "
                    "You can download or view them locally."
                )

                # Provide download link
                try:
                    with open(fig_path, "rb") as file:
                        st.download_button(
                            label=f"📥 Download {fig_path.name}",
                            data=file.read(),
                            file_name=fig_path.name,
                            mime="application/pdf"
                        )
                except Exception as e:
                    st.error(f"Could not load figure: {e}")

            # List all available figures
            st.markdown("**All Available Figures:**")
            for relation, paths in figure_groups.items():
                with st.expander(f"📊 {relation} ({len(paths)} files)"):
                    for path in paths:
                        st.markdown(f"- `{path.name}`")
        else:
            st.warning("No figures available for this configuration.")

    # Footer
    st.markdown("---")

    # Data source information
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(
            "🔬 **Attention Analysis Results Explorer** | "
            f"Currently viewing: {selected_language.upper()} - {selected_model} | "
            "Built with Streamlit"
        )
    with col2:
        st.markdown(
            f"📊 **Data Source**: [GitHub Repository](https://github.com/{explorer.github_repo})"
        )


if __name__ == "__main__":
    main()
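
# Launch locally with:
#   streamlit run visualizer/streamlit_app.py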