FrederikRautenberg's picture
Rename Sliders and refactor text input
d161181
import numpy as np
from pathlib import Path
import paderbox as pb
import torch
from onnxruntime import InferenceSession
from pvq_manipulation.models.vits import Vits_NT
from pvq_manipulation.models.ffjord import FFJORD
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER
import librosa
from pvq_manipulation.helper.vad import EnergyVAD
import gradio as gr
from pvq_manipulation.helper.creapy_wrapper import process_file
from creapy.utils import config
import os
# Let torch use every CPU core reported by the OS (one thread if unknown).
torch.set_num_threads(os.cpu_count() or 1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Names of the PVQ regressors; one ONNX model per entry is loaded below.
pvq_labels = ['Weight', 'Resonance', 'Breathiness', 'Roughness', 'Loudness', 'Strain', 'Pitch']
dataset_dict = pb.io.load_yaml('./Dataset/dataset.yaml')

# Module-level cache for the most recently processed example; it is filled
# and invalidated by `update_manipulation` / `delete_cache`.
cached_example_id = cached_loaded_example = cached_labels = None
cached_d_vector = cached_unmanipulated = cached_transcription = None

# Directory holding the embedding statistics (mean.json / std.json).
stats_path = Path('./Dataset/Embeddings/')

# Normalizing flow (FFJORD) and its configuration.
storage_dir_normalizing_flow = Path("./models/norm_flow")
config_norm_flow = pb.io.load_yaml(storage_dir_normalizing_flow / "config.json")
normalizing_flow = FFJORD.load_model(
    storage_dir_normalizing_flow,
    checkpoint="model.pt",
    device=device,
)

# TTS model (VITS).
storage_dir_tts = Path("./models/tts_model/")
tts_model = Vits_NT.load_model(storage_dir_tts, "model.pt")

# Point creapy's config module at the project-local YAML files.
config._CONFIG_DIR = "./pvq_manipulation/helper/creapy_config.yaml"
config._USER_CONFIG_DIR = "./pvq_manipulation/helper/user_config.yaml"
config.USER_CONFIG_DIR = "./pvq_manipulation/helper/user_config.yaml"

# HuBERT feature extractor; pass `storage_dir=` to choose where the HuBERT
# model is stored.
hubert_model = HubertExtractor(
    layer=SID_LARGE_LAYER,
    model_name="HUBERT_LARGE",
    backend="torchaudio",
    device=device,
)

# One CPU ONNX inference session per PVQ label.
reg_stor_dir = Path('./models/pvq_extractor/')
onnx_sessions = {}
for pvq in pvq_labels:
    onnx_sessions[pvq] = InferenceSession(
        str(reg_stor_dir / f"{pvq}.onnx"),
        providers=["CPUExecutionProvider"],
    )
def get_manipulation(
    example,
    labels,
    flow,
    tts_model,
    d_vector,
    config_norm_flow,
    manipulation_idx=0,
    manipulation_fkt=1,
):
    """Synthesize speech with one label of the speaker embedding manipulated.

    The speaker embedding is (optionally) normalized, mapped through
    ``flow.forward`` conditioned on the original ``labels`` and mapped back
    through ``flow.sample`` conditioned on the manipulated labels. The
    resulting embedding is handed to the TTS model together with the
    original one.

    Args:
        example: Dict providing ``'transcription'`` and
            ``'d_vector_storage_root'``.
        labels: 2-D tensor of conditioning labels (indexed as
            ``[:, manipulation_idx]``); it is not modified.
        flow: Object exposing ``forward`` and ``sample``; both take a tuple
            and the first element of their result is used.
        tts_model: Object exposing ``synthesize_from_example``.
        d_vector: Speaker embedding tensor.
        config_norm_flow: Mapping with the key ``'flag_remove_mean'``; when
            truthy the embedding is normalized with ``mean.json`` /
            ``std.json`` from ``stats_path``.
        manipulation_idx: Column of ``labels`` to change.
        manipulation_fkt: Value added to that column.

    Returns:
        Whatever ``tts_model.synthesize_from_example`` returns.
    """
    labels_manipulated = labels.clone()
    labels_manipulated[:, manipulation_idx] += manipulation_fkt

    remove_mean = config_norm_flow['flag_remove_mean']

    # Keep every tensor that enters the flow on the same device as the
    # labels; otherwise mixing CPU statistics with CUDA labels fails.
    flow_device = labels.device
    speaker_embedding = d_vector.to(flow_device)

    if remove_mean:
        global_mean = torch.tensor(
            pb.io.load(stats_path / "mean.json"),
            dtype=torch.float32,
            device=flow_device,
        )
        global_std = torch.tensor(
            pb.io.load(stats_path / "std.json"),
            dtype=torch.float32,
            device=flow_device,
        )
        speaker_embedding_norm = (speaker_embedding - global_mean) / global_std
    else:
        speaker_embedding_norm = speaker_embedding

    output_forward = flow.forward((speaker_embedding_norm.float(), labels))[0]
    sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]

    if remove_mean:
        # Undo the normalization applied before the flow.
        sampled_class_manipulated = sampled_class_manipulated * global_std + global_mean

    # `.cpu()` is required before `.numpy()` when the tensors live on CUDA;
    # it is a no-op for CPU tensors.
    return tts_model.synthesize_from_example({
        'text': example['transcription'],
        'd_vector': d_vector.detach().cpu().numpy(),
        'd_vector_man': sampled_class_manipulated.detach().cpu().numpy(),
        'd_vector_storage_root': example['d_vector_storage_root'],
    })
