"""Gradio app for EEG channel mapping and artifact removal.

Workflow:
  1. Channel mapping — the user uploads raw EEG (.csv) plus channel
     locations (.loc/.locs); channels are mapped onto the 30-channel
     template, with missing channels imputed (zero / adjacent / manual mean).
  2. Decode data — a selected model (IC-U-Net family / ART) denoises the
     mapped data and the result is offered for download.

State is carried between callbacks in a JSON component (``state_json``);
the heavy lifting lives in the project-local ``utils`` and
``channel_mapping`` modules.
"""
import os
import random

import gradio as gr
import mne  # noqa: F401  -- kept: project convention imports mne alongside its submodules
import numpy as np
from mne.channels import read_custom_montage

import utils
from channel_mapping import mapping, reorder_data

quickstart = """
# Quickstart

## 1. Channel mapping

### Raw data
1. The data need to be a two-dimensional array (channel, timepoint).
2. Make sure you have **resampled** your data to **256 Hz**.
3. Upload your EEG data in `.csv` format.

### Channel locations
Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
>If you cannot obtain it, we recommend you to download the standard montage here. If the channels in those files doesn't match yours, you can use **EEGLAB** to modify them to your needed montage.

### Imputation
The models was trained using the EEG signals of 30 channels, including: `Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2`.
We expect your input data to include these channels as well. If your data doesn't contain all of the mentioned channels, there are 3 imputation ways you can choose from:

Manually:
- **mean**: select the channels you wish to use for imputing the required one, and we will average their values. If you select nothing, zeros will be imputed.
For example, you didn't have **FCZ** and you choose **FC1, FC2, FZ, CZ** to impute it(depending on the channels you have), we will compute the mean of these 4 channels and assign this new value to **FCZ**.

Automatically:
Firstly, we will attempt to find neighboring channel to use as alternative. For instance, if the required channel is **FC3** but you only have **FC1**, we will use it as a replacement for **FC3**.
Then, depending on the **Imputation** way you chose, we will:
- **zero**: fill the missing channels with zeros.
- **adjacent**: fill the missing channels using neighboring channels which are located closer to the center.
For example, if the required channel is **FC3** but you only have **F3, C3**, then we will choose **C3** as the imputing value for **FC3**.

>Note: The imputed channels **need to be removed** after the data being reconstructed.

### Mapping result
Once the mapping process is finished, the **template montage** and the **input montage**(with the channels choosen by the mapping function displaying their names) will be shown.

### Missing channels
The channels displayed here are those for which the template didn't find suitable channels to use, and utilized **Imputation** to fill the missing values. Therefore, you need to **remove these channels** after you download the denoised data.

### Template location file
You need to use this as the **new location file** for the denoised data.

## 2. Decode data

### Model
Select the model you want to use. The detailed description of the models can be found in other pages.
"""

icunet = """
# IC-U-Net

### Abstract
Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
"""

unetpp = """
# IC-U-Net++

### Abstract
Electroencephalographic (EEG) data is considered contaminated with various types of artifacts. Deep learning has been successfully applied to developing EEG artifact removal techniques to increase the signal-to-noise ratio (SNR) and enhance brain-computer interface performance. Recently, our research team has proposed an end-to-end UNet-based EEG artifact removal technique, IC-U-Net, which can reconstruct signals against various artifacts. However, this model suffers from being prone to overfitting with a limited training dataset size and demanding a high computational cost. To address these issues, this study attempted to leverage the architecture of UNet++ to improve the practicability of IC-U-Net by introducing dense skip connections in the encoder-decoder architecture. Results showed that this proposed model obtained superior SNR to the original model with half the number of parameters. Also, this proposed model achieved comparable convergency using a quarter of the training data size.
"""

# Client-side JS run after mapping: overlays the channel checkboxes on top of
# the rendered raw-montage image using the css_position computed server-side.
chkbox_js = """
(state_json) => {
    state_json = JSON.parse(JSON.stringify(state_json));
    if(state_json.state == "finished") return;

    document.querySelector("#chs-chkbox>div:nth-of-type(2)").style.cssText = `
        position: relative;
        width: 560px;
        height: 560px;
        background: url("file=${state_json.files.raw_montage}");
    `;
    let all_chkbox = document.querySelectorAll("#chs-chkbox> div:nth-of-type(2)> label");
    all_chkbox = Array.apply(null, all_chkbox);
    all_chkbox.forEach((item, index) => {
        let channel = state_json.inputByIndex[index];
        let left = state_json.inputByName[channel].css_position[0];
        let bottom = state_json.inputByName[channel].css_position[1];
        //console.log(`left: ${left}, bottom: ${bottom}`);
        item.style.cssText = `
            position: absolute;
            left: ${left};
            bottom: ${bottom};
        `;
        item.className = "";
        item.querySelector("span").innerText = "";
    });
}
"""

