# Page-header residue from the hosting site, kept as a comment so the file parses:
# author beta3 — commit a7719c0 "Update src/streamlit_app.py" (verified).
import streamlit as st
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import mne
from pathlib import Path
import zipfile
import os
# Configure the browser tab (title and icon), use the wide page layout and
# start with the sidebar expanded.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first st.* call in the script — keep it above any other Streamlit output.
st.set_page_config(
    page_title="EEG Mental Arithmetic Explorer",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Inject a custom stylesheet that styles the .main-header/.sub-header classes,
# the dark sidebar (labels, selectbox, radio, alerts), the tab bar, metric
# values, alerts, h3 section headers and dataframes.
# unsafe_allow_html=True lets the raw <style> tag through instead of escaping it.
st.markdown("""
<style>
/* Main header styling */
.main-header {
font-size: 2.8rem;
font-weight: 700;
text-align: center;
color: #1e3a8a;
margin-bottom: 0.5rem;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.sub-header {
text-align: center;
color: #64748b;
font-size: 1.1rem;
margin-bottom: 2.5rem;
font-weight: 400;
}
/* Sidebar styling */
[data-testid="stSidebar"] {
background-color: #1e293b;
}
[data-testid="stSidebar"] [data-testid="stMarkdownContainer"] p {
color: #e2e8f0;
}
[data-testid="stSidebar"] h1,
[data-testid="stSidebar"] h2,
[data-testid="stSidebar"] h3 {
color: #f1f5f9;
}
/* Sidebar selectbox and radio buttons */
[data-testid="stSidebar"] .stSelectbox label,
[data-testid="stSidebar"] .stRadio label {
color: #f1f5f9 !important;
font-weight: 500;
}
/* Dropdown menu background */
[data-testid="stSidebar"] [data-baseweb="select"] > div {
background-color: #334155;
color: #f1f5f9;
}
/* Radio button text */
[data-testid="stSidebar"] [data-baseweb="radio"] label {
color: #e2e8f0;
}
/* Success and info boxes in sidebar */
[data-testid="stSidebar"] .stAlert {
background-color: #334155;
color: #e2e8f0;
}
/* Tabs styling */
.stTabs [data-baseweb="tab-list"] {
gap: 1rem;
background-color: #f1f5f9;
padding: 0.5rem;
border-radius: 0.5rem;
}
.stTabs [data-baseweb="tab"] {
padding: 0.75rem 1.5rem;
font-weight: 500;
border-radius: 0.375rem;
color: #334155;
}
.stTabs [data-baseweb="tab"][aria-selected="true"] {
background-color: #1e40af;
color: white;
}
/* Metric cards */
[data-testid="stMetricValue"] {
font-size: 1.75rem;
font-weight: 600;
color: #1e40af;
}
/* Info boxes */
.stAlert {
border-radius: 0.5rem;
}
/* Section headers */
h3 {
color: #1e40af;
font-weight: 600;
margin-top: 1.5rem;
margin-bottom: 1rem;
border-bottom: 2px solid #e2e8f0;
padding-bottom: 0.5rem;
}
/* Dataframe styling */
[data-testid="stDataFrame"] {
border-radius: 0.5rem;
}
</style>
""", unsafe_allow_html=True)
# Page title and subtitle, rendered as raw HTML so they pick up the
# .main-header / .sub-header classes from the injected stylesheet.
st.markdown('<p class="main-header">EEG Mental Arithmetic Explorer</p>', unsafe_allow_html=True)
st.markdown('<p class="sub-header">Cognitive Workload Assessment through Brain Activity Analysis</p>', unsafe_allow_html=True)
# Data paths - Root level structure.
# Both paths are relative, so they resolve against the process's current
# working directory rather than this script's directory.
ZIP_FILE_PATH = "edf_files.zip"  # archive shipped with the app
EDF_EXTRACT_PATH = "edf_extracted"  # flat directory the .edf files are unpacked into
# Uncompress EDF files if needed
@st.cache_resource
def extract_edf_files():
    """Extract the EDF recordings from ``ZIP_FILE_PATH`` into ``EDF_EXTRACT_PATH``.

    Every ``.edf`` member of the archive is written directly into the root of
    ``EDF_EXTRACT_PATH`` (archive subdirectories are flattened away). macOS
    metadata, meaning anything under ``__MACOSX`` plus ``._*`` AppleDouble files,
    is skipped. Files that already exist in the target directory are left as-is.

    The return value is cached by ``st.cache_resource``, so later reruns of the
    script reuse it instead of re-scanning the archive.

    Returns:
        True if EDF files are already present or were extracted; False if
        extraction was needed but ``ZIP_FILE_PATH`` does not exist.
    """
    import shutil

    extract_dir = Path(EDF_EXTRACT_PATH)
    # Skip work only when EDF files are actually present: an empty directory
    # left behind by an interrupted run must not count as a finished extraction.
    if extract_dir.is_dir() and any(extract_dir.glob("*.edf")):
        return True
    if not os.path.exists(ZIP_FILE_PATH):
        return False

    with st.spinner("Extracting EDF files... This may take a moment."):
        extract_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(ZIP_FILE_PATH, 'r') as zip_ref:
            for member in zip_ref.namelist():
                filename = os.path.basename(member)
                if (not member.endswith('.edf')
                        or member.startswith('__MACOSX')
                        or filename.startswith('._')):
                    continue
                target_path = extract_dir / filename
                if target_path.exists():
                    continue
                # Stream to a temporary name and rename once complete, so an
                # interrupted extraction never leaves a truncated .edf behind
                # (the "*.edf" glob used elsewhere ignores the .part file).
                partial_path = target_path.with_name(filename + '.part')
                with zip_ref.open(member) as source, open(partial_path, 'wb') as target:
                    shutil.copyfileobj(source, target)
                os.replace(partial_path, target_path)
    return True
# Halt the app with setup instructions when the EDF archive could not be found;
# st.stop() prevents any of the UI below from rendering.
if not (extraction_success := extract_edf_files()):
    st.error(f"Could not find {ZIP_FILE_PATH}")
    st.info("""
Expected structure:
```
space/
├── app.py
├── requirements.txt
├── README.md
└── edf_files.zip
```
""")
    st.stop()
def get_available_subjects(edf_files=None):
    """Return the sorted, de-duplicated subject IDs found among the EDF files.

    Filenames follow ``<subject_id><suffix>.edf``, where the suffix starts at
    the LAST underscore (e.g. ``Subject01_1.edf`` -> ``Subject01``). This is the
    same convention ``load_edf_data`` uses to rebuild a file path from an ID
    and a suffix. Files whose stem has no underscore, or nothing before it,
    are ignored.

    Args:
        edf_files: Optional iterable of ``pathlib.Path`` objects to inspect.
            Defaults to ``list_available_files()``.

    Returns:
        Sorted list of subject ID strings.
    """
    if edf_files is None:
        edf_files = list_available_files()
    subjects = set()
    for edf_path in edf_files:
        # rpartition (not split) keeps underscores that belong to the subject
        # ID itself, e.g. "Sub_01_2" -> "Sub_01".
        subject_id, sep, _ = edf_path.stem.rpartition('_')
        if sep and subject_id:
            subjects.add(subject_id)
    return sorted(subjects)
def list_available_files(directory=None):
    """List the ``.edf`` files sitting directly inside the extraction directory.

    Args:
        directory: Directory to scan (str or Path); defaults to
            ``EDF_EXTRACT_PATH``. Subdirectories are not searched.

    Returns:
        ``pathlib.Path`` objects sorted by path, so callers see a stable order
        (raw ``glob`` order depends on the filesystem). Returns an empty list
        when the directory does not exist.
    """
    root = Path(EDF_EXTRACT_PATH if directory is None else directory)
    if not root.is_dir():
        return []
    return sorted(root.glob("*.edf"))
@st.cache_resource
def load_edf_data(subject_id, suffix):
    """Load one EDF recording from the extraction directory as a DataFrame.

    Args:
        subject_id: Subject prefix of the filename, e.g. ``"Subject00"``.
        suffix: Recording suffix appended to the ID, e.g. ``"_1"``.

    Returns:
        Tuple ``(df, sfreq, channels, file_path)``:
        - ``df`` has a ``time`` column (seconds from the first sample) followed
          by one column per channel, with the signal converted to microvolts;
        - ``sfreq`` is the sampling rate in Hz;
        - ``channels`` is the list of channel names;
        - ``file_path`` is the path that was read.

    Raises:
        FileNotFoundError: if ``<subject_id><suffix>.edf`` does not exist. The
            message lists up to ten available files to help debugging.
        RuntimeError: if MNE fails to read the file. The original error is
            chained as ``__cause__``.

    NOTE(review): ``st.cache_resource`` returns the same DataFrame object on
    every call, so callers must not mutate ``df`` in place.
    """
    # Direct path in extracted directory
    file_path = f"{EDF_EXTRACT_PATH}/{subject_id}{suffix}.edf"
    if not os.path.exists(file_path):
        # List available files for debugging
        available_names = sorted(f.name for f in Path(EDF_EXTRACT_PATH).glob("*.edf"))
        raise FileNotFoundError(
            f"Could not find: {subject_id}{suffix}.edf\n"
            f"Available files ({len(available_names)}): {available_names[:10]}"
        )
    try:
        # verbose=True keeps MNE's loading messages and warnings in the logs.
        raw = mne.io.read_raw_edf(file_path, preload=True, verbose=True)
    except Exception as e:
        raise RuntimeError(f"Error loading EDF file {file_path}: {e}") from e

    # get_data() returns Volts with shape (n_channels, n_samples); convert to µV.
    data_uv = raw.get_data() * 1e6
    channels = raw.ch_names
    sfreq = raw.info['sfreq']
    time = np.arange(data_uv.shape[1]) / sfreq
    df = pd.DataFrame(data_uv.T, columns=channels)
    df.insert(0, 'time', time)
    return df, sfreq, channels, file_path
# (A second, byte-identical definition of list_available_files() used to sit
# here and silently shadow the first one; the function is defined once,
# earlier in this module.)
st.sidebar.header("Dataset Controls")
# Check available files.
# Both guards below call st.stop() so nothing further renders when extraction
# produced no EDF files or no filename yields a subject ID.
edf_files = list_available_files()
if not edf_files:
    st.error("No EDF files found after extraction!")
    st.info(f"Checked directory: {EDF_EXTRACT_PATH}")
    st.stop()
# Total number of .edf files (one per recording, not de-duplicated by subject).
unique_files = len(edf_files)
st.sidebar.success(f"Found {unique_files} EDF files")
subject_ids = get_available_subjects()
if not subject_ids:
    st.error("No subject files found!")
    st.stop()
# Sidebar label for each recording type -> filename suffix that identifies it
# on disk (<subject>_1.edf is the baseline, <subject>_2.edf the task).
RECORDING_SUFFIXES = {
    "Resting State (Baseline)": "_1",
    "Mental Arithmetic Task": "_2",
}

selected_subject = st.sidebar.selectbox("Select Subject", subject_ids, index=0)
recording_type = st.sidebar.radio("Recording Type", list(RECORDING_SUFFIXES), index=0)
suffix = RECORDING_SUFFIXES[recording_type]

# Summary of the current selection; the empty strings add extra vertical
# spacing around the horizontal rules.
for sidebar_text in (
    "---",
    "",
    "### Subject Information",
    f"**ID:** {selected_subject}",
    f"**Recording:** {recording_type}",
    "",
    "---",
    "### Data Source",
):
    st.sidebar.markdown(sidebar_text)
st.sidebar.info("Data loaded from EDF files")
# Main content
# The four tabs are created up front; their contents are only rendered further
# down when data_loaded is True.
tab1, tab2, tab3, tab4 = st.tabs(["Signal Viewer", "Spectral Analysis", "Statistics", "About Dataset"])
# Load data
# Any failure (missing file or read error) is reported on the page and recorded
# in data_loaded instead of crashing the script, so the rest of the page still renders.
try:
    with st.spinner(f"Loading {selected_subject}{suffix}..."):
        df, sfreq, channels, file_path = load_edf_data(selected_subject, suffix)
    data_loaded = True
    st.sidebar.success(f"Loaded: {Path(file_path).name}")
except Exception as e:
    st.error(f"Error loading data: {e}")
    st.info(f"Attempting to load: {selected_subject}{suffix}")
    data_loaded = False
# Frequency bands shared by the PSD shading and the band-power bar chart:
# name -> (low Hz, high Hz, translucent shading colour, bar colour).
EEG_BANDS = {
    'Delta': (0.5, 4, 'rgba(255, 0, 0, 0.1)', '#ff6b6b'),
    'Theta': (4, 8, 'rgba(255, 165, 0, 0.1)', '#ffa500'),
    'Alpha': (8, 13, 'rgba(255, 255, 0, 0.1)', '#ffff00'),
    'Beta': (13, 30, 'rgba(0, 255, 0, 0.1)', '#90ee90'),
    'Gamma': (30, 50, 'rgba(0, 0, 255, 0.1)', '#6495ed'),
}


def _render_signal_viewer(df, channels, sfreq):
    """Render the "Signal Viewer" tab: traces of chosen channels in a time window."""
    st.markdown("### EEG Signal Visualization")
    last_time = float(df['time'].max())
    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        time_range = st.slider(
            "Time Window (seconds)",
            min_value=0.0,
            max_value=last_time,
            value=(0.0, min(10.0, last_time)),
            step=0.5
        )
    with col2:
        selected_channels = st.multiselect(
            "Select Channels",
            channels,
            # Slicing already yields every channel when there are fewer than six.
            default=channels[:6]
        )
    with col3:
        plot_style = st.selectbox("Plot Style", ["Stacked", "Overlay"])

    if not selected_channels:
        st.warning("Please select at least one channel to display")
        return

    # Filter data by time range (both ends inclusive).
    df_plot = df[df['time'].between(time_range[0], time_range[1])]

    if plot_style == "Stacked":
        # One subplot row per channel, all sharing the time axis.
        fig = make_subplots(
            rows=len(selected_channels),
            cols=1,
            shared_xaxes=True,
            vertical_spacing=0.02,
            subplot_titles=selected_channels
        )
        for row, channel in enumerate(selected_channels, 1):
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot[channel],
                    mode='lines',
                    name=channel,
                    line=dict(width=1),
                    showlegend=False
                ),
                row=row, col=1
            )
        fig.update_layout(
            height=150 * len(selected_channels),
            showlegend=False,
            hovermode='x unified'
        )
        fig.update_xaxes(title_text="Time (s)", row=len(selected_channels), col=1)
    else:  # Overlay
        fig = go.Figure()
        for channel in selected_channels:
            fig.add_trace(
                go.Scatter(
                    x=df_plot['time'],
                    y=df_plot[channel],
                    mode='lines',
                    name=channel,
                    line=dict(width=1)
                )
            )
        fig.update_layout(
            height=600,
            xaxis_title="Time (s)",
            yaxis_title="Amplitude (μV)",
            hovermode='x unified',
            legend=dict(
                orientation="v",
                yanchor="top",
                y=1,
                xanchor="left",
                x=1.01
            )
        )
    st.plotly_chart(fig, use_container_width=True)

    # Signal metrics
    st.markdown("### Signal Metrics")
    metric_cols = st.columns(4)
    with metric_cols[0]:
        st.metric("Channels", len(selected_channels))
    with metric_cols[1]:
        st.metric("Sampling Rate", f"{sfreq:.0f} Hz")
    with metric_cols[2]:
        # n_samples / sfreq is the recording length; the last time stamp is
        # one sample period short of it.
        st.metric("Duration", f"{len(df) / sfreq:.2f} s")
    with metric_cols[3]:
        st.metric("Samples", len(df_plot))


