"""Inference helpers: load the pretrained networks and synthesize audio.

``load_model`` downloads/loads every network used at inference time and
``run`` turns a transcript plus a style caption into generated audio.
"""

import argparse
import math
import os
import random
import re
import time

import laion_clap
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from g2p_en import G2p
from huggingface_hub import snapshot_download
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    T5EncoderModel,
)

from capspeech.nar import bigvgan
from capspeech.nar.inference import sample
from capspeech.nar.model.modules import MelSpec
from capspeech.nar.network.crossdit import CrossDiT
from capspeech.nar.utils import load_yaml_with_includes
from capspeech.nar.utils import make_pad_mask

# Special (non-phoneme) tokens.
# NOTE(review): these literals must match the entries of the checkpoint's
# vocab.txt exactly -- confirm before merging.
_BLANK = "<BLK>"
_B_START = "<B_start>"
_B_END = "<B_end>"
_I_START = "<I_start>"
_I_END = "<I_end>"

# Output sample rate; run() uses it for both the frame count and the RTF.
_SAMPLE_RATE = 24000
# Presumably the mel hop length and number of mel bands of the
# 'bigvgan_v2_24khz_100band_256x' vocoder -- TODO confirm against
# nar_pretrain.yaml.
_HOP_LENGTH = 256
_N_MELS = 100

# Tasks for which a checkpoint named "nar_<task>.pt" is loaded.
_SUPPORTED_TASKS = ("PT", "CapTTS", "EmoCapTTS", "AccCapTTS", "AgentTTS")


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Symbols kept by encode(); anything the tokenizer emits outside this list
# is dropped.
valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
    'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
    'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER',
    'ER0', 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH',
    'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N',
    'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R',
    'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1',
    'UW2', 'V', 'W', 'Y', 'Z', 'ZH', _BLANK, ',', '.', '!', '?',
    _B_START, _B_END, _I_START, _I_END,
]


def _encode_tagged(text, text_tokenizer, start_tag, end_tag):
    """Tokenize a transcript holding one ``start_tag ... end_tag`` span.

    The text up to and including the first ``>`` (the leading audio label)
    is discarded, the remainder is cut into the parts before, inside and
    after the span, each part is tokenized separately, and the tags are
    re-inserted between them with " " separators.

    Raises:
        ValueError: If ``end_tag`` does not occur in ``text``.
    """
    if end_tag not in text:
        raise ValueError(f"missing {end_tag} in transcript: {text!r}")

    text = text.split(">", 1)[1].strip()  # remove the audio label
    before = text.split(start_tag)[0]
    inside = text.split(start_tag)[1].split(end_tag)[0]
    after = text.split(end_tag)[1]

    phn_before = text_tokenizer(before)
    if phn_before:
        phn_before.append(" ")
    phn_before.append(start_tag)
    phn_before.append(" ")

    phn_inside = text_tokenizer(inside)
    if phn_inside:
        phn_inside.append(" ")
    phn_inside.append(end_tag)

    phn_after = text_tokenizer(after)
    if phn_after:
        # Separator between the end tag and the trailing segment.
        phn_inside.append(" ")

    return [*phn_before, *phn_inside, *phn_after]


def encode(text, text_tokenizer):
    """Convert a transcript into the list of symbols fed to the model.

    Args:
        text: Transcript. When it contains a start tag it must begin with
            an audio label terminated by ``>`` and contain the matching
            end tag.
        text_tokenizer: Callable mapping a string to a list of string
            tokens (``load_model`` supplies a ``G2p`` instance).

    Returns:
        Tokens with " " replaced by the blank token, restricted to
        ``valid_symbols``.

    Raises:
        ValueError: If a start tag is present without its end tag.
    """
    if _B_START in text:
        phn = _encode_tagged(text, text_tokenizer, _B_START, _B_END)
    elif _I_START in text:
        phn = _encode_tagged(text, text_tokenizer, _I_START, _I_END)
    else:
        phn = text_tokenizer(text)

    phn = [item.replace(' ', _BLANK) for item in phn]
    return [item for item in phn if item in valid_symbols]


