Files changed (1) hide show
  1. app_utils.py +331 -0
app_utils.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import utils
2
+ import os
3
+ import math
4
+ import json
5
+ import jsbeautifier
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import mne
9
+ from mne.channels import read_custom_montage
10
+ from scipy.interpolate import Rbf
11
+ from scipy.optimize import linear_sum_assignment
12
+ from sklearn.neighbors import NearestNeighbors
13
+
14
def get_matched(tpl_names, tpl_dict):
    """Return the template channel names that are flagged as matched.

    Args:
        tpl_names: Iterable of template channel names; the result keeps
            this order.
        tpl_dict: Mapping from channel name to an info dict that holds a
            boolean ``"matched"`` entry.

    Returns:
        List of the names in ``tpl_names`` whose ``"matched"`` flag is true.
    """
    # Rely on truthiness rather than comparing against the True literal.
    return [name for name in tpl_names if tpl_dict[name]["matched"]]
16
+
17
def get_empty_template(tpl_names, tpl_dict):
    """Return the template channel names that are not matched yet.

    Args:
        tpl_names: Iterable of template channel names; the result keeps
            this order.
        tpl_dict: Mapping from channel name to an info dict that holds a
            boolean ``"matched"`` entry.

    Returns:
        List of the names in ``tpl_names`` whose ``"matched"`` flag is false.
    """
    # Rely on truthiness rather than comparing against the False literal.
    return [name for name in tpl_names if not tpl_dict[name]["matched"]]
19
+
20
def get_unassigned_input(in_names, in_dict):
    """Return the input channel names that have not been assigned yet.

    Args:
        in_names: Iterable of input channel names; the result keeps this
            order.
        in_dict: Mapping from channel name to an info dict that holds a
            boolean ``"assigned"`` entry.

    Returns:
        List of the names in ``in_names`` whose ``"assigned"`` flag is false.
    """
    # Rely on truthiness rather than comparing against the False literal.
    return [name for name in in_names if not in_dict[name]["assigned"]]
22
+
23
def _index_channels(montage, flag_key):
    """Uppercase the montage's channel names in place and build its lookup dict."""
    upper_names = [str.upper(name) for name in montage.ch_names]
    # Rename every channel in one call and read the positions once, instead
    # of repeating both operations for each channel.
    montage.rename_channels(dict(zip(montage.ch_names, upper_names)))
    ch_pos = montage.get_positions()['ch_pos']
    return {
        up_name: {
            "index": i,
            "coord_3d": ch_pos[up_name],
            flag_key: False,
        }
        for i, up_name in enumerate(upper_names)
    }


def read_montage(loc_file, tpl_file="./template_chanlocs.loc"):
    """Load the template and input montages and index their channels.

    The channel names of both montages are converted to uppercase (the
    montage objects are renamed in place). For every channel an info dict
    is created that stores its position in ``ch_names`` (``"index"``), its
    3-D position as reported by ``get_positions()['ch_pos']``
    (``"coord_3d"``) and a status flag initialised to ``False``
    (``"matched"`` for template channels, ``"assigned"`` for input channels).

    Args:
        loc_file: Path of the input channel-location file, passed to
            ``read_custom_montage``.
        tpl_file: Path of the template channel-location file. Defaults to
            ``"./template_chanlocs.loc"``, the path that used to be
            hard-coded; note that it is resolved against the current
            working directory.

    Returns:
        Tuple ``(tpl_montage, in_montage, tpl_dict, in_dict)`` where the
        dicts are keyed by the uppercase channel names.
    """
    tpl_montage = read_custom_montage(tpl_file)
    in_montage = read_custom_montage(loc_file)
    tpl_dict = _index_channels(tpl_montage, "matched")
    in_dict = _index_channels(in_montage, "assigned")
    return tpl_montage, in_montage, tpl_dict, in_dict
