#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script generates speech with our pre-trained ZipVoice-Dialog or
ZipVoice-Dialog-Stereo models. If no local model is specified, the required
files will be automatically downloaded from HuggingFace.

Usage:

Note: If you have trouble connecting to HuggingFace, try switching to the
mirror endpoint:
export HF_ENDPOINT=https://hf-mirror.com

python3 -m zipvoice.bin.infer_zipvoice_dialog \
    --model-name "zipvoice_dialog" \
    --test-list test.tsv \
    --res-dir results

`--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`,
which generate mono and stereo dialogues, respectively.

Each line of `test.tsv` is in the format of a merged conversation:
'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'
or a split conversation:
'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}
    \t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'
"""

import argparse
import datetime as dt
import json
import os
from typing import List, Optional, Union

import numpy as np
import safetensors.torch
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from vocos import Vocos

from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
from zipvoice.tokenizer.tokenizer import DialogTokenizer
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.feature import VocosFbank

HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
PRETRAINED_MODEL = {
    "zipvoice_dialog": "zipvoice_dialog/model.pt",
    "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.pt",
}
TOKEN_FILE = {
    "zipvoice_dialog": "zipvoice_dialog/tokens.txt",
    "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/tokens.txt",
}
MODEL_CONFIG = {
    "zipvoice_dialog": "zipvoice_dialog/model.json",
    "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.json",
}


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice_dialog",
        choices=["zipvoice_dialog", "zipvoice_dialog_stereo"],
        help="The model used for inference",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="The model checkpoint. "
        "Will download the pre-trained checkpoint from HuggingFace if not specified.",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default=None,
        help="The model configuration file. "
        "Will download model.json from HuggingFace if not specified.",
    )

    parser.add_argument(
        "--vocoder-path",
        type=str,
        default=None,
        help="The vocoder checkpoint. "
        "Will download the pre-trained vocoder from HuggingFace if not specified.",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default=None,
        help="The file that maps tokens to token ids, "
        "which is a text file with one '{token}\t{token_id}' pair per line. "
        "Will download tokens.txt from HuggingFace if not specified.",
    )

    parser.add_argument(
        "--test-list",
        type=str,
        default=None,
        help="The list of prompt speech, prompt transcription, "
        "and text to synthesize, in the format of a merged conversation: "
        "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "
        "or a split conversation: "
        "'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}"
        "\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.",
    )

    parser.add_argument(
        "--res-dir",
        type=str,
        default="results",
        help="Path of the directory for generated wavs, "
        "used when --test-list is not None.",
    )

    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=1.5,
        help="The scale of classifier-free guidance during inference.",
    )

    parser.add_argument(
        "--num-step",
        type=int,
        default=16,
        help="The number of sampling steps.",
    )

    parser.add_argument(
        "--feat-scale",
        type=float,
        default=0.1,
        help="The scale factor of the fbank feature.",
    )

    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Control speech speed, 1.0 means normal, >1.0 means speed up.",
    )

    parser.add_argument(
        "--t-shift",
        type=float,
        default=0.5,
        help="Shift t towards smaller values if t_shift < 1.0.",
    )

    parser.add_argument(
        "--target-rms",
        type=float,
        default=0.1,
        help="Target speech normalization RMS value, "
        "set to 0 to disable normalization.",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=666,
        help="Random seed",
    )

    parser.add_argument(
        "--silence-wav",
        type=str,
        default="assets/silence.wav",
        help="Path of the silence wav file, used in two-channel generation "
        "with single-channel prompts.",
    )

    return parser


def get_vocoder(vocos_local_path: Optional[str] = None):
    # Load the Vocos vocoder from a local directory if given,
    # otherwise download it from HuggingFace.
    if vocos_local_path:
        vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
        state_dict = torch.load(
            f"{vocos_local_path}/pytorch_model.bin",
            weights_only=True,
            map_location="cpu",
        )
        vocoder.load_state_dict(state_dict)
    else:
        vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    return vocoder

def generate_sentence(
    save_path: str,
    prompt_text: str,
    prompt_wav: Union[str, List[str]],
    text: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: DialogTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate the waveform of a text based on a given prompt waveform and its
    transcription.

    Args:
        save_path (str): Path to save the generated wav.
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (Union[str, List[str]]): Path(s) to the prompt wav file(s);
            either a single wav file containing the merged conversational
            speech, or two wav files containing the two speakers' separate
            speech.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to
            waveforms.
        tokenizer (DialogTokenizer): The tokenizer used to convert text to
            tokens.
        feature_extractor (VocosFbank): The feature extractor used to extract
            acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features. Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.

    Returns:
        metrics (dict): Dictionary containing time and real-time factor
            metrics for processing.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav(s), resampling to the target rate if needed
    if isinstance(prompt_wav, str):
        prompt_wav = [
            prompt_wav,
        ]
    else:
        assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
    loaded_prompt_wavs = prompt_wav
    for i in range(len(prompt_wav)):
        loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
        if prompt_sampling_rate != sampling_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=prompt_sampling_rate, new_freq=sampling_rate
            )
            loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
    if len(loaded_prompt_wavs) == 1:
        prompt_wav = loaded_prompt_wavs[0]
    else:
        prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)

    # Normalize the prompt volume if it is quieter than the target RMS
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(
        prompt_wav, sampling_rate=sampling_rate
    ).to(device)
    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    (
        pred_features,
        pred_features_lens,
        pred_prompt_features,
        pred_prompt_features_lens,
    ) = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        speed=speed,
        t_shift=t_shift,
        duration="predict",
        num_step=num_step,
        guidance_scale=guidance_scale,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing
    start_vocoder_t = dt.datetime.now()
    wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    metrics = {
        "t": t,
        "t_no_vocoder": t_no_vocoder,
        "t_vocoder": t_vocoder,
        "wav_seconds": wav_seconds,
        "rtf": rtf,
        "rtf_no_vocoder": rtf_no_vocoder,
        "rtf_vocoder": rtf_vocoder,
    }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)

    return metrics
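
# Real-time factor (RTF) reference, restating what `generate_sentence` computes:
# RTF = processing_time / generated_audio_duration. For example, if 2.0 s of
# compute produces 10.0 s of audio, the RTF is 2.0 / 10.0 = 0.2, i.e. generation
# runs five times faster than real time. `rtf_no_vocoder` and `rtf_vocoder`
# split this into the feature-generation (model sampling) part and the Vocos
# decoding part, respectively.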

def generate_sentence_stereo(
    save_path: str,
    prompt_text: str,
    prompt_wav: Union[str, List[str]],
    text: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: DialogTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
    silence_wav: Optional[str] = None,
):
    """
    Generate the stereo waveform of a text based on a given prompt waveform
    and its transcription.

    Args:
        save_path (str): Path to save the generated wav.
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (Union[str, List[str]]): Path(s) to the prompt wav file(s);
            either a single wav file containing the merged conversational
            speech, or two wav files containing the two speakers' separate
            speech.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to
            waveforms.
        tokenizer (DialogTokenizer): The tokenizer used to convert text to
            tokens.
        feature_extractor (VocosFbank): The feature extractor used to extract
            acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features. Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.
        silence_wav (str): Path of the silence wav file, used in two-channel
            generation with single-channel prompts.

    Returns:
        metrics (dict): Dictionary containing time and real-time factor
            metrics for processing.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav(s), resampling to the target rate if needed
    if isinstance(prompt_wav, str):
        prompt_wav = [
            prompt_wav,
        ]
    else:
        assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
    loaded_prompt_wavs = prompt_wav
    for i in range(len(prompt_wav)):
        loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
        if prompt_sampling_rate != sampling_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=prompt_sampling_rate, new_freq=sampling_rate
            )
            loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])

    # Build a stereo prompt: a single merged prompt must already be stereo;
    # two mono prompts are laid out on separate channels over a silence buffer.
    if len(loaded_prompt_wavs) == 1:
        assert (
            loaded_prompt_wavs[0].size(0) == 2
        ), "Merged prompt wav must be stereo for stereo dialogue generation"
        prompt_wav = loaded_prompt_wavs[0]
    else:
        assert len(loaded_prompt_wavs) == 2
        if loaded_prompt_wavs[0].size(0) == 2:
            prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
        else:
            assert loaded_prompt_wavs[0].size(0) == 1
            silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
            assert silence_sampling_rate == sampling_rate
            prompt_wav = silence_wav[
                :, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
            ]
            prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
            prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]

    # Normalize the prompt volume if it is quieter than the target RMS
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(
        prompt_wav, sampling_rate=sampling_rate
    ).to(device)
    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    (
        pred_features,
        pred_features_lens,
        pred_prompt_features,
        pred_prompt_features_lens,
    ) = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        speed=speed,
        t_shift=t_shift,
        duration="predict",
        num_step=num_step,
        guidance_scale=guidance_scale,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing: decode the two channels separately; the first
    # half of the feature dimension is the left channel, the second half is
    # the right channel.
    start_vocoder_t = dt.datetime.now()
    feat_dim = pred_features.size(1) // 2
    wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1)
    wav_right = (
        vocoder.decode(pred_features[:, feat_dim : feat_dim * 2])
        .squeeze(1)
        .clamp(-1, 1)
    )
    wav = torch.cat([wav_left, wav_right], dim=0)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    metrics = {
        "t": t,
        "t_no_vocoder": t_no_vocoder,
        "t_vocoder": t_vocoder,
        "wav_seconds": wav_seconds,
        "rtf": rtf,
        "rtf_no_vocoder": rtf_no_vocoder,
        "rtf_vocoder": rtf_vocoder,
    }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)

    return metrics
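
# Stereo prompt layout built by `generate_sentence_stereo` when two mono prompt
# wavs of lengths T1 and T2 are given (the buffer is sliced from the silence wav,
# so non-speech regions stay silent):
#
#   channel 0 (speaker 1): [ spk1 prompt (T1) | silence (T2)     ]
#   channel 1 (speaker 2): [ silence (T1)     | spk2 prompt (T2) ]
#
# A single merged prompt must already be a stereo file; a pair of stereo prompts
# is simply concatenated along the time axis.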

def generate_list(
    model_name: str,
    res_dir: str,
    test_list: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: DialogTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.5,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
    silence_wav: Optional[str] = None,
):
    total_t = []
    total_t_no_vocoder = []
    total_t_vocoder = []
    total_wav_seconds = []

    with open(test_list, "r") as fr:
        lines = fr.readlines()

    for i, line in enumerate(lines):
        # Parse one TSV row: 6 columns = split conversation, 4 = merged.
        items = line.strip().split("\t")
        if len(items) == 6:
            (
                wav_name,
                prompt_text_1,
                prompt_text_2,
                prompt_wav_1,
                prompt_wav_2,
                text,
            ) = items
            prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}"
            prompt_wav = [prompt_wav_1, prompt_wav_2]
        elif len(items) == 4:
            wav_name, prompt_text, prompt_wav, text = items
        else:
            raise ValueError(f"Invalid line: {line}")
        assert text.startswith("[S1]")
        save_path = f"{res_dir}/{wav_name}.wav"

        if model_name == "zipvoice_dialog":
            metrics = generate_sentence(
                save_path=save_path,
                prompt_text=prompt_text,
                prompt_wav=prompt_wav,
                text=text,
                model=model,
                vocoder=vocoder,
                tokenizer=tokenizer,
                feature_extractor=feature_extractor,
                device=device,
                num_step=num_step,
                guidance_scale=guidance_scale,
                speed=speed,
                t_shift=t_shift,
                target_rms=target_rms,
                feat_scale=feat_scale,
                sampling_rate=sampling_rate,
            )
        else:
            assert model_name == "zipvoice_dialog_stereo"
            metrics = generate_sentence_stereo(
                save_path=save_path,
                prompt_text=prompt_text,
                prompt_wav=prompt_wav,
                text=text,
                model=model,
                vocoder=vocoder,
                tokenizer=tokenizer,
                feature_extractor=feature_extractor,
                device=device,
                num_step=num_step,
                guidance_scale=guidance_scale,
                speed=speed,
                t_shift=t_shift,
                target_rms=target_rms,
                feat_scale=feat_scale,
                sampling_rate=sampling_rate,
                silence_wav=silence_wav,
            )
        print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
        total_t.append(metrics["t"])
        total_t_no_vocoder.append(metrics["t_no_vocoder"])
        total_t_vocoder.append(metrics["t_vocoder"])
        total_wav_seconds.append(metrics["wav_seconds"])

    print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
    print(
        f"Average RTF w/o vocoder: "
        f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
    )
    print(
        f"Average RTF vocoder: "
        f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
    )
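
# Example `--test-list` rows handled by `generate_list` (tab-separated; the
# names and paths below are placeholders, not shipped files). The synthesized
# `text` must start with "[S1]" and alternate speaker tags. For the
# split-conversation form, the two prompt transcriptions carry no tags; they
# are merged into "[S1]{spk1}[S2]{spk2}" before tokenization. For the merged
# form, the prompt transcription is assumed to carry the same [S1]/[S2] tags.
#
#   merged conversation (4 columns):
#     dialog_1<TAB>[S1]Hi there.[S2]Hello!<TAB>prompts/dialog_1.wav<TAB>[S1]How are you?[S2]Fine, thanks.
#   split conversation (6 columns):
#     dialog_2<TAB>Hi there.<TAB>Hello!<TAB>prompts/spk1.wav<TAB>prompts/spk2.wav<TAB>[S1]How are you?[S2]Fine, thanks.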

@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = AttributeDict()
    params.update(vars(args))
    fix_random_seed(params.seed)

    assert (
        params.test_list is not None
    ), "For inference, please provide prompts and text with '--test-list'"

    # Select the computation device
    if torch.cuda.is_available():
        params.device = torch.device("cuda", 0)
    elif torch.backends.mps.is_available():
        params.device = torch.device("mps")
    else:
        params.device = torch.device("cpu")

    print("Loading model...")

    # Download any missing files from the HuggingFace repo
    if params.model_config is None:
        model_config = hf_hub_download(
            HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name]
        )
    else:
        model_config = params.model_config

    with open(model_config, "r") as f:
        model_config = json.load(f)

    if params.token_file is None:
        token_file = hf_hub_download(
            HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name]
        )
    else:
        token_file = params.token_file

    tokenizer = DialogTokenizer(token_file=token_file)
    tokenizer_config = {
        "vocab_size": tokenizer.vocab_size,
        "pad_id": tokenizer.pad_id,
        "spk_a_id": tokenizer.spk_a_id,
        "spk_b_id": tokenizer.spk_b_id,
    }

    if params.checkpoint is None:
        model_ckpt = hf_hub_download(
            HUGGINGFACE_REPO,
            filename=PRETRAINED_MODEL[params.model_name],
        )
    else:
        model_ckpt = params.checkpoint

    if params.model_name == "zipvoice_dialog":
        model = ZipVoiceDialog(
            **model_config["model"],
            **tokenizer_config,
        )
    else:
        assert params.model_name == "zipvoice_dialog_stereo"
        model = ZipVoiceDialogStereo(
            **model_config["model"],
            **tokenizer_config,
        )

    if model_ckpt.endswith(".safetensors"):
        safetensors.torch.load_model(model, model_ckpt)
    elif model_ckpt.endswith(".pt"):
        load_checkpoint(filename=model_ckpt, model=model, strict=True)
    else:
        raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")

    model = model.to(params.device)
    model.eval()

    vocoder = get_vocoder(params.vocoder_path)
    vocoder = vocoder.to(params.device)
    vocoder.eval()

    if model_config["feature"]["type"] == "vocos":
        if params.model_name == "zipvoice_dialog":
            num_channels = 1
        else:
            assert params.model_name == "zipvoice_dialog_stereo"
            num_channels = 2
        feature_extractor = VocosFbank(num_channels=num_channels)
    else:
        raise NotImplementedError(
            f"Unsupported feature type: {model_config['feature']['type']}"
        )

    params.sampling_rate = model_config["feature"]["sampling_rate"]

    print("Start generating...")
    os.makedirs(params.res_dir, exist_ok=True)
    generate_list(
        model_name=params.model_name,
        res_dir=params.res_dir,
        test_list=params.test_list,
        model=model,
        vocoder=vocoder,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        device=params.device,
        num_step=params.num_step,
        guidance_scale=params.guidance_scale,
        speed=params.speed,
        t_shift=params.t_shift,
        target_rms=params.target_rms,
        feat_scale=params.feat_scale,
        sampling_rate=params.sampling_rate,
        silence_wav=params.silence_wav,
    )
    print("Done")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
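
# Fully offline invocation sketch (paths are placeholders): every file that
# would otherwise be fetched from HuggingFace can be supplied locally via the
# flags defined in `get_parser`. The `--vocoder-path` directory is expected to
# contain the `config.yaml` and `pytorch_model.bin` that `get_vocoder` reads;
# these can be pre-downloaded from the `charactr/vocos-mel-24khz` repo (e.g.
# with `hf_hub_download`) on a machine that has network access.
#
#   python3 -m zipvoice.bin.infer_zipvoice_dialog \
#       --model-name zipvoice_dialog_stereo \
#       --checkpoint /path/to/model.pt \
#       --model-config /path/to/model.json \
#       --token-file /path/to/tokens.txt \
#       --vocoder-path /path/to/vocos-mel-24khz \
#       --test-list test.tsv \
#       --res-dir results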