def estimate_duration_range(text):
    """Return word-count based duration bounds for ``text``.

    Returns:
        ``(min_duration, max_duration, ref_min, ref_max)`` computed as the
        whitespace-separated word count divided by 4.0, 1.5, 3.0 and 1.5
        respectively.
    """
    num_words = len(text.strip().split())
    min_duration = num_words / 4.0
    max_duration = num_words / 1.5
    ref_min = num_words / 3.0
    ref_max = num_words / 1.5
    return min_duration, max_duration, ref_min, ref_max


def get_duration(text, predicted_duration):
    """Return ``predicted_duration`` if plausible, else a random fallback.

    ``<...>`` tags are stripped before counting words. If the prediction
    lies outside the estimated range, a value drawn uniformly from the
    reference range (rounded to two decimals) is returned instead. The
    result is therefore random unless the ``random`` module is seeded.

    Args:
        text: Transcript, possibly containing ``<...>`` tags.
        predicted_duration: Duration produced by the duration predictor.
    """
    cleaned_text = re.sub(r"<[^>]*>", "", text)
    min_dur, max_dur, ref_min, ref_max = estimate_duration_range(cleaned_text)
    # NOTE(review): assumes the extra allowance applies to transcripts that
    # contain the "<I_start>" tag -- confirm which tag is intended here.
    event_dur = random.uniform(0.5, 2.0) if _I_START in text else 0
    lower = min_dur + event_dur
    upper = max_dur + event_dur
    if predicted_duration < lower or predicted_duration > upper:
        return round(random.uniform(ref_min, ref_max), 2) + event_dur
    return predicted_duration


def run(
    model_list,
    device,
    duration,
    transcript,
    caption,
    speed=1.0,
    steps=25,
    cfg=2.0
):
    """Generate audio for ``transcript`` in the style described by ``caption``.

    Args:
        model_list: The nine-element list returned by ``load_model``.
        device: Torch device the inputs are moved to.
        duration: Target duration; when ``None`` it is predicted from the
            caption and transcript and sanity-checked by ``get_duration``.
        transcript: Text to synthesize, optionally with a leading audio
            label and tags (see ``encode``).
        caption: Style description passed to the caption encoder.
        speed: Divisor applied to the duration; non-positive values are
            treated as 1.
        steps: Forwarded to ``sample``.
        cfg: Forwarded to ``sample``.

    Returns:
        Whatever ``sample`` returns; its last axis is treated as the
        number of audio samples when reporting the real-time factor.
    """
    (
        model,
        vocoder,
        phn2num,
        text_tokenizer,
        clap_model,
        duration_tokenizer,
        duration_model,
        caption_tokenizer,
        caption_encoder,
    ) = model_list
    print("Start Generation...")
    start_time = time.time()

    if _B_START in transcript or _I_START in transcript:
        # The leading "<some_label>" becomes the tag text "some label".
        tag = transcript.split(">", 1)[0].strip()
        tag = tag[1:].lower().replace("_", " ")
    else:
        tag = "none"

    phn = encode(transcript, text_tokenizer)
    text_tokens = [phn2num[item] for item in phn]
    text = torch.LongTensor(text_tokens).unsqueeze(0).to(device)

    if duration is None:
        duration_inputs = caption + " " + transcript
        duration_inputs = duration_tokenizer(
            duration_inputs,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=400,
        )

    with torch.no_grad():
        batch_encoding = caption_tokenizer(caption, return_tensors="pt")
        ori_tokens = batch_encoding["input_ids"].to(device)
        # Keep the encoder output as-is: a blanket squeeze() would also
        # drop the sequence axis for a one-token caption, after which
        # prompt.shape[1] would no longer be the sequence length.
        # assumes last_hidden_state is (1, seq_len, hidden) -- TODO confirm
        prompt = caption_encoder(input_ids=ori_tokens).last_hidden_state.to(device)

        tag_data = [tag]
        tag_embed = clap_model.get_text_embedding(tag_data, use_tensor=True)
        clap = tag_embed.squeeze().unsqueeze(0).to(device)

        if duration is None:
            duration_outputs = duration_model(**duration_inputs)
            predicted_duration = duration_outputs.logits.squeeze().item()
            duration = get_duration(transcript, predicted_duration)

        if speed <= 0:
            # A zero or negative speed cannot scale the duration.
            speed = 1
        duration = duration / speed

        num_frames = math.ceil(duration * _SAMPLE_RATE / _HOP_LENGTH)
        audio_clips = torch.zeros([1, num_frames, _N_MELS]).to(device)
        cond = None
        seq_len_prompt = prompt.shape[1]
        prompt_lens = torch.Tensor([prompt.shape[1]])
        prompt_mask = make_pad_mask(prompt_lens, seq_len_prompt).to(prompt.device)
        gen = sample(
            model,
            vocoder,
            audio_clips,
            cond,
            text,
            prompt,
            clap,
            prompt_mask,
            steps=steps,
            cfg=cfg,
            sway_sampling_coef=-1.0,
            device=device,
        )

    end_time = time.time()
    audio_len = gen.shape[-1] / _SAMPLE_RATE  # sampling rate fixed in this work
    rtf = (end_time - start_time) / audio_len
    print(f"RTF: {rtf:.4f}")
    return gen


