""" |
|
This script generates speech with our pre-trained ZipVoice or |
|
ZipVoice-Distill models. If no local model is specified, |
|
Required files will be automatically downloaded from HuggingFace. |
|
|
|
Usage: |
|
|
|
Note: If you having trouble connecting to HuggingFace, |
|
try switching endpoint to mirror site: |
|
export HF_ENDPOINT=https://hf-mirror.com |
|
|
|
(1) Inference of a single sentence: |
|
|
|
python3 -m zipvoice.bin.infer_zipvoice \ |
|
--model-name "zipvoice" \ |
|
--prompt-wav prompt.wav \ |
|
--prompt-text "I am a prompt." \ |
|
--text "I am a sentence." \ |
|
--res-wav-path result.wav |
|
|
|
(2) Inference of a list of sentences: |
|
|
|
python3 -m zipvoice.bin.infer_zipvoice \ |
|
--model-name "zipvoice" \ |
|
--test-list test.tsv \ |
|
--res-dir results |
|
|
|
`--model-name` can be `zipvoice` or `zipvoice_distill`, |
|
which are the models before and after distillation, respectively. |
|
|
|
Each line of `test.tsv` is in the format of |
|
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`. |
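
For example, a line of `test.tsv` could look like this (the wav name, prompt
text and paths below are just placeholders):

    utt_0001\tI am a prompt.\tprompt.wav\tI am a sentence.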
"""

import argparse
import datetime as dt
import json
import os
from typing import Optional

import numpy as np
import safetensors.torch
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from vocos import Vocos

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import (
    EmiliaTokenizer,
    EspeakTokenizer,
    LibriTTSTokenizer,
    SimpleTokenizer,
)
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.feature import VocosFbank

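# Relative paths of the pre-trained checkpoints, token tables, and model
# configs inside the HuggingFace repo below; used for automatic downloads
# when no local files are specified.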
HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
PRETRAINED_MODEL = {
    "zipvoice": "zipvoice/model.pt",
    "zipvoice_distill": "zipvoice_distill/model.pt",
}
TOKEN_FILE = {
    "zipvoice": "zipvoice/tokens.txt",
    "zipvoice_distill": "zipvoice_distill/tokens.txt",
}
MODEL_CONFIG = {
    "zipvoice": "zipvoice/zipvoice_base.json",
    "zipvoice_distill": "zipvoice_distill/zipvoice_base.json",
}


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice",
        choices=["zipvoice", "zipvoice_distill"],
        help="The model used for inference",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="The model checkpoint. "
        "Will download the pre-trained checkpoint from HuggingFace if not specified.",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default=None,
        help="The model configuration file. "
        "Will download zipvoice_base.json from HuggingFace if not specified.",
    )

    parser.add_argument(
        "--vocoder-path",
        type=str,
        default=None,
        help="The vocoder checkpoint. "
        "Will download the pre-trained vocoder from HuggingFace if not specified.",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default=None,
        help="The file that maps tokens to ids, "
        "which is a text file with '{token}\t{token_id}' per line. "
        "Will download the corresponding tokens.txt from HuggingFace if not specified.",
    )

    parser.add_argument(
        "--tokenizer",
        type=str,
        default="emilia",
        choices=["emilia", "libritts", "espeak", "simple"],
        help="Tokenizer type.",
    )

    parser.add_argument(
        "--lang",
        type=str,
        default="en-us",
        help="Language identifier, used when the tokenizer type is espeak. See "
        "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
    )

    parser.add_argument(
        "--test-list",
        type=str,
        default=None,
        help="The list of prompt speech, prompt transcription, "
        "and text to synthesize, in the format of "
        "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
    )

    parser.add_argument(
        "--prompt-wav",
        type=str,
        default=None,
        help="The prompt wav to mimic",
    )

    parser.add_argument(
        "--prompt-text",
        type=str,
        default=None,
        help="The transcription of the prompt wav",
    )

    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="The text to synthesize",
    )

    parser.add_argument(
        "--res-dir",
        type=str,
        default="results",
        help="""
        Directory to save the generated wavs,
        used when --test-list is given.
        """,
    )

    parser.add_argument(
        "--res-wav-path",
        type=str,
        default="result.wav",
        help="""
        Path to save the generated wav,
        used when --test-list is not given.
        """,
    )

    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=None,
        help="The scale of classifier-free guidance during inference.",
    )

    parser.add_argument(
        "--num-step",
        type=int,
        default=None,
        help="The number of sampling steps.",
    )

    parser.add_argument(
        "--feat-scale",
        type=float,
        default=0.1,
        help="The scale factor of the fbank features.",
    )

    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Control the speech speed; 1.0 means normal speed, >1.0 means faster.",
    )

    parser.add_argument(
        "--t-shift",
        type=float,
        default=0.5,
        help="Shift t towards smaller values if t_shift < 1.0.",
    )

    parser.add_argument(
        "--target-rms",
        type=float,
        default=0.1,
        help="Target RMS value for speech normalization; set to 0 to disable normalization.",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=666,
        help="Random seed",
    )

    return parser



def get_vocoder(vocos_local_path: Optional[str] = None):
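    """Load the Vocos vocoder, either from a local directory that contains
    `config.yaml` and `pytorch_model.bin`, or, if no path is given, the
    pre-trained `charactr/vocos-mel-24khz` model downloaded from HuggingFace."""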
    if vocos_local_path:
        vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
        state_dict = torch.load(
            f"{vocos_local_path}/pytorch_model.bin",
            weights_only=True,
            map_location="cpu",
        )
        vocoder.load_state_dict(state_dict)
    else:
        vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    return vocoder



def generate_sentence(
    save_path: str,
    prompt_text: str,
    prompt_wav: str,
    text: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        save_path (str): Path to save the generated wav.
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (str): Path to the prompt wav file.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
        feature_extractor (VocosFbank): The feature extractor used to
            extract acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features.
            Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.

    Returns:
        metrics (dict): Dictionary containing time and real-time
            factor metrics for processing.
    """

    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)

    if prompt_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=prompt_sampling_rate, new_freq=sampling_rate
        )
        prompt_wav = resampler(prompt_wav)

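    # Normalize a quiet prompt up to target_rms; the generated waveform is
    # scaled back by the inverse factor before saving (see below).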
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

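    # Extract fbank features from the prompt and scale them by feat_scale;
    # the same factor is divided out of the predicted features before vocoding.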
    prompt_features = feature_extractor.extract(
        prompt_wav, sampling_rate=sampling_rate
    ).to(device)
    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    start_t = dt.datetime.now()

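    # Sample the target features conditioned on the prompt; duration="predict"
    # lets the model determine the output duration itself.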
    (
        pred_features,
        pred_features_lens,
        pred_prompt_features,
        pred_prompt_features_lens,
    ) = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        speed=speed,
        t_shift=t_shift,
        duration="predict",
        num_step=num_step,
        guidance_scale=guidance_scale,
    )

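    # (batch, frames, channels) -> (batch, channels, frames) as expected by the
    # vocoder, and undo the feat_scale applied to the features.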
    pred_features = pred_features.permute(0, 2, 1) / feat_scale

    start_vocoder_t = dt.datetime.now()
    wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    metrics = {
        "t": t,
        "t_no_vocoder": t_no_vocoder,
        "t_vocoder": t_vocoder,
        "wav_seconds": wav_seconds,
        "rtf": rtf,
        "rtf_no_vocoder": rtf_no_vocoder,
        "rtf_vocoder": rtf_vocoder,
    }

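    # Scale the output back so that its loudness matches the original prompt.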
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)

    return metrics



def generate_list(
    res_dir: str,
    test_list: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
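    """Generate a waveform for every line of the test list file and report
    RTF (real-time factor) statistics."""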
    total_t = []
    total_t_no_vocoder = []
    total_t_vocoder = []
    total_wav_seconds = []

    with open(test_list, "r") as fr:
        lines = fr.readlines()

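    # Each line: {wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}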
    for i, line in enumerate(lines):
        wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
        save_path = f"{res_dir}/{wav_name}.wav"
        metrics = generate_sentence(
            save_path=save_path,
            prompt_text=prompt_text,
            prompt_wav=prompt_wav,
            text=text,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            device=device,
            num_step=num_step,
            guidance_scale=guidance_scale,
            speed=speed,
            t_shift=t_shift,
            target_rms=target_rms,
            feat_scale=feat_scale,
            sampling_rate=sampling_rate,
        )
        print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
        total_t.append(metrics["t"])
        total_t_no_vocoder.append(metrics["t_no_vocoder"])
        total_t_vocoder.append(metrics["t_vocoder"])
        total_wav_seconds.append(metrics["wav_seconds"])

    print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
    print(
        f"Average RTF w/o vocoder: "
        f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
    )
    print(
        f"Average RTF vocoder: "
        f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
    )



@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = AttributeDict()
    params.update(vars(args))
    fix_random_seed(params.seed)

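    # Default sampling hyper-parameters per model, used when --num-step and
    # --guidance-scale are not given: the distilled model uses fewer steps
    # but a larger guidance scale.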
    model_defaults = {
        "zipvoice": {
            "num_step": 16,
            "guidance_scale": 1.0,
        },
        "zipvoice_distill": {
            "num_step": 8,
            "guidance_scale": 3.0,
        },
    }

    model_specific_defaults = model_defaults.get(params.model_name, {})

    for param, value in model_specific_defaults.items():
        if getattr(params, param) is None:
            setattr(params, param, value)
            print(f"Setting {param} to default value: {value}")

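    # Exactly one input mode must be used: either --test-list, or all of
    # --prompt-wav, --prompt-text and --text.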
    assert (params.test_list is not None) ^ (
        (params.prompt_wav and params.prompt_text and params.text) is not None
    ), (
        "For inference, please provide prompts and text with either '--test-list'"
        " or '--prompt-wav, --prompt-text and --text'."
    )

    if torch.cuda.is_available():
        params.device = torch.device("cuda", 0)
    elif torch.backends.mps.is_available():
        params.device = torch.device("mps")
    else:
        params.device = torch.device("cpu")

print("Loading model...") |
|
if params.model_config is None: |
|
model_config = hf_hub_download( |
|
HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name] |
|
) |
|
else: |
|
model_config = params.model_config |
|
|
|
with open(model_config, "r") as f: |
|
model_config = json.load(f) |
|
|
|
if params.token_file is None: |
|
token_file = hf_hub_download( |
|
HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name] |
|
) |
|
else: |
|
token_file = params.token_file |
|
|
|
if params.tokenizer == "emilia": |
|
tokenizer = EmiliaTokenizer(token_file=token_file) |
|
elif params.tokenizer == "libritts": |
|
tokenizer = LibriTTSTokenizer(token_file=token_file) |
|
elif params.tokenizer == "espeak": |
|
tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang) |
|
else: |
|
assert params.tokenizer == "simple" |
|
tokenizer = SimpleTokenizer(token_file=token_file) |
|
|
|
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id} |
|
|
|
    if params.checkpoint is None:
        model_ckpt = hf_hub_download(
            HUGGINGFACE_REPO,
            filename=PRETRAINED_MODEL[params.model_name],
        )
    else:
        model_ckpt = params.checkpoint

    if params.model_name == "zipvoice":
        model = ZipVoice(
            **model_config["model"],
            **tokenizer_config,
        )
    else:
        assert params.model_name == "zipvoice_distill"
        model = ZipVoiceDistill(
            **model_config["model"],
            **tokenizer_config,
        )

    if model_ckpt.endswith(".safetensors"):
        safetensors.torch.load_model(model, model_ckpt)
    elif model_ckpt.endswith(".pt"):
        load_checkpoint(filename=model_ckpt, model=model, strict=True)
    else:
        raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")

    model = model.to(params.device)
    model.eval()

    vocoder = get_vocoder(params.vocoder_path)
    vocoder = vocoder.to(params.device)
    vocoder.eval()

    if model_config["feature"]["type"] == "vocos":
        feature_extractor = VocosFbank()
    else:
        raise NotImplementedError(
            f"Unsupported feature type: {model_config['feature']['type']}"
        )
    params.sampling_rate = model_config["feature"]["sampling_rate"]

print("Start generating...") |
|
if params.test_list: |
|
os.makedirs(params.res_dir, exist_ok=True) |
|
generate_list( |
|
res_dir=params.res_dir, |
|
test_list=params.test_list, |
|
model=model, |
|
vocoder=vocoder, |
|
tokenizer=tokenizer, |
|
feature_extractor=feature_extractor, |
|
device=params.device, |
|
num_step=params.num_step, |
|
guidance_scale=params.guidance_scale, |
|
speed=params.speed, |
|
t_shift=params.t_shift, |
|
target_rms=params.target_rms, |
|
feat_scale=params.feat_scale, |
|
sampling_rate=params.sampling_rate, |
|
) |
|
else: |
|
generate_sentence( |
|
save_path=params.res_wav_path, |
|
prompt_text=params.prompt_text, |
|
prompt_wav=params.prompt_wav, |
|
text=params.text, |
|
model=model, |
|
vocoder=vocoder, |
|
tokenizer=tokenizer, |
|
feature_extractor=feature_extractor, |
|
device=params.device, |
|
num_step=params.num_step, |
|
guidance_scale=params.guidance_scale, |
|
speed=params.speed, |
|
t_shift=params.t_shift, |
|
target_rms=params.target_rms, |
|
feat_scale=params.feat_scale, |
|
sampling_rate=params.sampling_rate, |
|
) |
|
print("Done") |
|
|
|
|
|
if __name__ == "__main__": |
|
torch.set_num_threads(1) |
|
torch.set_num_interop_threads(1) |
|
main() |
|
|