Spaces:
Runtime error
Runtime error
import utils | |
import time | |
import os | |
import numpy as np | |
import mne | |
from mne.channels import read_custom_montage | |
def reorder_data(filename, old_idx, fill_mode, state_obj):
    """Remap the rows of a data file onto the 30-channel template order.

    Reads the data with ``utils.read_train_data(filename)``. It then builds a
    ``(30, n_samples)`` array where row ``i`` is the mean of the input rows
    listed in ``old_idx[i]``, or all zeros when that collection is empty. The
    result is saved with ``utils.save_data`` to
    ``state_obj["filepath"] + 'mapped.csv'``.

    Args:
        filename: path of the input data file.
        old_idx: sequence of (at least) 30 index collections, one per template
            channel; presumably the ``"newOrder"`` list built by ``pack_data``.
            An index equal to the number of input rows selects an all-zero row.
        fill_mode: not used by this function; kept for interface compatibility.
        state_obj: mapping whose ``"filepath"`` entry is the output filename
            prefix.

    Returns:
        None. The result is only written to disk.
    """
    old_data = utils.read_train_data(filename)
    n_samples = old_data.shape[1]
    # Append one all-zero row, so that index ``old_data.shape[0]`` (the
    # ``z_row_idx`` handed out by ``mapping``) selects zeros.
    old_data = np.concatenate((old_data, np.zeros((1, n_samples))), axis=0)

    new_data = np.zeros((30, n_samples))
    for i in range(30):
        idx_set = list(old_idx[i])
        # Use len() rather than ``== []``: an empty tuple/ndarray must also mean
        # "no source channel" instead of averaging nothing into a NaN row.
        if len(idx_set) == 0:
            continue  # the row is already zero
        new_data[i, :] = old_data[idx_set, :].mean(axis=0)

    new_filename = state_obj["filepath"] + 'mapped.csv'
    utils.save_data(new_data, new_filename)
class Channel:
    """One electrode of a montage, as used by ``mapping`` and ``pack_data``.

    Attributes live in the instance ``__dict__``, because ``pack_data``
    serializes channels through ``v.__dict__``. Do not add ``__slots__``.
    """

    def __init__(self, index, name=None, used=False, coord=None,
                 css_position=None, topo_index=None, topo_position=None):
        self.name = name  # channel name (upper-cased by mapping), e.g. 'FP1'
        self.index = index  # position of the channel in its montage's ch_names
        self.used = used  # set True once a template channel is mapped onto it
        self.coord = coord  # position array taken from montage.get_positions()
        self.css_position = css_position  # [left, bottom] percentage strings
        self.topo_index = topo_index  # [row, col] in the template grid
        self.topo_position = topo_position  # [vertical, horizontal] labels

    def prefix(self):
        """Return the alphabetic part of the name, without a trailing 'Z'.

        Examples: 'FP1' -> 'FP', 'FCZ' -> 'FC', 'CZ' -> 'C'. A name with
        no letters yields ''.
        """
        letters = ''.join(filter(str.isalpha, self.name))
        # endswith() is safe on '' (indexing letters[-1] would raise).
        return letters[:-1] if letters.endswith('Z') else letters

    def suffix(self):
        """Return the digits of the name as an int, or -1 for names ending in 'Z'.

        Raises:
            ValueError: if the name neither ends in 'Z' nor contains a digit.
        """
        if self.name.endswith('Z'):
            return -1
        return int(''.join(filter(str.isdigit, self.name)))