with gr.Blocks() as demo:
    # Cross-callback state (mapping progress, temp-file paths, channel tables).
    state_json = gr.JSON(elem_id="state", visible=False)

    with gr.Row():
        gr.Markdown(
            """
            """
        )
    with gr.Row():
        # ---- Step 1: channel mapping ----
        with gr.Column():
            gr.Markdown(
                """
                # 1.Channel Mapping
                """
            )
            with gr.Row():
                in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
                # FIX: "locs" was missing its leading dot, so .locs files were
                # rejected even though the label advertises them.
                in_raw_loc = gr.File(label="Channel locations (.loc, .locs)",
                                     file_types=[".loc", ".locs"])
            with gr.Row():
                in_fill_mode = gr.Dropdown(choices=["zero",
                                                    ("adjacent channel", "adjacent"),
                                                    ("mean (manually select channels)", "mean")],
                                           value="zero", label="Imputation", scale=2)
                map_btn = gr.Button("Mapping", scale=1)

            # Raw output of channel_mapping.mapping(), consumed by mapping_result().
            channels_json = gr.JSON(visible=False)
            res_md = gr.Markdown(
                """
                ### Mapping result:
                """,
                visible=False
            )
            with gr.Row():
                tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
                map_montage = gr.Image(label="Choosen channels", visible=False)
            # Checkboxes for the manual "mean" imputation flow; positioned over
            # the montage image by chkbox_js.
            chs_chkbox = gr.CheckboxGroup(elem_id="chs-chkbox", label="", visible=False)
            next_btn = gr.Button("Next", interactive=False, visible=False)
            miss_txtbox = gr.Textbox(label="Missing channels", visible=False)
            tpl_loc_file = gr.File("./template_chanlocs.loc", show_label=False,
                                   visible=False)

        # ---- Step 2: decoding / denoising ----
        with gr.Column():
            gr.Markdown(
                """
                # 2.Decode Data
                """
            )
            with gr.Row():
                in_model_name = gr.Dropdown(choices=["IC-U-Net", "IC-U-Net++", "IC-U-Net-Attn",
                                                     "ART", "(mapped data)"],
                                            value="IC-U-Net", label="Model", scale=2)
                run_btn = gr.Button(scale=1, interactive=False)
            out_denoised_data = gr.File(label="Denoised data")

    with gr.Row():
        with gr.Tab("ART"):
            gr.Markdown()
        with gr.Tab("IC-U-Net"):
            gr.Markdown(icunet)
        with gr.Tab("IC-U-Net++"):
            gr.Markdown(unetpp)
        # FIX: tab was labeled "IC-U-Net-Att", inconsistent with the
        # "IC-U-Net-Attn" entry in the model dropdown above.
        with gr.Tab("IC-U-Net-Attn"):
            gr.Markdown()
        with gr.Tab("QuickStart"):
            gr.Markdown(quickstart)

    #demo.load(js=js)

    def reset_layout(raw_data):
        """Create a fresh temp_data/ folder next to the upload and hide all
        result widgets from any previous run."""
        filepath = os.path.dirname(str(raw_data))
        try:
            os.mkdir(filepath+"/temp_data/")
        except OSError:
            # Folder left over from a previous run: wipe and recreate it.
            utils.dataDelete(filepath+"/temp_data/")
            os.mkdir(filepath+"/temp_data/")

        state_obj = {
            "filepath": filepath+"/temp_data/",
            "files": {}
        }
        return {state_json : state_obj,
                chs_chkbox : gr.CheckboxGroup(choices=[], value=[], label="", visible=False), # choices, value ???
                next_btn : gr.Button("Next", interactive=False, visible=False),
                run_btn : gr.Button(interactive=False),
                tpl_montage : gr.Image(visible=False),
                map_montage : gr.Image(value=None, visible=False),
                miss_txtbox : gr.Textbox(visible=False),
                res_md : gr.Markdown(visible=False),
                tpl_loc_file : gr.File(visible=False)}

    def mapping_result(state_obj, channels_obj, raw_data, fill_mode):
        """Merge the mapping output into the app state.

        With "mean" imputation and missing channels present, enter the manual
        selection flow (state "initializing"); otherwise reorder/impute the
        data immediately and finish.
        """
        state_obj.update(channels_obj)
        if fill_mode=="mean" and channels_obj["missingChannelsIndex"]!=[]:
            state_obj.update({
                "state" : "initializing",
                "fillingCount" : 0,
                # NOTE: stored as the LAST index (len-1); all later comparisons
                # use <= / ==, and displayed totals add 1 back.
                "totalFillingNum" : len(channels_obj["missingChannelsIndex"])-1
            })
            return {state_json : state_obj,
                    next_btn : gr.Button(visible=True)}
        else:
            reorder_data(raw_data, channels_obj["newOrder"], fill_mode, state_obj)

            missing_channels = [state_obj["templateByIndex"][idx]
                                for idx in state_obj["missingChannelsIndex"]]
            missing_channels = ', '.join(missing_channels)
            state_obj.update({
                "state" : "finished",
            })
            return {state_json : state_obj,
                    res_md : gr.Markdown(visible=True),
                    miss_txtbox : gr.Textbox(value=missing_channels, visible=True),
                    tpl_loc_file : gr.File(visible=True),
                    run_btn : gr.Button(interactive=True)}

    def show_montage(state_obj, raw_loc):
        """Render the uploaded montage to a PNG (random suffix defeats browser
        caching) and reveal the montage images appropriate to the state."""
        filepath = state_obj["filepath"]
        raw_montage = read_custom_montage(raw_loc)
        # convert all channel names to uppercase
        for i in range(len(raw_montage.ch_names)):
            channel = raw_montage.ch_names[i]
            raw_montage.rename_channels({channel: str.upper(channel)})

        if state_obj["state"] == "initializing":
            filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
            state_obj["files"]["raw_montage"] = filename
            raw_fig = raw_montage.plot()
            raw_fig.set_size_inches(5.6, 5.6)
            raw_fig.savefig(filename, pad_inches=0)
            return {state_json : state_obj}

        elif state_obj["state"] == "finished":
            # didn't find any way to hide the dark points...
            filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
            state_obj["files"]["map_montage"] = filename

            # Show names only for channels the mapping actually used; an
            # imputed CZ is skipped so it isn't displayed as a real channel.
            show_names= []
            for channel in state_obj["inputByName"]:
                if state_obj["inputByName"][channel]["used"]:
                    if channel=='CZ' and state_obj["CZImputed"]:
                        continue
                    show_names.append(channel)
            mapped_fig = raw_montage.plot(show_names=show_names)
            mapped_fig.set_size_inches(5.6, 5.6)
            mapped_fig.savefig(filename, pad_inches=0)
            return {state_json : state_obj,
                    tpl_montage : gr.Image(visible=True),
                    map_montage : gr.Image(value=filename, visible=True)}

        elif state_obj["state"] == "selecting":
            # update in_montage here ?
            return {state_json : state_obj}

    def generate_chkbox(state_obj):
        """On entering the manual flow, populate the checkbox group with the
        input channels and label it with the first missing channel."""
        if state_obj["state"] == "initializing":
            in_channels = [channel for channel in state_obj["inputByName"]]
            state_obj["state"] = "selecting"

            first_idx = state_obj["missingChannelsIndex"][0]
            first_name = state_obj["templateByIndex"][first_idx]
            chkbox_label = first_name+' (1/'+str(state_obj["totalFillingNum"]+1)+')'
            return {state_json : state_obj,
                    chs_chkbox : gr.CheckboxGroup(choices=in_channels,
                                                  label=chkbox_label, visible=True),
                    next_btn : gr.Button(interactive=True)}
        else:
            return {state_json : state_obj}

    map_btn.click(
        fn = reset_layout,
        inputs = in_raw_data,
        outputs = [state_json, chs_chkbox, next_btn, run_btn, tpl_montage,
                   map_montage, miss_txtbox, res_md, tpl_loc_file]
    ).success(
        fn = mapping,
        inputs = [in_raw_data, in_raw_loc, in_fill_mode],
        outputs = channels_json
    ).success(
        fn = mapping_result,
        inputs = [state_json, channels_json, in_raw_data, in_fill_mode],
        outputs = [state_json, chs_chkbox, next_btn, miss_txtbox, res_md,
                   tpl_loc_file, run_btn]
    ).success(
        fn = show_montage,
        inputs = [state_json, in_raw_loc],
        outputs = [state_json, tpl_montage, map_montage]
    ).success(
        fn = generate_chkbox,
        inputs = state_json,
        outputs = [state_json, chs_chkbox, next_btn]
    ).success(
        fn = None,
        js = chkbox_js,
        inputs = state_json,
        outputs = []
    )

    def check_next(state_obj, selected, raw_data, fill_mode):
        """Record the user's selection for the current missing channel and
        either advance to the next one or finish and reorder the data."""
        if state_obj["state"] == "selecting":
            # save info before clicking on next_btn
            prev_target_idx = state_obj["missingChannelsIndex"][state_obj["fillingCount"]]
            prev_target_name = state_obj["templateByIndex"][prev_target_idx]
            selected_idx = [state_obj["inputByName"][channel]["index"] for channel in selected]
            state_obj["newOrder"][prev_target_idx] = selected_idx

            # A single previously-unused channel counts as a direct match, so
            # the target is no longer "missing" (-1 marks it resolved).
            if len(selected)==1 and state_obj["inputByName"][selected[0]]["used"]==False:
                state_obj["inputByName"][selected[0]]["used"] = True
                state_obj["missingChannelsIndex"][state_obj["fillingCount"]] = -1
            print('Selection for missing channel "{}"({}): {}'.format(
                prev_target_name, prev_target_idx, selected))

            # update next round
            state_obj["fillingCount"] += 1
            if state_obj["fillingCount"] <= state_obj["totalFillingNum"]:
                target_idx = state_obj["missingChannelsIndex"][state_obj["fillingCount"]]
                target_name = state_obj["templateByIndex"][target_idx]
                chkbox_label = target_name+' ('+str(state_obj["fillingCount"]+1)+'/'+str(state_obj["totalFillingNum"]+1)+')'
                btn_label = "Submit" if state_obj["fillingCount"]==state_obj["totalFillingNum"] else "Next"
                return {state_json : state_obj,
                        chs_chkbox : gr.CheckboxGroup(value=[], label=chkbox_label),
                        next_btn : gr.Button(btn_label)}
            else:
                state_obj["state"] = "finished"
                reorder_data(raw_data, state_obj["newOrder"], fill_mode, state_obj)

                missing_channels = []
                for idx in state_obj["missingChannelsIndex"]:
                    if idx != -1:
                        missing_channels.append(state_obj["templateByIndex"][idx])
                missing_channels = ', '.join(missing_channels)
                return {state_json : state_obj,
                        chs_chkbox : gr.CheckboxGroup(visible=False),
                        next_btn : gr.Button(visible=False),
                        res_md : gr.Markdown(visible=True),
                        miss_txtbox : gr.Textbox(value=missing_channels, visible=True),
                        tpl_loc_file : gr.File(visible=True),
                        run_btn : gr.Button(interactive=True)}
        # NOTE(review): no return when state != "selecting" — next_btn should
        # only be clickable while selecting; confirm this invariant holds.

    next_btn.click(
        fn = check_next,
        inputs = [state_json, chs_chkbox, in_raw_data, in_fill_mode],
        outputs = [state_json, chs_chkbox, next_btn, run_btn, res_md,
                   miss_txtbox, tpl_loc_file]
    ).success(
        fn = show_montage,
        inputs = [state_json, in_raw_loc],
        outputs = [state_json, tpl_montage, map_montage]
    )

    @run_btn.click(inputs=[state_json, in_raw_data, in_model_name], outputs=out_denoised_data)
    def run_model(state_obj, raw_file, model_name):
        """Run the selected model on the mapped data and return the path of
        the denoised CSV ("(mapped data)" returns the mapped file as-is)."""
        filepath = state_obj["filepath"]
        input_name = os.path.basename(str(raw_file))
        output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
        if model_name == "(mapped data)":
            return filepath + 'mapped.csv'

        # step1: Data preprocessing
        total_file_num = utils.preprocessing(filepath, 'mapped.csv', 256)
        # step2: Signal reconstruction
        utils.reconstruct(model_name, total_file_num, filepath, output_name)
        return filepath + output_name

if __name__ == "__main__":
    demo.launch()