|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import plotly.graph_objects as go |
|
|
from plotly.subplots import make_subplots |
|
|
import mne |
|
|
from pathlib import Path |
|
|
import zipfile |
|
|
import os |
|
|
|
|
|
# Browser-tab title and icon, a wide content area, and the sidebar opened on first load.
st.set_page_config(
    layout="wide",
    page_title="EEG Mental Arithmetic Explorer",
    initial_sidebar_state="expanded",
    page_icon="🧠",
)
|
|
|
|
|
# Inject the app-wide stylesheet. It defines the .main-header / .sub-header
# classes used by the title markup, a dark theme for the sidebar (background,
# text, select boxes, radio buttons, alerts), and styling for tabs, metric
# values, alerts, h3 section headers and dataframes.
st.markdown("""
<style>
/* Main header styling */
.main-header {
    font-size: 2.8rem;
    font-weight: 700;
    text-align: center;
    color: #1e3a8a;
    margin-bottom: 0.5rem;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

.sub-header {
    text-align: center;
    color: #64748b;
    font-size: 1.1rem;
    margin-bottom: 2.5rem;
    font-weight: 400;
}

/* Sidebar styling */
[data-testid="stSidebar"] {
    background-color: #1e293b;
}

[data-testid="stSidebar"] [data-testid="stMarkdownContainer"] p {
    color: #e2e8f0;
}

[data-testid="stSidebar"] h1,
[data-testid="stSidebar"] h2,
[data-testid="stSidebar"] h3 {
    color: #f1f5f9;
}

/* Sidebar selectbox and radio buttons */
[data-testid="stSidebar"] .stSelectbox label,
[data-testid="stSidebar"] .stRadio label {
    color: #f1f5f9 !important;
    font-weight: 500;
}

/* Dropdown menu background */
[data-testid="stSidebar"] [data-baseweb="select"] > div {
    background-color: #334155;
    color: #f1f5f9;
}

/* Radio button text */
[data-testid="stSidebar"] [data-baseweb="radio"] label {
    color: #e2e8f0;
}

/* Success and info boxes in sidebar */
[data-testid="stSidebar"] .stAlert {
    background-color: #334155;
    color: #e2e8f0;
}

/* Tabs styling */
.stTabs [data-baseweb="tab-list"] {
    gap: 1rem;
    background-color: #f1f5f9;
    padding: 0.5rem;
    border-radius: 0.5rem;
}

.stTabs [data-baseweb="tab"] {
    padding: 0.75rem 1.5rem;
    font-weight: 500;
    border-radius: 0.375rem;
    color: #334155;
}

.stTabs [data-baseweb="tab"][aria-selected="true"] {
    background-color: #1e40af;
    color: white;
}

/* Metric cards */
[data-testid="stMetricValue"] {
    font-size: 1.75rem;
    font-weight: 600;
    color: #1e40af;
}

/* Info boxes */
.stAlert {
    border-radius: 0.5rem;
}

/* Section headers */
h3 {
    color: #1e40af;
    font-weight: 600;
    margin-top: 1.5rem;
    margin-bottom: 1rem;
    border-bottom: 2px solid #e2e8f0;
    padding-bottom: 0.5rem;
}

/* Dataframe styling */
[data-testid="stDataFrame"] {
    border-radius: 0.5rem;
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Title and subtitle, rendered as raw HTML so the custom CSS classes apply.
for _css_class, _text in (
    ("main-header", "EEG Mental Arithmetic Explorer"),
    ("sub-header", "Cognitive Workload Assessment through Brain Activity Analysis"),
):
    st.markdown(f'<p class="{_css_class}">{_text}</p>', unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Archive containing the raw EDF recordings, relative to the working directory.
ZIP_FILE_PATH: str = "edf_files.zip"
# Directory the archive's .edf members are written into (directory
# components inside the archive are dropped on extraction).
EDF_EXTRACT_PATH: str = "edf_extracted"
|
|
|
|
|
|
|
|
@st.cache_resource
def extract_edf_files():
    """Make sure the EDF recordings are unpacked into ``EDF_EXTRACT_PATH``.

    Only ``.edf`` members of ``ZIP_FILE_PATH`` are extracted. Directory
    components inside the archive are dropped (each file lands directly in
    ``EDF_EXTRACT_PATH``), macOS metadata entries (``__MACOSX/`` folders and
    ``._*`` AppleDouble files) are skipped, and files already present in the
    target directory are left untouched.

    Returns:
        True if the extraction directory already holds EDF files or the
        archive was found and extracted; False if no EDF files are present
        and the archive does not exist.
    """
    import shutil

    extract_dir = Path(EDF_EXTRACT_PATH)

    # Check for actual EDF files rather than just the directory: an earlier
    # run may have created the directory and then failed mid-extraction.
    if extract_dir.is_dir() and any(extract_dir.glob("*.edf")):
        return True

    if not os.path.exists(ZIP_FILE_PATH):
        return False

    with st.spinner("Extracting EDF files... This may take a moment."):
        extract_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(ZIP_FILE_PATH, "r") as zip_ref:
            for member in zip_ref.namelist():
                # basename() flattens the archive layout and also prevents
                # path components like "../" from escaping the target dir.
                filename = os.path.basename(member)
                if (
                    member.startswith("__MACOSX")
                    or filename.startswith("._")
                    or not filename.endswith(".edf")
                ):
                    continue

                target_path = extract_dir / filename
                if target_path.exists():
                    continue

                # Stream in chunks instead of reading the whole member into memory.
                with zip_ref.open(member) as source, open(target_path, "wb") as target:
                    shutil.copyfileobj(source, target)

    return True
|
|
|
|
|
# Help text shown when extract_edf_files() reports failure: the repository
# layout the app expects to find.
_EXPECTED_LAYOUT_HELP = """
Expected structure:
```
space/
├── app.py
├── requirements.txt
├── README.md
└── edf_files.zip
```
"""

# Stop the script run early when no EDF data could be made available.
extraction_success = extract_edf_files()
if not extraction_success:
    st.error(f"Could not find {ZIP_FILE_PATH}")
    st.info(_EXPECTED_LAYOUT_HELP)
    st.stop()
|
|
|
|
|
def get_available_subjects(edf_files=None):
    """Return the sorted, de-duplicated subject IDs present among EDF files.

    The subject ID is everything before the *last* underscore of the file
    stem (``Subject00_1`` -> ``Subject00``). Splitting on the last underscore
    keeps the ID consistent with ``load_edf_data``, which rebuilds the file
    name as ``f"{subject_id}{suffix}.edf"``, even when the ID itself contains
    underscores. Stems without an underscore, or with nothing before it, are
    ignored.

    Args:
        edf_files: Optional iterable of ``pathlib.Path`` objects to inspect.
            Defaults to ``list_available_files()``.

    Returns:
        A sorted list of unique subject ID strings.
    """
    if edf_files is None:
        edf_files = list_available_files()

    subjects = set()
    for edf_file in edf_files:
        subject_id, sep, _recording = edf_file.stem.rpartition("_")
        if sep and subject_id:
            subjects.add(subject_id)
    return sorted(subjects)
|
|
|
|
|
def list_available_files(directory=None):
    """Return the ``.edf`` files directly inside the extraction directory.

    Args:
        directory: Directory to scan. Defaults to ``EDF_EXTRACT_PATH``.

    Returns:
        A list of ``pathlib.Path`` objects sorted by path. It is empty when
        the directory does not exist.
    """
    if directory is None:
        directory = EDF_EXTRACT_PATH

    folder = Path(directory)
    if not folder.exists():
        return []

    # glob() order depends on the filesystem; sort so the UI is stable.
    return sorted(folder.glob("*.edf"))
|
|
|
|
|
@st.cache_resource
def load_edf_data(subject_id, suffix):
    """Load one extracted EDF recording into a DataFrame.

    Args:
        subject_id: Subject identifier; the file read is
            ``f"{EDF_EXTRACT_PATH}/{subject_id}{suffix}.edf"``.
        suffix: Recording suffix appended to ``subject_id`` (the app passes
            ``"_1"`` or ``"_2"``).

    Returns:
        ``(df, sfreq, channels, file_path)``:
        ``df`` has a leading ``time`` column (seconds, ``sample / sfreq``)
        followed by one column per channel, scaled by 1e6. MNE's
        ``get_data()`` returns SI units (volts for EEG), so the values are µV.
        ``sfreq`` is the sampling rate from the file header, ``channels`` is
        MNE's channel-name list, and ``file_path`` is the path that was read.

        Results are cached with ``st.cache_resource``, so every cache hit
        returns the same DataFrame object; callers must not mutate it in place.

    Raises:
        FileNotFoundError: The requested file does not exist. The message
            lists up to 10 available files.
        RuntimeError: MNE failed to read the file. The original exception is
            chained as ``__cause__``.
    """
    file_path = f"{EDF_EXTRACT_PATH}/{subject_id}{suffix}.edf"

    if not os.path.exists(file_path):
        # List what *is* there so a naming mismatch is easy to diagnose.
        available_names = sorted(f.name for f in Path(EDF_EXTRACT_PATH).glob("*.edf"))
        raise FileNotFoundError(
            f"Could not find: {subject_id}{suffix}.edf\n"
            f"Available files ({len(available_names)}): {available_names[:10]}"
        )

    try:
        raw = mne.io.read_raw_edf(file_path, preload=True, verbose=True)
        data = raw.get_data()
    except Exception as e:
        # Chain the cause so the underlying MNE traceback is not lost.
        raise RuntimeError(f"Error loading EDF file {file_path}: {e}") from e

    data_uv = data * 1e6  # volts -> microvolts
    channels = raw.ch_names
    sfreq = raw.info['sfreq']
    time = np.arange(data.shape[1]) / sfreq

    # data is (n_channels, n_samples); transpose so each channel is a column.
    df = pd.DataFrame(data_uv.T, columns=channels)
    df.insert(0, 'time', time)

    return df, sfreq, channels, file_path
|
|
|
|
|
# list_available_files() is defined once, earlier in this file. A verbatim
# duplicate that used to sit here shadowed that definition and was removed.
|
|
|
|
|
|
|
|
def _halt_with_error(message, detail=None):
    """Show an error (plus an optional info line) and end this script run."""
    st.error(message)
    if detail is not None:
        st.info(detail)
    st.stop()


# ---- Sidebar: discover recordings and subjects -----------------------------
st.sidebar.header("Dataset Controls")

edf_files = list_available_files()
if not edf_files:
    _halt_with_error(
        "No EDF files found after extraction!",
        f"Checked directory: {EDF_EXTRACT_PATH}",
    )

unique_files = len(edf_files)
st.sidebar.success(f"Found {unique_files} EDF files")

subject_ids = get_available_subjects()
if not subject_ids:
    _halt_with_error("No subject files found!")
|
|
|
|
|
# ---- Sidebar: recording selection and summary ------------------------------
selected_subject = st.sidebar.selectbox("Select Subject", subject_ids, index=0)

# Each user-facing recording label maps to the filename suffix of its EDF file.
_RECORDING_SUFFIXES = {
    "Resting State (Baseline)": "_1",
    "Mental Arithmetic Task": "_2",
}
recording_type = st.sidebar.radio(
    "Recording Type",
    list(_RECORDING_SUFFIXES),
    index=0,
)
suffix = _RECORDING_SUFFIXES[recording_type]

for _sidebar_md in (
    "---",
    "",
    "### Subject Information",
    f"**ID:** {selected_subject}",
    f"**Recording:** {recording_type}",
    "",
    "---",
    "### Data Source",
):
    st.sidebar.markdown(_sidebar_md)
st.sidebar.info("Data loaded from EDF files")
|
|
|
|
|
|
|
|
tab1, tab2, tab3, tab4 = st.tabs(["Signal Viewer", "Spectral Analysis", "Statistics", "About Dataset"])

# Load the selected recording. On failure, report the error and leave
# data_loaded False so the page renders its fallback warning.
data_loaded = False
try:
    with st.spinner(f"Loading {selected_subject}{suffix}..."):
        df, sfreq, channels, file_path = load_edf_data(selected_subject, suffix)
except Exception as e:
    st.error(f"Error loading data: {e}")
    st.info(f"Attempting to load: {selected_subject}{suffix}")
else:
    data_loaded = True
    st.sidebar.success(f"Loaded: {Path(file_path).name}")
|
|
|
|
|
if data_loaded:
    # ---- Tab 1: time-domain signal viewer ----------------------------------
    with tab1:
        st.markdown("### EEG Signal Visualization")

        col1, col2, col3 = st.columns([2, 2, 1])
        duration_s = float(df['time'].max())

        with col1:
            time_range = st.slider(
                "Time Window (seconds)",
                min_value=0.0,
                max_value=duration_s,
                value=(0.0, min(10.0, duration_s)),
                step=0.5,
            )

        with col2:
            # Slicing already yields the full list when there are < 6 channels.
            selected_channels = st.multiselect(
                "Select Channels",
                channels,
                default=channels[:6],
            )

        with col3:
            plot_style = st.selectbox(
                "Plot Style",
                ["Stacked", "Overlay"],
            )

        if selected_channels:
            window = (df['time'] >= time_range[0]) & (df['time'] <= time_range[1])
            df_plot = df[window]

            if plot_style == "Stacked":
                n_rows = len(selected_channels)
                # Plotly rejects vertical_spacing > 1 / (rows - 1); cap it so
                # selecting many channels cannot raise.
                fig = make_subplots(
                    rows=n_rows,
                    cols=1,
                    shared_xaxes=True,
                    vertical_spacing=min(0.02, 1 / n_rows),
                    subplot_titles=selected_channels,
                )
                for row, channel in enumerate(selected_channels, start=1):
                    fig.add_trace(
                        go.Scatter(
                            x=df_plot['time'],
                            y=df_plot[channel],
                            mode='lines',
                            name=channel,
                            line=dict(width=1),
                            showlegend=False,
                        ),
                        row=row,
                        col=1,
                    )
                fig.update_layout(
                    height=150 * n_rows,
                    showlegend=False,
                    hovermode='x unified',
                )
                fig.update_xaxes(title_text="Time (s)", row=n_rows, col=1)
            else:
                # All selected channels overlaid on one amplitude axis.
                fig = go.Figure()
                for channel in selected_channels:
                    fig.add_trace(
                        go.Scatter(
                            x=df_plot['time'],
                            y=df_plot[channel],
                            mode='lines',
                            name=channel,
                            line=dict(width=1),
                        )
                    )
                fig.update_layout(
                    height=600,
                    xaxis_title="Time (s)",
                    yaxis_title="Amplitude (μV)",
                    hovermode='x unified',
                    legend=dict(
                        orientation="v",
                        yanchor="top",
                        y=1,
                        xanchor="left",
                        x=1.01,
                    ),
                )

            st.plotly_chart(fig, use_container_width=True)

            st.markdown("### Signal Metrics")
            metric_cols = st.columns(4)
            with metric_cols[0]:
                st.metric("Channels", len(selected_channels))
            with metric_cols[1]:
                st.metric("Sampling Rate", f"{sfreq:.0f} Hz")
            with metric_cols[2]:
                st.metric("Duration", f"{duration_s:.2f} s")
            with metric_cols[3]:
                st.metric("Samples", len(df_plot))
        else:
            st.warning("Please select at least one channel to display")

    # ---- Tab 2: spectral analysis ------------------------------------------
    with tab2:
        st.markdown("### Power Spectral Density Analysis")

        _, col2 = st.columns([3, 1])
        with col2:
            channel_for_psd = st.selectbox(
                "Select Channel for PSD",
                channels,
                index=0,
            )
            freq_bands = st.checkbox("Show Frequency Bands", value=True)

        # scipy is only needed here, so import it locally. The trapezoid from
        # scipy.integrate works on every NumPy version, whereas np.trapezoid
        # only exists in NumPy >= 2.0.
        from scipy import signal
        from scipy.integrate import trapezoid

        # EEG bands as (low Hz, high Hz, shading colour). The PSD shading and
        # the band-power bars share this single definition.
        eeg_bands = {
            'Delta': (0.5, 4, 'rgba(255, 0, 0, 0.1)'),
            'Theta': (4, 8, 'rgba(255, 165, 0, 0.1)'),
            'Alpha': (8, 13, 'rgba(255, 255, 0, 0.1)'),
            'Beta': (13, 30, 'rgba(0, 255, 0, 0.1)'),
            'Gamma': (30, 50, 'rgba(0, 0, 255, 0.1)'),
        }

        channel_data = df[channel_for_psd].values
        # Use 4 s Welch segments (two cycles of the 0.5 Hz delta edge), which
        # give 0.25 Hz resolution. A fixed 256-sample window gives only ~2 Hz
        # resolution at 500 Hz, leaving a couple of bins inside the delta band.
        nperseg = min(len(channel_data), max(1, int(round(4 * sfreq))))
        frequencies, psd = signal.welch(channel_data, fs=sfreq, nperseg=nperseg)
        # Clamp before log10 so an exactly-zero bin does not become -inf dB.
        psd_db = 10 * np.log10(np.maximum(psd, np.finfo(float).tiny))

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=frequencies,
            y=psd_db,
            mode='lines',
            name='PSD',
            line=dict(color='steelblue', width=2),
        ))

        if freq_bands:
            for low, high, color in eeg_bands.values():
                fig.add_vrect(
                    x0=low, x1=high,
                    fillcolor=color,
                    layer="below",
                    line_width=0,
                )

            # Label each band at its centre, level with the PSD peak.
            y_max = psd_db.max()
            annotations = [
                dict(
                    x=(low + high) / 2,
                    y=y_max,
                    text=band_name,
                    showarrow=False,
                    font=dict(size=10, color='black'),
                    bgcolor='rgba(255, 255, 255, 0.8)',
                    borderpad=4,
                )
                for band_name, (low, high, _color) in eeg_bands.items()
            ]
            fig.update_layout(annotations=annotations)

        fig.update_layout(
            height=500,
            xaxis_title="Frequency (Hz)",
            yaxis_title="Power Spectral Density (dB/Hz)",
            hovermode='x',
        )
        fig.update_xaxes(range=[0, 100])
        st.plotly_chart(fig, use_container_width=True)

        st.markdown("### Band Power Analysis")

        # Absolute band power: trapezoidal integral of the PSD over the band.
        band_powers = {}
        for band_name, (low, high, _color) in eeg_bands.items():
            in_band = (frequencies >= low) & (frequencies <= high)
            band_powers[band_name] = trapezoid(psd[in_band], frequencies[in_band])

        fig_bands = go.Figure(data=[
            go.Bar(
                x=list(band_powers.keys()),
                y=list(band_powers.values()),
                marker_color=['#ff6b6b', '#ffa500', '#ffff00', '#90ee90', '#6495ed'],
            )
        ])
        fig_bands.update_layout(
            height=400,
            xaxis_title="Frequency Band",
            yaxis_title="Absolute Power",
            showlegend=False,
        )
        st.plotly_chart(fig_bands, use_container_width=True)

    # ---- Tab 3: descriptive statistics -------------------------------------
    with tab3:
        st.markdown("### Statistical Analysis")

        # Vectorised per-channel summary (pandas std uses ddof=1, as before).
        channel_frame = df[channels]
        stats_df = pd.DataFrame({
            'Channel': channels,
            'Mean (μV)': channel_frame.mean().to_numpy(dtype=float),
            'Std (μV)': channel_frame.std().to_numpy(dtype=float),
            'Min (μV)': channel_frame.min().to_numpy(dtype=float),
            'Max (μV)': channel_frame.max().to_numpy(dtype=float),
        })
        stats_df['Range (μV)'] = stats_df['Max (μV)'] - stats_df['Min (μV)']

        # Format for display only. Keeping the columns numeric preserves
        # correct numeric sorting in the interactive dataframe; the old code
        # converted them to strings, which sorted lexicographically.
        numeric_cols = ['Mean (μV)', 'Std (μV)', 'Min (μV)', 'Max (μV)', 'Range (μV)']
        st.dataframe(
            stats_df.style.format({col: "{:.2f}" for col in numeric_cols}),
            height=400,
        )

        st.markdown("### Channel Correlation Matrix")
        corr_matrix = channel_frame.corr()

        fig_corr = go.Figure(data=go.Heatmap(
            z=corr_matrix.values,
            x=channels,
            y=channels,
            colorscale='RdBu',
            zmid=0,
            text=corr_matrix.values,
            texttemplate='%{text:.2f}',
            textfont={"size": 8},
            colorbar=dict(title="Correlation"),
        ))
        fig_corr.update_layout(
            height=750,
            title="Channel Correlation Matrix",
        )
        st.plotly_chart(fig_corr, use_container_width=True)

    # ---- Tab 4: dataset description ----------------------------------------
    with tab4:
        st.markdown("""
### About This Dataset

This dataset contains EEG recordings from 36 healthy participants during resting state
and mental arithmetic task performance.

#### Key Features
- **Participants**: 36 healthy subjects
- **Recordings**: Paired (resting state + task)
- **Channels**: 23 EEG channels (International 10/20 system)
- **Duration**: 60 seconds per recording
- **Sampling Rate**: Approximately 500 Hz
- **Task**: Serial subtraction (4-digit minus 2-digit numbers)

#### Subject Groups
- **Good Performers** (24 subjects): Mean 21 operations in 4 minutes
- **Poor Performers** (12 subjects): Mean 7 operations in 4 minutes

#### Preprocessing
- High-pass filter at 30 Hz
- Notch filter at 50 Hz
- ICA artifact removal (eyes, muscles, cardiac)

#### Citation
```
Zyma I, Tukaev S, Seleznov I, Kiyono K, Popov A, Chernykh M, Shpenkov O.
Electroencephalograms during Mental Arithmetic Task Performance.
Data. 2019; 4(1):14.
https://doi.org/10.3390/data4010014
```

#### Resources
- [PhysioNet Dataset](https://physionet.org/content/eegmat/1.0.0/)
- [Original Paper](https://doi.org/10.3390/data4010014)
- [Hugging Face Dataset](https://huggingface.co/datasets/BrainSpectralAnalytics/eeg-mental-arithmetic)

#### Contact
Ivan Seleznov: ivan.seleznov1@gmail.com
""")

else:
    st.warning("Unable to load data. Please check the selected subject and recording type.")
|
|
|
|
|
|
|
|
# Centered footer line under a horizontal rule; shown on every run.
_FOOTER_HTML = (
    '<p style="text-align: center; color: #94a3b8; font-size: 0.9rem;">'
    'Built with Streamlit | EEG Mental Arithmetic Dataset Explorer</p>'
)
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)