import utils
import time
import os
import numpy as np
import mne
from mne.channels import read_custom_montage

# Number of electrodes in the template montage; every mapping targets this layout.
NUM_TEMPLATE_CHANNELS = 30

# Template channel name -> alternative name accepted from the input montage when
# the template name itself is absent (the template channel is renamed to match).
_CHANNEL_ALIASES = {
    'T3': 'T7',
    'T4': 'T8',
    'T5': 'P7',
    'T6': 'P8',
    'TP7': 'T5\'',
    'TP8': 'T6\'',
}

# Row prefixes substituted for the temporal (outer-column) channels of grid
# rows 2, 3, 4 and 5 when imputing them.
_TEMPORAL_ROW_PREFIX = ['FC', 'C', 'CP', 'P']

# (row, col) cells of the 7x5 template grid that hold no electrode.
_EMPTY_TOPO_CELLS = ((0, 0), (0, 2), (0, 4), (6, 0), (6, 4))


def reorder_data(filename, old_idx, fill_mode, state_obj):
    """Remap a recording's rows onto the template channel order and save it.

    Args:
        filename: path passed to ``utils.read_train_data``; rows are treated as
            channels and columns as samples.
        old_idx: sequence of at least NUM_TEMPLATE_CHANNELS index collections.
            ``old_idx[i]`` lists the input rows averaged into template channel
            ``i``; an empty collection leaves that output row at zero. Index
            ``n_rows`` (one past the last input row) addresses an all-zero
            padding row appended here, matching ``z_row_idx`` in ``mapping``.
        fill_mode: unused; kept for interface compatibility.
        state_obj: mapping whose ``"filepath"`` entry prefixes the output name.

    The result is written to ``state_obj["filepath"] + 'mapped.csv'`` through
    ``utils.save_data``; nothing is returned.
    """
    old_data = utils.read_train_data(filename)
    n_samples = old_data.shape[1]
    new_data = np.zeros((NUM_TEMPLATE_CHANNELS, n_samples))
    new_filename = state_obj["filepath"] + 'mapped.csv'

    # Append an all-zero row so the "zero fill" index old_data.shape[0] is valid.
    padded = np.concatenate((old_data, np.zeros((1, n_samples))), axis=0)

    for i in range(NUM_TEMPLATE_CHANNELS):
        idx_set = old_idx[i]
        if len(idx_set) == 0:
            continue  # row is already zero
        new_data[i, :] = np.mean(padded[list(idx_set)], axis=0)

    utils.save_data(new_data, new_filename)


class Channel:
    """Per-electrode metadata for the template and input montages.

    ``pack_data`` exports instances through ``__dict__``, so every attribute
    assigned in ``__init__`` appears in the packed result, in this order.
    """

    def __init__(self, index, name=None, used=False, coord=None, css_position=None,
                 topo_index=None, topo_position=None):
        self.name = name                    # upper-cased channel label
        self.index = index                  # position in the montage's ch_names
        self.used = used                    # True once assigned to a template channel
        self.coord = coord                  # entry of montage.get_positions()['ch_pos']
        self.css_position = css_position    # ["<left>%", "<bottom>%"] for the UI
        self.topo_index = topo_index        # [row, col] in the 7x5 template grid
        self.topo_position = topo_position  # [vertical, horizontal] region labels

    def prefix(self):
        """Return the letters of the name, minus a trailing 'Z'.

        e.g. 'FC3' -> 'FC', 'FZ' -> 'F', 'FPZ' -> 'FP'.
        """
        letters = ''.join(filter(str.isalpha, self.name))
        return letters[:-1] if letters.endswith('Z') else letters

    def suffix(self):
        """Return the digits of the name as an int, or -1 if it ends with 'Z'.

        Raises:
            ValueError: if the name neither ends with 'Z' nor contains digits.
        """
        if self.name.endswith('Z'):
            return -1
        return int(''.join(filter(str.isdigit, self.name)))