def pack_data(new_idx, missing_channels, tpl_dict, in_dict, tpl_ordered_name, in_ordered_name):
    """Bundle a channel-mapping result into a plain dict.

    Args:
        new_idx: per-template-channel input index; -1 means "unmapped".
        missing_channels: template channel indices to report as missing/shared.
        tpl_dict: {name: channel object} for the template montage.
        in_dict: {name: channel object} for the input montage.
        tpl_ordered_name: template channel names in montage order.
        in_ordered_name: input channel names in montage order.

    Returns:
        dict with keys "newOrder" (each index wrapped in a one-item list, or
        [] for -1), "missingChannelsIndex", "templateByName",
        "templateByIndex", "inputByName" and "inputByIndex". The *ByName
        values map each name to that object's ``__dict__``.
    """
    def attr_dicts(channels):
        # name -> the object's attribute dict (the same dict object, not a copy)
        return {ch_name: ch_obj.__dict__ for ch_name, ch_obj in channels.items()}

    order = [[] if idx == -1 else [idx] for idx in new_idx]
    return {
        "newOrder": order,
        "missingChannelsIndex": missing_channels,
        "templateByName": attr_dicts(tpl_dict),
        "templateByIndex": tpl_ordered_name,
        "inputByName": attr_dicts(in_dict),
        "inputByIndex": in_ordered_name,
    }
