ArtifactRemovalTransformer / channel_mapping.py
roseDwayane's picture
test
115555e
import utils
import time
import os
import numpy as np
import mne
from mne.channels import read_custom_montage
def reorder_data(filename, old_idx, fill_mode, state_obj):
    """Rearrange a recording's rows into 30 output channels and save the result.

    Args:
        filename: Handed to ``utils.read_train_data`` to load the source array
            (rows are indexed as channels, columns as samples).
        old_idx: Indexable with at least 30 entries. Entry ``i`` is a list of
            source row indices whose mean becomes output row ``i``; an empty
            list leaves output row ``i`` at zero. The index equal to the
            number of source rows addresses an all-zero row appended below.
        fill_mode: Accepted for interface compatibility; not read here.
        state_obj: Mapping whose ``"filepath"`` entry is the prefix of the
            output path; the result is written to ``<filepath>mapped.csv``
            through ``utils.save_data``.
    """
    source = utils.read_train_data(filename)
    n_samples = source.shape[1]
    remapped = np.zeros((30, n_samples))
    target_path = state_obj["filepath"] + 'mapped.csv'

    # Pad the source with one zero row so that an index of ``source.shape[0]``
    # selects silence instead of raising.
    padded = np.concatenate((source, np.zeros((1, n_samples))), axis=0)

    for row in range(30):
        idx_set = old_idx[row]
        if idx_set == []:
            # Nothing mapped to this channel: the row keeps its initial zeros.
            continue
        # Average every source row assigned to this output channel.
        remapped[row, :] = np.take(padded, idx_set, axis=0).mean(axis=0)

    utils.save_data(remapped, target_path)
class Channel:
    """A single electrode plus the bookkeeping attached to it during mapping.

    Everything is stored as a plain instance attribute so that the instance's
    ``__dict__`` can be exported directly (``pack_data`` does this).

    Args:
        index: Position of the channel in its montage / data array.
        name: Channel label, e.g. ``'FC3'`` or ``'CZ'``.
        used: Whether the channel has already been assigned to a target slot.
        coord: Channel position as provided by the montage.
        css_position: Pair of percentage strings used to place the channel.
        topo_index: ``[row, col]`` of the channel in the template grid.
        topo_position: ``[vertical, horizontal]`` region labels.
    """

    def __init__(self, index, name=None, used=False, coord=None, css_position=None, topo_index=None, topo_position=None):
        # Assignment order is kept stable: it determines the key order of the
        # exported ``__dict__``.
        self.name = name
        self.index = index
        self.used = used
        self.coord = coord
        self.css_position = css_position
        self.topo_index = topo_index
        self.topo_position = topo_position

    def prefix(self):
        """Return the letters of the name without a trailing ``'Z'``.

        ``'FC3' -> 'FC'``, ``'FZ' -> 'F'``. A name containing no letters
        yields ``''`` instead of raising ``IndexError``.
        """
        letters = ''.join(filter(str.isalpha, self.name))
        # ``endswith`` is safe on the empty string, unlike ``letters[-1]``.
        return letters[:-1] if letters.endswith('Z') else letters

    def suffix(self):
        """Return the numeric part of the name, or ``-1`` for a name ending in ``'Z'``.

        Raises:
            ValueError: If the name does not end in ``'Z'`` and holds no digits.
        """
        if self.name.endswith('Z'):
            return -1
        return int(''.join(filter(str.isdigit, self.name)))