def pack_data(new_idx, missing_channels, tpl_dict, in_dict, tpl_ordered_name, in_ordered_name):
    """Bundle a mapping result into a dict of plain containers.

    Args:
        new_idx: input row index per template channel, -1 when unmapped.
        missing_channels: template indices that are missing or share an input.
        tpl_dict / in_dict: ``{name: Channel}`` for template / input montages.
        tpl_ordered_name / in_ordered_name: channel names in montage order.

    Returns:
        dict whose ``"newOrder"`` holds ``[idx]`` per template channel, or ``[]``
        where ``new_idx`` is -1 (the format ``reorder_data`` expects).
    """
    return {
        "newOrder": [[i] if i != -1 else [] for i in new_idx],
        "missingChannelsIndex": missing_channels,
        "templateByName": {k: v.__dict__ for k, v in tpl_dict.items()},  # {name: attrs}
        "templateByIndex": tpl_ordered_name,  # list
        "inputByName": {k: v.__dict__ for k, v in in_dict.items()},
        "inputByIndex": in_ordered_name,
    }


def _load_montage_channels(montage, n_channels):
    """Upper-case the first *n_channels* names of *montage* (in place) and return
    ``{name: Channel}`` with each channel's on-plot CSS position."""
    fig = montage.plot()
    # NOTE(review): this figure is never closed; in a long-running process
    # consider closing it (e.g. matplotlib.pyplot.close(fig)) — confirm backend.
    fig.set_size_inches(5.6, 5.6)
    ax = fig.axes[0]
    ax.set_aspect('equal')
    ax.figure.canvas.draw()  # force a draw so transData reflects the final layout
    coords = ax.collections[0].get_offsets().data
    abs_coords = ax.transData.transform(coords)  # data -> display coordinates

    channels = {}
    for j in range(n_channels):
        original = montage.ch_names[j]
        name = original.upper()
        montage.rename_channels({original: name})
        # assumes a 560 px canvas (5.6 in at 100 dpi) and that 11/7 px offsets
        # centre the UI marker — TODO confirm against the front end.
        css_left = (abs_coords[j][0] - 11) / 560
        css_bottom = (abs_coords[j][1] - 7) / 560
        channels[name] = Channel(
            index=j,
            name=name,
            coord=montage.get_positions()['ch_pos'][name],
            css_position=[str(round(css_left * 100, 2)) + "%",
                          str(round(css_bottom * 100, 2)) + "%"],
        )
    return channels


def _build_topo_grid(template_montage, template_dict):
    """Lay the template channels out on a 7x5 grid.

    Also sets ``topo_index`` / ``topo_position`` on each template Channel.

    Returns:
        (grid, temporal_channels): the grid rows ('' for empty cells) and the
        outer-column channels of rows 2-5, ordered row by row, left then right.
    """
    grid = []
    temporal_channels = []
    cnt = 0
    for row in range(7):
        cells = []
        for col in range(5):
            if (row, col) in _EMPTY_TOPO_CELLS:
                cells.append('')
                continue
            name = template_montage.ch_names[cnt]
            cnt += 1
            cells.append(name)
            ver = 'front' if row < 3 else 'center' if row == 3 else 'back'
            hor = 'left' if col < 2 else 'center' if col == 2 else 'right'
            template_dict[name].topo_index = [row, col]
            template_dict[name].topo_position = [ver, hor]
            if row > 1 and col in (0, 4):
                temporal_channels.append(name)
        grid.append(cells)
    return grid, temporal_channels


