Spaces:
Running
on
Zero
Running
on
Zero
Update silentcipher/server.py
Browse files- silentcipher/server.py +479 -479
silentcipher/server.py
CHANGED
|
@@ -1,480 +1,480 @@
|
|
| 1 |
-
from calendar import c
|
| 2 |
-
import os
|
| 3 |
-
import argparse
|
| 4 |
-
import re
|
| 5 |
-
from tabnanny import check
|
| 6 |
-
import yaml
|
| 7 |
-
import time
|
| 8 |
-
import numpy as np
|
| 9 |
-
import soundfile as sf
|
| 10 |
-
from scipy import stats as st
|
| 11 |
-
import librosa
|
| 12 |
-
from pydub import AudioSegment
|
| 13 |
-
import torch
|
| 14 |
-
from torch import nn
|
| 15 |
-
|
| 16 |
-
from .model import Encoder, CarrierDecoder, MsgDecoder
|
| 17 |
-
from .stft import STFT
|
| 18 |
-
|
| 19 |
-
class Model():
|
| 20 |
-
|
| 21 |
-
def __init__(self, config, device='cpu'):
    """
    Build the watermarking networks, move them to ``device`` and load their weights.

    Args:
        config: Hyper-parameter namespace. This constructor reads ``n_messages``,
            ``model_type``, ``message_dim``, ``message_len``, ``enc_n_layers``,
            ``dec_c_n_layers``, ``message_band_size``, ``N_FFT``, ``HOP_LENGTH``,
            ``SR`` and ``load_ckpt`` from it.
        device: Torch device the sub-modules are moved to. Defaults to 'cpu'.
    """
    self.config = config
    self.device = device

    self.n_messages = config.n_messages
    self.model_type = config.model_type
    self.message_dim = config.message_dim
    self.message_len = config.message_len

    # model dimensions
    self.enc_conv_dim = 16
    self.enc_num_repeat = 3
    self.dec_c_num_repeat = self.enc_num_repeat
    self.dec_m_conv_dim = 1
    self.dec_m_num_repeat = 8
    self.encoder_out_dim = 32
    self.dec_c_conv_dim = 32*3

    band_size = config.message_band_size

    # Networks are constructed first, in the original order, and moved afterwards.
    carrier_encoder = Encoder(n_layers=config.enc_n_layers,
                              message_dim=self.message_dim,
                              out_dim=self.encoder_out_dim,
                              message_band_size=band_size,
                              n_fft=config.N_FFT)
    carrier_decoder = CarrierDecoder(config=config,
                                     conv_dim=self.dec_c_conv_dim,
                                     n_layers=config.dec_c_n_layers,
                                     message_band_size=band_size)
    # One message decoder per embedded message.
    message_decoders = [MsgDecoder(message_dim=self.message_dim,
                                   message_band_size=band_size)
                        for _ in range(self.n_messages)]

    self.enc_c = carrier_encoder.to(self.device)
    self.dec_c = carrier_decoder.to(self.device)
    self.dec_m = [decoder.to(self.device) for decoder in message_decoders]

    # Reference power that waveforms are rescaled to before encoding/decoding.
    self.average_energy_VCTK=0.002837200844477648
    self.stft = STFT(config.N_FFT, config.HOP_LENGTH)
    self.stft.to(self.device)
    self.load_models(config.load_ckpt)
    self.sr = config.SR