def mapping(data_file, loc_file, fill_mode):
    """Map the channels of an input montage onto the 30-channel template.

    Stage 1 matches template channel names exactly, or through a small alias
    table (e.g. 'T3' -> 'T7'). If every template channel is found, or if
    ``fill_mode == "mean"``, the result is returned right away. Otherwise,
    stage 2 lays the template out on a 7x5 grid and fills each unmatched
    template channel from a neighbouring input channel.

    Args:
        data_file: data file read with ``utils.read_train_data``. Only its
            row count is used here, as the index of the all-zero row that
            ``reorder_data`` appends.
        loc_file: channel-location file of the input, read with
            ``read_custom_montage``.
        fill_mode: "mean" stops after stage 1, reporting unmatched channels
            as missing. "zero" tries one neighbour and otherwise points to
            the zero row. "adjacent" searches several neighbours, falling
            back to CZ; if the input has no CZ, it is first imputed from the
            input channel closest to the x/y origin.

    Returns:
        The dict built by ``pack_data``, plus a boolean "CZImputed" entry.

    Raises:
        ValueError: if stage 2 is needed and ``fill_mode`` is not "zero" or
            "adjacent", or if CZ must be imputed but no input channel has a
            usable position.
    """
    second1 = time.time()
    data = utils.read_train_data(data_file)
    template_dict = {}
    input_dict = {}
    # NOTE(review): the template is read relative to the current working directory.
    template_montage = read_custom_montage("./template_chanlocs.loc")
    input_montage = read_custom_montage(loc_file)
    montages = [template_montage, input_montage]
    dicts = [template_dict, input_dict]
    num = [30, len(input_montage.ch_names)]

    # Build an {UPPER_NAME: Channel} dict per montage. Each channel records its
    # on-screen position, as CSS percentages of the plotted figure.
    for montage, chan_dict, n_chan in zip(montages, dicts, num):
        fig = montage.plot(show=False)
        fig.set_size_inches(5.6, 5.6)
        ax = fig.axes[0]
        ax.set_aspect('equal')
        ax.figure.canvas.draw()  # update the figure so transData is final
        coords = ax.collections[0].get_offsets().data
        abs_coords = ax.transData.transform(coords)  # data -> display pixels
        # NOTE(review): fig is never closed. If MNE registers it with pyplot,
        # figures accumulate across calls; consider matplotlib.pyplot.close(fig).

        # Convert all channel names to uppercase, then fetch positions once.
        montage.rename_channels(
            {name: str.upper(name) for name in montage.ch_names[:n_chan]})
        ch_pos = montage.get_positions()['ch_pos']
        for j in range(n_chan):
            channel = montage.ch_names[j]
            # The /560 presumably matches the 5.6 in figure at 100 dpi. The
            # -11/-7 px offsets are kept from the original layout code.
            css_left = (abs_coords[j][0] - 11) / 560
            css_bottom = (abs_coords[j][1] - 7) / 560
            chan_dict[channel] = Channel(
                index=j,
                name=channel,
                coord=ch_pos[channel],
                css_position=[str(round(css_left * 100, 2)) + "%",
                              str(round(css_bottom * 100, 2)) + "%"],
            )

    new_idx = [-1] * 30
    missing_channels = []
    exact_missing_channels = []
    z_row_idx = data.shape[0]  # index of the zero row appended by reorder_data

    # STAGE_1
    # Match the template's channel names with the input ones, directly or via alias.
    all_found = True
    alias = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8',
        'TP7': 'T5\'',
        'TP8': 'T6\'',
    }
    for i in range(30):
        channel = template_montage.ch_names[i]
        if channel not in input_dict:
            target = alias.get(channel)
            if target is None or target not in input_dict:
                exact_missing_channels.append(i)
                all_found = False
                continue
            # Adopt the input's naming for this template channel.
            template_montage.rename_channels({channel: target})
            template_dict[target] = template_dict.pop(channel)
            template_dict[target].name = target
            channel = target
        new_idx[i] = input_dict[channel].index
        input_dict[channel].used = True

    if all_found:
        second2 = time.time()
        print('Finish at stage 1 ! (', second2 - second1, 's)')
        channels_obj = pack_data(new_idx, [],
                                 template_dict, input_dict,
                                 template_montage.ch_names, input_montage.ch_names)
        channels_obj.update({"CZImputed": False})
        return channels_obj
    elif fill_mode == "mean":
        channels_obj = pack_data(new_idx, exact_missing_channels,
                                 template_dict, input_dict,
                                 template_montage.ch_names, input_montage.ch_names)
        channels_obj.update({"CZImputed": False})
        return channels_obj

    # Fail early instead of crashing later with KeyError('').
    if fill_mode not in ('zero', 'adjacent'):
        raise ValueError(f"unsupported fill_mode: {fill_mode!r}")

    # STAGE_2
    # Lay the template channels out on a 7x5 grid, row-major; 5 corner cells are empty.
    template_topo_pos = []
    temporal_channels = []
    temporal_row_prefix = ['FC', 'C', 'CP', 'P']
    empty_cells = [[0, 0], [0, 2], [0, 4], [6, 0], [6, 4]]
    ch_counter = 0
    for i in range(7):
        row = []
        for j in range(5):
            if [i, j] in empty_cells:
                row.append('')
                continue
            channel = template_montage.ch_names[ch_counter]
            row.append(channel)
            ver = 'front' if i < 3 else 'center' if i == 3 else 'back'
            hor = 'left' if j < 2 else 'center' if j == 2 else 'right'
            template_dict[channel].topo_index = [i, j]
            template_dict[channel].topo_position = [ver, hor]
            # Outer columns of rows 2-5, in order: 2 per row (left, right).
            if i > 1 and j in [0, 4]:
                temporal_channels.append(channel)
            ch_counter += 1
        template_topo_pos.append(row)

    # Ensure that CZ is found or imputed by another channel.
    CZ_impute_flag = False
    if 'CZ' not in input_dict and fill_mode == 'adjacent':
        CZ_impute_flag = True
        min_dist = 1e5
        nearest_channel = None
        for channel in input_montage.ch_names:
            curr_x, curr_y, _ = input_dict[channel].coord.round(6)
            dist = curr_x ** 2 + curr_y ** 2
            if dist < min_dist:  # NaN positions never compare smaller
                nearest_channel = channel
                min_dist = dist
        if nearest_channel is None:
            raise ValueError("cannot impute CZ: no input channel has a usable position")
        cz_idx = template_dict['CZ'].index
        if input_dict[nearest_channel].used:
            missing_channels.append(cz_idx)
        input_dict[nearest_channel].used = True
        input_dict['CZ'] = input_dict[nearest_channel]
        # Map template CZ onto that neighbour here. Otherwise the loop below
        # would re-impute CZ from C1/FCZ/C3, contradicting the flag and print.
        new_idx[cz_idx] = input_dict[nearest_channel].index
        print("CZ's nearest neighbor:", nearest_channel)

    for i in range(30):
        if new_idx[i] != -1:
            continue
        channel = template_montage.ch_names[i]
        tpl_chan = template_dict[channel]
        curr_prefix = tpl_chan.prefix()
        curr_suffix = tpl_chan.suffix()
        curr_row, curr_col = tpl_chan.topo_index
        curr_ver, curr_hor = tpl_chan.topo_position
        impute_channel = ''

        # Temporal channels search their row using the row's main prefix (7 = left, 8 = right).
        if channel in temporal_channels:
            curr_prefix = temporal_row_prefix[temporal_channels.index(channel) // 2]
            curr_suffix = 7 if curr_hor == 'left' else 8

        if fill_mode == 'zero':
            # Try a single neighbour; if unavailable or already used, use the zero row.
            if curr_hor == 'center':
                impute_channel = curr_prefix + '1'
            else:
                impute_channel = curr_prefix + str(curr_suffix - 2)
            if impute_channel not in input_dict or input_dict[impute_channel].used:
                new_idx[i] = z_row_idx
                missing_channels.append(i)
                continue
        elif fill_mode == 'adjacent':
            ver_dir = 1 if curr_ver == 'front' else -1  # step toward the centre row
            if curr_hor == 'center':  # FZ, FPZ...
                if curr_prefix + '1' in input_dict:  # ex: FZ<-F1
                    impute_channel = curr_prefix + '1'
                elif template_topo_pos[curr_row + ver_dir][curr_col] in input_dict:  # ex: front:FZ<-FCZ
                    impute_channel = template_topo_pos[curr_row + ver_dir][curr_col]
                elif curr_prefix + '3' in input_dict:  # ex: FZ<-F3
                    impute_channel = curr_prefix + '3'
                else:
                    impute_channel = 'CZ'
            else:  # 'left' or 'right'
                ver_ctrl = 1 if curr_ver == 'front' else 2 if curr_ver == 'back' else 3  # bit0: row+1, bit1: row-1
                # Search horizontally toward the midline, e.g. F7<-F5/F3/F1.
                step = 0
                tmp_suffix = curr_suffix
                while tmp_suffix > 0:
                    tmp_suffix = curr_suffix - 2 * step
                    if curr_prefix + str(tmp_suffix) in input_dict:
                        impute_channel = curr_prefix + str(tmp_suffix)
                        break
                    if step == 2:
                        # After two horizontal steps, also check row+1 / row-1.
                        if ver_ctrl & 1 and template_topo_pos[curr_row + 1][curr_col] in input_dict:
                            impute_channel = template_topo_pos[curr_row + 1][curr_col]
                            break
                        if ver_ctrl & 2 and template_topo_pos[curr_row - 1][curr_col] in input_dict:
                            impute_channel = template_topo_pos[curr_row - 1][curr_col]
                            break
                    step += 1
                # Search vertically toward the centre row (inclusive). Centre-row channels skip this.
                if impute_channel == '':
                    tmp_row = curr_row + ver_dir
                    while tmp_row - ver_dir != 3:
                        if template_topo_pos[tmp_row][curr_col] in input_dict:
                            impute_channel = template_topo_pos[tmp_row][curr_col]
                            break
                        tmp_row += ver_dir
            # If still no available channel, fall back to CZ.
            if impute_channel == '':
                impute_channel = 'CZ'

        new_idx[i] = input_dict[impute_channel].index
        if input_dict[impute_channel].used:  # this channel is shared with others
            missing_channels.append(i)
        input_dict[impute_channel].used = True

    second2 = time.time()
    print('Finish at stage 2 ! (', second2 - second1, 's)')
    channels_obj = pack_data(new_idx, missing_channels,
                             template_dict, input_dict,
                             template_montage.ch_names, input_montage.ch_names)
    channels_obj.update({"CZImputed": CZ_impute_flag})
    return channels_obj
# reload |