def _find_adjacent_channel(prefix, suffix, row, col, ver, hor, grid, input_dict):
    """Return the input channel name used to impute a template channel in
    "adjacent" mode, falling back to 'CZ' when no neighbour is present."""
    ver_dir = 1 if ver == 'front' else -1

    if hor == 'center':  # midline channels: FZ, FPZ, ...
        if prefix + '1' in input_dict:  # ex: FZ <- F1
            return prefix + '1'
        if grid[row + ver_dir][col] in input_dict:  # ex: front FZ <- FCZ
            return grid[row + ver_dir][col]
        if prefix + '3' in input_dict:  # ex: FZ <- F3
            return prefix + '3'
        return 'CZ'

    # bit0: also try row+1, bit1: also try row-1
    ver_ctrl = 1 if ver == 'front' else 2 if ver == 'back' else 3

    # Search horizontally toward the midline, ex: F7 <- F5 / F3 / F1. After the
    # third candidate, the neighbouring row(s) in the same column are tried.
    cnt = 0
    tmp_suffix = suffix
    while tmp_suffix > 0:
        tmp_suffix = suffix - 2 * cnt
        if prefix + str(tmp_suffix) in input_dict:
            return prefix + str(tmp_suffix)
        if cnt == 2:
            if ver_ctrl & 1 and grid[row + 1][col] in input_dict:
                return grid[row + 1][col]
            if ver_ctrl & 2 and grid[row - 1][col] in input_dict:
                return grid[row - 1][col]
        cnt += 1

    # Search vertically toward the centre row (row 3), stopping after it.
    tmp_row = row + ver_dir
    while tmp_row - ver_dir != 3:
        if grid[tmp_row][col] in input_dict:
            return grid[tmp_row][col]
        tmp_row += ver_dir

    return 'CZ'


def _nearest_to_vertex(montage, input_dict):
    """Return the channel whose (x, y) position is closest to the z axis.

    Ties keep the first channel. Returns None when no channel qualifies.
    """
    nearest, best = None, float('inf')
    for name in montage.ch_names:
        x, y, _ = input_dict[name].coord.round(6)
        dist = x ** 2 + y ** 2
        if dist < best:
            nearest, best = name, dist
    return nearest