def get_creak_label(example):
    """Return the mean creak prediction for an example, multiplied by 100.

    The 16 kHz audio stored under ``example['loaded_audio_data']['16_000']``
    is passed to ``process_file``; only the predictions selected by the
    returned indices enter the mean.
    """
    _, creak_predictions, selected = process_file(
        example['loaded_audio_data']['16_000']
    )
    return np.mean(creak_predictions[selected]) * 100
def load_speaker_labels(example):
    """Predict the conditioning labels for an example.

    HuBERT features are extracted from the 16 kHz audio, averaged over their
    last axis and fed to every ONNX regressor in ``onnx_sessions``. The creak
    score from ``get_creak_label`` is appended and all values are divided by
    100.

    Returns:
        A 1-D float tensor on ``device`` ordered like ``pvq_labels`` followed
        by the creak value.
    """
    waveform = torch.tensor(
        example['loaded_audio_data']['16_000'], dtype=torch.float
    )[None, :]
    lengths = torch.tensor([waveform.shape[-1]])
    if torch.cuda.is_available():
        waveform, lengths = waveform.cuda(), lengths.cuda()

    with torch.no_grad():
        features, _ = hubert_model(waveform, 16_000, sequence_lengths=lengths)
    features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)

    # Every regressor receives the same input under the name "X".
    onnx_input = {"X": features[None]}
    pvqd_predictions = {
        pvq: onnx_sessions[pvq].run(None, onnx_input)[0].squeeze(1).tolist()[0]
        for pvq in pvq_labels
    }
    pvqd_predictions['Creak_mean'] = get_creak_label(example)

    ordered_keys = pvq_labels + ["Creak_mean"]
    scaled = [pvqd_predictions[key] / 100 for key in ordered_keys]
    return torch.tensor(scaled, device=device).float()
def _resample_and_trim(signal, orig_sr, target_sr):
    """Resample ``signal`` to ``target_sr`` and pass it through ``EnergyVAD``."""
    resampled = librosa.resample(signal, orig_sr=orig_sr, target_sr=target_sr)
    if resampled.ndim == 1:
        # The VAD is called with a leading axis added for 1-D signals.
        resampled = resampled[None, :]
    vad = EnergyVAD(sample_rate=target_sr)
    return vad({'audio_data': resampled})['audio_data']


def load_audio_files(example):
    """Load the observation audio and store 16 kHz and 24 kHz versions.

    The file at ``example['audio_path']['observation']`` is read once. For
    each target rate the *originally loaded* signal is resampled and then
    processed by ``EnergyVAD``. (Previously the 24 kHz version was produced
    from the already 16 kHz signal while still passing the file's original
    sample rate as ``orig_sr``, yielding audio at the wrong rate.)

    Args:
        example: Dict with ``example['audio_path']['observation']``.

    Returns:
        The same ``example`` dict, with ``example['loaded_audio_data']`` set
        to a dict holding the keys ``'16_000'`` and ``'24_000'``.
    """
    observation_loaded, sr = pb.io.load_audio(
        example['audio_path']['observation'], return_sample_rate=True
    )
    example['loaded_audio_data'] = {}
    for key, target_sr in (('16_000', 16_000), ('24_000', 24_000)):
        example['loaded_audio_data'][key] = _resample_and_trim(
            observation_loaded, sr, target_sr
        )
    return example
