Files changed (1) hide show
  1. utils.py +82 -56
utils.py CHANGED
@@ -15,20 +15,21 @@ from scipy.signal import decimate, resample_poly, firwin, lfilter
15
 
16
 
17
  os.environ["CUDA_VISIBLE_DEVICES"]="0"
 
18
 
19
- def resample(signal, fs):
20
- # downsample the signal to a sample rate of 256 Hz
21
- if fs>256:
22
- fs_down = 256 # Desired sample rate
23
  q = int(fs / fs_down) # Downsampling factor
24
  signal_new = []
25
  for ch in signal:
26
  x_down = decimate(ch, q)
27
  signal_new.append(x_down)
28
 
29
- # upsample the signal to a sample rate of 256 Hz
30
- elif fs<256:
31
- fs_up = 256 # Desired sample rate
32
  p = int(fs_up / fs) # Upsampling factor
33
  signal_new = []
34
  for ch in signal:
@@ -70,14 +71,14 @@ def cut_data(filepath, raw_data):
70
  total = int(len(raw_data[0]) / 1024)
71
  for i in range(total):
72
  table = raw_data[:, i * 1024:(i + 1) * 1024]
73
- filename = filepath + '/temp2/' + str(i) + '.csv'
74
  with open(filename, 'w', newline='') as csvfile:
75
  writer = csv.writer(csvfile)
76
  writer.writerows(table)
77
  return total
78
 
79
 
80
- def glue_data(file_name, total, output):
81
  gluedata = 0
82
  for i in range(total):
83
  file_name1 = file_name + 'output{}.csv'.format(str(i))
@@ -96,11 +97,7 @@ def glue_data(file_name, total, output):
96
  raw_data[:, 1] = smooth
97
  gluedata = np.append(gluedata, raw_data, axis=1)
98
  #print(gluedata.shape)
99
- filename2 = output
100
- with open(filename2, 'w', newline='') as csvfile:
101
- writer = csv.writer(csvfile)
102
- writer.writerows(gluedata)
103
- #print("GLUE DONE!" + filename2)
104
 
105
 
106
  def save_data(data, filename):
@@ -112,91 +109,105 @@ def dataDelete(path):
112
  try:
113
  shutil.rmtree(path)
114
  except OSError as e:
115
- print(e)
 
116
  else:
117
  pass
118
  #print("The directory is deleted successfully")
119
 
120
 
121
  def decode_data(data, std_num, mode=5):
122
- decode = data
123
  if mode == "ICUNet":
124
  # 1. read name
125
- model = cumbersome_model2.UNet1(n_channels=30, n_classes=30)
126
  resumeLoc = './model/ICUNet/modelsave' + '/checkpoint.pth.tar'
127
  # 2. load model
128
- checkpoint = torch.load(resumeLoc, map_location='cpu')
129
  model.load_state_dict(checkpoint['state_dict'], False)
130
  model.eval()
131
  # 3. decode strategy
132
  with torch.no_grad():
133
  data = data[np.newaxis, :, :]
134
- data = torch.Tensor(data)
135
  decode = model(data)
136
 
137
 
138
- elif mode == "UNetpp" or mode == "AttUnet":
139
  # 1. read name
140
- if mode == "UNetpp":
141
- model = UNet_family.NestedUNet3(num_classes=30)
142
- elif mode == "AttUnet":
143
- model = UNet_attention.UNetpp3_Transformer(num_classes=30)
144
- resumeLoc = './model/'+ mode + '/modelsave' + '/checkpoint.pth.tar'
145
  # 2. load model
146
- checkpoint = torch.load(resumeLoc, map_location='cpu')
147
  model.load_state_dict(checkpoint['state_dict'], False)
148
  model.eval()
149
  # 3. decode strategy
150
  with torch.no_grad():
151
  data = data[np.newaxis, :, :]
152
- data = torch.Tensor(data)
153
  decode1, decode2, decode = model(data)
154
 
155
 
156
- elif mode == "EEGART":
157
  # 1. read name
158
  resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
159
  # 2. load model
160
- checkpoint = torch.load(resumeLoc, map_location='cpu')
161
- model = tf_model.make_model(30, 30, N=2)
162
  model.load_state_dict(checkpoint['state_dict'])
163
  model.eval()
164
  # 3. decode strategy
165
  with torch.no_grad():
166
- data = torch.FloatTensor(data)
167
  data = data.unsqueeze(0)
168
  src = data
169
- tgt = data
170
  batch = tf_data.Batch(src, tgt, 0)
171
  out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
172
  decode = model.generator(out)
173
  decode = decode.permute(0, 2, 1)
174
- #add_tensor = torch.zeros(1, 30, 1)
175
- #decode = torch.cat((decode, add_tensor), dim=2)
176
 
177
  # 4. numpy
178
  #print(decode.shape)
179
- #decode = np.array(decode.cpu()).astype(np.float64)
180
- decode = np.array(decode).astype(np.float64)
181
- print(type(decode))
182
- print(decode.shape)
183
- #decode = decode.tolist()
184
  return decode