def _render_spectral_analysis(df, channels, sfreq):
    """Render the "Spectral Analysis" tab: Welch PSD plus absolute band powers."""
    from scipy import integrate, signal

    st.markdown("### Power Spectral Density Analysis")
    col1, col2 = st.columns([3, 1])
    with col2:
        channel_for_psd = st.selectbox(
            "Select Channel for PSD",
            channels,
            index=0
        )
        freq_bands = st.checkbox("Show Frequency Bands", value=True)

    # Compute PSD.
    channel_data = df[channel_for_psd].to_numpy()
    # ~4 s Welch windows give 0.25 Hz resolution. A fixed 256-sample window
    # (~2 Hz at 500 Hz) leaves only two bins inside the delta/theta/alpha bands,
    # so their integrated power would cover roughly half of each band.
    nperseg = min(len(channel_data), max(256, int(4 * sfreq)))
    frequencies, psd = signal.welch(channel_data, fs=sfreq, nperseg=nperseg)
    # Floor at the smallest positive float so a flat channel (PSD == 0) maps to
    # a very low dB value instead of -inf.
    psd_db = 10 * np.log10(np.maximum(psd, np.finfo(float).tiny))

    # Plot PSD.
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=frequencies,
        y=psd_db,
        mode='lines',
        name='PSD',
        line=dict(color='steelblue', width=2)
    ))
    if freq_bands:
        # Shade each band and label it at the top of the curve.
        y_max = psd_db.max()
        for band_name, (low, high, shade, _bar) in EEG_BANDS.items():
            fig.add_vrect(x0=low, x1=high, fillcolor=shade, layer="below", line_width=0)
            fig.add_annotation(
                x=(low + high) / 2,
                y=y_max,
                text=band_name,
                showarrow=False,
                font=dict(size=10, color='black'),
                bgcolor='rgba(255, 255, 255, 0.8)',
                borderpad=4
            )
    fig.update_layout(
        height=500,
        xaxis_title="Frequency (Hz)",
        yaxis_title="Power Spectral Density (dB/Hz)",
        hovermode='x'
    )
    fig.update_xaxes(range=[0, 100])
    # The chart fills the wide column next to the controls.
    with col1:
        st.plotly_chart(fig, use_container_width=True)

    # Band power analysis: integrate the linear PSD over each band.
    st.markdown("### Band Power Analysis")
    band_powers = {}
    for band_name, (low, high, _shade, _bar) in EEG_BANDS.items():
        in_band = (frequencies >= low) & (frequencies <= high)
        # scipy.integrate.trapezoid works on every NumPy version, whereas
        # np.trapezoid only exists from NumPy 2.0 onwards.
        band_powers[band_name] = integrate.trapezoid(psd[in_band], frequencies[in_band])

    fig_bands = go.Figure(data=[
        go.Bar(
            x=list(band_powers),
            y=list(band_powers.values()),
            marker_color=[bar for (_low, _high, _shade, bar) in EEG_BANDS.values()]
        )
    ])
    fig_bands.update_layout(
        height=400,
        xaxis_title="Frequency Band",
        yaxis_title="Absolute Power",
        showlegend=False
    )
    st.plotly_chart(fig_bands, use_container_width=True)