49
+
50
def match_name(stage1_info):
    """Match input channels to template channels by (aliased) name.

    The input location file is taken from
    ``stage1_info["fileNames"]["inputData"]``. A template channel is matched
    when an input channel with the same uppercase name exists. If a template
    name is a key of the alias table (T3/T4/T5/T6) and the input contains
    the corresponding alias (T7/T8/P7/P8), the template channel is renamed
    to that alias first.

    Args:
        stage1_info: Dict holding at least ``["fileNames"]["inputData"]``.
            It is updated in place with ``"unassignedInput"``,
            ``"emptyTemplate"`` and ``"mappingResult"``.

    Returns:
        Tuple ``(stage1_info, channel_info, tpl_montage, in_montage)``.
        ``channel_info`` bundles the template/input name lists and info
        dicts. ``mappingResult`` is a one-element list whose entry stores,
        per template channel, the matched input index (``[None]`` when
        unmatched) and whether the slot holds original data.
    """
    # read the location file
    loc_file = stage1_info["fileNames"]["inputData"]
    tpl_montage, in_montage, tpl_dict, in_dict = read_montage(loc_file)
    tpl_names = tpl_montage.ch_names
    in_names = in_montage.ch_names

    # One slot per template channel. Every slot gets its own placeholder
    # list so that mutating one entry can never leak into the others.
    n_tpl = len(tpl_names)
    old_idx = [[None] for _ in range(n_tpl)]  # in_channel indices in tpl order
    is_orig_data = [False] * n_tpl

    alias_dict = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8',
    }
    # Iterate over a snapshot: channels of the montage are renamed in the loop.
    for i, name in enumerate(list(tpl_names)):
        alias = alias_dict.get(name)
        if alias is not None and alias in in_dict:
            tpl_montage.rename_channels({name: alias})
            tpl_dict[alias] = tpl_dict.pop(name)
            name = alias

        if name in in_dict:
            old_idx[i] = [in_dict[name]["index"]]
            is_orig_data[i] = True
            tpl_dict[name]["matched"] = True
            in_dict[name]["assigned"] = True

    # pick up the names changed by the alias handling
    tpl_names = tpl_montage.ch_names

    stage1_info.update({
        "unassignedInput": get_unassigned_input(in_names, in_dict),
        "emptyTemplate": get_empty_template(tpl_names, tpl_dict),
        "mappingResult": [
            {
                "index": old_idx,
                "isOriginalData": is_orig_data,
            }
        ],
    })
    channel_info = {
        "templateNames": tpl_names,
        "inputNames": in_names,
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    }
    return stage1_info, channel_info, tpl_montage, in_montage
97
+
98
def _thin_plate_warp(anchors_src, anchors_dst, points):
    """Fit one thin-plate RBF per output axis on the anchor pairs and warp ``points``."""
    src_cols = anchors_src.T
    point_cols = points.T
    warped_axes = [
        Rbf(*src_cols, anchors_dst[:, axis], function='thin_plate')(*point_cols)
        for axis in range(anchors_dst.shape[1])
    ]
    return np.vstack(warped_axes).T