def pack_data(new_idx, missing_channels, tpl_dict, in_dict, tpl_ordered_name, in_ordered_name):
    """Bundle the mapping result into one plain dictionary.

    Args:
        new_idx: One source index per target slot; ``-1`` marks "no source".
        missing_channels: Stored unchanged under ``"missingChannelsIndex"``.
        tpl_dict: ``{name: object}`` for the template channels.
        in_dict: ``{name: object}`` for the input channels.
        tpl_ordered_name: Template channel names in index order.
        in_ordered_name: Input channel names in index order.

    Returns:
        A dict with the keys ``newOrder``, ``missingChannelsIndex``,
        ``templateByName``, ``templateByIndex``, ``inputByName`` and
        ``inputByIndex``. Each ``newOrder`` entry is a one-element list, or an
        empty list where the index was ``-1``. The ``*ByName`` values are the
        objects' own ``__dict__`` mappings (not copies).
    """
    order = []
    for source_index in new_idx:
        if source_index != -1:
            order.append([source_index])
        else:
            order.append([])

    template_by_name = {}
    for label, channel in tpl_dict.items():
        template_by_name[label] = channel.__dict__

    input_by_name = {}
    for label, channel in in_dict.items():
        input_by_name[label] = channel.__dict__

    packed = {}
    packed["newOrder"] = order
    packed["missingChannelsIndex"] = missing_channels
    packed["templateByName"] = template_by_name
    packed["templateByIndex"] = tpl_ordered_name
    packed["inputByName"] = input_by_name
    packed["inputByIndex"] = in_ordered_name
    return packed