185
 
186
- def preprocessing(filepath, filename, samplerate):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  # establish temp folder
188
  try:
189
- os.mkdir(filepath+"/temp2/")
190
  except OSError as e:
191
- dataDelete(filepath+"/temp2/")
192
- os.mkdir(filepath+"/temp2/")
193
  print(e)
194
 
195
  # read data
196
- signal = read_train_data(filepath+'/'+filename)
 
 
 
197
  #print(signal.shape)
198
  # resample
199
- signal = resample(signal, samplerate)
200
  #print(signal.shape)
201
  # FIR_filter
202
  signal = FIR_filter(signal, 1, 50)
@@ -206,13 +217,29 @@ def preprocessing(filepath, filename, samplerate):
206
 
207
  return total_file_num
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  # model = tf.keras.models.load_model('./denoise_model/')
211
- def reconstruct(model_name, total, filepath, outputfile):
212
  # -------------------decode_data---------------------------
213
  second1 = time.time()
214
  for i in range(total):
215
- file_name = filepath + '/temp2/{}.csv'.format(str(i))
216
  data_noise = read_train_data(file_name)
217
 
218
  std = np.std(data_noise)
@@ -222,18 +249,17 @@ def reconstruct(model_name, total, filepath, outputfile):
222
 
223
  # Deep Learning Artifact Removal
224
  d_data = decode_data(data_noise, std, model_name)
225
- #d_data = d_data[0]
226
 
227
- outputname = filepath + '/temp2/output{}.csv'.format(str(i))
228
  save_data(d_data, outputname)
229
- #d_data.to_csv(outputname, index=False)
230
-
231
 
232
  # --------------------glue_data----------------------------
233
- glue_data(filepath+"/temp2/", total, filepath+'/'+outputfile)
234
  # -------------------delete_data---------------------------
235
- dataDelete(filepath+"/temp2/")
236
  second2 = time.time()
237
-
238
- print("Using", model_name,"model to reconstruct", outputfile, " has been success in", second2 - second1, "sec(s)")
239
-
 
 
15
 
16
 
17
  os.environ["CUDA_VISIBLE_DEVICES"]="0"
18
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
 
20
+ def resample(signal, fs, tgt_fs):
21
+ # downsample the signal to the target sample rate
22
+ if fs>tgt_fs:
23
+ fs_down = tgt_fs # Desired sample rate
24
  q = int(fs / fs_down) # Downsampling factor
25
  signal_new = []
26
  for ch in signal:
27
  x_down = decimate(ch, q)
28
  signal_new.append(x_down)
29
 
30
+ # upsample the signal to the target sample rate
31
+ elif fs<tgt_fs:
32
+ fs_up = tgt_fs # Desired sample rate
33
  p = int(fs_up / fs) # Upsampling factor
34
  signal_new = []
35
  for ch in signal:
 
71
  total = int(len(raw_data[0]) / 1024)
72
  for i in range(total):
73
  table = raw_data[:, i * 1024:(i + 1) * 1024]
74
+ filename = filepath + 'temp2/' + str(i) + '.csv'
75
  with open(filename, 'w', newline='') as csvfile:
76
  writer = csv.writer(csvfile)
77
  writer.writerows(table)
78
  return total
79
 
80
 
81
+ def glue_data(file_name, total):
82
  gluedata = 0
83
  for i in range(total):
84
  file_name1 = file_name + 'output{}.csv'.format(str(i))
 
97
  raw_data[:, 1] = smooth
98
  gluedata = np.append(gluedata, raw_data, axis=1)
99
  #print(gluedata.shape)
100
+ return gluedata
 
 
 
 
101
 
102
 
103
  def save_data(data, filename):
 
109
  try:
110
  shutil.rmtree(path)
111
  except OSError as e:
112
+ pass
113
+ #print(e)
114
  else:
115
  pass
116
  #print("The directory is deleted successfully")
117
 
118
 
119
  def decode_data(data, std_num, mode=5):
120
+
121
  if mode == "ICUNet":
122
  # 1. read name
123
+ model = cumbersome_model2.UNet1(n_channels=30, n_classes=30).to(device)
124
  resumeLoc = './model/ICUNet/modelsave' + '/checkpoint.pth.tar'
125
  # 2. load model
126
+ checkpoint = torch.load(resumeLoc, map_location=device)
127
  model.load_state_dict(checkpoint['state_dict'], False)
128
  model.eval()
129
  # 3. decode strategy
130
  with torch.no_grad():
131
  data = data[np.newaxis, :, :]
132
+ data = torch.Tensor(data).to(device)
133
  decode = model(data)
134
 
135
 
136
+ elif mode == "ICUNet++" or mode == "ICUNet_attn":
137
  # 1. read name
