| | import os |
| | import torch |
| | import gdown |
| | import logging |
| | import langid |
| | langid.set_languages(['en', 'zh', 'ja']) |
| |
|
| | import pathlib |
| | import platform |
# Alias the pathlib class of the *other* OS family to the native one, keeping
# the replaced class reachable as the module-level name `temp`.
# NOTE(review): presumably this lets objects that were pickled with the
# foreign path class be loaded on this platform -- confirm against the
# checkpoint / prompt files actually used.
if platform.system().lower() == 'windows':
    temp, pathlib.PosixPath = pathlib.PosixPath, pathlib.WindowsPath
elif platform.system().lower() == 'linux':
    temp, pathlib.WindowsPath = pathlib.WindowsPath, pathlib.PosixPath
| |
|
| | import numpy as np |
| | from data.tokenizer import ( |
| | AudioTokenizer, |
| | tokenize_audio, |
| | ) |
| | from data.collation import get_text_token_collater |
| | from models.vallex import VALLE |
| | from utils.g2p import PhonemeBpeTokenizer |
| | from utils.sentence_cutter import split_text_into_sentences |
| |
|
| | from macros import * |
| |
|
# Use the first CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Google Drive share link; its file id is the one preload_models() downloads.
url = 'https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing'

# Directory and file name under which the model checkpoint is stored.
checkpoints_dir = "./checkpoints/"
model_checkpoint_name = "vallex-checkpoint.pt"

# Both stay None until preload_models() has been called.
model = None
codec = None

# Text front end shared by the generation functions in this module.
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
text_collater = get_text_token_collater()
| |
|
def preload_models():
    """Download the checkpoint if necessary and initialise the global models.

    Sets the module-level ``model`` to a ``VALLE`` instance (weights loaded
    from ``checkpoints_dir/model_checkpoint_name``, switched to eval mode, on
    ``device``) and the module-level ``codec`` to an ``AudioTokenizer`` built
    for ``device``. The checkpoint is fetched with ``gdown`` only when the
    file does not exist yet.

    Raises:
        FileNotFoundError: If the checkpoint file is still missing after the
            download attempt.
        RuntimeError: If the checkpoint's parameters do not match the model.
    """
    global model, codec
    checkpoint_path = os.path.join(checkpoints_dir, model_checkpoint_name)

    # exist_ok avoids the check-then-create race of exists() + mkdir() and
    # also creates any missing parent directories.
    os.makedirs(checkpoints_dir, exist_ok=True)
    if not os.path.exists(checkpoint_path):
        gdown.download(id="10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl", output=checkpoint_path, quiet=False)
        # Fail here with a clear message instead of a bare error from
        # torch.load() when the download did not produce the file.
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint download failed, {checkpoint_path} does not exist"
            )

    model = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    ).to(device)

    # Deserialise on the CPU; load_state_dict() copies the values into the
    # parameters that already live on `device`.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    missing_keys, unexpected_keys = model.load_state_dict(
        checkpoint["model"], strict=True
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if missing_keys:
        raise RuntimeError(f"Checkpoint is missing keys: {missing_keys}")
    model.eval()

    codec = AudioTokenizer(device)
| |
|
@torch.no_grad()
def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
    """Synthesize audio for one piece of text.

    ``preload_models()`` must have been called first so that the module-level
    ``model`` and ``codec`` are initialised.

    Args:
        text: Text to synthesize. Newlines are removed and surrounding spaces
            are stripped before tokenization.
        prompt: Optional voice prompt. Either a path to an existing ``.npz``
            file, or a name that is looked up as ``./presets/<prompt>.npz``
            and then ``./customs/<prompt>.npz``. With ``None`` empty prompts
            are used.
        language: A key of ``lang2token``, or ``"auto"`` to let ``langid``
            classify the text.
        accent: ``"no-accent"``, or a key of ``langdropdown2token`` whose
            language is passed to the model as the text language.

    Returns:
        ``samples[0][0]`` of the decoded codec output, moved to the CPU and
        converted to a NumPy array.

    Raises:
        ValueError: If ``prompt`` is given but no matching file exists.
    """
    global model, codec, text_tokenizer, text_collater
    text = text.replace("\n", "").strip(" ")

    # Resolve the language and wrap the text in its language token.
    if language == "auto":
        language = langid.classify(text)[0]
    lang_token = lang2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    if prompt is not None:
        # Try the prompt as a path first, then as a preset / custom name.
        prompt_path = prompt
        for folder in ("./presets/", "./customs/"):
            if os.path.exists(prompt_path):
                break
            prompt_path = folder + prompt + ".npz"
        if not os.path.exists(prompt_path):
            raise ValueError(f"Cannot find prompt {prompt}")

        # np.load() on an .npz archive keeps the file open; the context
        # manager closes it once the arrays have been read.
        with np.load(prompt_path) as prompt_data:
            audio_prompts = prompt_data['audio_tokens']
            text_prompts = prompt_data['text_tokens']
            lang_pr = code2lang[int(prompt_data['lang_code'])]

        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        # No prompt: zero-length prompt tensors, prompt language follows the
        # text language ('mix' falls back to 'en').
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = lang if lang != 'mix' else 'en'

    enroll_x_lens = text_prompts.shape[-1]
    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    # The prompt's text tokens are prepended to the tokens of the new text.
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens

    # An explicit accent overrides the language reported by the tokenizer.
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    samples = codec.decode(
        [(encoded_frames.transpose(2, 1), None)]
    )

    return samples[0][0].cpu().numpy()