|
| 63 |
-
|
| 64 |
-
def letters_encoding(self, patch_len, message_lst):
    """
    Encodes a list of messages into a compact representation and a padded representation.

    Each message gets an end token (index 0) appended after its symbols are
    shifted up by one, is turned into one-hot rows, and is then repeated along
    the time axis until it spans ``patch_len`` columns (the last repetition is
    cut short when ``patch_len`` is not a multiple of ``self.message_len``).

    Args:
        patch_len (int): Number of time columns the tiled message must cover.
        message_lst (list): One symbol sequence per message; only the first
            ``self.n_messages`` entries are read.

    Returns:
        tuple: A tuple containing two numpy arrays:
            - message: shape (n_messages, message_dim, patch_len), the one-hot
              message tiled along the last axis.
            - message_compact: shape (n_messages, message_len, message_dim),
              one one-hot row per symbol.

    Raises:
        AssertionError: If the length of any message in message_lst is not equal to self.message_len - 1.
    """
    message_len = self.message_len
    # Ceil-divide so one tiling always covers patch_len; the slice below trims the excess.
    n_tiles = -(-patch_len // message_len)
    identity = np.identity(self.message_dim)

    message = []
    message_compact = []
    for i in range(self.n_messages):
        symbols = message_lst[i]
        assert len(symbols) == message_len - 1, (
            f'message {i} has {len(symbols)} symbols, expected {message_len - 1}')
        # Shift symbols by one so that index 0 is free to act as the end token.
        index = np.concatenate((np.array(symbols) + 1, [0]))
        one_hot = identity[index]
        message_compact.append(one_hot)
        message.append(np.tile(one_hot.T, (1, n_tiles))[:, :patch_len])

    return np.stack(message), np.stack(message_compact)
|
| 100 |
-
|
| 101 |
-
def get_best_ps(self, y_one_sec):
    """
    Calculates the best phase shift value for watermark decoding.

    Sample offsets 0, 10, 20, ... below ``config.HOP_LENGTH`` are tried. For
    each offset every message decoder is run and the offset whose decoded
    symbols agree most across message repetitions is kept. Offsets that leave
    less than one complete message repetition are skipped.

    Args:
        y_one_sec (numpy.ndarray): Input audio signal.

    Returns:
        int: The best phase shift value (0 if no offset could be scored).
    """

    def check_accuracy(pred_values):
        # Mean, over message positions, of the share of repetitions that
        # agree with the most frequent symbol at that position.
        accuracy = 0
        for i in range(pred_values.shape[1]):
            counts = np.unique(pred_values[:, i], return_counts=True)[1]
            accuracy += np.max(counts) / pred_values.shape[0]
        return accuracy / pred_values.shape[1]

    message_len = self.config.message_len
    y = torch.FloatTensor(y_one_sec).unsqueeze(0).unsqueeze(0).to(self.device)
    max_accuracy = 0
    final_phase_shift = 0

    for ps in range(0, self.config.HOP_LENGTH, 10):
        carrier, _ = self.stft.transform(y[0:1, 0:1, ps:].squeeze(1))
        carrier = carrier[:, None]

        for i in range(self.n_messages):  # decode each msg_i using decoder_m_i
            msg_reconst = self.dec_m[i](carrier)
            pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
            n_repeats = msg_reconst.shape[3] // message_len
            if n_repeats == 0:
                # Too short to hold one full message: nothing to score here.
                continue
            pred_values = pred_values[:n_repeats * message_len].reshape([-1, message_len])
            cur_acc = check_accuracy(pred_values)
            if cur_acc > max_accuracy:
                max_accuracy = cur_acc
                final_phase_shift = ps

    return final_phase_shift
|
| 143 |
-
|
| 144 |
-
def get_confidence(self, pred_values, message):
    """
    Calculates the confidence of the predicted values based on the provided message.

    The confidence is the fraction of entries of ``pred_values`` that equal the
    corresponding entry of ``message``, with ``message`` compared against every row.

    Parameters:
        pred_values (numpy.ndarray): 2-D array of predicted symbols, one row per repetition.
        message (array-like): The reference symbols, one per column of pred_values.

    Returns:
        float: The confidence score.

    Raises:
        AssertionError: If the length of the message is not equal to the number of columns in pred_values.
    """
    # Accept lists/tuples as well as arrays; the [None] indexing below needs an ndarray.
    message = np.asarray(message)
    assert len(message) == pred_values.shape[1], f'{len(message)} | {pred_values.shape}'
    return np.mean((pred_values == message[None]).astype(np.float32)).item()
|
| 161 |
-
|
| 162 |
-
def sdr(self, orig, recon):
    """
    Calculate the Signal-to-Distortion Ratio (SDR) between the original and reconstructed signals.

    Computed as ``20 * log10(rms(orig) / rms(orig - recon))``.

    Parameters:
        orig (numpy.ndarray): The original signal.
        recon (numpy.ndarray): The reconstructed signal.

    Returns:
        float: The Signal-to-Distortion Ratio (SDR) value in dB. ``inf`` when
        the two signals are identical (no distortion).
    """
    signal_rms = ((np.mean(orig ** 2)) ** 0.5)
    distortion_rms = ((np.mean((orig - recon) ** 2)) ** 0.5)
    if distortion_rms == 0:
        # No distortion at all: report infinity without a divide-by-zero warning.
        return float('inf')
    with np.errstate(divide='ignore'):
        # A silent original yields log10(0) = -inf; keep that value, drop the warning.
        return 20 * np.log10(signal_rms / distortion_rms)
|
| 179 |
-
|
| 180 |
-
def load_audio(self, path):
    """
    Load an audio file from the given path and return the audio array and sample rate.

    Samples are converted to float32 and divided by ``2 ** (8 * sample_width - 1)``.

    Args:
        path (str): The path to the audio file.

    Returns:
        tuple: ``(audio_array, sample_rate)``. ``audio_array`` has shape
        (n_frames, n_channels), or (n_frames,) when the file has one channel.
    """
    segment = AudioSegment.from_file(path)
    raw_samples = np.array(segment.get_array_of_samples(), dtype=np.float32)
    full_scale = 1 << (8 * segment.sample_width - 1)
    audio_array = raw_samples.reshape((-1, segment.channels)) / full_scale
    sr = segment.frame_rate

    # Mono audio is returned as a flat vector.
    if audio_array.shape[1] == 1:
        audio_array = audio_array[:, 0]

    return audio_array, sr
|
| 198 |
-
|
| 199 |
-
def encode(self, in_path, out_path, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
    """
    Encodes a message into an audio file.

    Parameters:
    - in_path (str): The path to the input audio file.
    - out_path (str): The path to save the output audio file.
    - message_list (list): A list of messages to be encoded into the audio file.
    - message_sdr (float, optional): The Signal-to-Distortion Ratio (SDR) of the message. Defaults to None.
    - calc_sdr (bool, optional): Whether to calculate the SDR of the encoded audio. Defaults to True.
    - disable_checks (bool, optional): Whether to disable input checks. Defaults to False.

    Returns:
    - dict: ``status`` (always True), ``sdr`` (a string with two decimals, or a
      list of such strings when encode_wav returns a list), ``time_taken``
      (seconds spent in encode_wav) and ``time_taken_per_second`` (time_taken
      divided by the input duration in seconds).
    """
    y, orig_sr = self.load_audio(in_path)

    # Only the encoding itself is timed, not file loading or writing.
    start = time.time()
    encoded_y, sdr_value = self.encode_wav(y, orig_sr, message_list=message_list, message_sdr=message_sdr,
                                           calc_sdr=calc_sdr, disable_checks=disable_checks)
    time_taken = time.time() - start

    sf.write(out_path, encoded_y, orig_sr)

    if isinstance(sdr_value, list):
        sdr_report = [f'{sdr_i:.2f}' for sdr_i in sdr_value]
    else:
        sdr_report = f'{sdr_value:.2f}'

    return {'status': True,
            'sdr': sdr_report,
            'time_taken': time_taken,
            'time_taken_per_second': time_taken / (y.shape[0] / orig_sr)}
|
| 225 |
-
|
| 226 |
-
def decode(self, path, phase_shift_decoding):
    """
    Load the audio file at ``path`` and decode its watermark with decode_wav.

    Parameters:
        path (str): The path to the audio file.
        phase_shift_decoding (bool): Flag indicating whether to use phase shift decoding.

    Returns:
        Whatever decode_wav returns for the loaded waveform.
    """
    waveform, sample_rate = self.load_audio(path)
    decoded = self.decode_wav(waveform, sample_rate, phase_shift_decoding)
    return decoded
|
| 241 |
-
|
| 242 |
-
def encode_wav(self, y_multi_channel, orig_sr, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
    """
    Encodes a multi-channel audio waveform with a given message.

    Args:
        y_multi_channel (numpy.ndarray): Waveform of shape (n_samples,) or (n_samples, n_channels).
        orig_sr (int): The original sampling rate of the audio waveform.
        message_list (list): Either one message (a list of ints, each formatted
            as an 8-bit value) that is used for every channel, or a list with
            one such message per channel.
        message_sdr (float, optional): SDR passed to the carrier decoder. If not provided, ``config.message_sdr`` is used.
        calc_sdr (bool, optional): Flag indicating whether to calculate the SDR of the encoded waveform. Defaults to True.
        disable_checks (bool, optional): Flag indicating whether to disable the silent-input check. Defaults to False.

    Returns:
        tuple: ``(waveform, sdr)``. The waveform has the same layout as the
        input. ``sdr`` is a list with one entry per channel for 2-D input and a
        single value for 1-D input; an entry is 0 when ``calc_sdr`` is False or
        when the channel was silent and passed through unchanged.

    Raises:
        AssertionError: If the number of messages does not match the number of channels in the input audio waveform.
    """

    def binary_encode(mes):
        # Write every value as 8 bits, then read the bit string two bits at a
        # time: each value becomes four symbols in the range 0..3.
        binary_message = ''.join(['{0:08b}'.format(mes_i) for mes_i in mes])
        return [int(binary_message[i*2:i*2+2], 2) for i in range(len(binary_message)//2)]

    single_channel = len(y_multi_channel.shape) == 1
    if single_channel:
        y_multi_channel = y_multi_channel[:, None]
    n_channels = y_multi_channel.shape[1]

    if message_sdr is None:
        message_sdr = self.config.message_sdr
        print(f'Using the default SDR of {self.config.message_sdr} dB')

    # A flat list of ints (Python or NumPy) is one message shared by all channels.
    if isinstance(message_list[0], (int, np.integer)):
        message_list = [message_list]*n_channels

    y_watermarked_multi_channel = []
    sdrs = []

    assert len(message_list) == n_channels, f'{len(message_list)} | {n_channels} Mismatch in the number of messages and channels in the input audio.'

    for channel_i in range(n_channels):
        y = y_multi_channel[:, channel_i]
        message = message_list[channel_i]

        with torch.no_grad():

            orig_y = y.copy()
            if orig_sr != self.sr:
                if orig_sr > self.sr:
                    print(f'WARNING! Reducing the sampling rate of the original audio from {orig_sr} -> {self.sr}. High frequency components may be lost!')
                y = librosa.resample(y, orig_sr = orig_sr, target_sr = self.sr)
            original_power = np.mean(y**2)

            if not disable_checks and original_power == 0:
                print('WARNING! The input audio has a power of 0.This means the audio is likely just silence. Skipping encoding.')
                # Pass the silent channel through untouched and keep processing
                # the remaining channels, so the output keeps the input layout.
                y_watermarked_multi_channel.append(orig_y[:, None])
                sdrs.append(0)
                continue

            # Rescale the channel so its mean power equals average_energy_VCTK.
            y = y * np.sqrt(self.average_energy_VCTK / original_power)
            y = torch.FloatTensor(y).unsqueeze(0).unsqueeze(0).to(self.device)
            carrier, carrier_phase = self.stft.transform(y.squeeze(1))
            carrier = carrier[:, None]
            carrier_phase = carrier_phase[:, None]

            binary_encoded_message = binary_encode(message)

            msgs, msgs_compact = self.letters_encoding(carrier.shape[3], [binary_encoded_message])
            msg_enc = torch.from_numpy(msgs[None]).to(self.device).float()

            carrier_enc = self.enc_c(carrier)  # encode the carrier
            msg_enc = self.enc_c.transform_message(msg_enc)

            # concat encodings on features axis
            merged_enc = torch.cat((carrier_enc, carrier.repeat(1, 32, 1, 1), msg_enc.repeat(1, 32, 1, 1)), dim=1)

            message_info = self.dec_c(merged_enc, message_sdr)
            if self.config.frame_level_normalization:
                message_info = message_info*(torch.mean((carrier**2), dim=2, keepdim=True)**0.5)
            elif self.config.utterance_level_normalization:
                message_info = message_info*(torch.mean((carrier**2), dim=(2,3), keepdim=True)**0.5)

            # Combine message and carrier; every branch yields a non-negative result.
            if self.config.ensure_negative_message:
                message_info = -message_info
                carrier_reconst = torch.nn.functional.relu(message_info + carrier)
            elif self.config.ensure_constrained_message:
                # Clamp the message to [-carrier, carrier] element-wise.
                message_info[message_info > carrier] = carrier[message_info > carrier]
                message_info[-message_info > carrier] = -carrier[-message_info > carrier]
                carrier_reconst = message_info + carrier
                assert torch.all(carrier_reconst >= 0), 'negative values found in carrier_reconst'
            else:
                carrier_reconst = torch.abs(message_info + carrier)

            self.stft.num_samples = y.shape[2]

            y = self.stft.inverse(carrier_reconst.squeeze(1), carrier_phase.squeeze(1)).data.cpu().numpy()[0, 0]
            # Undo the power normalisation applied before encoding.
            y = y * np.sqrt(original_power / (self.average_energy_VCTK))
            if orig_sr != self.sr:
                y = librosa.resample(y, orig_sr = self.sr, target_sr = orig_sr)
                # The down/up resampling round trip can change the length by a
                # few samples; match the input length so the SDR subtraction
                # and the output layout stay valid.
                n_samples = orig_y.shape[0]
                if y.shape[0] > n_samples:
                    y = y[:n_samples]
                elif y.shape[0] < n_samples:
                    y = np.pad(y, (0, n_samples - y.shape[0]))

            channel_sdr = self.sdr(orig_y, y) if calc_sdr else 0

            y_watermarked_multi_channel.append(y[:, None])
            sdrs.append(channel_sdr)

    y_watermarked_multi_channel = np.concatenate(y_watermarked_multi_channel, axis=1)

    if single_channel:
        y_watermarked_multi_channel = y_watermarked_multi_channel[:, 0]
        sdrs = sdrs[0]

    return y_watermarked_multi_channel, sdrs
|
| 359 |
-
|
| 360 |
-
def decode_wav(self, y_multi_channel, orig_sr, phase_shift_decoding):
    """
    Decode the given audio waveform to extract hidden messages.

    Args:
        y_multi_channel (numpy.ndarray): Waveform of shape (n_samples,) or (n_samples, n_channels).
        orig_sr (int): The original sample rate of the audio waveform.
        phase_shift_decoding (bool or str): Any truthy value other than the
            string 'false' enables the search for the best sample offset.

    Returns:
        dict or list: For 1-D input a single result dictionary, for 2-D input a
        list with one dictionary per channel. A successful result holds
        ``messages`` (one list of 8-bit ints per message decoder),
        ``confidences`` and ``status=True``. A failed result holds empty
        ``messages``/``confidences``, ``error='Could not find message'`` and
        ``status=False``.

    Decoding errors are not raised; they are reported through ``status=False``.
    """

    def convert_to_8_bit_segments(msg_list):
        # Inverse of the encoder's packing: write every symbol as two bits and
        # read the concatenated bit string back eight bits at a time.
        segment_message_list = []
        for msg_list_i in msg_list:
            binary_format = ''.join(['{0:02b}'.format(mes_i) for mes_i in msg_list_i])
            eight_bit_segments = [int(binary_format[i*8:i*8+8], 2) for i in range(len(binary_format)//8)]
            segment_message_list.append(eight_bit_segments)
        return segment_message_list

    def failure():
        # A fresh dictionary per channel so callers may mutate results freely.
        return {'messages': [], 'confidences': [], 'error': 'Could not find message', 'status': False}

    single_channel = len(y_multi_channel.shape) == 1
    if single_channel:
        y_multi_channel = y_multi_channel[:, None]

    message_len = self.config.message_len
    results = []

    for channel_i in range(y_multi_channel.shape[1]):
        y = y_multi_channel[:, channel_i]
        try:
            with torch.no_grad():
                if orig_sr != self.sr:
                    y = librosa.resample(y, orig_sr = orig_sr, target_sr = self.sr)
                original_power = np.mean(y**2)
                if original_power == 0:
                    # Silence cannot be power-normalised (division by zero)
                    # and cannot carry a message.
                    results.append(failure())
                    continue
                # Rescale the channel so its mean power equals average_energy_VCTK.
                y = y * np.sqrt(self.average_energy_VCTK / original_power)
                if phase_shift_decoding and phase_shift_decoding != 'false':
                    ps = self.get_best_ps(y)
                else:
                    ps = 0
                y = torch.FloatTensor(y[ps:]).unsqueeze(0).unsqueeze(0).to(self.device)
                carrier, _ = self.stft.transform(y.squeeze(1))
                carrier = carrier[:, None]

                msg_reconst_list = []
                confidence = []

                for i in range(self.n_messages):  # decode each msg_i using decoder_m_i
                    msg_reconst = self.dec_m[i](carrier)
                    pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
                    # Keep only complete repetitions of the message.
                    n_repeats = msg_reconst.shape[3] // message_len
                    pred_values = pred_values[:n_repeats * message_len].reshape([-1, message_len])

                    # Majority vote per message position across all repetitions.
                    ord_values = st.mode(pred_values, keepdims=False).mode
                    # Symbol 0 is the end token; np.min raises when it is
                    # absent, which is reported as a failed decode below.
                    end_char = np.min(np.nonzero(ord_values == 0)[0])
                    confidence.append(self.get_confidence(pred_values, ord_values))
                    # Rotate so the message starts right after the end token,
                    # and drop the token itself.
                    ord_values = np.concatenate([ord_values[end_char+1:], ord_values[:end_char]], axis=0)

                    # Undo the +1 shift applied to the symbols at encoding time.
                    msg_reconst_list.append((ord_values - 1).tolist())

                msg_reconst_list = convert_to_8_bit_segments(msg_reconst_list)

            results.append({'messages': msg_reconst_list, 'confidences': confidence, 'status': True})
        except Exception:
            # Any decoding error means no message was found in this channel;
            # KeyboardInterrupt/SystemExit are deliberately not swallowed.
            results.append(failure())

    if single_channel:
        results = results[0]

    return results
|
| 437 |
-
|
| 438 |
-
def convert_dataparallel_to_normal(self, checkpoint):
    """
    Return a copy of ``checkpoint`` whose keys have a leading 'module.' removed.

    Keys without that prefix are kept as they are; values are not copied.
    """
    prefix = 'module.'
    converted = {}
    for key in checkpoint:
        if key.startswith(prefix):
            converted[key[len(prefix):]] = checkpoint[key]
        else:
            converted[key] = checkpoint[key]
    return converted
|
| 441 |
-
|
| 442 |
-
def load_models(self, ckpt_dir):
    """
    Load the weights of all sub-networks from ``ckpt_dir``.

    Reads ``enc_c.ckpt``, ``dec_c.ckpt`` and one ``dec_m_<i>.ckpt`` per message
    decoder, strips a leading 'module.' from the state-dict keys and loads the
    result into the matching module.

    Args:
        ckpt_dir (str): Directory containing the checkpoint files.
    """

    def restore(module, filename):
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state = torch.load(os.path.join(ckpt_dir, filename), map_location=self.device)
        module.load_state_dict(self.convert_dataparallel_to_normal(state))

    restore(self.enc_c, "enc_c.ckpt")
    restore(self.dec_c, "dec_c.ckpt")
    for i, m in enumerate(self.dec_m):
        restore(m, f"dec_m_{i}.ckpt")
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
def get_model(model_type='44.1k', ckpt_path='../Models/44_1_khz/73999_iteration', config_path='../Models/44_1_khz/73999_iteration/hparams.yaml', device='cpu'):
    """
    Build a ``Model`` for the requested sampling-rate variant.

    When ``ckpt_path`` or ``config_path`` does not exist locally, the
    ``sony/silentcipher`` snapshot is downloaded from the Hugging Face Hub and
    the checkpoint/config of the requested ``model_type`` inside it is used.

    Args:
        model_type (str): '44.1k' or '16k'.
        ckpt_path (str): Directory containing the checkpoint files.
        config_path (str): Path to the ``hparams.yaml`` file.
        device: Torch device passed to ``Model``.

    Returns:
        Model: The model with its checkpoints loaded.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    # Location of each variant inside the Hugging Face snapshot.
    hub_subdirs = {
        '44.1k': '44_1_khz/73999_iteration',
        '16k': '16_khz/97561_iteration',
    }
    if model_type not in hub_subdirs:
        # Fail with a clear error instead of falling through to an unbound `model`.
        raise ValueError('Please specify a valid model_type [44.1k, 16k]')

    if not os.path.exists(ckpt_path) or not os.path.exists(config_path):
        print('ckpt path or config path does not exist! Downloading the model from the Hugging Face Hub...')
        from huggingface_hub import snapshot_download
        folder_dir = snapshot_download(repo_id="sony/silentcipher")
        ckpt_path = os.path.join(folder_dir, hub_subdirs[model_type])
        config_path = os.path.join(folder_dir, hub_subdirs[model_type] + '/hparams.yaml')

    # Close the config file deterministically instead of leaking the handle.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    config = argparse.Namespace(**config)
    config.load_ckpt = ckpt_path

    return Model(config, device)
|
|
|
|
| 1 |
+
from calendar import c
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import re
|
| 5 |
+
from tabnanny import check
|
| 6 |
+
import yaml
|
| 7 |
+
import time
|
| 8 |
+
import numpy as np
|
| 9 |
+
import soundfile as sf
|
| 10 |
+
from scipy import stats as st
|
| 11 |
+
import librosa
|
| 12 |
+
from pydub import AudioSegment
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
from .model import Encoder, CarrierDecoder, MsgDecoder
|
| 17 |
+
from .stft import STFT
|
| 18 |
+
|
| 19 |
+
class Model():
|
| 20 |
+
|
| 21 |
+
def __init__(self, config, device='cpu'):
    """
    Build the watermarking networks, move them to ``device`` and load their weights.

    Args:
        config: Hyper-parameter namespace. This constructor reads ``n_messages``,
            ``model_type``, ``message_dim``, ``message_len``, ``enc_n_layers``,
            ``dec_c_n_layers``, ``message_band_size``, ``N_FFT``, ``HOP_LENGTH``,
            ``SR`` and ``load_ckpt`` from it.
        device: Torch device the sub-modules are moved to. Defaults to 'cpu'.
    """
    self.config = config
    self.device = device

    self.n_messages = config.n_messages
    self.model_type = config.model_type
    self.message_dim = config.message_dim
    self.message_len = config.message_len

    # model dimensions
    self.enc_conv_dim = 16
    self.enc_num_repeat = 3
    self.dec_c_num_repeat = self.enc_num_repeat
    self.dec_m_conv_dim = 1
    self.dec_m_num_repeat = 8
    self.encoder_out_dim = 32
    self.dec_c_conv_dim = 32*3

    band_size = config.message_band_size

    # Networks are constructed first, in the original order, and moved afterwards.
    carrier_encoder = Encoder(n_layers=config.enc_n_layers,
                              message_dim=self.message_dim,
                              out_dim=self.encoder_out_dim,
                              message_band_size=band_size,
                              n_fft=config.N_FFT)
    carrier_decoder = CarrierDecoder(config=config,
                                     conv_dim=self.dec_c_conv_dim,
                                     n_layers=config.dec_c_n_layers,
                                     message_band_size=band_size)
    # One message decoder per embedded message.
    message_decoders = [MsgDecoder(message_dim=self.message_dim,
                                   message_band_size=band_size)
                        for _ in range(self.n_messages)]

    self.enc_c = carrier_encoder.to(self.device)
    self.dec_c = carrier_decoder.to(self.device)
    self.dec_m = [decoder.to(self.device) for decoder in message_decoders]

    # Reference power that waveforms are rescaled to before encoding/decoding.
    self.average_energy_VCTK=0.002837200844477648
    self.stft = STFT(config.N_FFT, config.HOP_LENGTH)
    self.stft.to(self.device)
    self.load_models(config.load_ckpt)
    self.sr = config.SR
|
| 63 |
+
|
| 64 |
+
def letters_encoding(self, patch_len, message_lst):
    """
    Encodes a list of messages into a compact representation and a padded representation.

    Each message gets an end token (index 0) appended after its symbols are
    shifted up by one, is turned into one-hot rows, and is then repeated along
    the time axis until it spans ``patch_len`` columns (the last repetition is
    cut short when ``patch_len`` is not a multiple of ``self.message_len``).

    Args:
        patch_len (int): Number of time columns the tiled message must cover.
        message_lst (list): One symbol sequence per message; only the first
            ``self.n_messages`` entries are read.

    Returns:
        tuple: A tuple containing two numpy arrays:
            - message: shape (n_messages, message_dim, patch_len), the one-hot
              message tiled along the last axis.
            - message_compact: shape (n_messages, message_len, message_dim),
              one one-hot row per symbol.

    Raises:
        AssertionError: If the length of any message in message_lst is not equal to self.message_len - 1.
    """
    message_len = self.message_len
    # Ceil-divide so one tiling always covers patch_len; the slice below trims the excess.
    n_tiles = -(-patch_len // message_len)
    identity = np.identity(self.message_dim)

    message = []
    message_compact = []
    for i in range(self.n_messages):
        symbols = message_lst[i]
        assert len(symbols) == message_len - 1, (
            f'message {i} has {len(symbols)} symbols, expected {message_len - 1}')
        # Shift symbols by one so that index 0 is free to act as the end token.
        index = np.concatenate((np.array(symbols) + 1, [0]))
        one_hot = identity[index]
        message_compact.append(one_hot)
        message.append(np.tile(one_hot.T, (1, n_tiles))[:, :patch_len])

    return np.stack(message), np.stack(message_compact)
|
| 100 |
+
|
| 101 |
+
def get_best_ps(self, y_one_sec):
    """
    Calculates the best phase shift value for watermark decoding.

    Sample offsets 0, 10, 20, ... below ``config.HOP_LENGTH`` are tried. For
    each offset every message decoder is run and the offset whose decoded
    symbols agree most across message repetitions is kept. Offsets that leave
    less than one complete message repetition are skipped.

    Args:
        y_one_sec (numpy.ndarray): Input audio signal.

    Returns:
        int: The best phase shift value (0 if no offset could be scored).
    """

    def check_accuracy(pred_values):
        # Mean, over message positions, of the share of repetitions that
        # agree with the most frequent symbol at that position.
        accuracy = 0
        for i in range(pred_values.shape[1]):
            counts = np.unique(pred_values[:, i], return_counts=True)[1]
            accuracy += np.max(counts) / pred_values.shape[0]
        return accuracy / pred_values.shape[1]

    message_len = self.config.message_len
    y = torch.FloatTensor(y_one_sec).unsqueeze(0).unsqueeze(0).to(self.device)
    max_accuracy = 0
    final_phase_shift = 0

    for ps in range(0, self.config.HOP_LENGTH, 10):
        carrier, _ = self.stft.transform(y[0:1, 0:1, ps:].squeeze(1))
        carrier = carrier[:, None]

        for i in range(self.n_messages):  # decode each msg_i using decoder_m_i
            msg_reconst = self.dec_m[i](carrier)
            pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
            n_repeats = msg_reconst.shape[3] // message_len
            if n_repeats == 0:
                # Too short to hold one full message: nothing to score here.
                continue
            pred_values = pred_values[:n_repeats * message_len].reshape([-1, message_len])
            cur_acc = check_accuracy(pred_values)
            if cur_acc > max_accuracy:
                max_accuracy = cur_acc
                final_phase_shift = ps

    return final_phase_shift
|
| 143 |
+
|
| 144 |
+
def get_confidence(self, pred_values, message):
    """
    Calculates the confidence of the predicted values based on the provided message.

    The confidence is the fraction of entries of ``pred_values`` that equal the
    corresponding entry of ``message``, with ``message`` compared against every row.

    Parameters:
        pred_values (numpy.ndarray): 2-D array of predicted symbols, one row per repetition.
        message (array-like): The reference symbols, one per column of pred_values.

    Returns:
        float: The confidence score.

    Raises:
        AssertionError: If the length of the message is not equal to the number of columns in pred_values.
    """
    # Accept lists/tuples as well as arrays; the [None] indexing below needs an ndarray.
    message = np.asarray(message)
    assert len(message) == pred_values.shape[1], f'{len(message)} | {pred_values.shape}'
    return np.mean((pred_values == message[None]).astype(np.float32)).item()
|
| 161 |
+
|
| 162 |
+
def sdr(self, orig, recon):
    """
    Calculate the Signal-to-Distortion Ratio (SDR) between the original and reconstructed signals.

    Computed as ``20 * log10(rms(orig) / rms(orig - recon))``.

    Parameters:
        orig (numpy.ndarray): The original signal.
        recon (numpy.ndarray): The reconstructed signal.

    Returns:
        float: The Signal-to-Distortion Ratio (SDR) value in dB. ``inf`` when
        the two signals are identical (no distortion).
    """
    signal_rms = ((np.mean(orig ** 2)) ** 0.5)
    distortion_rms = ((np.mean((orig - recon) ** 2)) ** 0.5)
    if distortion_rms == 0:
        # No distortion at all: report infinity without a divide-by-zero warning.
        return float('inf')
    with np.errstate(divide='ignore'):
        # A silent original yields log10(0) = -inf; keep that value, drop the warning.
        return 20 * np.log10(signal_rms / distortion_rms)
|
| 179 |
+
|
| 180 |
+
def load_audio(self, path):
    """
    Load an audio file from the given path and return the audio array and sample rate.

    Samples are converted to float32 and divided by ``2 ** (8 * sample_width - 1)``.

    Args:
        path (str): The path to the audio file.

    Returns:
        tuple: ``(audio_array, sample_rate)``. ``audio_array`` has shape
        (n_frames, n_channels), or (n_frames,) when the file has one channel.
    """
    segment = AudioSegment.from_file(path)
    raw_samples = np.array(segment.get_array_of_samples(), dtype=np.float32)
    full_scale = 1 << (8 * segment.sample_width - 1)
    audio_array = raw_samples.reshape((-1, segment.channels)) / full_scale
    sr = segment.frame_rate

    # Mono audio is returned as a flat vector.
    if audio_array.shape[1] == 1:
        audio_array = audio_array[:, 0]

    return audio_array, sr
|
| 198 |
+
|
| 199 |
+
def encode(self, in_path, out_path, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
    """
    Encodes a message into an audio file.

    Parameters:
    - in_path (str): The path to the input audio file.
    - out_path (str): The path to save the output audio file.
    - message_list (list): A list of messages to be encoded into the audio file.
    - message_sdr (float, optional): The Signal-to-Distortion Ratio (SDR) of the message. Defaults to None.
    - calc_sdr (bool, optional): Whether to calculate the SDR of the encoded audio. Defaults to True.
    - disable_checks (bool, optional): Whether to disable input checks. Defaults to False.

    Returns:
    - dict: ``status`` (always True), ``sdr`` (a string with two decimals, or a
      list of such strings when encode_wav returns a list), ``time_taken``
      (seconds spent in encode_wav) and ``time_taken_per_second`` (time_taken
      divided by the input duration in seconds).
    """
    y, orig_sr = self.load_audio(in_path)

    # Only the encoding itself is timed, not file loading or writing.
    start = time.time()
    encoded_y, sdr_value = self.encode_wav(y, orig_sr, message_list=message_list, message_sdr=message_sdr,
                                           calc_sdr=calc_sdr, disable_checks=disable_checks)
    time_taken = time.time() - start

    sf.write(out_path, encoded_y, orig_sr)

    if isinstance(sdr_value, list):
        sdr_report = [f'{sdr_i:.2f}' for sdr_i in sdr_value]
    else:
        sdr_report = f'{sdr_value:.2f}'

    return {'status': True,
            'sdr': sdr_report,
            'time_taken': time_taken,
            'time_taken_per_second': time_taken / (y.shape[0] / orig_sr)}
|
| 225 |
+
|
| 226 |
+
def decode(self, path, phase_shift_decoding):
    """
    Read an audio file and extract the hidden message from it.

    Parameters:
    path (str): Location of the audio file to read.
    phase_shift_decoding (bool): Forwarded unchanged to ``decode_wav``; controls whether phase shift decoding is used.

    Returns:
    Whatever ``decode_wav`` returns for the loaded waveform (the decoded message status and value).
    """
    waveform, sample_rate = self.load_audio(path)
    decoded = self.decode_wav(waveform, sample_rate, phase_shift_decoding)
    return decoded
|
| 241 |
+
|
| 242 |
+
def encode_wav(self, y_multi_channel, orig_sr, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
    """
    Encodes a multi-channel audio waveform with a given message.

    Args:
        y_multi_channel (numpy.ndarray): The audio waveform to be encoded, either 1-D (samples,) or 2-D (samples, channels).
        orig_sr (int): The original sampling rate of the audio waveform.
        message_list (list): The list of messages to be encoded. Either one message (a list of ints), which is
            reused for every channel, or a list with one message per channel.
        message_sdr (float, optional): The signal-to-distortion ratio (SDR) of the message. If not provided, the default SDR from the configuration is used.
        calc_sdr (bool, optional): Flag indicating whether to calculate the SDR of the encoded waveform. Defaults to True.
            When False, 0 is reported as the SDR.
        disable_checks (bool, optional): Flag indicating whether to disable input audio checks. Defaults to False.
            When True the zero-power (silence) check is skipped, so silent input reaches the power
            normalisation, which then divides by zero.

    Returns:
        tuple: A tuple containing the encoded audio waveform and the SDR. For 1-D input the waveform is 1-D and the
            SDR is a single value; for 2-D input the waveform is 2-D and the SDR is a list with one entry per channel.
            If a channel is found to be silent (and checks are enabled), encoding is skipped entirely: an
            unmodified copy of the input is returned with an SDR of 0 (or a list of zeros), in the same shapes.

    Raises:
        AssertionError: If the number of messages does not match the number of channels in the input audio waveform.
    """

    def binary_encode(mes):
        # Write every value as 8 bits, then regroup the bit string into 2-bit symbols.
        binary_message = ''.join(['{0:08b}'.format(mes_i) for mes_i in mes])
        return [int(binary_message[i*2:i*2+2], 2) for i in range(len(binary_message)//2)]

    original_input = y_multi_channel
    single_channel = len(y_multi_channel.shape) == 1
    if single_channel:
        y_multi_channel = y_multi_channel[:, None]
    n_channels = y_multi_channel.shape[1]

    if message_sdr is None:
        message_sdr = self.config.message_sdr
        print(f'Using the default SDR of {self.config.message_sdr} dB')

    # A flat list of ints is a single message shared by all channels.
    if isinstance(message_list[0], (int, np.integer)):
        message_list = [message_list]*n_channels

    y_watermarked_multi_channel = []
    sdrs = []

    # Raised explicitly (rather than via `assert`) so the check is not stripped under `python -O`.
    if len(message_list) != n_channels:
        raise AssertionError(f'{len(message_list)} | {n_channels} Mismatch in the number of messages and channels in the input audio.')

    for channel_i in range(n_channels):
        y = y_multi_channel[:, channel_i]
        message = message_list[channel_i]

        with torch.no_grad():

            orig_y = y.copy()
            if orig_sr != self.sr:
                if orig_sr > self.sr:
                    print(f'WARNING! Reducing the sampling rate of the original audio from {orig_sr} -> {self.sr}. High frequency components may be lost!')
                y = librosa.resample(y, orig_sr = orig_sr, target_sr = self.sr)
            original_power = np.mean(y**2)

            if not disable_checks and original_power == 0:
                print('WARNING! The input audio has a power of 0.This means the audio is likely just silence. Skipping encoding.')
                # Keep the documented return shapes: previously a single 1-D channel and a
                # scalar 0 were returned even when the input had several channels.
                if single_channel:
                    return orig_y, 0
                return original_input.copy(), [0]*n_channels

            # Rescale the channel so that its mean power equals self.average_energy_VCTK.
            y = y * np.sqrt(self.average_energy_VCTK / original_power)
            y = torch.FloatTensor(y).unsqueeze(0).unsqueeze(0).to(self.device)
            carrier, carrier_phase = self.stft.transform(y.squeeze(1))
            carrier = carrier[:, None]
            carrier_phase = carrier_phase[:, None]

            binary_encoded_message = binary_encode(message)

            msgs, msgs_compact = self.letters_encoding(carrier.shape[3], [binary_encoded_message])
            msg_enc = torch.from_numpy(msgs[None]).to(self.device).float()

            carrier_enc = self.enc_c(carrier)  # encode the carrier
            msg_enc = self.enc_c.transform_message(msg_enc)

            # concat encodings on features axis
            merged_enc = torch.cat((carrier_enc, carrier.repeat(1, 32, 1, 1), msg_enc.repeat(1, 32, 1, 1)), dim=1)

            message_info = self.dec_c(merged_enc, message_sdr)
            # Scale the message by the RMS of the carrier, per frame or over the whole utterance.
            if self.config.frame_level_normalization:
                message_info = message_info*(torch.mean((carrier**2), dim=2, keepdim=True)**0.5)
            elif self.config.utterance_level_normalization:
                message_info = message_info*(torch.mean((carrier**2), dim=(2,3), keepdim=True)**0.5)

            # Combine message and carrier; every branch yields a non-negative result.
            if self.config.ensure_negative_message:
                message_info = -message_info
                carrier_reconst = torch.nn.functional.relu(message_info + carrier)  # decode carrier, output in stft domain
            elif self.config.ensure_constrained_message:
                # Clamp the message to [-carrier, carrier] so the sum cannot go negative.
                message_info[message_info > carrier] = carrier[message_info > carrier]
                message_info[-message_info > carrier] = -carrier[-message_info > carrier]
                carrier_reconst = message_info + carrier  # decode carrier, output in stft domain
                assert torch.all(carrier_reconst >= 0), 'negative values found in carrier_reconst'
            else:
                carrier_reconst = torch.abs(message_info + carrier)  # decode carrier, output in stft domain

            self.stft.num_samples = y.shape[2]

            y = self.stft.inverse(carrier_reconst.squeeze(1), carrier_phase.squeeze(1)).data.cpu().numpy()[0, 0]
            # Undo the power normalisation applied before the STFT.
            y = y * np.sqrt(original_power / (self.average_energy_VCTK))
            if orig_sr != self.sr:
                y = librosa.resample(y, orig_sr = self.sr, target_sr = orig_sr)

            if calc_sdr:
                sdr = self.sdr(orig_y, y)
            else:
                sdr = 0

            y_watermarked_multi_channel.append(y[:, None])
            sdrs.append(sdr)

    y_watermarked_multi_channel = np.concatenate(y_watermarked_multi_channel, axis=1)

    if single_channel:
        y_watermarked_multi_channel = y_watermarked_multi_channel[:, 0]
        sdrs = sdrs[0]

    return y_watermarked_multi_channel, sdrs
|
| 359 |
+
|
| 360 |
+
def decode_wav(self, y_multi_channel, orig_sr, phase_shift_decoding):
    """
    Decode the given audio waveform to extract hidden messages.

    Args:
        y_multi_channel (numpy.ndarray): The audio waveform, either 1-D (samples,) or 2-D (samples, channels).
        orig_sr (int): The original sample rate of the audio waveform.
        phase_shift_decoding (bool or str): Flag indicating whether to perform phase shift decoding.
            Any truthy value other than the string 'false' enables it.

    Returns:
        dict or list: A list of dictionary containing the decoded messages, confidences, and status for each channel if the input is multi-channel.
        Otherwise, a dictionary containing the decoded messages, confidences, and status for a single channel.
        When decoding a channel fails, its dictionary has empty 'messages' and 'confidences',
        an 'error' entry and 'status' set to False; the error is not propagated.
    """

    def convert_to_8_bit_segments(msg_list):
        # Inverse of the encoder's packing: join the 2-bit symbols and cut the bit string into bytes.
        segment_message_list = []
        for msg_list_i in msg_list:
            binary_format = ''.join(['{0:02b}'.format(mes_i) for mes_i in msg_list_i])
            eight_bit_segments = [int(binary_format[i*8:i*8+8], 2) for i in range(len(binary_format)//8)]
            segment_message_list.append(eight_bit_segments)
        return segment_message_list

    single_channel = len(y_multi_channel.shape) == 1
    if single_channel:
        y_multi_channel = y_multi_channel[:, None]

    results = []

    for channel_i in range(y_multi_channel.shape[1]):
        y = y_multi_channel[:, channel_i]
        try:
            with torch.no_grad():
                if orig_sr != self.sr:
                    y = librosa.resample(y, orig_sr = orig_sr, target_sr = self.sr)
                original_power = np.mean(y**2)
                # Rescale the channel so that its mean power equals self.average_energy_VCTK.
                y = y * np.sqrt(self.average_energy_VCTK / original_power)
                if phase_shift_decoding and phase_shift_decoding != 'false':
                    ps = self.get_best_ps(y)
                else:
                    ps = 0
                # Drop the first `ps` samples before computing the STFT.
                y = torch.FloatTensor(y[ps:]).unsqueeze(0).unsqueeze(0).to(self.device)
                carrier, _ = self.stft.transform(y.squeeze(1))
                carrier = carrier[:, None]

                msg_reconst_list = []
                confidence = []

                for i in range(self.n_messages):  # decode each msg_i using decoder_m_i
                    msg_reconst = self.dec_m[i](carrier)
                    pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
                    # Keep only whole repetitions of the message and stack them row-wise.
                    pred_values = pred_values[0:int(msg_reconst.shape[3]/self.config.message_len)*self.config.message_len]
                    pred_values = pred_values.reshape([-1, self.config.message_len])

                    # Majority vote over the repetitions for every message position.
                    ord_values = st.mode(pred_values, keepdims=False).mode
                    # Position of the first 0 symbol; np.min raises on an empty
                    # result, which is reported below as "Could not find message".
                    end_char = np.min(np.nonzero(ord_values == 0)[0])
                    confidence.append(self.get_confidence(pred_values, ord_values))
                    # NOTE(review): end_char indexes into ord_values, so this branch looks
                    # unreachable if ord_values has message_len entries; confirm before removing.
                    if end_char == self.config.message_len:
                        ord_values = ord_values[:self.config.message_len-1]
                    else:
                        # Rotate so the symbols after the 0 come first, and drop the 0 itself.
                        ord_values = np.concatenate([ord_values[end_char+1:], ord_values[:end_char]], axis=0)

                    msg_reconst_list.append((ord_values - 1).tolist())

                msg_reconst_list = convert_to_8_bit_segments(msg_reconst_list)

            results.append({'messages': msg_reconst_list, 'confidences': confidence, 'status': True})
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt and SystemExit still propagate.
            results.append({'messages': [], 'confidences': [], 'error': 'Could not find message', 'status': False})

    if single_channel:
        results = results[0]

    return results
|
| 437 |
+
|
| 438 |
+
def convert_dataparallel_to_normal(self, checkpoint):
    """
    Return a copy of ``checkpoint`` whose keys have a leading 'module.' removed.

    Keys that do not start with 'module.' are kept as they are; only one
    leading occurrence of the prefix is stripped. Values are not copied.
    """
    prefix = 'module.'
    renamed = {}
    for key in checkpoint:
        target_key = key[len(prefix):] if key.startswith(prefix) else key
        renamed[target_key] = checkpoint[key]
    return renamed
|
| 441 |
+
|
| 442 |
+
def load_models(self, ckpt_dir):
    """
    Load the encoder, carrier decoder and message decoder weights from ``ckpt_dir``.

    Reads "enc_c.ckpt", "dec_c.ckpt" and one "dec_m_<i>.ckpt" per entry of
    ``self.dec_m``. Every checkpoint is mapped onto ``self.device`` and passed
    through ``convert_dataparallel_to_normal`` before being loaded.
    """
    def read_state(filename):
        raw_state = torch.load(os.path.join(ckpt_dir, filename), map_location=self.device)
        return self.convert_dataparallel_to_normal(raw_state)

    self.enc_c.load_state_dict(read_state("enc_c.ckpt"))
    self.dec_c.load_state_dict(read_state("dec_c.ckpt"))
    for index, msg_decoder in enumerate(self.dec_m):
        msg_decoder.load_state_dict(read_state(f"dec_m_{index}.ckpt"))
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def get_model(model_type='44.1k', ckpt_path='../Models/44_1_khz/73999_iteration', config_path='../Models/44_1_khz/73999_iteration/hparams.yaml', device='cpu'):
    """
    Build a ``Model`` from a checkpoint directory and its hparams YAML file.

    If ``ckpt_path`` or ``config_path`` does not exist locally, the
    "sony/silentcipher" repository is downloaded from the Hugging Face Hub and
    the paths for the requested ``model_type`` inside that snapshot are used.

    Args:
        model_type (str): Either '44.1k' or '16k'. It only selects which
            snapshot sub-directory is used after a download; existing local
            paths are used as given.
        ckpt_path (str): Directory containing the model checkpoints.
        config_path (str): Path to the hparams YAML file.
        device (str): Device passed on to ``Model``.

    Returns:
        Model: The model constructed from the loaded configuration, with
        ``config.load_ckpt`` set to the checkpoint directory that was used.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    hub_subdirs = {
        '44.1k': '44_1_khz/73999_iteration',
        '16k': '16_khz/97561_iteration',
    }

    # Fail early and explicitly: previously an unknown model_type only printed this
    # message and then crashed with UnboundLocalError on `return model`.
    if model_type not in hub_subdirs:
        raise ValueError('Please specify a valid model_type [44.1k, 16k]')

    if not os.path.exists(ckpt_path) or not os.path.exists(config_path):
        print('ckpt path or config path does not exist! Downloading the model from the Hugging Face Hub...')
        # Imported lazily so huggingface_hub is only needed when a download is required.
        from huggingface_hub import snapshot_download
        folder_dir = snapshot_download(repo_id="sony/silentcipher")
        ckpt_path = os.path.join(folder_dir, hub_subdirs[model_type])
        config_path = os.path.join(folder_dir, hub_subdirs[model_type] + '/hparams.yaml')

    # The context manager closes the config file, which was previously left open.
    with open(config_path) as config_file:
        config = argparse.Namespace(**yaml.safe_load(config_file))
    config.load_ckpt = ckpt_path

    return Model(config, device)
|