jordand committed
Commit 97724f1 · verified · 1 Parent(s): 45073ca

Update silentcipher/server.py

Files changed (1): silentcipher/server.py (+479 -479)
silentcipher/server.py CHANGED
import os
import argparse
import re
import yaml
import time
import numpy as np
import soundfile as sf
from scipy import stats as st
import librosa
from pydub import AudioSegment
import torch
from torch import nn

from .model import Encoder, CarrierDecoder, MsgDecoder
from .stft import STFT

class Model():

    def __init__(self, config, device='cpu'):

        self.config = config
        self.device = device

        self.n_messages = config.n_messages
        self.model_type = config.model_type
        self.message_dim = config.message_dim
        self.message_len = config.message_len

        # model dimensions
        self.enc_conv_dim = 16
        self.enc_num_repeat = 3
        self.dec_c_num_repeat = self.enc_num_repeat
        self.dec_m_conv_dim = 1
        self.dec_m_num_repeat = 8
        self.encoder_out_dim = 32
        self.dec_c_conv_dim = 32 * 3

        self.enc_c = Encoder(n_layers=self.config.enc_n_layers,
                             message_dim=self.message_dim,
                             out_dim=self.encoder_out_dim,
                             message_band_size=self.config.message_band_size,
                             n_fft=self.config.N_FFT)

        self.dec_c = CarrierDecoder(config=self.config,
                                    conv_dim=self.dec_c_conv_dim,
                                    n_layers=self.config.dec_c_n_layers,
                                    message_band_size=self.config.message_band_size)

        self.dec_m = [MsgDecoder(message_dim=self.message_dim,
                                 message_band_size=self.config.message_band_size) for _ in range(self.n_messages)]

        # ------ move the modules to the target device ------
        self.enc_c = self.enc_c.to(self.device)
        self.dec_c = self.dec_c.to(self.device)
        self.dec_m = [m.to(self.device) for m in self.dec_m]

        self.average_energy_VCTK = 0.002837200844477648
        self.stft = STFT(self.config.N_FFT, self.config.HOP_LENGTH)
        self.stft.to(self.device)
        self.load_models(config.load_ckpt)
        self.sr = self.config.SR

    def letters_encoding(self, patch_len, message_lst):
        """
        Encodes a list of messages into a padded representation and a compact representation.

        Args:
            patch_len (int): The length of the patch (number of STFT frames).
            message_lst (list): A list of messages to be encoded.

        Returns:
            tuple: A tuple containing two numpy arrays:
                - message: A padded representation, where each message is repeated to match the patch length.
                - message_compact: A compact representation, where each message symbol is encoded as a one-hot vector.

        Raises:
            AssertionError: If the length of any message in message_lst is not equal to self.config.message_len - 1.
        """
        message = []
        message_compact = []
        for i in range(self.n_messages):
            assert len(message_lst[i]) == self.config.message_len - 1
            # shift the symbols up by one and append 0 as the end-of-message marker
            index = np.concatenate((np.array(message_lst[i]) + 1, [0]))
            one_hot = np.identity(self.message_dim)[index]
            message_compact.append(one_hot)
            if patch_len % self.message_len == 0:
                message.append(np.tile(one_hot.T, (1, patch_len // self.message_len)))
            else:
                tiled = np.tile(one_hot.T, (1, patch_len // self.message_len))
                tiled = np.concatenate([tiled, one_hot.T[:, 0:patch_len % self.message_len]], axis=1)
                message.append(tiled)
        message = np.stack(message)
        message_compact = np.stack(message_compact)
        # message = np.pad(message, ((0, 0), (0, 129 - self.message_dim), (0, 0)), 'constant')
        return message, message_compact
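
    # Worked example (hypothetical numbers): with message_len=4 and message_dim=5,
    # the message [2, 0, 1] becomes index = [3, 1, 2, 0], i.e. four one-hot
    # columns of height 5. For patch_len=10 the columns are tiled twice
    # (8 frames) and the first 2 columns are appended again to fill the
    # remaining frames.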

    def get_best_ps(self, y_one_sec):
        """
        Calculates the best phase shift value for watermark decoding.

        Args:
            y_one_sec (numpy.ndarray): Input audio signal.

        Returns:
            int: The best phase shift value, in samples.
        """

        def check_accuracy(pred_values):
            # average, over symbol positions, of the fraction of repetitions
            # that agree with the majority value at that position
            accuracy = 0
            for i in range(pred_values.shape[1]):
                unique, counts = np.unique(pred_values[:, i], return_counts=True)
                accuracy += np.max(counts) / pred_values.shape[0]
            return accuracy / pred_values.shape[1]

        y = torch.FloatTensor(y_one_sec).unsqueeze(0).unsqueeze(0).to(self.device)
        max_accuracy = 0
        final_phase_shift = 0

        for ps in range(0, self.config.HOP_LENGTH, 10):

            carrier, _ = self.stft.transform(y[0:1, 0:1, ps:].squeeze(1))
            carrier = carrier[:, None]

            for i in range(self.n_messages):  # decode each msg_i using decoder_m_i
                msg_reconst = self.dec_m[i](carrier)
                pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
                pred_values = pred_values[0:int(msg_reconst.shape[3] / self.config.message_len) * self.config.message_len]
                pred_values = pred_values.reshape([-1, self.config.message_len])
                cur_acc = check_accuracy(pred_values)
                if cur_acc > max_accuracy:
                    max_accuracy = cur_acc
                    final_phase_shift = ps

        return final_phase_shift
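
    # Note: shifting the signal by a whole hop reproduces the same frame
    # alignment, so the search only scans offsets within one hop
    # (in steps of 10 samples) and keeps the offset whose repeated message
    # symbols agree most consistently.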

    def get_confidence(self, pred_values, message):
        """
        Calculates the confidence of the predicted values based on the provided message.

        Parameters:
            pred_values (numpy.ndarray): The predicted values, one row per message repetition.
            message (numpy.ndarray): The consensus message the predictions are compared against.

        Returns:
            float: The confidence score.

        Raises:
            AssertionError: If the length of the message is not equal to the number of columns in pred_values.
        """
        assert len(message) == pred_values.shape[1], f'{len(message)} | {pred_values.shape}'
        return np.mean((pred_values == message[None]).astype(np.float32)).item()

    def sdr(self, orig, recon):
        """
        Calculate the Signal-to-Distortion Ratio (SDR) between the original and reconstructed signals.

        Parameters:
            orig (numpy.ndarray): The original signal.
            recon (numpy.ndarray): The reconstructed signal.

        Returns:
            float: The Signal-to-Distortion Ratio (SDR) value, in dB.
        """
        rms1 = (np.mean(orig ** 2)) ** 0.5
        rms2 = (np.mean((orig - recon) ** 2)) ** 0.5
        sdr = 20 * np.log10(rms1 / rms2)
        return sdr
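
    # Quick sanity check: if the residual RMS is one tenth of the signal RMS,
    # sdr = 20 * log10(10) = 20 dB; higher values mean a less audible watermark.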

    def load_audio(self, path):
        """
        Load an audio file from the given path and return the audio array and sample rate.

        Args:
            path (str): The path to the audio file.

        Returns:
            tuple: A tuple containing the audio array (float32, scaled to [-1, 1]) and the sample rate.
        """
        audio = AudioSegment.from_file(path)
        # interleaved integer samples -> (n_samples, n_channels) floats in [-1, 1]
        audio_array = np.array(audio.get_array_of_samples(), dtype=np.float32).reshape((-1, audio.channels)) / (
            1 << (8 * audio.sample_width - 1))
        sr = audio.frame_rate
        if audio_array.shape[1] == 1:
            audio_array = audio_array[:, 0]

        return audio_array, sr

    def encode(self, in_path, out_path, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
        """
        Encodes a message into an audio file.

        Parameters:
        - in_path (str): The path to the input audio file.
        - out_path (str): The path to save the output audio file.
        - message_list (list): A list of messages to be encoded into the audio file.
        - message_sdr (float, optional): The target watermark SDR in dB. Defaults to None, in which case the default from the configuration is used.
        - calc_sdr (bool, optional): Whether to calculate the SDR of the encoded audio. Defaults to True.
        - disable_checks (bool, optional): Whether to disable input checks. Defaults to False.

        Returns:
        - dict: A dictionary containing the status of the encoding process, the SDR value(s), the time taken for encoding, and the time taken per second of audio.
        """
        y, orig_sr = self.load_audio(in_path)
        start = time.time()
        encoded_y, sdr = self.encode_wav(y, orig_sr, message_list=message_list, message_sdr=message_sdr, calc_sdr=calc_sdr, disable_checks=disable_checks)
        time_taken = time.time() - start
        sf.write(out_path, encoded_y, orig_sr)

        if isinstance(sdr, list):
            return {'status': True, 'sdr': [f'{sdr_i:.2f}' for sdr_i in sdr], 'time_taken': time_taken, 'time_taken_per_second': time_taken / (y.shape[0] / orig_sr)}
        else:
            return {'status': True, 'sdr': f'{sdr:.2f}', 'time_taken': time_taken, 'time_taken_per_second': time_taken / (y.shape[0] / orig_sr)}

    def decode(self, path, phase_shift_decoding):
        """
        Decode the audio file at the given path, optionally using phase shift decoding.

        Parameters:
            path (str): The path to the audio file.
            phase_shift_decoding (bool): Flag indicating whether to use phase shift decoding.

        Returns:
            dict: A dictionary containing the decoded messages, confidences, and status.
        """
        y, orig_sr = self.load_audio(path)

        return self.decode_wav(y, orig_sr, phase_shift_decoding)

    def encode_wav(self, y_multi_channel, orig_sr, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
        """
        Encodes a multi-channel audio waveform with a given message.

        Args:
            y_multi_channel (numpy.ndarray): The multi-channel audio waveform to be encoded.
            orig_sr (int): The original sampling rate of the audio waveform.
            message_list (list): The list of messages to be encoded. Each message corresponds to a channel in the audio waveform.
            message_sdr (float, optional): The target signal-to-distortion ratio (SDR) of the watermark, in dB. If not provided, the default SDR from the configuration is used.
            calc_sdr (bool, optional): Flag indicating whether to calculate the SDR of the encoded waveform. Defaults to True.
            disable_checks (bool, optional): Flag indicating whether to disable input audio checks. Defaults to False.

        Returns:
            tuple: A tuple containing the encoded multi-channel audio waveform and the SDR (if calculated).

        Raises:
            AssertionError: If the number of messages does not match the number of channels in the input audio waveform.
        """

        single_channel = False
        if len(y_multi_channel.shape) == 1:
            single_channel = True
            y_multi_channel = y_multi_channel[:, None]

        if message_sdr is None:
            message_sdr = self.config.message_sdr
            print(f'Using the default SDR of {self.config.message_sdr} dB')

        if isinstance(message_list[0], int):
            # a single message was passed; encode the same message into every channel
            message_list = [message_list] * y_multi_channel.shape[1]

        y_watermarked_multi_channel = []
        sdrs = []

        assert len(message_list) == y_multi_channel.shape[1], f'{len(message_list)} | {y_multi_channel.shape[1]} Mismatch in the number of messages and channels in the input audio.'

        for channel_i in range(y_multi_channel.shape[1]):
            y = y_multi_channel[:, channel_i]
            message = message_list[channel_i]

            with torch.no_grad():

                orig_y = y.copy()
                if orig_sr != self.sr:
                    if orig_sr > self.sr:
                        print(f'WARNING! Reducing the sampling rate of the original audio from {orig_sr} -> {self.sr}. High frequency components may be lost!')
                    y = librosa.resample(y, orig_sr=orig_sr, target_sr=self.sr)
                original_power = np.mean(y**2)

                if not disable_checks:
                    if original_power == 0:
                        print('WARNING! The input audio has a power of 0. This means the audio is likely just silence. Skipping encoding.')
                        return orig_y, 0

                y = y * np.sqrt(self.average_energy_VCTK / original_power)  # normalize to the average energy of VCTK samples
                y = torch.FloatTensor(y).unsqueeze(0).unsqueeze(0).to(self.device)
                carrier, carrier_phase = self.stft.transform(y.squeeze(1))
                carrier = carrier[:, None]
                carrier_phase = carrier_phase[:, None]

                def binary_encode(mes):
                    # concatenate the 8-bit representation of each byte,
                    # then split the bit string into 2-bit symbols (values 0-3)
                    binary_message = ''.join(['{0:08b}'.format(mes_i) for mes_i in mes])
                    two_bit_msg = []
                    for i in range(len(binary_message) // 2):
                        two_bit_msg.append(int(binary_message[i * 2:i * 2 + 2], 2))
                    return two_bit_msg
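
                # Worked example (illustrative): binary_encode([72]) first forms
                # the bit string '01001000', then splits it into the 2-bit
                # symbols [1, 0, 2, 0]; letters_encoding later one-hot encodes these.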

                binary_encoded_message = binary_encode(message)

                msgs, msgs_compact = self.letters_encoding(carrier.shape[3], [binary_encoded_message])
                msg_enc = torch.from_numpy(msgs[None]).to(self.device).float()

                carrier_enc = self.enc_c(carrier)  # encode the carrier
                msg_enc = self.enc_c.transform_message(msg_enc)

                # concatenate the encodings on the feature axis
                merged_enc = torch.cat((carrier_enc, carrier.repeat(1, 32, 1, 1), msg_enc.repeat(1, 32, 1, 1)), dim=1)

                message_info = self.dec_c(merged_enc, message_sdr)
                if self.config.frame_level_normalization:
                    # scale the message by the per-frame carrier energy
                    message_info = message_info * (torch.mean((carrier**2), dim=2, keepdim=True)**0.5)
                elif self.config.utterance_level_normalization:
                    # scale the message by the utterance-level carrier energy
                    message_info = message_info * (torch.mean((carrier**2), dim=(2, 3), keepdim=True)**0.5)

                if self.config.ensure_negative_message:
                    message_info = -message_info
                    carrier_reconst = torch.nn.functional.relu(message_info + carrier)  # watermarked carrier, in the STFT domain
                elif self.config.ensure_constrained_message:
                    message_info[message_info > carrier] = carrier[message_info > carrier]
                    message_info[-message_info > carrier] = -carrier[-message_info > carrier]
                    carrier_reconst = message_info + carrier  # watermarked carrier, in the STFT domain
                    assert torch.all(carrier_reconst >= 0), 'negative values found in carrier_reconst'
                else:
                    carrier_reconst = torch.abs(message_info + carrier)  # watermarked carrier, in the STFT domain

                self.stft.num_samples = y.shape[2]

                y = self.stft.inverse(carrier_reconst.squeeze(1), carrier_phase.squeeze(1)).data.cpu().numpy()[0, 0]
                y = y * np.sqrt(original_power / self.average_energy_VCTK)  # undo the VCTK energy normalization
                if orig_sr != self.sr:
                    y = librosa.resample(y, orig_sr=self.sr, target_sr=orig_sr)

            if calc_sdr:
                sdr = self.sdr(orig_y, y)
            else:
                sdr = 0

            y_watermarked_multi_channel.append(y[:, None])
            sdrs.append(sdr)

        y_watermarked_multi_channel = np.concatenate(y_watermarked_multi_channel, axis=1)

        if single_channel:
            y_watermarked_multi_channel = y_watermarked_multi_channel[:, 0]
            sdrs = sdrs[0]

        return y_watermarked_multi_channel, sdrs

    def decode_wav(self, y_multi_channel, orig_sr, phase_shift_decoding):
        """
        Decode the given audio waveform to extract the hidden messages.

        Args:
            y_multi_channel (numpy.ndarray): The multi-channel audio waveform.
            orig_sr (int): The original sample rate of the audio waveform.
            phase_shift_decoding (bool or str): Flag indicating whether to perform phase shift decoding.

        Returns:
            dict or list: A list of dictionaries containing the decoded messages, confidences, and status for each channel if the input is multi-channel.
                Otherwise, a single dictionary for the single channel. Channels that fail to decode are reported with status False.
        """
        single_channel = False
        if len(y_multi_channel.shape) == 1:
            single_channel = True
            y_multi_channel = y_multi_channel[:, None]

        results = []

        for channel_i in range(y_multi_channel.shape[1]):
            y = y_multi_channel[:, channel_i]
            try:
                with torch.no_grad():
                    if orig_sr != self.sr:
                        y = librosa.resample(y, orig_sr=orig_sr, target_sr=self.sr)
                    original_power = np.mean(y**2)
                    y = y * np.sqrt(self.average_energy_VCTK / original_power)  # normalize to the average energy of VCTK samples
                    if phase_shift_decoding and phase_shift_decoding != 'false':
                        ps = self.get_best_ps(y)
                    else:
                        ps = 0
                    y = torch.FloatTensor(y[ps:]).unsqueeze(0).unsqueeze(0).to(self.device)
                    carrier, _ = self.stft.transform(y.squeeze(1))
                    carrier = carrier[:, None]

                    msg_reconst_list = []
                    confidence = []

                    for i in range(self.n_messages):  # decode each msg_i using decoder_m_i
                        msg_reconst = self.dec_m[i](carrier)
                        pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
                        pred_values = pred_values[0:int(msg_reconst.shape[3] / self.config.message_len) * self.config.message_len]
                        pred_values = pred_values.reshape([-1, self.config.message_len])

                        # majority vote per symbol position across the message repetitions
                        ord_values = st.mode(pred_values, keepdims=False).mode
                        end_char = np.min(np.nonzero(ord_values == 0)[0])
                        confidence.append(self.get_confidence(pred_values, ord_values))
                        if end_char == self.config.message_len:
                            ord_values = ord_values[:self.config.message_len - 1]
                        else:
                            # rotate so the symbols after the end marker come first
                            ord_values = np.concatenate([ord_values[end_char + 1:], ord_values[:end_char]], axis=0)

                        # pred_values = ''.join([chr(v + 64) for v in ord_values])
                        msg_reconst_list.append((ord_values - 1).tolist())

                    def convert_to_8_bit_segments(msg_list):
                        # regroup the decoded 2-bit symbols into 8-bit byte values
                        segment_message_list = []
                        for msg_list_i in msg_list:
                            binary_format = ''.join(['{0:02b}'.format(mes_i) for mes_i in msg_list_i])
                            eight_bit_segments = [int(binary_format[i * 8:i * 8 + 8], 2) for i in range(len(binary_format) // 8)]
                            segment_message_list.append(eight_bit_segments)
                        return segment_message_list
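
                    # Illustrative inverse of binary_encode: the symbols [1, 0, 2, 0]
                    # regroup into the bit string '01001000', i.e. the single byte 72.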
                    msg_reconst_list = convert_to_8_bit_segments(msg_reconst_list)

                results.append({'messages': msg_reconst_list, 'confidences': confidence, 'status': True})
            except Exception:
                results.append({'messages': [], 'confidences': [], 'error': 'Could not find message', 'status': False})

        if single_channel:
            results = results[0]

        return results

    def convert_dataparallel_to_normal(self, checkpoint):
        # strip the 'module.' prefix that torch.nn.DataParallel adds to parameter names
        return {i[len('module.'):] if i.startswith('module.') else i: checkpoint[i] for i in checkpoint}

    def load_models(self, ckpt_dir):

        self.enc_c.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, "enc_c.ckpt"), map_location=self.device)))
        self.dec_c.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, "dec_c.ckpt"), map_location=self.device)))
        for i, m in enumerate(self.dec_m):
            m.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, f"dec_m_{i}.ckpt"), map_location=self.device)))


def get_model(model_type='44.1k', ckpt_path='../Models/44_1_khz/73999_iteration', config_path='../Models/44_1_khz/73999_iteration/hparams.yaml', device='cpu'):

    if model_type == '44.1k':
        if not os.path.exists(ckpt_path) or not os.path.exists(config_path):
            print('ckpt path or config path does not exist! Downloading the model from the Hugging Face Hub...')
            from huggingface_hub import snapshot_download
            folder_dir = snapshot_download(repo_id="sony/silentcipher")
            ckpt_path = os.path.join(folder_dir, '44_1_khz/73999_iteration')
            config_path = os.path.join(folder_dir, '44_1_khz/73999_iteration/hparams.yaml')
    elif model_type == '16k':
        if not os.path.exists(ckpt_path) or not os.path.exists(config_path):
            print('ckpt path or config path does not exist! Downloading the model from the Hugging Face Hub...')
            from huggingface_hub import snapshot_download
            folder_dir = snapshot_download(repo_id="sony/silentcipher")
            ckpt_path = os.path.join(folder_dir, '16_khz/97561_iteration')
            config_path = os.path.join(folder_dir, '16_khz/97561_iteration/hparams.yaml')
    else:
        raise ValueError('Please specify a valid model_type [44.1k, 16k]')

    with open(config_path) as f:
        config = yaml.safe_load(f)
    config = argparse.Namespace(**config)
    config.load_ckpt = ckpt_path
    model = Model(config, device)

    return model
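
A minimal usage sketch (illustrative, not part of the commit): it assumes this module is importable as silentcipher.server and that the checkpoint's configured message length accepts a five-byte payload.

    from silentcipher.server import get_model

    # load the 44.1 kHz model; weights are fetched from the Hugging Face Hub
    # if the local checkpoint paths do not exist
    model = get_model(model_type='44.1k', device='cpu')

    # hypothetical file paths; the message is a list of byte values (0-255)
    result = model.encode('in.wav', 'out.wav', [123, 234, 111, 222, 11])
    print(result['sdr'], result['time_taken'])

    decoded = model.decode('out.wav', phase_shift_decoding=True)
    print(decoded['messages'], decoded['confidences'])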
 