import gradio as gr
import numpy as np
import os
import random
import utils
from channel_mapping import mapping, reorder_data
import mne
from mne.channels import read_custom_montage
quickstart = """
# Quickstart
## 1. Channel mapping
### Raw data
1. The data must be a two-dimensional array of shape (channel, timepoint).
2. Make sure you have **resampled** your data to **256 Hz**.
3. Upload your EEG data in `.csv` format.
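For example, you could resample with SciPy before exporting (a minimal sketch; the file name, channel count, and original sampling rate are placeholders):

```python
import numpy as np
from scipy.signal import resample

# Hypothetical recording: 2 channels, 10 seconds at 500 Hz
orig_sr, target_sr = 500, 256
data = np.random.randn(2, orig_sr * 10)

# Resample along the time axis down to 256 Hz
n_target = int(data.shape[1] * target_sr / orig_sr)
data_256 = resample(data, n_target, axis=1)

# Export as a (channel, timepoint) .csv for upload
np.savetxt("my_eeg_256hz.csv", data_256, delimiter=",")
```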
### Channel locations
Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
>If you cannot obtain one, we recommend downloading the standard montage <a href="">here</a>. If the channels in those files don't match yours, you can use **EEGLAB** to modify them to the montage you need.
### Imputation
The models were trained on 30-channel EEG signals, namely: `Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2`.
We expect your input data to include these channels as well.
If your data doesn't contain all of the channels above, there are three imputation methods you can choose from:
<u>Manually</u>:
- **mean**: select the channels you wish to use for imputing the required one, and we will average their values; if you select nothing, zeros will be imputed. For example, if you don't have **FCZ** and you choose **FC1, FC2, FZ, CZ** (depending on the channels you have), we will compute the mean of these four channels and assign the result to **FCZ**.
<u>Automatically</u>:
First, we will attempt to find a neighboring channel to use as an alternative. For instance, if the required channel is **FC3** but you only have **FC1**, we will use **FC1** as a replacement for **FC3**.
Then, depending on the **Imputation** method you chose, we will:
- **zero**: fill the missing channels with zeros.
- **adjacent**: fill the missing channels using neighboring channels located closer to the center. For example, if the required channel is **FC3** but you only have **F3** and **C3**, we will choose **C3** as the imputing value for **FC3**.
>Note: The imputed channels **need to be removed** after the data has been reconstructed.
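The three methods can be illustrated with a toy array (hypothetical shapes and channel indices, not the app's internal code):

```python
import numpy as np

# Toy data: 3 recorded channels x 4 timepoints
data = np.array([[1., 1., 1., 1.],   # e.g. FC1
                 [3., 3., 3., 3.],   # e.g. FC2
                 [9., 9., 9., 9.]])  # e.g. CZ

# zero: fill the missing channel with zeros
zero_fill = np.zeros(data.shape[1])

# adjacent: copy a neighboring channel located closer to the center
adjacent_fill = data[2]                # reuse CZ

# mean: average the user-selected channels (here FC1 and FC2)
mean_fill = data[[0, 1]].mean(axis=0)  # -> [2., 2., 2., 2.]
```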
### Mapping result
Once the mapping process is finished, the **template montage** and the **input montage** (with the channels chosen by the mapping function labeled with their names) will be shown.
### Missing channels
The channels displayed here are those for which the template could not find a suitable input channel, so the chosen **Imputation** method was used to fill in their values.
Therefore, you need to
<span style="color:red">**remove these channels**</span>
after you download the denoised data.
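After downloading, the imputed rows can be dropped like this (a sketch only; the channel names and shapes are placeholders):

```python
import numpy as np

# Template channel order (the full 30-name list in practice) and the
# "Missing channels" reported by the app -- both placeholders here
template = ["FP1", "FP2", "F7", "F3"]
missing = ["F7"]

denoised = np.arange(8.).reshape(4, 2)  # toy stand-in for the denoised .csv

# Keep only the rows whose channels were actually recorded
keep = [i for i, ch in enumerate(template) if ch not in missing]
cleaned = denoised[keep]
```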
### Template location file
You need to use this as the **new location file** for the denoised data.
## 2. Decode data
### Model
Select the model you want to use.
Detailed descriptions of the models can be found on the other tabs.
"""
icunet = """
# IC-U-Net
### Abstract
Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
"""
unetpp = """
# IC-U-Net++
### Abstract
Electroencephalographic (EEG) data is often contaminated with various types of artifacts. Deep learning has been successfully applied to developing EEG artifact removal techniques to increase the signal-to-noise ratio (SNR) and enhance brain-computer interface performance. Recently, our research team has proposed an end-to-end UNet-based EEG artifact removal technique, IC-U-Net, which can reconstruct signals against various artifacts. However, this model suffers from being prone to overfitting with a limited training dataset size and demanding a high computational cost. To address these issues, this study attempted to leverage the architecture of UNet++ to improve the practicability of IC-U-Net by introducing dense skip connections in the encoder-decoder architecture. Results showed that this proposed model obtained superior SNR to the original model with half the number of parameters. Also, this proposed model achieved comparable convergence using a quarter of the training data size.
"""
chkbox_js = """
(state_json) => {
state_json = JSON.parse(JSON.stringify(state_json));
if(state_json.state == "finished") return;
document.querySelector("#chs-chkbox>div:nth-of-type(2)").style.cssText = `
position: relative;
width: 560px;
height: 560px;
background: url("file=${state_json.files.raw_montage}");
`;
let all_chkbox = document.querySelectorAll("#chs-chkbox> div:nth-of-type(2)> label");
all_chkbox = Array.apply(null, all_chkbox);
all_chkbox.forEach((item, index) => {
let channel = state_json.inputByIndex[index];
let left = state_json.inputByName[channel].css_position[0];
let bottom = state_json.inputByName[channel].css_position[1];
//console.log(`left: ${left}, bottom: ${bottom}`);
item.style.cssText = `
position: absolute;
left: ${left};
bottom: ${bottom};
`;
item.className = "";
item.querySelector("span").innerText = "";
});
}
"""
with gr.Blocks() as demo:
state_json = gr.JSON(elem_id="state", visible=False)
with gr.Row():
gr.Markdown(
"""
"""
)
with gr.Row():
with gr.Column():
gr.Markdown(
"""
# 1.Channel Mapping
"""
)
with gr.Row():
in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
in_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", ".locs"])
with gr.Row():
in_fill_mode = gr.Dropdown(choices=["zero",
("adjacent channel", "adjacent"),
("mean (manually select channels)", "mean")],
value="zero",
label="Imputation",
scale=2)
map_btn = gr.Button("Mapping", scale=1)
channels_json = gr.JSON(visible=False)
res_md = gr.Markdown(
"""
### Mapping result:
""",
visible=False
)
with gr.Row():
tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
map_montage = gr.Image(label="Chosen channels", visible=False)
chs_chkbox = gr.CheckboxGroup(elem_id="chs-chkbox", label="", visible=False)
next_btn = gr.Button("Next", interactive=False, visible=False)
miss_txtbox = gr.Textbox(label="Missing channels", visible=False)
tpl_loc_file = gr.File("./template_chanlocs.loc", show_label=False, visible=False)
with gr.Column():
gr.Markdown(
"""
# 2.Decode Data
"""
)
with gr.Row():
in_model_name = gr.Dropdown(choices=["IC-U-Net", "IC-U-Net++", "IC-U-Net-Attn", "ART", "(mapped data)"],
value="IC-U-Net",
label="Model",
scale=2)
run_btn = gr.Button(scale=1, interactive=False)
out_denoised_data = gr.File(label="Denoised data")
with gr.Row():
with gr.Tab("ART"):
gr.Markdown()
with gr.Tab("IC-U-Net"):
gr.Markdown(icunet)
with gr.Tab("IC-U-Net++"):
gr.Markdown(unetpp)
with gr.Tab("IC-U-Net-Att"):
gr.Markdown()
with gr.Tab("QuickStart"):
gr.Markdown(quickstart)
#demo.load(js=js)
def reset_layout(raw_data):
# establish temp folder
filepath = os.path.dirname(str(raw_data))
try:
os.mkdir(filepath+"/temp_data/")
except OSError:
# temp folder already exists: clear it and recreate
utils.dataDelete(filepath+"/temp_data/")
os.mkdir(filepath+"/temp_data/")
state_obj = {
"filepath": filepath+"/temp_data/",
"files": {}
}
return {state_json : state_obj,
chs_chkbox : gr.CheckboxGroup(choices=[], value=[], label="", visible=False), # clear choices and selection
next_btn : gr.Button("Next", interactive=False, visible=False),
run_btn : gr.Button(interactive=False),
tpl_montage : gr.Image(visible=False),
map_montage : gr.Image(value=None, visible=False),
miss_txtbox : gr.Textbox(visible=False),
res_md : gr.Markdown(visible=False),
tpl_loc_file : gr.File(visible=False)}
def mapping_result(state_obj, channels_obj, raw_data, fill_mode):
state_obj.update(channels_obj)
if fill_mode=="mean" and channels_obj["missingChannelsIndex"]!=[]:
state_obj.update({
"state" : "initializing",
"fillingCount" : 0,
"totalFillingNum" : len(channels_obj["missingChannelsIndex"])-1
})
#print("Missing channels:", state_obj["missingChannelsIndex"])
return {state_json : state_obj,
next_btn : gr.Button(visible=True)}
else:
reorder_data(raw_data, channels_obj["newOrder"], fill_mode, state_obj)
missing_channels = [state_obj["templateByIndex"][idx] for idx in state_obj["missingChannelsIndex"]]
missing_channels = ', '.join(missing_channels)
state_obj.update({
"state" : "finished",
#"fillingCount" : -1,
#"totalFillingNum" : -1
})
return {state_json : state_obj,
res_md : gr.Markdown(visible=True),
miss_txtbox : gr.Textbox(value=missing_channels, visible=True),
tpl_loc_file : gr.File(visible=True),
run_btn : gr.Button(interactive=True)}
def show_montage(state_obj, raw_loc):
filepath = state_obj["filepath"]
raw_montage = read_custom_montage(raw_loc)
# convert all channel names to uppercase
for i in range(len(raw_montage.ch_names)):
channel = raw_montage.ch_names[i]
raw_montage.rename_channels({channel: str.upper(channel)})
if state_obj["state"] == "initializing":
filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
state_obj["files"]["raw_montage"] = filename
raw_fig = raw_montage.plot()
raw_fig.set_size_inches(5.6, 5.6)
raw_fig.savefig(filename, pad_inches=0)
return {state_json : state_obj}#,
#tpl_montage : gr.Image(visible=True),
#in_montage : gr.Image(value=filename, visible=True),
#map_montage : gr.Image(visible=False)}
elif state_obj["state"] == "finished":
# didn't find any way to hide the dark points...
# tmp
filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
state_obj["files"]["map_montage"] = filename
show_names= []
for channel in state_obj["inputByName"]:
if state_obj["inputByName"][channel]["used"]:
if channel=='CZ' and state_obj["CZImputed"]:
continue
show_names.append(channel)
mapped_fig = raw_montage.plot(show_names=show_names)
mapped_fig.set_size_inches(5.6, 5.6)
mapped_fig.savefig(filename, pad_inches=0)
return {state_json : state_obj,
tpl_montage : gr.Image(visible=True),
map_montage : gr.Image(value=filename, visible=True)}
elif state_obj["state"] == "selecting":
# update in_montage here ?
#return {in_montage : gr.Image()}
return {state_json : state_obj}
def generate_chkbox(state_obj):
if state_obj["state"] == "initializing":
in_channels = [channel for channel in state_obj["inputByName"]]
state_obj["state"] = "selecting"
first_idx = state_obj["missingChannelsIndex"][0]
first_name = state_obj["templateByIndex"][first_idx]
chkbox_label = first_name+' (1/'+str(state_obj["totalFillingNum"]+1)+')'
return {state_json : state_obj,
chs_chkbox : gr.CheckboxGroup(choices=in_channels, label=chkbox_label, visible=True),
next_btn : gr.Button(interactive=True)}
else:
return {state_json : state_obj}
map_btn.click(
fn = reset_layout,
inputs = in_raw_data,
outputs = [state_json, chs_chkbox, next_btn, run_btn, tpl_montage, map_montage, miss_txtbox,
res_md, tpl_loc_file]
).success(
fn = mapping,
inputs = [in_raw_data, in_raw_loc, in_fill_mode],
outputs = channels_json
).success(
fn = mapping_result,
inputs = [state_json, channels_json, in_raw_data, in_fill_mode],
outputs = [state_json, chs_chkbox, next_btn, miss_txtbox, res_md, tpl_loc_file, run_btn]
).success(
fn = show_montage,
inputs = [state_json, in_raw_loc],
outputs = [state_json, tpl_montage, map_montage]
).success(
fn = generate_chkbox,
inputs = state_json,
outputs = [state_json, chs_chkbox, next_btn]
).success(
fn = None,
js = chkbox_js,
inputs = state_json,
outputs = []
)
def check_next(state_obj, selected, raw_data, fill_mode):
if state_obj["state"] == "selecting":
# save info before clicking on next_btn
prev_target_idx = state_obj["missingChannelsIndex"][state_obj["fillingCount"]]
prev_target_name = state_obj["templateByIndex"][prev_target_idx]
selected_idx = [state_obj["inputByName"][channel]["index"] for channel in selected]
state_obj["newOrder"][prev_target_idx] = selected_idx
if len(selected)==1 and state_obj["inputByName"][selected[0]]["used"]==False:
state_obj["inputByName"][selected[0]]["used"] = True
state_obj["missingChannelsIndex"][state_obj["fillingCount"]] = -1
print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
# update next round
state_obj["fillingCount"] += 1
if state_obj["fillingCount"] <= state_obj["totalFillingNum"]:
target_idx = state_obj["missingChannelsIndex"][state_obj["fillingCount"]]
target_name = state_obj["templateByIndex"][target_idx]
chkbox_label = target_name+' ('+str(state_obj["fillingCount"]+1)+'/'+str(state_obj["totalFillingNum"]+1)+')'
btn_label = "Submit" if state_obj["fillingCount"]==state_obj["totalFillingNum"] else "Next"
return {state_json : state_obj,
chs_chkbox : gr.CheckboxGroup(value=[], label=chkbox_label),
next_btn : gr.Button(btn_label)}
else:
state_obj["state"] = "finished"
reorder_data(raw_data, state_obj["newOrder"], fill_mode, state_obj)
missing_channels = []
for idx in state_obj["missingChannelsIndex"]:
if idx != -1:
missing_channels.append(state_obj["templateByIndex"][idx])
missing_channels = ', '.join(missing_channels)
return {state_json : state_obj,
chs_chkbox : gr.CheckboxGroup(visible=False),
next_btn : gr.Button(visible=False),
res_md : gr.Markdown(visible=True),
miss_txtbox : gr.Textbox(value=missing_channels, visible=True),
tpl_loc_file : gr.File(visible=True),
run_btn : gr.Button(interactive=True)}
next_btn.click(
fn = check_next,
inputs = [state_json, chs_chkbox, in_raw_data, in_fill_mode],
outputs = [state_json, chs_chkbox, next_btn, run_btn, res_md, miss_txtbox, tpl_loc_file]
).success(
fn = show_montage,
inputs = [state_json, in_raw_loc],
outputs = [state_json, tpl_montage, map_montage]
)
@run_btn.click(inputs=[state_json, in_raw_data, in_model_name], outputs=out_denoised_data)
def run_model(state_obj, raw_file, model_name):
filepath = state_obj["filepath"]
input_name = os.path.basename(str(raw_file))
output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
if model_name == "(mapped data)":
return filepath + 'mapped.csv'
# step1: Data preprocessing
total_file_num = utils.preprocessing(filepath, 'mapped.csv', 256)
# step2: Signal reconstruction
utils.reconstruct(model_name, total_file_num, filepath, output_name)
return filepath + output_name
if __name__ == "__main__":
demo.launch()