import streamlit as st
import pandas as pd
import json
import os
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from pathlib import Path
import glob
import requests
from io import StringIO
import zipfile
import tempfile
import shutil

# Set page config
st.set_page_config(
    page_title="Attention Analysis Results Explorer",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for better styling.
# NOTE(review): the CSS payload is empty in this copy of the file, and the
# header markdown calls below carry plain text only -- the original HTML/CSS
# appears to have been stripped. Restore it here and in _section_header().
st.markdown(""" """, unsafe_allow_html=True)

# Local mirror of the GitHub results. Shared by the explorer class and the
# sidebar cache controls so both always refer to the same directory.
CACHE_DIR = Path(tempfile.gettempdir()) / "attention_results_cache"

# Seconds to wait for a single GitHub request. Without a timeout,
# requests.get can block the Streamlit script forever on a stalled connection.
REQUEST_TIMEOUT = 30


class AttentionResultsExplorer:
    """Mirror attention-analysis results from GitHub and load them for display.

    The mirrored tree is read as
    ``results_<language>/<config>/<model>/{metadata,uas_scores,
    number_of_heads_matching,variability,figures}``.
    """

    def __init__(self, github_repo="ACMCMC/attention", use_cache=True):
        """Create the explorer, downloading the repository if needed.

        Args:
            github_repo: ``owner/name`` of the GitHub repository to mirror.
            use_cache: Reuse an existing local mirror when one is present.
                When False, the repository is downloaded again.

        Raises:
            Exception: Propagated from ``_download_repository`` when the
                top-level repository listing cannot be fetched.
        """
        self.github_repo = github_repo
        self.use_cache = use_cache
        self.cache_dir = CACHE_DIR
        self.base_path = self.cache_dir

        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if not use_cache or not self._cache_exists():
            self._download_repository()

        self.languages = self._get_available_languages()
        self.relation_types = None

    # ------------------------------------------------------------------ #
    # Downloading
    # ------------------------------------------------------------------ #
    def _cache_exists(self):
        """Return True if the cache holds at least one ``results_*`` directory."""
        # Checking for any results directory (not just results_en) avoids
        # re-downloading on every rerun for repos without English results.
        return self.cache_dir.exists() and any(
            d.is_dir() and d.name.startswith("results_")
            for d in self.cache_dir.iterdir()
        )

    def _get(self, url):
        """GET ``url`` with a timeout and raise on HTTP error statuses."""
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        return response

    def _download_repository(self):
        """Download every top-level ``results_*`` directory of the repository.

        Shows progress in the Streamlit UI. Errors are reported with
        ``st.error`` and then re-raised to the caller.
        """
        st.info("🔄 Downloading results data from GitHub... This may take a moment.")

        # GitHub API to get the repository contents.
        # NOTE: every directory costs one GitHub API request, and GitHub
        # rate-limits unauthenticated API requests.
        api_url = f"https://api.github.com/repos/{self.github_repo}/contents"

        try:
            contents = self._get(api_url).json()
            result_dirs = [
                item["name"]
                for item in contents
                if item["type"] == "dir" and item["name"].startswith("results_")
            ]
            st.write(f"Found {len(result_dirs)} result directories: {', '.join(result_dirs)}")

            progress_bar = st.progress(0)
            for i, result_dir in enumerate(result_dirs):
                st.write(f"Downloading {result_dir}...")
                self._download_directory(result_dir)
                progress_bar.progress((i + 1) / len(result_dirs))

            st.success("✅ Download completed!")
        except Exception as e:
            st.error(f"❌ Error downloading repository: {str(e)}")
            st.error("Please check the repository URL and your internet connection.")
            raise

    def _download_directory(self, dir_name, path=""):
        """Recursively mirror one repository directory into the cache.

        Args:
            dir_name: Name of the directory to download.
            path: Repository-relative parent path, ending in "/" when non-empty.

        Failures are reported with ``st.warning`` and do not abort the
        overall download (best effort).
        """
        url = f"https://api.github.com/repos/{self.github_repo}/contents/{path}{dir_name}"
        try:
            contents = self._get(url).json()

            local_dir = self.cache_dir / path / dir_name
            local_dir.mkdir(parents=True, exist_ok=True)

            for item in contents:
                if item["type"] == "file":
                    self._download_file(item, local_dir)
                elif item["type"] == "dir":
                    self._download_directory(item["name"], f"{path}{dir_name}/")
        except Exception as e:
            st.warning(f"Could not download {dir_name}: {str(e)}")

    def _download_file(self, file_info, local_dir):
        """Download one file described by a GitHub contents-API entry.

        Args:
            file_info: Contents-API item; its ``name`` and ``download_url``
                keys are used.
            local_dir: Directory to write the file into.
        """
        try:
            response = self._get(file_info["download_url"])
            # Store the raw bytes for every file type (CSV, JSON, PDF, ...).
            # Decoding via response.text depends on requests' charset guess,
            # and text-mode writes translate newlines; bytes are lossless.
            (local_dir / file_info["name"]).write_bytes(response.content)
        except Exception as e:
            st.warning(f"Could not download file {file_info['name']}: {str(e)}")

    # ------------------------------------------------------------------ #
    # Discovery
    # ------------------------------------------------------------------ #
    def _model_dir(self, language, config, model):
        """Return the cache directory for one language/config/model triple."""
        return self.base_path / f"results_{language}" / config / model

    def _get_available_languages(self):
        """Return sorted language codes taken from ``results_<lang>`` directory names."""
        if not self.base_path.exists():
            return []
        return sorted(
            d.name.replace("results_", "")
            for d in self.base_path.iterdir()
            if d.is_dir() and d.name.startswith("results_")
        )

    def _get_experimental_configs(self, language):
        """Return sorted configuration directory names for ``language``."""
        lang_dir = self.base_path / f"results_{language}"
        if not lang_dir.exists():
            return []
        return sorted(d.name for d in lang_dir.iterdir() if d.is_dir())

    def _get_models(self, language, config):
        """Return sorted model directory names for ``language``/``config``."""
        config_dir = self.base_path / f"results_{language}" / config
        if not config_dir.exists():
            return []
        return sorted(d.name for d in config_dir.iterdir() if d.is_dir())

    def _parse_config_name(self, config_name):
        """Split a config name like ``a_1+b_2`` into ``{"A": "1", "B": "2"}``.

        Parts are separated by "+"; each part is split at its first "_" into
        key and value. Parts without "_" are ignored.
        """
        config_dict = {}
        for part in config_name.split("+"):
            if "_" in part:
                key, value = part.split("_", 1)
                # key cannot contain "_" after splitting at the first one.
                config_dict[key.title()] = value
        return config_dict

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #
    def _load_metadata(self, language, config, model):
        """Load ``metadata/metadata.json`` as a dict, or return None if absent."""
        metadata_path = self._model_dir(language, config, model) / "metadata" / "metadata.json"
        if not metadata_path.exists():
            return None
        with open(metadata_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _load_relation_csvs(self, directory, pattern, prefix, label, suffix=""):
        """Load per-relation CSV files into ``{relation: DataFrame}``.

        Args:
            directory: Directory to search.
            pattern: Glob pattern selecting the CSV files.
            prefix: Text removed from each file stem to get the relation name.
            label: Human-readable data name shown in the progress text.
            suffix: Additional text removed from the stem (after ``prefix``).

        Files that fail to parse are skipped with a ``st.warning``.
        """
        data = {}
        if not directory.exists():
            return data

        # Sorted so the relation selectboxes list relations in a stable order.
        csv_files = sorted(directory.glob(pattern))
        if not csv_files:
            return data

        progress_bar = st.progress(0)
        status_text = st.empty()
        for i, csv_file in enumerate(csv_files):
            relation = csv_file.stem.replace(prefix, "")
            if suffix:
                relation = relation.replace(suffix, "")
            status_text.text(f"Loading {label}: {relation}")
            try:
                data[relation] = pd.read_csv(csv_file, index_col=0)
            except Exception as e:
                st.warning(f"Could not load {csv_file.name}: {e}")
            progress_bar.progress((i + 1) / len(csv_files))
        progress_bar.empty()
        status_text.empty()
        return data

    def _load_uas_scores(self, language, config, model):
        """Load ``uas_scores/uas_<relation>.csv`` files keyed by relation."""
        return self._load_relation_csvs(
            self._model_dir(language, config, model) / "uas_scores",
            "uas_*.csv",
            "uas_",
            "UAS data",
        )

    def _load_head_matching(self, language, config, model):
        """Load ``number_of_heads_matching/heads_matching_<relation>[_<model>].csv`` files."""
        return self._load_relation_csvs(
            self._model_dir(language, config, model) / "number_of_heads_matching",
            "heads_matching_*.csv",
            "heads_matching_",
            "head matching data",
            suffix=f"_{model}",
        )

    def _load_variability(self, language, config, model):
        """Load ``variability/variability_list.csv``, or return None if unavailable."""
        var_path = self._model_dir(language, config, model) / "variability" / "variability_list.csv"
        if var_path.exists():
            try:
                return pd.read_csv(var_path, index_col=0)
            except Exception as e:
                st.warning(f"Could not load variability data: {e}")
        return None

    def _get_available_figures(self, language, config, model):
        """Return the sorted list of PDF paths in the model's ``figures`` directory."""
        figures_dir = self._model_dir(language, config, model) / "figures"
        if not figures_dir.exists():
            return []
        return sorted(figures_dir.glob("*.pdf"))


# ---------------------------------------------------------------------- #
# Rendering helpers
# ---------------------------------------------------------------------- #
def _section_header(text):
    """Render a section title.

    NOTE(review): the original HTML wrapper (presumably a styled div) seems to
    have been stripped from this file; only the text is rendered here.
    """
    st.markdown(text, unsafe_allow_html=True)


def _render_heatmap(df, color_scale, title, color_label):
    """Plot a Layer x Head DataFrame as a heatmap using its own row/column labels."""
    fig = px.imshow(
        df.values,
        x=[f"Head {i}" for i in df.columns],
        y=[f"Layer {i}" for i in df.index],
        color_continuous_scale=color_scale,
        title=title,
        labels=dict(color=color_label),
    )
    fig.update_layout(height=600)
    st.plotly_chart(fig, use_container_width=True)


def _argmax_cell_label(df):
    """Return ``"L<layer>-H<head>"`` for the cell holding ``df``'s maximum.

    Uses the DataFrame's index/column labels, so the label matches the
    heatmap axes and the label-based "Best Layer" metric.
    """
    row, col = np.unravel_index(df.values.argmax(), df.values.shape)
    return f"L{df.index[row]}-H{df.columns[col]}"


def _render_overview_tab(metadata, uas_data, heads_data, variability_data, figures):
    """Render run metadata and counts of the available result artifacts."""
    _section_header("Experiment Overview")

    if metadata:
        total = metadata.get("total_number")
        # `or 0` guards against an explicit null count in the JSON.
        processed = metadata.get("number_processed_correctly") or 0

        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Samples", metadata.get("total_number", "N/A"))
        with col2:
            st.metric("Processed Correctly", metadata.get("number_processed_correctly", "N/A"))
        with col3:
            st.metric("Errors", metadata.get("number_errored", "N/A"))
        with col4:
            success_rate = processed / total * 100 if total else 0
            st.metric("Success Rate", f"{success_rate:.1f}%")

        # Interpolate the seed into the text: a second positional argument to
        # st.markdown is `unsafe_allow_html`, so it would never be displayed.
        st.markdown(f"**Random Seed:** {metadata.get('random_seed', 'N/A')}")

        if metadata.get("errored_phrases"):
            st.markdown("**Errored Phrase IDs:**")
            st.write(metadata["errored_phrases"])
    else:
        st.warning("No metadata available for this configuration.")

    # Quick stats about available data
    _section_header("Available Data")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("UAS Relations", len(uas_data))
    with col2:
        st.metric("Head Matching Relations", len(heads_data))
    with col3:
        st.metric("Variability Data", "✓" if variability_data is not None else "✗")
    with col4:
        st.metric("Figure Files", len(figures))


def _render_uas_tab(uas_data):
    """Render the per-relation UAS score table, heatmap and summary statistics."""
    _section_header("UAS (Unlabeled Attachment Score) Analysis")

    if not uas_data:
        st.warning("No UAS score data available for this configuration.")
        return

    selected_relation = st.selectbox(
        "Select Dependency Relation",
        options=list(uas_data.keys()),
        help="Choose a dependency relation to visualize UAS scores",
        key="uas_relation",
    )
    if not selected_relation or selected_relation not in uas_data:
        return

    df = uas_data[selected_relation]

    st.markdown("**UAS Scores Matrix (Layer × Head)**")
    st.dataframe(df, use_container_width=True)

    _render_heatmap(df, "Viridis", f"UAS Scores Heatmap - {selected_relation}", "UAS Score")

    st.markdown("**Statistics**")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Max Score", f"{df.values.max():.4f}")
    with col2:
        st.metric("Min Score", f"{df.values.min():.4f}")
    with col3:
        st.metric("Mean Score", f"{df.values.mean():.4f}")
    with col4:
        st.metric("Std Dev", f"{df.values.std():.4f}")


def _render_head_matching_tab(heads_data):
    """Render head-matching counts: table, heatmap, per-layer totals and stats."""
    _section_header("Attention Head Matching Analysis")

    if not heads_data:
        st.warning("No head matching data available for this configuration.")
        return

    selected_relation = st.selectbox(
        "Select Dependency Relation",
        options=list(heads_data.keys()),
        help="Choose a dependency relation to visualize head matching patterns",
        key="heads_relation",
    )
    if not selected_relation or selected_relation not in heads_data:
        return

    df = heads_data[selected_relation]

    st.markdown("**Head Matching Counts Matrix (Layer × Head)**")
    st.dataframe(df, use_container_width=True)

    _render_heatmap(df, "Blues", f"Head Matching Counts - {selected_relation}", "Match Count")

    # Total matches per layer (row sums).
    layer_totals = df.sum(axis=1)
    fig_bar = px.bar(
        x=layer_totals.index,
        y=layer_totals.values,
        title=f"Total Matches per Layer - {selected_relation}",
        labels={"x": "Layer", "y": "Total Matches"},
    )
    fig_bar.update_layout(height=400)
    st.plotly_chart(fig_bar, use_container_width=True)

    st.markdown("**Statistics**")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Matches", int(df.values.sum()))
    with col2:
        st.metric("Max per Cell", int(df.values.max()))
    with col3:
        st.metric("Best Layer", f"Layer {layer_totals.idxmax()}")
    with col4:
        st.metric("Best Head", _argmax_cell_label(df))


def _render_variability_tab(variability_data):
    """Render the variability table, heatmap, per-head trend lines and stats."""
    _section_header("Attention Variability Analysis")

    if variability_data is None:
        st.warning("No variability data available for this configuration.")
        return

    st.markdown("**Variability Matrix (Layer × Head)**")
    st.dataframe(variability_data, use_container_width=True)

    _render_heatmap(variability_data, "Reds", "Attention Variability Heatmap", "Variability Score")

    # One line per head column, plotted across layers (rows).
    fig_line = go.Figure()
    for col in variability_data.columns:
        fig_line.add_trace(go.Scatter(
            x=variability_data.index,
            y=variability_data[col],
            mode="lines+markers",
            name=f"Head {col}",
            line=dict(width=2),
        ))
    fig_line.update_layout(
        title="Variability Trends Across Layers",
        xaxis_title="Layer",
        yaxis_title="Variability Score",
        height=500,
    )
    st.plotly_chart(fig_line, use_container_width=True)

    st.markdown("**Statistics**")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Max Variability", f"{variability_data.values.max():.4f}")
    with col2:
        st.metric("Min Variability", f"{variability_data.values.min():.4f}")
    with col3:
        st.metric("Mean Variability", f"{variability_data.values.mean():.4f}")
    with col4:
        st.metric("Most Variable", _argmax_cell_label(variability_data))


def _render_figures_tab(figures, model):
    """List generated PDF figures grouped by relation and offer a download."""
    _section_header("Generated Figures")

    if not figures:
        st.warning("No figures available for this configuration.")
        return

    st.markdown(f"**Available Figures: {len(figures)}**")

    # Group figures by relation name extracted from the filename.
    figure_groups = {}
    for fig_path in figures:
        relation = fig_path.stem.replace("heads_matching_", "").replace(f"_{model}", "")
        figure_groups.setdefault(relation, []).append(fig_path)

    selected_fig_relation = st.selectbox(
        "Select Relation for Figure View",
        options=list(figure_groups.keys()),
        help="Choose a dependency relation to view its figure",
        key="figure_relation",
    )

    if selected_fig_relation and selected_fig_relation in figure_groups:
        fig_path = figure_groups[selected_fig_relation][0]
        st.markdown(f"**Figure: {fig_path.name}**")
        st.markdown(f"**Path:** `{fig_path}`")

        st.info(
            "📄 PDF figures are available in the results directory. "
            "Due to Streamlit limitations, PDF files cannot be displayed directly in the browser. "
            "You can download or view them locally."
        )

        try:
            pdf_bytes = fig_path.read_bytes()
        except OSError as e:
            st.error(f"Could not load figure: {e}")
        else:
            st.download_button(
                label=f"📥 Download {fig_path.name}",
                data=pdf_bytes,
                file_name=fig_path.name,
                mime="application/pdf",
            )

    st.markdown("**All Available Figures:**")
    for relation, paths in figure_groups.items():
        with st.expander(f"📊 {relation} ({len(paths)} files)"):
            for path in paths:
                st.markdown(f"- `{path.name}`")


def main():
    """Streamlit entry point: sidebar selection, one data load, and five result tabs."""
    # Title
    st.markdown("🔍 Attention Analysis Results Explorer", unsafe_allow_html=True)

    # Sidebar for navigation
    st.sidebar.title("🔧 Configuration")

    # Cache management section
    st.sidebar.markdown("### 📁 Data Management")

    # NOTE(review): Streamlit reruns this script on every interaction, so with
    # this box unchecked the full repository is re-downloaded on every rerun.
    use_cache = st.sidebar.checkbox(
        "Use cached data",
        value=True,
        help="Use previously downloaded data if available",
    )

    if st.sidebar.button("🔄 Refresh Data", help="Download fresh data from GitHub"):
        # Clear cache and re-download on the next run.
        if CACHE_DIR.exists():
            shutil.rmtree(CACHE_DIR)
        st.rerun()

    # Show cache status
    if CACHE_DIR.exists():
        st.sidebar.success("✅ Data cached locally")
    else:
        st.sidebar.info("📥 Will download data from GitHub")

    st.sidebar.markdown("---")

    try:
        explorer = AttentionResultsExplorer(use_cache=use_cache)
    except Exception as e:
        st.error(f"❌ Failed to initialize data explorer: {str(e)}")
        st.error("Please check your internet connection and try again.")
        return

    if not explorer.languages:
        st.error("❌ No result data found. Please check the GitHub repository.")
        return

    selected_language = st.sidebar.selectbox(
        "Select Language",
        options=explorer.languages,
        help="Choose the language dataset to explore",
    )

    configs = explorer._get_experimental_configs(selected_language)
    if not configs:
        st.error(f"No configurations found for language: {selected_language}")
        return

    selected_config = st.sidebar.selectbox(
        "Select Experimental Configuration",
        options=configs,
        help="Choose the experimental configuration",
    )

    config_details = explorer._parse_config_name(selected_config)
    st.sidebar.markdown("**Configuration Details:**")
    for key, value in config_details.items():
        st.sidebar.markdown(f"- **{key}**: {value}")

    models = explorer._get_models(selected_language, selected_config)
    if not models:
        st.error(f"No models found for {selected_language}/{selected_config}")
        return

    selected_model = st.sidebar.selectbox(
        "Select Model",
        options=models,
        help="Choose the model to analyze",
    )

    selection = (selected_language, selected_config, selected_model)

    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "📊 Overview",
        "🎯 UAS Scores",
        "🧠 Head Matching",
        "📈 Variability",
        "🖼️ Figures",
    ])

    with tab1:
        # Load every dataset once per rerun and share it with all tabs, so
        # the CSVs are read from disk a single time.
        metadata = explorer._load_metadata(*selection)
        uas_data = explorer._load_uas_scores(*selection)
        heads_data = explorer._load_head_matching(*selection)
        variability_data = explorer._load_variability(*selection)
        figures = explorer._get_available_figures(*selection)
        _render_overview_tab(metadata, uas_data, heads_data, variability_data, figures)

    with tab2:
        _render_uas_tab(uas_data)

    with tab3:
        _render_head_matching_tab(heads_data)

    with tab4:
        _render_variability_tab(variability_data)

    with tab5:
        _render_figures_tab(figures, selected_model)

    # Footer
    st.markdown("---")

    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(
            "🔬 **Attention Analysis Results Explorer** | "
            f"Currently viewing: {selected_language.upper()} - {selected_model} | "
            "Built with Streamlit"
        )
    with col2:
        st.markdown(
            f"📊 **Data Source**: [GitHub Repository](https://github.com/{explorer.github_repo})"
        )


if __name__ == "__main__":
    main()