def mapping(data_file, loc_file, fill_mode, template_loc_file="./template_chanlocs.loc"):
    """Map the channels of an input montage onto the 30-channel template.

    Stage 1 matches template channels to input channels by upper-cased name,
    also accepting the alternatives in ``_CHANNEL_ALIASES``. If every template
    channel matches, or ``fill_mode == "mean"``, the result is returned
    immediately. Otherwise stage 2 fills each unmatched template channel:

    * ``"zero"``: use the channel two steps toward the midline in the same row
      (``<prefix>1`` for midline channels) if it exists and is unused;
      otherwise point at the all-zero padding row (index ``data.shape[0]``).
    * ``"adjacent"``: search neighbouring grid positions, falling back to CZ.
      A missing CZ is first imputed by the input channel closest to the z axis.

    Args:
        data_file: recording passed to ``utils.read_train_data``; only its row
            count is used (as the zero-padding row index).
        loc_file: input electrode-location file for ``read_custom_montage``.
        fill_mode: ``"mean"``, ``"zero"`` or ``"adjacent"``.
        template_loc_file: template electrode-location file; a relative path
            is resolved against the current working directory.

    Returns:
        The dict from ``pack_data`` plus a boolean ``"CZImputed"`` entry.

    Raises:
        ValueError: if stage 2 runs with an unsupported ``fill_mode``, or CZ
            must be imputed but the input montage offers no candidate.
    """
    start = time.time()
    data = utils.read_train_data(data_file)
    template_montage = read_custom_montage(template_loc_file)
    input_montage = read_custom_montage(loc_file)
    template_dict = _load_montage_channels(template_montage, NUM_TEMPLATE_CHANNELS)
    input_dict = _load_montage_channels(input_montage, len(input_montage.ch_names))

    new_idx = [-1] * NUM_TEMPLATE_CHANNELS
    missing_channels = []
    exact_missing_channels = []
    # Index one past the last data row; reorder_data() appends a zero row there.
    z_row_idx = data.shape[0]

    # STAGE 1: match template channel names (or their aliases) with the input.
    for i in range(NUM_TEMPLATE_CHANNELS):
        channel = template_montage.ch_names[i]
        if channel not in input_dict:
            target = _CHANNEL_ALIASES.get(channel)
            if target is None or target not in input_dict:
                exact_missing_channels.append(i)
                continue
            # Rename the template channel so both sides use the input's name.
            template_montage.rename_channels({channel: target})
            template_dict[target] = template_dict.pop(channel)
            template_dict[target].name = target
            channel = target
        new_idx[i] = input_dict[channel].index
        input_dict[channel].used = True

    if not exact_missing_channels:
        print('Finish at stage 1 ! (', time.time() - start, 's)')
        channels_obj = pack_data(new_idx, [], template_dict, input_dict,
                                 template_montage.ch_names, input_montage.ch_names)
        channels_obj["CZImputed"] = False
        return channels_obj
    if fill_mode == "mean":
        channels_obj = pack_data(new_idx, exact_missing_channels, template_dict, input_dict,
                                 template_montage.ch_names, input_montage.ch_names)
        channels_obj["CZImputed"] = False
        return channels_obj

    # STAGE 2: impute the remaining channels from their grid neighbours.
    template_topo_pos, temporal_channels = _build_topo_grid(template_montage, template_dict)

    # Ensure CZ exists in the input, since adjacent mode falls back to it.
    cz_imputed = False
    if 'CZ' not in input_dict and fill_mode == 'adjacent':
        cz_imputed = True
        nearest_channel = _nearest_to_vertex(input_montage, input_dict)
        if nearest_channel is None:
            raise ValueError("cannot impute CZ: input montage has no usable channel positions")
        cz_index = template_dict['CZ'].index
        if input_dict[nearest_channel].used:
            missing_channels.append(cz_index)  # shared with another template channel
        input_dict[nearest_channel].used = True
        input_dict['CZ'] = input_dict[nearest_channel]
        # Map template CZ to its imputed channel now, so the loop below neither
        # re-imputes it from another neighbour nor reports it as missing twice.
        new_idx[cz_index] = input_dict[nearest_channel].index
        print("CZ's nearest neighbor:", nearest_channel)

    for i in range(NUM_TEMPLATE_CHANNELS):
        if new_idx[i] != -1:
            continue
        channel = template_montage.ch_names[i]
        tpl = template_dict[channel]
        curr_prefix = tpl.prefix()
        curr_suffix = tpl.suffix()
        curr_row, curr_col = tpl.topo_index
        curr_ver, curr_hor = tpl.topo_position

        # Temporal channels are imputed as if they were the 7/8 electrode of
        # their grid row (e.g. the row-3 left channel is treated as C7).
        if channel in temporal_channels:
            curr_prefix = _TEMPORAL_ROW_PREFIX[temporal_channels.index(channel) // 2]
            curr_suffix = 7 if curr_hor == 'left' else 8

        if fill_mode == 'zero':
            if curr_hor == 'center':
                impute_channel = curr_prefix + '1'
            else:
                impute_channel = curr_prefix + str(curr_suffix - 2)
            if impute_channel not in input_dict or input_dict[impute_channel].used:
                new_idx[i] = z_row_idx
                missing_channels.append(i)
                continue
        elif fill_mode == 'adjacent':
            impute_channel = _find_adjacent_channel(
                curr_prefix, curr_suffix, curr_row, curr_col,
                curr_ver, curr_hor, template_topo_pos, input_dict)
        else:
            raise ValueError(f"unsupported fill_mode: {fill_mode!r}")

        new_idx[i] = input_dict[impute_channel].index
        if input_dict[impute_channel].used:
            missing_channels.append(i)  # this input channel is shared with others
        input_dict[impute_channel].used = True

    print('Finish at stage 2 ! (', time.time() - start, 's)')
    channels_obj = pack_data(new_idx, missing_channels, template_dict, input_dict,
                             template_montage.ch_names, input_montage.ch_names)
    channels_obj["CZImputed"] = cz_imputed
    return channels_obj
# reload