| |
|
@torch.no_grad()
def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):
    """Synthesize audio for long text, sentence by sentence.

    ``preload_models()`` must have been called first so that the module-level
    ``model`` and ``codec`` are initialised. The text is split with
    ``split_text_into_sentences``; the codes generated for every non-empty
    sentence are concatenated and decoded in one go.

    Two modes are available:

    * ``fixed-prompt``: every sentence is generated with the prompt the user
      provided.
    * ``sliding-window``: after each sentence, with probability 0.5 the
      sentence just generated becomes the prompt for the next one; otherwise
      the original prompt is restored. The speaker identity may drift in this
      mode.

    Args:
        text: Text to synthesize.
        prompt: Optional voice prompt. Either a path to an existing ``.npz``
            file, or a name that is looked up as ``./presets/<prompt>.npz``
            and then ``./customs/<prompt>.npz``. With ``None`` or ``""``
            empty prompts are used and ``mode`` is forced to
            ``'sliding-window'``.
        language: A key of ``lang2token``, or ``"auto"`` to let ``langid``
            classify the whole text.
        accent: ``"no-accent"``, or a key of ``langdropdown2token`` whose
            language is passed to the model as the text language.
        mode: ``'fixed-prompt'`` or ``'sliding-window'``.

    Returns:
        ``samples[0][0]`` of the decoded codec output, moved to the CPU and
        converted to a NumPy array.

    Raises:
        ValueError: If ``prompt`` is given but no matching file exists, or if
            ``mode`` is not one of the two supported values.
    """
    global model, codec, text_tokenizer, text_collater
    if prompt is None or prompt == "":
        mode = 'sliding-window'
    sentences = split_text_into_sentences(text)

    if language == "auto":
        language = langid.classify(text)[0]

    if prompt is not None and prompt != "":
        # Try the prompt as a path first, then as a preset / custom name.
        prompt_path = prompt
        for folder in ("./presets/", "./customs/"):
            if os.path.exists(prompt_path):
                break
            prompt_path = folder + prompt + ".npz"
        if not os.path.exists(prompt_path):
            raise ValueError(f"Cannot find prompt {prompt}")

        # np.load() on an .npz archive keeps the file open; the context
        # manager closes it once the arrays have been read.
        with np.load(prompt_path) as prompt_data:
            audio_prompts = prompt_data['audio_tokens']
            text_prompts = prompt_data['text_tokens']
            lang_pr = code2lang[int(prompt_data['lang_code'])]

        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        # No prompt: zero-length prompt tensors, prompt language follows the
        # text language ('mix' falls back to 'en').
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = language if language != 'mix' else 'en'

    if mode not in ('fixed-prompt', 'sliding-window'):
        raise ValueError(f"No such mode {mode}")
    sliding_window = mode == 'sliding-window'

    complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
    # Kept so that sliding-window mode can fall back to the user's prompt.
    original_audio_prompts = audio_prompts
    original_text_prompts = text_prompts

    for sentence in sentences:
        sentence = sentence.replace("\n", "").strip(" ")
        if sentence == "":
            continue
        lang_token = lang2token[language]
        lang = token2lang[lang_token]
        sentence = lang_token + sentence + lang_token

        enroll_x_lens = text_prompts.shape[-1]
        logging.info(f"synthesize text: {sentence}")
        phone_tokens, langs = text_tokenizer.tokenize(text=f"_{sentence}".strip())
        text_tokens, text_tokens_lens = text_collater([phone_tokens])
        # The prompt's text tokens are prepended to the sentence's tokens.
        text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
        text_tokens_lens += enroll_x_lens

        # An explicit accent overrides the language reported by the tokenizer.
        lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
        encoded_frames = model.inference(
            text_tokens.to(device),
            text_tokens_lens.to(device),
            audio_prompts,
            enroll_x_lens=enroll_x_lens,
            top_k=-100,
            temperature=1,
            prompt_language=lang_pr,
            text_language=langs if accent == "no-accent" else lang,
        )
        complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)

        if sliding_window:
            if torch.rand(1) < 0.5:
                # Use the sentence just generated as the next prompt.
                audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
                text_prompts = text_tokens[:, enroll_x_lens:]
            else:
                # Go back to the original prompt.
                audio_prompts = original_audio_prompts
                text_prompts = original_text_prompts

    samples = codec.decode(
        [(complete_tokens, None)]
    )
    return samples[0][0].cpu().numpy()