def align_coords(channel_info, tpl_montage, in_montage):
    """Warp the input channel positions onto the template layout.

    The channels matched by name serve as anchor pairs. A thin-plate RBF
    fitted on them maps every input channel to template space, once for the
    2-D positions drawn by ``montage.plot()`` and once for the 3-D positions.

    Args:
        channel_info: Dict with ``"templateNames"``, ``"inputNames"``,
            ``"templateDict"`` and ``"inputDict"``.
        tpl_montage: Template montage; only used to plot its 2-D layout.
        in_montage: Input montage; only used to plot its 2-D layout.

    Returns:
        ``channel_info``; its dicts are updated in place. Template entries
        gain ``"coord_2d"``; input entries gain the warped ``"coord_2d"`` and
        have ``"coord_3d"`` replaced by the warped position (plain lists).
    """
    tpl_names = channel_info["templateNames"]
    in_names = channel_info["inputNames"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    matched_names = get_matched(tpl_names, tpl_dict)
    tpl_rows = [tpl_dict[name]["index"] for name in matched_names]
    in_rows = [in_dict[name]["index"] for name in matched_names]

    # ---- 2D alignment (for visualization purposes) ----
    figs = [tpl_montage.plot(), in_montage.plot()]
    # Read the plotted 2-D coordinates back from the first collection of each
    # figure's first axes.
    # NOTE(review): assumes that collection is the sensor scatter, ordered
    # like ch_names - confirm.
    all_tpl_2d = figs[0].axes[0].collections[0].get_offsets().data
    all_in_2d = figs[1].axes[0].collections[0].get_offsets().data
    plt.close('all')

    warped_2d = _thin_plate_warp(all_in_2d[in_rows], all_tpl_2d[tpl_rows], all_in_2d)
    for i, name in enumerate(tpl_names):
        tpl_dict[name]["coord_2d"] = all_tpl_2d[i]
    for i, name in enumerate(in_names):
        in_dict[name]["coord_2d"] = warped_2d[i].tolist()

    # ---- 3D alignment ----
    # np.asarray accepts both ndarrays and plain lists. This function stores
    # lists back into in_dict, so calling .tolist() on the stored value would
    # raise AttributeError if the function ran twice on the same data.
    all_tpl_3d = np.asarray([tpl_dict[name]["coord_3d"] for name in tpl_names], dtype=float)
    all_in_3d = np.asarray([in_dict[name]["coord_3d"] for name in in_names], dtype=float)

    warped_3d = _thin_plate_warp(all_in_3d[in_rows], all_tpl_3d[tpl_rows], all_in_3d)
    for i, name in enumerate(in_names):
        in_dict[name]["coord_3d"] = warped_3d[i].tolist()

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return channel_info
154
+
155
def _coords_2d(names, info_dict):
    """Stack the x/y parts of each channel's ``coord_2d`` into an (N, 2) array."""
    xs = [info_dict[name]["coord_2d"][0] for name in names]
    ys = [info_dict[name]["coord_2d"][1] for name in names]
    return np.vstack((xs, ys)).T


def _store_css_positions(names, info_dict, pixel_coords):
    """Write the rounded ``css_position`` of every channel into ``info_dict``."""
    for name, (px_x, px_y) in zip(names, pixel_coords):
        # The figure is 6.4 in at 100 dpi (640 px), so px / 6.4 is the
        # position in percent of the figure size.
        info_dict[name]["css_position"] = [round(px_x / 6.4, 2), round(px_y / 6.4, 2)]


def save_figure(channel_info, tpl_montage, filename1, filename2):
    """Render the aligned input montage and record display positions.

    Two images are written: ``filename1`` shows all input channels in black
    on top of the template's head outline; ``filename2`` is the same figure
    with the still unassigned input channels redrawn in red. Afterwards the
    display position of every template and input channel is stored under
    ``"css_position"`` as ``[left, bottom]``.

    Args:
        channel_info: Dict with ``"templateNames"``, ``"inputNames"``,
            ``"templateDict"`` and ``"inputDict"``; every channel entry must
            already carry ``"coord_2d"``.
        tpl_montage: Template montage, plotted to obtain the head outline.
        filename1: Output path of the plain input-montage image.
        filename2: Output path of the image with unassigned channels marked.

    Returns:
        ``channel_info`` with both dicts updated in place.
    """
    tpl_names = channel_info["templateNames"]
    in_names = channel_info["inputNames"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    tpl_coords = _coords_2d(tpl_names, tpl_dict)
    in_coords = _coords_2d(in_names, in_dict)

    # extract the template's head outline from its own plot
    tpl_fig = tpl_montage.plot()
    head_lines = [line.get_data() for line in tpl_fig.axes[0].lines]

    # -------------------------plot input montage------------------------------
    fig = plt.figure(figsize=(6.4, 6.4), dpi=100)
    ax = fig.add_subplot(111)
    fig.tight_layout()
    ax.set_aspect('equal')
    ax.axis('off')

    # draw the template's head, then the input channels on top of it
    for x, y in head_lines:
        ax.plot(x, y, color='black', linewidth=1.0)
    ax.scatter(in_coords[:, 0], in_coords[:, 1], s=35, color='black')
    for i, name in enumerate(in_names):
        ax.text(in_coords[i, 0] + 0.004, in_coords[i, 1], name,
                color='black', fontsize=10.0, va='center')
    fig.savefig(filename1)

    # ---------------------------add indications-------------------------------
    # redraw the input channels that are still unassigned in red
    unassigned = [in_dict[name]["index"] for name in in_names
                  if not in_dict[name]["assigned"]]
    if unassigned:
        ax.scatter(in_coords[unassigned, 0], in_coords[unassigned, 1], s=35, color='red')
        for i in unassigned:
            ax.text(in_coords[i, 0] + 0.004, in_coords[i, 1], in_names[i],
                    color='red', fontsize=10.0, va='center')
    fig.savefig(filename2)

    # -------------------------------------------------------------------------
    # convert data coordinates to display coordinates (in px) before closing
    tpl_pixels = ax.transData.transform(tpl_coords)
    in_pixels = ax.transData.transform(in_coords)
    plt.close('all')

    _store_css_positions(tpl_names, tpl_dict, tpl_pixels)
    _store_css_positions(in_names, in_dict, in_pixels)

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return channel_info
224
+
225
def find_neighbors(channel_info, empty_tpl_names, old_idx):
    """Fill empty template slots with the indices of their nearest input channels.

    For every name in ``empty_tpl_names`` the (up to four) input channels
    closest in 3-D Euclidean distance are looked up, and their indices are
    written into ``old_idx`` at the template channel's ``"index"``.

    Args:
        channel_info: Dict with ``"inputNames"``, ``"templateDict"`` and
            ``"inputDict"``; the ``"coord_3d"`` entries are used as positions.
        empty_tpl_names: Names of the template channels to fill.
        old_idx: Per-template-channel list of input indices; modified in place.

    Returns:
        The same ``old_idx`` list.
    """
    input_names = channel_info["inputNames"]
    template_info = channel_info["templateDict"]
    input_info = channel_info["inputDict"]

    input_points = [np.array(input_info[ch]["coord_3d"]) for ch in input_names]
    query_points = [np.array(template_info[ch]["coord_3d"]) for ch in empty_tpl_names]

    # never ask for more neighbours than there are input channels
    n_neighbors = min(4, len(input_names))
    finder = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean')
    finder.fit(input_points)

    for tpl_name, point in zip(empty_tpl_names, query_points):
        _, nearest = finder.kneighbors(point.reshape(1, -1))
        slot = template_info[tpl_name]["index"]
        old_idx[slot] = nearest[0].tolist()

    return old_idx
243
+
244
def optimal_mapping(channel_info):
    """Assign the still unassigned input channels to template channels by distance.

    All template ``"matched"`` flags are reset. The Hungarian algorithm then
    assigns at most one unassigned input channel to each template channel so
    that the summed Euclidean distance between their ``"coord_3d"`` positions
    is minimal. Template channels that receive no input channel are filled
    with the indices of their nearest input channels via ``find_neighbors``.

    Args:
        channel_info: Dict with ``"templateNames"``, ``"inputNames"``,
            ``"templateDict"`` and ``"inputDict"``. The dicts are updated in
            place (``"matched"`` / ``"assigned"`` flags).

    Returns:
        Tuple ``(result, channel_info)``. ``result`` has ``"index"`` (one list
        of input indices per template channel) and ``"isOriginalData"``
        (``True`` where the slot got a channel from the assignment).
    """
    tpl_names = channel_info["templateNames"]
    in_names = channel_info["inputNames"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    unass_in_names = get_unassigned_input(in_names, in_dict)
    n_tpl = len(tpl_names)
    n_unass = len(unass_in_names)

    # reset all tpl.matched to False
    for info in tpl_dict.values():
        info["matched"] = False

    # reshape keeps the arrays 2-D even when no unassigned channel is left
    all_tpl = np.asarray(
        [tpl_dict[name]["coord_3d"] for name in tpl_names], dtype=float
    ).reshape(-1, 3)
    unass_in = np.asarray(
        [in_dict[name]["coord_3d"] for name in unass_in_names], dtype=float
    ).reshape(-1, 3)

    # Cost matrix for the Hungarian algorithm. With fewer unassigned channels
    # than template channels it is padded with high-cost dummy columns, so
    # that num_col >= num_row.
    cost_matrix = np.full((n_tpl, max(n_tpl, n_unass)), 1e6)
    # Euclidean distances (scaled by 1000) between every template channel and
    # every unassigned input channel, computed in one vectorized step.
    diffs = all_tpl[:, None, :] - unass_in[None, :, :]
    cost_matrix[:, :n_unass] = np.linalg.norm(diffs * 1000, axis=2)

    # optimally assign one in_channel to each tpl_channel by minimizing the
    # total distance between their positions
    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    # Store the mapping result. Every slot gets its own placeholder list
    # instead of 30 references to one shared list.
    old_idx = [[None] for _ in range(n_tpl)]
    is_orig_data = [False] * n_tpl
    for i, j in zip(row_idx, col_idx):
        if j >= n_unass:  # dummy column, no real channel behind it
            continue
        tpl_name = tpl_names[i]
        in_name = unass_in_names[j]

        old_idx[i] = [in_dict[in_name]["index"]]
        is_orig_data[i] = True
        tpl_dict[tpl_name]["matched"] = True
        in_dict[in_name]["assigned"] = True

    # fill the remaining empty tpl_channels
    empty_tpl_names = get_empty_template(tpl_names, tpl_dict)
    if empty_tpl_names:
        old_idx = find_neighbors(channel_info, empty_tpl_names, old_idx)

    result = {
        "index": old_idx,
        "isOriginalData": is_orig_data,
    }
    channel_info["inputDict"] = in_dict
    return result, channel_info
295
+
296
def mapping_result(stage1_info, channel_info, filename):
    """Map all remaining input channels in batches and write the result as JSON.

    The first batch is the name-based mapping already stored in
    ``stage1_info["mappingResult"]``. One further batch is produced by
    ``optimal_mapping`` for every block of unassigned input channels whose
    size is the number of template channels. The collected batches are
    written to ``filename`` and stored back into ``stage1_info``.

    Args:
        stage1_info: Dict with ``"unassignedInput"`` and ``"mappingResult"``;
            updated in place with ``"batch"`` and ``"mappingResult"``.
        channel_info: Dict with ``"templateNames"`` and ``"inputNames"``
            (plus whatever ``optimal_mapping`` reads).
        filename: Path of the JSON file to write.

    Returns:
        Tuple ``(stage1_info, channel_info)``.
    """
    # each batch can take as many input channels as the template has slots
    n_tpl = len(channel_info["templateNames"])
    unassigned_num = len(stage1_info["unassignedInput"])
    batch = math.ceil(unassigned_num / n_tpl) + 1

    # map the remaining in_channels; batch 0 is the name-based mapping
    results = stage1_info["mappingResult"]
    for _ in range(1, batch):
        # optimally select in_channels to map to the tpl_channels by proximity
        result, channel_info = optimal_mapping(channel_info)
        results.append(result)

    data = {
        "channelNum": len(channel_info["inputNames"]),
        "batch": batch,
        "mappingResult": results,
    }
    options = jsbeautifier.default_options()
    options.indent_size = 4
    json_data = jsbeautifier.beautify(json.dumps(data), options)
    with open(filename, 'w', encoding="utf-8") as jsonfile:
        jsonfile.write(json_data)

    stage1_info.update({
        "batch": batch,
        "mappingResult": results,
    })
    return stage1_info, channel_info
331
+