def delete_cache():
    """Reset every module-level cache entry to ``None``.

    The names are rebound instead of removed with ``del``: deleting a global
    unbinds it, so a second call, or any later read of the name (for example
    after reloading an example failed half-way), would raise ``NameError``.
    Rebinding still drops the references to the cached objects.
    """
    global cached_example_id, cached_loaded_example, cached_labels
    global cached_d_vector, cached_unmanipulated, cached_transcription
    cached_example_id = None
    cached_loaded_example = None
    cached_labels = None
    cached_d_vector = None
    cached_unmanipulated = None
    cached_transcription = None
def update_manipulation(manipulation_idx, example_id, transcription, manipulation_fkt):
    """Gradio callback: synthesize the original and the manipulated utterance.

    Audio, d-vector and labels of the selected example are cached in module
    globals and only reloaded when ``example_id`` changes. The unmanipulated
    synthesis is cached as well and only recomputed when the example or the
    transcription changed.

    Args:
        manipulation_idx: Index of the label to manipulate.
        example_id: Key into ``dataset_dict['dataset']``; also used to build
            the audio and embedding file paths.
        transcription: Text to synthesize.
        manipulation_fkt: Value added to the selected label.

    Returns:
        Two ``(24_000, waveform)`` tuples: the unmanipulated and the
        manipulated synthesis.
    """
    global cached_example_id, cached_loaded_example, cached_labels
    global cached_d_vector, cached_unmanipulated, cached_transcription

    speaker_id = dataset_dict['dataset'][example_id]['speaker_id']
    example = {
        'audio_path': {'observation': f"./Dataset/Audio_files/{example_id}.wav"},
        'd_vector_storage_root': f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth",
        'speaker_id': speaker_id,
        'example_id': example_id,
        'transcription': transcription
    }

    example_changed = cached_example_id != example_id
    if example_changed:
        delete_cache()
        cached_loaded_example = load_audio_files(example)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # embedding files from this directory.
        cached_d_vector = torch.load(f"./Dataset/Embeddings/{speaker_id}/{example_id}.pth")
        cached_labels = load_speaker_labels(example)

    # Comparing the freshly built `example` dict with the cached one (which
    # additionally holds the loaded audio) was always unequal on a cache hit
    # and forced a re-synthesis on every call; compare the actual cache keys.
    if example_changed or transcription != cached_transcription:
        with torch.no_grad():
            cached_unmanipulated = tts_model.synthesize_from_example({
                'text': transcription,
                'd_vector': cached_d_vector.detach().cpu().numpy(),
            })
        cached_transcription = transcription
        # Mark the cache as valid only after everything was computed, so a
        # failure above triggers a full reload on the next call.
        cached_example_id = example_id

    with torch.no_grad():
        wav_manipulated = get_manipulation(
            example=example,
            d_vector=cached_d_vector,
            labels=cached_labels[None, :],
            flow=normalizing_flow,
            tts_model=tts_model,
            manipulation_idx=manipulation_idx,
            manipulation_fkt=manipulation_fkt,
            config_norm_flow=config_norm_flow,
        )
    return (24_000, cached_unmanipulated), (24_000, wav_manipulated)
# Input widgets, in the order of `update_manipulation`'s parameters.
pvq_feature_dropdown = gr.Dropdown(
    label="PVQ Feature",
    # The value of each choice is the index of that label in the label vector.
    choices=[
        ('Weight', 0),
        ('Resonance', 1),
        ('Breathiness', 2),
        ('Roughness', 3),
        ('Creak', 7),
    ],
    value=2,
    type="value",
)
speaker_dropdown = gr.Dropdown(
    label="Speaker",
    # Shown as a running number; the value is the example id.
    choices=[
        (str(position), key)
        for position, key in enumerate(dataset_dict['dataset'])
    ],
    value="1422_149735_000006_000000",
    type="value",
)
text_input = gr.Textbox(
    label="Text Input",
    value="Department of Communications Engineering Paderborn University.",
    placeholder='Type something',
)
intensity_slider = gr.Slider(
    label="Manipulation Intensity",
    minimum=-1.0,
    maximum=2.0,
    value=1.0,
    step=0.1,
)

# Output widgets matching the two tuples returned by `update_manipulation`.
original_audio = gr.Audio(label="original synthesized utterance")
manipulated_audio = gr.Audio(label="manipulated synthesized utterance")

demo = gr.Interface(
    title="Perceptual Voice Quality (PVQ) Manipulation",
    fn=update_manipulation,
    inputs=[pvq_feature_dropdown, speaker_dropdown, text_input, intensity_slider],
    outputs=[original_audio, manipulated_audio],
)

if __name__ == "__main__":
    demo.launch(share=True)