# multi-sami/syn_hifigan.py
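# Interactive synthesis script for the multilingual Sámi FastPitch model: the FastPitch
# acoustic model predicts a mel spectrogram, which is optionally sharpened and then
# vocoded with HiFi-GAN.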
import argparse
import models
import time
import sys
import warnings
#from pathlib import Path
import torch
import numpy as np
from scipy.stats import norm
from scipy.io.wavfile import write
from torch.nn.utils.rnn import pad_sequence
#import style_controller
from common.utils import load_wav_to_torch
from common import utils, layers
from common.text.text_processing import TextProcessing
import os
#os.environ["CUDA_VISIBLE_DEVICES"]=""
#device = "cuda:0"
device = "cpu"
vocoder = "hifigan"
SHARPEN = True
from hifigan.data_function import MAX_WAV_VALUE, mel_spectrogram
from hifigan.models import Denoiser
import json
from scipy import ndimage
def parse_args(parser):
"""
Parse commandline arguments.
"""
    parser.add_argument('-i', '--input', type=str, required=False,
                        help='Full path to the input text (phrases separated by newlines)')
parser.add_argument('-o', '--output', default=None,
help='Output folder to save audio (file per phrase)')
parser.add_argument('--log-file', type=str, default=None,
help='Path to a DLLogger log file')
parser.add_argument('--save-mels', action='store_true', help='')
parser.add_argument('--cuda', action='store_true',
help='Run inference on a GPU using CUDA')
parser.add_argument('--cudnn-benchmark', action='store_true',
help='Enable cudnn benchmark mode')
#parser.add_argument('--fastpitch', type=str, default='output_smj_sander/FastPitch_checkpoint_660.pt',
#help='Full path to the generator checkpoint file (skip to use ground truth mels)') #########
parser.add_argument('--fastpitch', type=str, default='output_multilang/FastPitch_checkpoint_200.pt',
help='Full path to the generator checkpoint file (skip to use ground truth mels)') #########
    parser.add_argument('-d', '--denoising-strength', default=0.01, type=float,
                        help='Strength of the vocoder output denoiser')
parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
help='Sampling rate')
parser.add_argument('--stft-hop-length', type=int, default=256,
help='STFT hop length for estimating audio length from mel size')
    parser.add_argument('--amp', action='store_true',
                        help='Inference with AMP')
parser.add_argument('-bs', '--batch-size', type=int, default=1)
parser.add_argument('--ema', action='store_true',
help='Use EMA averaged model (if saved in checkpoints)')
parser.add_argument('--speaker', type=int, default=0,
help='Speaker ID for a multi-speaker model')
parser.add_argument('--language', type=int, default=0,
help='Language ID for a multilingual model')
parser.add_argument('--p-arpabet', type=float, default=0.0, help='') ################
text_processing = parser.add_argument_group('Text processing parameters')
text_processing.add_argument('--text-cleaners', nargs='*',
default=['basic_cleaners'], type=str,
help='Type of text cleaners for input text')
text_processing.add_argument('--symbol-set', type=str, default='all_sami', #################
help='Define symbol set for input text')
cond = parser.add_argument_group('conditioning on additional attributes')
cond.add_argument('--n-speakers', type=int, default=10,
help='Number of speakers in the model.')
cond.add_argument('--n-languages', type=int, default=3,
help='Number of languages in the model.')
return parser
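# Dict subclass whose items are also accessible as attributes; HiFi-GAN configs are
# commonly loaded into this structure (it appears unused directly in this script).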
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
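# Load model weights from a FastPitch/HiFi-GAN training checkpoint. Prefers the EMA
# weights when requested and available, and strips the 'module.' prefix left by
# DataParallel/DistributedDataParallel training.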
def load_model_from_ckpt(checkpoint_path, ema, model):
    checkpoint_data = torch.load(checkpoint_path, map_location=device)
status = ''
if 'state_dict' in checkpoint_data:
sd = checkpoint_data['state_dict']
if ema and 'ema_state_dict' in checkpoint_data:
sd = checkpoint_data['ema_state_dict']
status += ' (EMA)'
        elif ema and 'ema_state_dict' not in checkpoint_data:
            print(f'WARNING: EMA weights missing for {checkpoint_path}')
if any(key.startswith('module.') for key in sd):
sd = {k.replace('module.', ''): v for k,v in sd.items()}
status += ' ' + str(model.load_state_dict(sd, strict=False))
else:
model = checkpoint_data['model']
print(f'Loaded {checkpoint_path}{status}')
return model
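# Build a model from its command-line/model config, optionally load a checkpoint,
# and move it to the target device in eval mode.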
def load_and_setup_model(model_name, parser, checkpoint, amp, device,
unk_args=[], forward_is_infer=False, ema=True,
jitable=False):
model_parser = models.parse_model_args(model_name, parser, add_help=False)
model_args, model_unk_args = model_parser.parse_known_args()
unk_args[:] = list(set(unk_args) & set(model_unk_args))
    setattr(model_args, "energy_conditioning", True)  # force energy conditioning on (assumed to match the trained checkpoint)
model_config = models.get_model_config(model_name, model_args)
# print(model_config)
model = models.get_model(model_name, model_config, device,
forward_is_infer=forward_is_infer,
jitable=jitable)
if checkpoint is not None:
model = load_model_from_ckpt(checkpoint, ema, model)
    amp = False  # force full precision: half precision is not used for CPU inference
if amp:
model.half()
model.eval()
return model.to(device)
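# Wraps the FastPitch generator, the HiFi-GAN vocoder and text processing into a
# single object with a speak() method for end-to-end synthesis.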
class Synthesizer:
    def _load_pyt_or_ts_model(self, model_name, ckpt_path, format='pyt'):
if format == 'ts':
model = models.load_and_setup_ts_model(model_name, ckpt_path,
False, device)
model_train_setup = {}
return model, model_train_setup
is_ts_based_infer = False
model, _, model_train_setup = models.load_and_setup_model(
model_name, self.parser, ckpt_path, False, device,
unk_args=self.unk_args, forward_is_infer=True, jitable=is_ts_based_infer)
if is_ts_based_infer:
model = torch.jit.script(model)
return model, model_train_setup
def __init__(self):
parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
allow_abbrev=False)
self.parser = parse_args(parser)
self.args, self.unk_args = self.parser.parse_known_args()
self.generator = load_and_setup_model(
'FastPitch', parser, self.args.fastpitch, self.args.amp, device,
unk_args=self.unk_args, forward_is_infer=True, ema=self.args.ema,
jitable=False)
self.hifigan_model = "pretrained_models/hifigan/hifigan_gen_checkpoint_10000_ft.pt" # Better with Sander!
#self.hifigan_model = "pretrained_models/hifigan/hifigan_gen_checkpoint_6500.pt"
#self.vocoder = UnivNetModel.from_pretrained(model_name="tts_en_libritts_univnet")
        self.vocoder, voc_train_setup = self._load_pyt_or_ts_model('HiFi-GAN', self.hifigan_model)
        self.denoiser = Denoiser(self.vocoder, device=device)  #, win_length=self.args.win_length).to(device)
self.tp = TextProcessing(self.args.symbol_set, self.args.text_cleaners, p_arpabet=0.0)
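    # Unsharp masking: subtract a Gaussian-blurred copy of the mel spectrogram and add
    # the difference back, which emphasises fine spectral detail.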
def unsharp_mask(self, img, radius=1, amount=1):
blurred = ndimage.gaussian_filter(img, radius)
        sharpened = img + amount * (img - blurred)
return sharpened
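    # Synthesise one utterance: encode text, run FastPitch with the given speaker and
    # language (and their interpolation weights), optionally sharpen the mel, vocode
    # with HiFi-GAN, write a 16-bit wav and play it back.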
def speak(self, text, output_file="/tmp/tmp", spkr=0, lang=0, l_weight=1, s_weight=1, pace=0.95, clarity=1):
text = self.tp.encode_text(text)
#text = [9]+self.tp.encode_text(text)+[9]
text = torch.LongTensor([text]).to(device)
#probs = surprisals
        for p in [0]:  # single-iteration loop (p is unused)
with torch.no_grad():
print(s_weight, l_weight)
mel, mel_lens, *_ = self.generator(text, pace, max_duration=15, speaker=spkr, language=lang, speaker_weight=s_weight, language_weight=l_weight) #, ref_vector=embedding, speaker=speaker_i) #, **gen_kw, speaker 0 = bad audio, speaker 1 = better audio
if SHARPEN:
mel_np = mel.float().data.cpu().numpy()[0]
tgt_min = -11
tgt_max = 1.25
#print(np.min(mel_np), np.max(mel_np))
mel_np = self.unsharp_mask(mel_np, radius = 0.5, amount=0.5)
mel_np = self.unsharp_mask(mel_np, radius = 3, amount=.05)
# mel_np = self.unsharp_mask(mel_np, radius = 7, amount=0.05)
for i in range(0, 80):
mel_np[i,:]+=(i-30)*clarity*0.02
mel_np = (mel_np-np.min(mel_np))/ (np.max(mel_np)-np.min(mel_np)) * (tgt_max - tgt_min) + tgt_min
mel[0] = torch.from_numpy(mel_np).float().to(device)
"""
mel_np = mel.float().data.cpu().numpy()[0]
blurred_f = ndimage.gaussian_filter(mel_np, 1.0) #3
alpha = 0.2 #0.3 ta
mel_np = mel_np + alpha * (mel_np - blurred_f)
blurred_f = ndimage.gaussian_filter(mel_np, 3.0) #3
alpha = 0.1 # 0.1 ta
sharpened = mel_np + alpha * (mel_np - blurred_f)
for i in range(0,80):
sharpened[i, :]+=(i-40)*0.01 #0.01 ta
mel[0] = torch.from_numpy(sharpened).float().to(device)
"""
with torch.no_grad():
y_g_hat = self.vocoder(mel).float() ###########
#y_g_hat = self.denoiser(y_g_hat.squeeze(1), strength=0.01) #[:, 0]
audio = y_g_hat.squeeze()
# normalize volume
audio = audio/torch.max(torch.abs(audio))*0.95*32768
audio = audio.cpu().numpy().astype('int16')
write(output_file+".wav", 22050, audio)
os.system("play -q "+output_file+".wav")
return audio
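# Interactive demo: the checkpoint paths are parsed to build output file names, then each
# line read from stdin is synthesised with a few fixed speaker/language settings.
# The full speaker x language sweep further down is skipped by the 'continue' above it.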
if __name__ == '__main__':
syn = Synthesizer()
hifigan = syn.hifigan_model
hifigan_n = hifigan.replace(".pt", "")
fastpitch = syn.args.fastpitch
fastpitch_n = fastpitch.replace(".pt", "")
print(hifigan_n + " " + fastpitch_n)
hifigan_n_short = hifigan_n.split("/")
hifigan_n_shorter = hifigan_n_short[2].split("_")
hifigan_n_shortest = hifigan_n_shorter[3]
fastpitch_n_short = fastpitch_n.split("/")
fastpitch_n_shorter = fastpitch_n_short[1].split("_")
fastpitch_n_shortest = fastpitch_n_shorter[2]
#syn.speak("Gå lij riek mælggadav vádtsám, de bådij vijmak tjáppa vuobmáj.")
i = 0
spkr = 1
lang = 1
    while True:
text = input(">")
text1 = text.split(" ")
        syn.speak(text, output_file="/tmp/tmp", spkr=6, lang=1)
        syn.speak(text, output_file="/tmp/tmp", spkr=7, lang=1)
        continue  # skip the full speaker/language sweep below
for s in range(1,10):
for l in range(3): ##
print("speaker", s, "language", l) ##
syn.speak(text, output_file="/tmp/"+str(i)+"_"+text1[0]+"_"+str(s)+"_"+str(l)+"_FP_"+fastpitch_n_shortest+"univnet", spkr=s, lang=l)
#syn.speak(text, output_file="/home/hiovain/DeepLearningExamples/PyTorch/SpeechSynthesis/FastPitchMulti/inf_output_multi/"+str(i)+"_"+text1[0]+"_"+str(s)+"_"+str(l)+"_FP_"+fastpitch_n_shortest+"univnet", spkr=s, lang=l)
i += 1