def mapping(data_file, loc_file, fill_mode: str) -> dict:
    """Map the channels of an input recording onto the 30-channel template.

    The function works in two stages:

    * Stage 1 matches template channel names against the input channel names
      (after upper-casing both, and with a small alias table).
    * Stage 2 (only reached when stage 1 left gaps and ``fill_mode`` is not
      ``"mean"``) picks a substitute input channel, or the zero row, for every
      template slot that is still unmatched.

    Args:
        data_file: Passed to ``utils.read_train_data``; only the number of
            rows of the returned array is used here.
        loc_file: Channel-location file of the input recording, read with
            ``read_custom_montage``.
        fill_mode: ``"mean"``, ``"zero"`` or ``"adjacent"``.
            NOTE(review): any other value that reaches stage 2 leaves
            ``impute_channel`` empty and ends in ``input_dict['']`` (KeyError).

    Returns:
        The dictionary built by ``pack_data`` plus a ``"CZImputed"`` boolean.

    Side effects:
        Renames channels in both montages, prints timing information, and
        reads the template from ``./template_chanlocs.loc`` (relative to the
        current working directory).
    """
    second1 = time.time()
    data = utils.read_train_data(data_file)
    template_dict = {}
    input_dict = {}
    template_montage = read_custom_montage("./template_chanlocs.loc")
    input_montage = read_custom_montage(loc_file)
    montages = [template_montage, input_montage]
    dicts = [template_dict, input_dict]
    # Only the first 30 template channels are used; every input channel is used.
    num = [30, len(input_montage.ch_names)]

    # Build a Channel object for every channel of both montages.
    for i in range(2):
        # NOTE(review): the figure returned by plot() is not closed anywhere in
        # this function - confirm whether figures accumulate across calls.
        fig = montages[i].plot()
        fig.set_size_inches(5.6, 5.6)
        ax = fig.axes[0]
        ax.set_aspect('equal')
        ax.figure.canvas.draw()  # update the figure so the transform below is current
        coords = ax.collections[0].get_offsets().data
        # Convert the plotted data coordinates to display (pixel) coordinates.
        abs_coords = ax.transData.transform(coords)
        for j in range(num[i]):
            channel = montages[i].ch_names[j]
            # convert all channel names to uppercase
            montages[i].rename_channels({channel: str.upper(channel)})
            # Pixel position -> fraction of 560 after removing a fixed offset.
            # presumably 560 px = 5.6 in at 100 dpi, and 11/7 are margin
            # corrections - TODO confirm.
            css_left = (abs_coords[j][0]-11)/560
            css_bottom = (abs_coords[j][1]-7)/560
            channel = str.upper(channel)
            dicts[i][channel] = Channel(index=j,
                                        name=channel,
                                        coord=montages[i].get_positions()['ch_pos'][channel],
                                        css_position=[str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
                                        )

    # new_idx[i] is the input row assigned to template slot i (-1 = unassigned).
    new_idx = [-1]*30
    missing_channels = []
    exact_missing_channels = []
    # One past the last data row; reorder_data() appends an all-zero row there.
    z_row_idx = data.shape[0]

    # STAGE_1
    # match the template's channel names with the input ones
    finish_flag = 1
    # Alternative names tried when the template name itself is absent.
    alias = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8',
        'TP7': 'T5\'',
        'TP8': 'T6\'',
    }
    for i in range(30):
        channel = template_montage.ch_names[i]
        # Neither present in the input nor aliasable: definitely missing.
        if channel not in input_dict.keys() | alias.keys():
            exact_missing_channels.append(i)
            finish_flag = 0
            continue
        if channel not in input_dict and channel in alias:
            if alias[channel] in input_dict:
                # Adopt the alias as the template's name so later lookups agree.
                template_montage.rename_channels({channel: alias[channel]})
                template_dict[alias[channel]] = template_dict.pop(channel)
                template_dict[alias[channel]].name = alias[channel]
                channel = alias[channel]
            else:
                exact_missing_channels.append(i)
                finish_flag = 0
                continue
        new_idx[i] = input_dict[channel].index
        input_dict[channel].used = True

    if finish_flag == 1:
        # Every template channel was matched by name; nothing to impute.
        second2 = time.time()
        print('Finish at stage 1 ! (',second2 - second1,'s)')
        channels_obj = pack_data(new_idx, [],
                                 template_dict, input_dict,
                                 template_montage.ch_names, input_montage.ch_names)
        channels_obj.update({"CZImputed" : False})
        return channels_obj
    elif fill_mode == "mean":
        # "mean" mode stops here: unmatched slots keep -1 and are reported as
        # missing; no substitute is chosen in this function.
        channels_obj = pack_data(new_idx, exact_missing_channels,
                                 template_dict, input_dict,
                                 template_montage.ch_names, input_montage.ch_names)
        channels_obj.update({"CZImputed" : False})
        return channels_obj

    # STAGE_2
    # store channel positions in a 2-d array
    # The template is laid out on a 7x5 grid with five empty cells; the 30
    # template names are consumed in order, row by row.
    # assumes the template file lists its channels front-to-back, left-to-right
    # - TODO confirm.
    template_topo_pos = []
    temporal_channels = []
    # Row prefixes used for the outer-column channels of grid rows 2..5.
    temporal_row_prefix = ['FC', 'C', 'CP', 'P']
    cnt = 0
    for i in range(7):
        tmp = []
        for j in range(5):
            if [i,j] in [[0,0],[0,2],[0,4],[6,0],[6,4]]:
                tmp.append('')
            else:
                channel = template_montage.ch_names[cnt]
                tmp.append(channel)
                ver = 'front' if i<3 else 'center' if i==3 else 'back'
                hor = 'left' if j<2 else 'center' if j==2 else 'right'
                template_dict[channel].topo_index = [i, j]
                template_dict[channel].topo_position = [ver, hor]
                # Outer columns below row 1 are treated as temporal channels.
                if i > 1 and j in [0, 4]:
                    temporal_channels.append(channel)
                cnt += 1
        template_topo_pos.append(tmp)

    # ensure that CZ is found or imputed by another channel
    # ('CZ' is the last-resort substitute in "adjacent" mode, so it must exist.)
    CZ_impute_flag = False
    if 'CZ' not in input_dict and fill_mode=='adjacent':
        CZ_impute_flag = True
        min_dist = 1e5
        # Pick the input channel with the smallest x^2 + y^2.
        # NOTE(review): nearest_channel stays unbound if the input montage has
        # no channels.
        for channel in input_montage.ch_names:
            curr_x, curr_y, curr_z = input_dict[channel].coord.round(6)
            if curr_x**2 + curr_y**2 < min_dist:
                nearest_channel = channel
                min_dist = curr_x**2 + curr_y**2
        if input_dict[nearest_channel].used == True:
            missing_channels.append(template_dict['CZ'].index)
        input_dict[nearest_channel].used = True
        # 'CZ' becomes a second name for the same Channel object.
        input_dict['CZ'] = input_dict[nearest_channel]
        print("CZ's nearest neighbor:", nearest_channel)

    for i in range(30):
        if new_idx[i] != -1:
            continue  # already matched in stage 1
        channel = template_montage.ch_names[i]
        curr_prefix = template_dict[channel].prefix()
        curr_suffix = template_dict[channel].suffix()
        curr_row = template_dict[channel].topo_index[0]
        curr_col = template_dict[channel].topo_index[1]
        curr_ver = template_dict[channel].topo_position[0]
        curr_hor = template_dict[channel].topo_position[1]
        impute_channel = ''
        # if the current channel is a temporal channel
        # search with its row's prefix and the outermost suffix (7 left / 8 right)
        if channel in temporal_channels:
            curr_prefix = temporal_row_prefix[temporal_channels.index(channel)//2]
            curr_suffix = 7 if curr_hor=='left' else 8

        if fill_mode == 'zero':
            # Only the next channel toward the midline on the same row is
            # tried; if it is absent or already used, the slot gets the zero row.
            impute_channel = curr_prefix+str(1) if curr_hor=='center' else curr_prefix+str(curr_suffix-2)
            if impute_channel not in input_dict or input_dict[impute_channel].used==True:
                impute_channel = ''
                new_idx[i] = z_row_idx
                missing_channels.append(i)
                continue
        elif fill_mode == 'adjacent':
            # Vertical search direction: front rows look down, others look up.
            ver_dir = 1 if curr_ver == 'front' else -1
            if curr_hor == 'center': # FZ, FPZ...
                if curr_prefix+str(1) in input_dict: # ex: FZ<-F1
                    impute_channel = curr_prefix + str(1)
                elif template_topo_pos[curr_row+ver_dir][curr_col] in input_dict: # ex: front:FZ<-FCZ,
                    impute_channel = template_topo_pos[curr_row+ver_dir][curr_col]
                elif curr_prefix+str(3) in input_dict: # ex: FZ<-F3
                    impute_channel = curr_prefix + str(3)
                else:
                    impute_channel = 'CZ'
            elif curr_hor == 'left' or curr_hor == 'right':
                ver_ctrl = 1 if curr_ver=='front' else 2 if curr_ver=='back' else 3 # bit0: row+1, bit1: row-1
                # search horizontally
                # Candidates are suffix, suffix-2, suffix-4, ... while positive.
                cnt = 0
                tmp_suffix = curr_suffix
                while tmp_suffix > 0: # ex: F7<-F5/F3/F1
                    tmp_suffix = curr_suffix - 2*cnt
                    if curr_prefix+str(tmp_suffix) in input_dict:
                        impute_channel = curr_prefix + str(tmp_suffix)
                        break
                    if cnt == 2:
                        # check row+1/row-1
                        # (third step only: try the vertical neighbours allowed by ver_ctrl)
                        if ver_ctrl&1 and template_topo_pos[curr_row+1][curr_col] in input_dict:
                            impute_channel = template_topo_pos[curr_row+1][curr_col]
                            break
                        if ver_ctrl&2 and template_topo_pos[curr_row-1][curr_col] in input_dict:
                            impute_channel = template_topo_pos[curr_row-1][curr_col]
                            break
                    cnt += 1
                # search vertically
                # Step one row at a time toward row 3 in the same column.
                if impute_channel == '':
                    cnt = 0  # reset here but not read again before the next slot
                    tmp_row = curr_row + ver_dir
                    while tmp_row-ver_dir != 3: # terminate if the last channel is a middle one
                        if template_topo_pos[tmp_row][curr_col] in input_dict:
                            impute_channel = template_topo_pos[tmp_row][curr_col]
                            break
                        tmp_row += ver_dir
                # if still cannot find available channel...
                if impute_channel == '':
                    impute_channel = 'CZ'

        new_idx[i] = input_dict[impute_channel].index
        # NOTE(review): an index can be appended to missing_channels both here
        # and in the CZ block above, so the list may contain duplicates.
        if input_dict[impute_channel].used == True: # this channel is shared with others
            missing_channels.append(i)
        input_dict[impute_channel].used = True

    second2 = time.time()
    print('Finish at stage 2 ! (',second2 - second1,'s)')
    channels_obj = pack_data(new_idx, missing_channels,
                             template_dict, input_dict,
                             template_montage.ch_names, input_montage.ch_names)
    channels_obj.update({"CZImputed" : CZ_impute_flag})
    return channels_obj
# reload