def load_model(device, task):
    """Download the model snapshot and load every network used by ``run``.

    Args:
        device: Torch device for the acoustic model, vocoder and caption
            encoder.
        task: One of "PT", "CapTTS", "EmoCapTTS", "AccCapTTS", "AgentTTS";
            selects the checkpoint ``nar_<task>.pt``.

    Returns:
        ``[model, vocoder, phn2num, text_tokenizer, clap_model,
        duration_tokenizer, duration_model, caption_tokenizer,
        caption_encoder]``.

    Raises:
        ValueError: If ``task`` is not supported.
    """
    # Validate before downloading so a typo fails immediately.
    if task not in _SUPPORTED_TASKS:
        raise ValueError(
            f"unknown task {task!r}; expected one of {_SUPPORTED_TASKS}"
        )

    print("Downloading model from Huggingface...")
    local_dir = snapshot_download(
        repo_id="OpenSound/CapSpeech-models"
    )
    model_path = os.path.join(local_dir, f"nar_{task}.pt")

    print("Loading models...")
    params = load_yaml_with_includes(os.path.join(local_dir, "nar_pretrain.yaml"))
    model = CrossDiT(**params['model']).to(device)
    # Load onto CPU first so a checkpoint saved on GPU also works on
    # CPU-only hosts; load_state_dict copies into the model's own device.
    checkpoint = torch.load(model_path, map_location="cpu")['model']
    model.load_state_dict(checkpoint, strict=True)
    # Inference only, like every other network loaded below.
    model.eval()

    # load vocab: each non-blank line is "<index> <symbol>"
    vocab_fn = os.path.join(local_dir, "vocab.txt")
    phn2num = {}
    with open(vocab_fn, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            fields = line.split(" ")
            phn2num[fields[1]] = int(fields[0])

    # load g2p
    text_tokenizer = G2p()

    # load vocoder
    # instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
    vocoder = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
    # remove weight norm in the model and set to eval mode
    vocoder.remove_weight_norm()
    vocoder = vocoder.eval().to(device)

    # load t5
    caption_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
    caption_encoder = T5EncoderModel.from_pretrained("google/flan-t5-large").to(device).eval()

    # load clap
    clap_model = laion_clap.CLAP_Module(enable_fusion=False)
    clap_model.load_ckpt(os.path.join(local_dir, "clap-630k-best.pt"))

    # load duration predictor
    duration_dir = os.path.join(local_dir, "nar_duration_predictor")
    duration_tokenizer = AutoTokenizer.from_pretrained(duration_dir)
    duration_model = AutoModelForSequenceClassification.from_pretrained(duration_dir)
    duration_model.eval()

    model_list = [
        model,
        vocoder,
        phn2num,
        text_tokenizer,
        clap_model,
        duration_tokenizer,
        duration_model,
        caption_tokenizer,
        caption_encoder,
    ]
    return model_list