138
+ if mode == "ICUNet++":
139
+ model = UNet_family.NestedUNet3(num_classes=30).to(device)
140
+ elif mode == "ICUNet_attn":
141
+ model = UNet_attention.UNetpp3_Transformer(num_classes=30).to(device)
142
+ resumeLoc = './model/' + mode + '/modelsave' + '/checkpoint.pth.tar'
143
  # 2. load model
144
+ checkpoint = torch.load(resumeLoc, map_location=device)
145
  model.load_state_dict(checkpoint['state_dict'], False)
146
  model.eval()
147
  # 3. decode strategy
148
  with torch.no_grad():
149
  data = data[np.newaxis, :, :]
150
+ data = torch.Tensor(data).to(device)
151
  decode1, decode2, decode = model(data)
152
 
153
 
154
+ elif mode == "ART":
155
  # 1. read name
156
  resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
157
  # 2. load model
158
+ checkpoint = torch.load(resumeLoc, map_location=device)
159
+ model = tf_model.make_model(30, 30, N=2).to(device)
160
  model.load_state_dict(checkpoint['state_dict'])
161
  model.eval()
162
  # 3. decode strategy
163
  with torch.no_grad():
164
+ data = torch.FloatTensor(data).to(device)
165
  data = data.unsqueeze(0)
166
  src = data
167
+ tgt = data # you can modify to randomize data
168
  batch = tf_data.Batch(src, tgt, 0)
169
  out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
170
  decode = model.generator(out)
171
  decode = decode.permute(0, 2, 1)
172
+ add_tensor = torch.zeros(1, 30, 1).to(device)
173
+ decode = torch.cat((decode, add_tensor), dim=2)
174
 
175
  # 4. numpy
176
  #print(decode.shape)
177
+ decode = np.array(decode.cpu()).astype(np.float64)
 
 
 
 
178
  return decode
179
 
180
+
181
+ def reorder_data(raw_data, mapping_result):
182
+ new_data = np.zeros((30, raw_data.shape[1]))
183
+ zero_arr = np.zeros((1, raw_data.shape[1]))
184
+ for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
185
+ if flag == True:
186
+ new_data[i, :] = raw_data[indices[0], :]
187
+ elif indices[0] == None:
188
+ new_data[i, :] = zero_arr
189
+ else:
190
+ data = [raw_data[idx, :] for idx in indices]
191
+ new_data[i, :] = np.mean(data, axis=0)
192
+ return new_data
193
+
194
+ def preprocessing(filepath, inputfile, samplerate, mapping_result):
195
  # establish temp folder
196
  try:
197
+ os.mkdir(filepath+"temp2/")
198
  except OSError as e:
199
+ dataDelete(filepath+"temp2/")
200
+ os.mkdir(filepath+"temp2/")
201
  print(e)
202
 
203
  # read data
204
+ signal = read_train_data(inputfile)
205
+ #print(signal.shape)
206
+ # channel mapping
207
+ signal = reorder_data(signal, mapping_result)
208
  #print(signal.shape)
209
  # resample
210
+ signal = resample(signal, samplerate, 256)
211
  #print(signal.shape)
212
  # FIR_filter
213
  signal = FIR_filter(signal, 1, 50)
 
217
 
218
  return total_file_num
219
 
220
+ def restore_order(data, all_data, mapping_result):
221
+ for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
222
+ if flag == True:
223
+ all_data[indices[0], :] = data[i, :]
224
+ return all_data
225
+
226
+ def postprocessing(data, samplerate, outputfile, mapping_result, batch_cnt, channel_num):
227
+
228
+ # resample to original sampling rate
229
+ data = resample(data, 256, samplerate)
230
+ # reverse channel mapping
231
+ all_data = np.zeros((channel_num, data.shape[1])) if batch_cnt==0 else read_train_data(outputfile)
232
+ all_data = restore_order(data, all_data, mapping_result)
233
+ # save data
234
+ save_data(all_data, outputfile)
235
+
236
 
237
  # model = tf.keras.models.load_model('./denoise_model/')
238
+ def reconstruct(model_name, total, filepath, batch_cnt):
239
  # -------------------decode_data---------------------------
240
  second1 = time.time()
241
  for i in range(total):
242
+ file_name = filepath + 'temp2/{}.csv'.format(str(i))
243
  data_noise = read_train_data(file_name)
244
 
245
  std = np.std(data_noise)
 
249
 
250
  # Deep Learning Artifact Removal
251
  d_data = decode_data(data_noise, std, model_name)
252
+ d_data = d_data[0]
253
 
254
+ outputname = filepath + 'temp2/output{}.csv'.format(str(i))
255
  save_data(d_data, outputname)
 
 
256
 
257
  # --------------------glue_data----------------------------
258
+ data = glue_data(filepath+"temp2/", total)
259
  # -------------------delete_data---------------------------
260
+ dataDelete(filepath+"temp2/")
261
  second2 = time.time()
262
+
263
+ print(f"Using {model_name} model to reconstruct batch-{batch_cnt+1} has been success in {second2 - second1} sec(s)")
264
+ return data
265
+