import os
import re

import torch
import torchaudio
import sentencepiece as spm
from omegaconf import OmegaConf

import spaces

from indextts.utils.front import TextNormalizer
from utils.common import tokenize_by_CJK_char
from utils.feature_extractors import MelSpectrogramFeatures
from indextts.vqvae.xtts_dvae import DiscreteVAE
from indextts.utils.checkpoint import load_checkpoint
from indextts.gpt.model import UnifiedVoice
from indextts.BigVGAN.models import BigVGAN as Generator
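
# End-to-end pipeline: reference audio -> conditioning mel -> GPT-generated
# mel codes -> GPT latents -> BigVGAN waveform.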
class IndexTTS:
    def __init__(self, cfg_path='checkpoints/config.yaml', model_dir='checkpoints'):
        self.cfg = OmegaConf.load(cfg_path)
        self.device = 'cuda:0'  # assumes a CUDA-capable GPU is available
        self.model_dir = model_dir

        # Discrete VAE: maps mel spectrograms to/from discrete codes.
        self.dvae = DiscreteVAE(**self.cfg.vqvae)
        self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
        load_checkpoint(self.dvae, self.dvae_path)
        self.dvae = self.dvae.to(self.device)
        self.dvae.eval()
        print(">> vqvae weights restored from:", self.dvae_path)

        # Autoregressive GPT: predicts mel codes from text tokens and speaker conditioning.
        self.gpt = UnifiedVoice(**self.cfg.gpt)
        self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
        load_checkpoint(self.gpt, self.gpt_path)
        self.gpt = self.gpt.to(self.device)
        self.gpt.eval()
        print(">> GPT weights restored from:", self.gpt_path)
        self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)

        # BigVGAN vocoder: turns GPT latents into a waveform.
        self.bigvgan = Generator(self.cfg.bigvgan)
        self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
        vocoder_dict = torch.load(self.bigvgan_path, map_location='cpu')
        self.bigvgan.load_state_dict(vocoder_dict['generator'])
        self.bigvgan = self.bigvgan.to(self.device)
        self.bigvgan.eval()
        print(">> bigvgan weights restored from:", self.bigvgan_path)

        # The text normalizer is loaded on demand via load_normalizer().
        self.normalizer = None
        print(">> end load weights")

    def load_normalizer(self):
        # Loaded outside __init__ so the text front-end is only initialized when needed.
        self.normalizer = TextNormalizer()
        self.normalizer.load()
        print(">> TextNormalizer loaded")

    def preprocess_text(self, text):
        assert self.normalizer is not None, "call load_normalizer() before infer()"
        return self.normalizer.infer(text)
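
    # infer(): normalize the text, build a conditioning mel from the reference
    # audio, then synthesize each sentence and concatenate the results.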
    def infer(self, audio_prompt, text, output_path):
        text = self.preprocess_text(text)

        # Load the reference audio, downmix to mono, and resample to 24 kHz.
        audio, sr = torchaudio.load(audio_prompt)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        audio = torchaudio.transforms.Resample(sr, 24000)(audio)
        cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
        print(f"cond_mel shape: {cond_mel.shape}")

        auto_conditioning = cond_mel

        # BPE tokenizer for the text front-end.
        tokenizer = spm.SentencePieceProcessor()
        tokenizer.load(os.path.join(self.model_dir, self.cfg.dataset['bpe_model']))

        # Split the text into sentences on ASCII and full-width end punctuation.
        punctuation = ["!", "?", ".", ";", "!", "?", "。", ";"]
        pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
        sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
        print(sentences)
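
        # Decoding hyperparameters (values kept from the original script): beam
        # search with nucleus/top-k sampling, and a strong repetition penalty
        # to keep the autoregressive GPT from looping.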
        top_p = .8
        top_k = 30
        temperature = 1.0
        autoregressive_batch_size = 1
        length_penalty = 0.0
        num_beams = 3
        repetition_penalty = 10.0
        max_mel_tokens = 600
        sampling_rate = 24000
        lang = "ZH"  # currently unused downstream
        wavs = []

        for sent in sentences:
            print(sent)

            # Space-separate CJK characters so the BPE model sees them individually.
            cleaned_text = tokenize_by_CJK_char(sent)
            print(cleaned_text)

            text_tokens = torch.IntTensor(tokenizer.encode(cleaned_text)).unsqueeze(0).to(self.device)
            print(text_tokens)
            print(f"text_tokens shape: {text_tokens.shape}")
            text_token_syms = [tokenizer.IdToPiece(idx) for idx in text_tokens[0].tolist()]
            print(text_token_syms)
            text_len = torch.IntTensor([text_tokens.size(1)]).to(self.device)
            print(text_len)
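
            # Stage 1: the GPT autoregressively generates discrete mel codes
            # for this sentence, conditioned on the reference mel.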
            with torch.no_grad():
                codes = self.gpt.inference_speech(
                    auto_conditioning, text_tokens,
                    cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
                                                  device=text_tokens.device),
                    do_sample=True,
                    top_p=top_p,
                    top_k=top_k,
                    temperature=temperature,
                    num_return_sequences=autoregressive_batch_size,
                    length_penalty=length_penalty,
                    num_beams=num_beams,
                    repetition_penalty=repetition_penalty,
                    max_generate_length=max_mel_tokens)
                print(codes)
                print(f"codes shape: {codes.shape}")
                # Drop the trailing stop tokens.
                codes = codes[:, :-2]
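
                # Stage 2: re-run the GPT over the generated codes to get the
                # continuous latents the vocoder consumes.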
                latent = self.gpt(
                    auto_conditioning, text_tokens,
                    torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
                    torch.tensor([codes.shape[-1] * self.gpt.mel_length_compression],
                                 device=text_tokens.device),
                    cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
                                                  device=text_tokens.device),
                    return_latent=True, clip_inputs=False)
                latent = latent.transpose(1, 2)
                # Disabled: trim the text-token prefix from each latent
                # (text_lens_out is not computed in this version).
                # latent_list = []
                # for lat, t_len in zip(latent, text_lens_out):
                #     latent_list.append(lat[:, t_len:])
                # latent = torch.stack(latent_list)

                # Undo the transpose above for the vocoder call.
                wav, _ = self.bigvgan(latent.transpose(1, 2), auto_conditioning.transpose(1, 2))
                wav = wav.squeeze(1).cpu()
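
                # Scale to the 16-bit PCM range; saved as int16 below.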
                wav = 32767 * wav
                wav = torch.clip(wav, -32767.0, 32767.0)
                print(f"wav shape: {wav.shape}")

                wavs.append(wav)

        # Concatenate the per-sentence waveforms and write a single file.
        wav = torch.cat(wavs, dim=1)
        torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)


if __name__ == "__main__":
    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
    tts.load_normalizer()
    tts.infer(audio_prompt='test_data/input.wav',
              text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!',
              output_path="gen.wav")