def _render_statistics(df, channels):
    """Render the "Statistics" tab: per-channel summary table and correlations."""
    st.markdown("### Statistical Analysis")
    # Channel statistics table, computed column-wise in one pass.
    summary = df[channels].agg(['mean', 'std', 'min', 'max']).T
    stats_df = pd.DataFrame({
        'Channel': channels,
        'Mean (μV)': summary['mean'].to_numpy(),
        'Std (μV)': summary['std'].to_numpy(),
        'Min (μV)': summary['min'].to_numpy(),
        'Max (μV)': summary['max'].to_numpy(),
    })
    stats_df['Range (μV)'] = stats_df['Max (μV)'] - stats_df['Min (μV)']
    numeric_cols = ['Mean (μV)', 'Std (μV)', 'Min (μV)', 'Max (μV)', 'Range (μV)']
    # Format to 2 decimals for display only. The cells stay numeric, so sorting
    # a column is numeric rather than lexicographic ("100.00" < "20.00").
    st.dataframe(stats_df.style.format("{:.2f}", subset=numeric_cols), height=400)

    # Correlation heatmap
    st.markdown("### Channel Correlation Matrix")
    corr_matrix = df[channels].corr()
    fig_corr = go.Figure(data=go.Heatmap(
        z=corr_matrix.values,
        x=channels,
        y=channels,
        colorscale='RdBu',
        zmid=0,
        text=corr_matrix.values,
        texttemplate='%{text:.2f}',
        textfont={"size": 8},
        colorbar=dict(title="Correlation")
    ))
    fig_corr.update_layout(
        height=750,
        title="Channel Correlation Matrix"
    )
    st.plotly_chart(fig_corr, use_container_width=True)


