Spaces:
Runtime error
Runtime error
Upload app_utils.py
#2
by
roseDwayane
- opened
- app_utils.py +331 -0
app_utils.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import utils
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import json
|
5 |
+
import jsbeautifier
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import mne
|
9 |
+
from mne.channels import read_custom_montage
|
10 |
+
from scipy.interpolate import Rbf
|
11 |
+
from scipy.optimize import linear_sum_assignment
|
12 |
+
from sklearn.neighbors import NearestNeighbors
|
13 |
+
|
14 |
+
def get_matched(tpl_names, tpl_dict):
|
15 |
+
return [name for name in tpl_names if tpl_dict[name]["matched"]==True]
|
16 |
+
|
17 |
+
def get_empty_template(tpl_names, tpl_dict):
|
18 |
+
return [name for name in tpl_names if tpl_dict[name]["matched"]==False]
|
19 |
+
|
20 |
+
def get_unassigned_input(in_names, in_dict):
|
21 |
+
return [name for name in in_names if in_dict[name]["assigned"]==False]
|
22 |
+
|
23 |
+
def read_montage(loc_file):
|
24 |
+
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
25 |
+
in_montage = read_custom_montage(loc_file)
|
26 |
+
tpl_names = tpl_montage.ch_names
|
27 |
+
in_names = in_montage.ch_names
|
28 |
+
tpl_dict = {}
|
29 |
+
in_dict = {}
|
30 |
+
|
31 |
+
# convert all channel names to uppercase and store their information
|
32 |
+
for i, name in enumerate(tpl_names):
|
33 |
+
up_name = str.upper(name)
|
34 |
+
tpl_montage.rename_channels({name: up_name})
|
35 |
+
tpl_dict[up_name] = {
|
36 |
+
"index" : i,
|
37 |
+
"coord_3d" : tpl_montage.get_positions()['ch_pos'][up_name],
|
38 |
+
"matched" : False
|
39 |
+
}
|
40 |
+
for i, name in enumerate(in_names):
|
41 |
+
up_name = str.upper(name)
|
42 |
+
in_montage.rename_channels({name: up_name})
|
43 |
+
in_dict[up_name] = {
|
44 |
+
"index" : i,
|
45 |
+
"coord_3d" : in_montage.get_positions()['ch_pos'][up_name],
|
46 |
+
"assigned" : False
|
47 |
+
}
|
48 |
+
return tpl_montage, in_montage, tpl_dict, in_dict
|
49 |
+
|
50 |
+
def match_name(stage1_info):
|
51 |
+
# read the location file
|
52 |
+
loc_file = stage1_info["fileNames"]["inputData"]
|
53 |
+
tpl_montage, in_montage, tpl_dict, in_dict = read_montage(loc_file)
|
54 |
+
tpl_names = tpl_montage.ch_names
|
55 |
+
in_names = in_montage.ch_names
|
56 |
+
old_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
|
57 |
+
is_orig_data = [False]*30
|
58 |
+
|
59 |
+
alias_dict = {
|
60 |
+
'T3': 'T7',
|
61 |
+
'T4': 'T8',
|
62 |
+
'T5': 'P7',
|
63 |
+
'T6': 'P8'
|
64 |
+
}
|
65 |
+
for i, name in enumerate(tpl_names):
|
66 |
+
if name in alias_dict and alias_dict[name] in in_dict:
|
67 |
+
tpl_montage.rename_channels({name: alias_dict[name]})
|
68 |
+
tpl_dict[alias_dict[name]] = tpl_dict.pop(name)
|
69 |
+
name = alias_dict[name]
|
70 |
+
|
71 |
+
if name in in_dict:
|
72 |
+
old_idx[i] = [in_dict[name]["index"]]
|
73 |
+
is_orig_data[i] = True
|
74 |
+
tpl_dict[name]["matched"] = True
|
75 |
+
in_dict[name]["assigned"] = True
|
76 |
+
|
77 |
+
# update the names
|
78 |
+
tpl_names = tpl_montage.ch_names
|
79 |
+
|
80 |
+
stage1_info.update({
|
81 |
+
"unassignedInput" : get_unassigned_input(in_names, in_dict),
|
82 |
+
"emptyTemplate" : get_empty_template(tpl_names, tpl_dict),
|
83 |
+
"mappingResult" : [
|
84 |
+
{
|
85 |
+
"index" : old_idx,
|
86 |
+
"isOriginalData" : is_orig_data
|
87 |
+
}
|
88 |
+
]
|
89 |
+
})
|
90 |
+
channel_info = {
|
91 |
+
"templateNames" : tpl_names,
|
92 |
+
"inputNames" : in_names,
|
93 |
+
"templateDict" : tpl_dict,
|
94 |
+
"inputDict" : in_dict
|
95 |
+
}
|
96 |
+
return stage1_info, channel_info, tpl_montage, in_montage
|
97 |
+
|
98 |
+
def align_coords(channel_info, tpl_montage, in_montage):
|
99 |
+
tpl_names = channel_info["templateNames"]
|
100 |
+
in_names = channel_info["inputNames"]
|
101 |
+
tpl_dict = channel_info["templateDict"]
|
102 |
+
in_dict = channel_info["inputDict"]
|
103 |
+
matched_names = get_matched(tpl_names, tpl_dict)
|
104 |
+
|
105 |
+
# 2D alignment (for visualization purposes)
|
106 |
+
fig = [tpl_montage.plot(), in_montage.plot()]
|
107 |
+
ax = [fig[0].axes[0], fig[1].axes[0]]
|
108 |
+
|
109 |
+
# extract the displayed 2D coordinates
|
110 |
+
all_tpl = ax[0].collections[0].get_offsets().data
|
111 |
+
all_in= ax[1].collections[0].get_offsets().data
|
112 |
+
matched_tpl = np.array([all_tpl[tpl_dict[name]["index"]] for name in matched_names])
|
113 |
+
matched_in = np.array([all_in[in_dict[name]["index"]] for name in matched_names])
|
114 |
+
plt.close('all')
|
115 |
+
|
116 |
+
# apply TPS to transform in_channels to align with tpl_channels positions
|
117 |
+
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
|
118 |
+
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
|
119 |
+
|
120 |
+
# apply the transformation to all in_channels
|
121 |
+
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
|
122 |
+
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
123 |
+
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
124 |
+
|
125 |
+
for i, name in enumerate(tpl_names):
|
126 |
+
tpl_dict[name]["coord_2d"] = all_tpl[i]
|
127 |
+
for i, name in enumerate(in_names):
|
128 |
+
in_dict[name]["coord_2d"] = transformed_in[i].tolist()
|
129 |
+
|
130 |
+
|
131 |
+
# 3D alignment
|
132 |
+
all_tpl = np.array([tpl_dict[name]["coord_3d"].tolist() for name in tpl_names])
|
133 |
+
all_in = np.array([in_dict[name]["coord_3d"].tolist() for name in in_names])
|
134 |
+
matched_tpl = np.array([all_tpl[tpl_dict[name]["index"]] for name in matched_names])
|
135 |
+
matched_in = np.array([all_in[in_dict[name]["index"]] for name in matched_names])
|
136 |
+
|
137 |
+
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
|
138 |
+
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
|
139 |
+
rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
|
140 |
+
|
141 |
+
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
|
142 |
+
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
|
143 |
+
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
144 |
+
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
145 |
+
|
146 |
+
for i, name in enumerate(in_names):
|
147 |
+
in_dict[name]["coord_3d"] = transformed_in[i].tolist()
|
148 |
+
|
149 |
+
channel_info.update({
|
150 |
+
"templateDict" : tpl_dict,
|
151 |
+
"inputDict" : in_dict
|
152 |
+
})
|
153 |
+
return channel_info
|
154 |
+
|
155 |
+
def save_figure(channel_info, tpl_montage, filename1, filename2):
|
156 |
+
tpl_names = channel_info["templateNames"]
|
157 |
+
in_names = channel_info["inputNames"]
|
158 |
+
tpl_dict = channel_info["templateDict"]
|
159 |
+
in_dict = channel_info["inputDict"]
|
160 |
+
|
161 |
+
tpl_x = [tpl_dict[name]["coord_2d"][0] for name in tpl_names]
|
162 |
+
tpl_y = [tpl_dict[name]["coord_2d"][1] for name in tpl_names]
|
163 |
+
in_x = [in_dict[name]["coord_2d"][0] for name in in_names]
|
164 |
+
in_y = [in_dict[name]["coord_2d"][1] for name in in_names]
|
165 |
+
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
166 |
+
in_coords = np.vstack((in_x, in_y)).T
|
167 |
+
|
168 |
+
# extract template's head figure
|
169 |
+
tpl_fig = tpl_montage.plot()
|
170 |
+
tpl_ax = tpl_fig.axes[0]
|
171 |
+
lines = tpl_ax.lines
|
172 |
+
head_lines = []
|
173 |
+
for line in lines:
|
174 |
+
x, y = line.get_data()
|
175 |
+
head_lines.append((x,y))
|
176 |
+
|
177 |
+
# -------------------------plot input montage------------------------------
|
178 |
+
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
179 |
+
ax = fig.add_subplot(111)
|
180 |
+
fig.tight_layout()
|
181 |
+
ax.set_aspect('equal')
|
182 |
+
ax.axis('off')
|
183 |
+
|
184 |
+
# plot template's head
|
185 |
+
for x, y in head_lines:
|
186 |
+
ax.plot(x, y, color='black', linewidth=1.0)
|
187 |
+
# plot in_channels on it
|
188 |
+
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
189 |
+
for i, name in enumerate(in_names):
|
190 |
+
ax.text(in_coords[i,0]+0.004, in_coords[i,1], name, color='black', fontsize=10.0, va='center')
|
191 |
+
# save input_montage
|
192 |
+
fig.savefig(filename1)
|
193 |
+
|
194 |
+
# ---------------------------add indications-------------------------------
|
195 |
+
# plot unmatched input channels in red
|
196 |
+
indices = [in_dict[name]["index"] for name in in_names if in_dict[name]["assigned"]==False]
|
197 |
+
if indices != []:
|
198 |
+
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
199 |
+
for i in indices:
|
200 |
+
ax.text(in_coords[i,0]+0.004, in_coords[i,1], in_names[i], color='red', fontsize=10.0, va='center')
|
201 |
+
# save mapped_montage
|
202 |
+
fig.savefig(filename2)
|
203 |
+
|
204 |
+
# -------------------------------------------------------------------------
|
205 |
+
# store the tpl and in_channels' display positions (in px)
|
206 |
+
tpl_coords = ax.transData.transform(tpl_coords)
|
207 |
+
in_coords = ax.transData.transform(in_coords)
|
208 |
+
plt.close('all')
|
209 |
+
|
210 |
+
for i, name in enumerate(tpl_names):
|
211 |
+
left = tpl_coords[i,0]/6.4
|
212 |
+
bottom = tpl_coords[i,1]/6.4
|
213 |
+
tpl_dict[name]["css_position"] = [round(left, 2), round(bottom, 2)]
|
214 |
+
for i, name in enumerate(in_names):
|
215 |
+
left = in_coords[i,0]/6.4
|
216 |
+
bottom = in_coords[i,1]/6.4
|
217 |
+
in_dict[name]["css_position"] = [round(left, 2), round(bottom, 2)]
|
218 |
+
|
219 |
+
channel_info.update({
|
220 |
+
"templateDict" : tpl_dict,
|
221 |
+
"inputDict" : in_dict
|
222 |
+
})
|
223 |
+
return channel_info
|
224 |
+
|
225 |
+
def find_neighbors(channel_info, empty_tpl_names, old_idx):
|
226 |
+
in_names = channel_info["inputNames"]
|
227 |
+
tpl_dict = channel_info["templateDict"]
|
228 |
+
in_dict = channel_info["inputDict"]
|
229 |
+
|
230 |
+
all_in = [np.array(in_dict[name]["coord_3d"]) for name in in_names]
|
231 |
+
empty_tpl = [np.array(tpl_dict[name]["coord_3d"]) for name in empty_tpl_names]
|
232 |
+
|
233 |
+
# use KNN to choose k nearest channels
|
234 |
+
k = 4 if len(in_names)>4 else len(in_names)
|
235 |
+
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
236 |
+
knn.fit(all_in)
|
237 |
+
for i, name in enumerate(empty_tpl_names):
|
238 |
+
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
239 |
+
idx = tpl_dict[name]["index"]
|
240 |
+
old_idx[idx] = indices[0].tolist()
|
241 |
+
|
242 |
+
return old_idx
|
243 |
+
|
244 |
+
def optimal_mapping(channel_info):
|
245 |
+
tpl_names = channel_info["templateNames"]
|
246 |
+
in_names = channel_info["inputNames"]
|
247 |
+
tpl_dict = channel_info["templateDict"]
|
248 |
+
in_dict = channel_info["inputDict"]
|
249 |
+
unass_in_names = get_unassigned_input(in_names, in_dict)
|
250 |
+
# reset all tpl.matched to False
|
251 |
+
for name in tpl_dict:
|
252 |
+
tpl_dict[name]["matched"] = False
|
253 |
+
|
254 |
+
all_tpl = np.array([tpl_dict[name]["coord_3d"] for name in tpl_names])
|
255 |
+
unass_in = np.array([in_dict[name]["coord_3d"] for name in unass_in_names])
|
256 |
+
|
257 |
+
# initialize the cost matrix for the Hungarian algorithm
|
258 |
+
if len(unass_in_names) < 30:
|
259 |
+
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
260 |
+
else:
|
261 |
+
cost_matrix = np.zeros((30, len(unass_in_names)))
|
262 |
+
# fill the cost matrix with Euclidean distances between tpl and unassigned in_channels
|
263 |
+
for i in range(30):
|
264 |
+
for j in range(len(unass_in_names)):
|
265 |
+
cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unass_in[j])*1000)
|
266 |
+
|
267 |
+
# apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
|
268 |
+
# by minimizing the total distances between their positions.
|
269 |
+
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
270 |
+
|
271 |
+
# store the mapping result
|
272 |
+
old_idx = [[None]]*30
|
273 |
+
is_orig_data = [False]*30
|
274 |
+
for i, j in zip(row_idx, col_idx):
|
275 |
+
if j < len(unass_in_names): # filter out dummy channels
|
276 |
+
tpl_name = tpl_names[i]
|
277 |
+
in_name = unass_in_names[j]
|
278 |
+
|
279 |
+
old_idx[i] = [in_dict[in_name]["index"]]
|
280 |
+
is_orig_data[i] = True
|
281 |
+
tpl_dict[tpl_name]["matched"] = True
|
282 |
+
in_dict[in_name]["assigned"] = True
|
283 |
+
|
284 |
+
# fill the remaining empty tpl_channels
|
285 |
+
empty_tpl_names = get_empty_template(tpl_names, tpl_dict)
|
286 |
+
if empty_tpl_names != []:
|
287 |
+
old_idx = find_neighbors(channel_info, empty_tpl_names, old_idx)
|
288 |
+
|
289 |
+
result = {
|
290 |
+
"index" : old_idx,
|
291 |
+
"isOriginalData" : is_orig_data
|
292 |
+
}
|
293 |
+
channel_info["inputDict"] = in_dict
|
294 |
+
return result, channel_info
|
295 |
+
|
296 |
+
def mapping_result(stage1_info, channel_info, filename):
|
297 |
+
unassigned_num = len(stage1_info["unassignedInput"])
|
298 |
+
batch = math.ceil(unassigned_num/30) + 1
|
299 |
+
|
300 |
+
# map the remaining in_channels
|
301 |
+
results = stage1_info["mappingResult"]
|
302 |
+
for i in range(1, batch):
|
303 |
+
# optimally select 30 in_channels to map to the tpl_channels based on proximity
|
304 |
+
result, channel_info = optimal_mapping(channel_info)
|
305 |
+
results += [result]
|
306 |
+
'''
|
307 |
+
for i in range(batch):
|
308 |
+
results[i]["name"] = {}
|
309 |
+
for j, indices in enumerate(results[i]["index"]):
|
310 |
+
names = [channel_info["inputNames"][idx] for idx in indices] if indices!=[None] else ["zero"]
|
311 |
+
results[i]["name"][channel_info["templateNames"][j]] = names
|
312 |
+
'''
|
313 |
+
data = {
|
314 |
+
#"templateNames" : channel_info["templateNames"],
|
315 |
+
#"inputNames" : channel_info["inputNames"],
|
316 |
+
"channelNum" : len(channel_info["inputNames"]),
|
317 |
+
"batch" : batch,
|
318 |
+
"mappingResult" : results
|
319 |
+
}
|
320 |
+
options = jsbeautifier.default_options()
|
321 |
+
options.indent_size = 4
|
322 |
+
json_data = jsbeautifier.beautify(json.dumps(data), options)
|
323 |
+
with open(filename, 'w') as jsonfile:
|
324 |
+
jsonfile.write(json_data)
|
325 |
+
|
326 |
+
stage1_info.update({
|
327 |
+
"batch" : batch,
|
328 |
+
"mappingResult" : results
|
329 |
+
})
|
330 |
+
return stage1_info, channel_info
|
331 |
+
|