def _render_about():
    """Render the "About Dataset" tab: dataset description, citation and links."""
    st.markdown("""
### About This Dataset
This dataset contains EEG recordings from 36 healthy participants during resting state
and mental arithmetic task performance.
#### Key Features
- **Participants**: 36 healthy subjects
- **Recordings**: Paired (resting state + task)
- **Channels**: 23 EEG channels (International 10/20 system)
- **Duration**: 60 seconds per recording
- **Sampling Rate**: Approximately 500 Hz
- **Task**: Serial subtraction (4-digit minus 2-digit numbers)
#### Subject Groups
- **Good Performers** (24 subjects): Mean 21 operations in 4 minutes
- **Poor Performers** (12 subjects): Mean 7 operations in 4 minutes
#### Preprocessing
- High-pass filter at 30 Hz
- Notch filter at 50 Hz
- ICA artifact removal (eyes, muscles, cardiac)
#### Citation
```
Zyma I, Tukaev S, Seleznov I, Kiyono K, Popov A, Chernykh M, Shpenkov O.
Electroencephalograms during Mental Arithmetic Task Performance.
Data. 2019; 4(1):14.
https://doi.org/10.3390/data4010014
```
#### Resources
- [PhysioNet Dataset](https://physionet.org/content/eegmat/1.0.0/)
- [Original Paper](https://doi.org/10.3390/data4010014)
- [Hugging Face Dataset](https://huggingface.co/datasets/BrainSpectralAnalytics/eeg-mental-arithmetic)
#### Contact
Ivan Seleznov: ivan.seleznov1@gmail.com
""")


# Fill the tabs only when the selected recording loaded successfully.
if data_loaded:
    with tab1:
        _render_signal_viewer(df, channels, sfreq)
    with tab2:
        _render_spectral_analysis(df, channels, sfreq)
    with tab3:
        _render_statistics(df, channels)
    with tab4:
        _render_about()
else:
    st.warning("Unable to load data. Please check the selected subject and recording type.")
# Footer
# A horizontal rule followed by a centred, muted credit line, rendered as raw
# HTML (hence unsafe_allow_html=True).
st.markdown("---")
st.markdown(
    '<p style="text-align: center; color: #94a3b8; font-size: 0.9rem;">Built with Streamlit | EEG Mental Arithmetic Dataset Explorer</p>',
    unsafe_allow_html=True
)