Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import time | |
import random | |
import argparse | |
import numpy as np | |
from tqdm import tqdm | |
from huggingface_hub import snapshot_download | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
from capspeech.nar import bigvgan | |
import librosa | |
from capspeech.nar.utils import make_pad_mask | |
from capspeech.nar.model.modules import MelSpec | |
from capspeech.nar.network.crossdit import CrossDiT | |
from capspeech.nar.inference import sample | |
from capspeech.nar.utils import load_yaml_with_includes | |
import soundfile as sf | |
from transformers import T5EncoderModel, AutoTokenizer | |
from g2p_en import G2p | |
import laion_clap | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import re | |
import time | |
def seed_everything(seed):
    """Seed every random number generator this script uses and pin cuDNN to deterministic mode.

    Args:
        seed: integer seed. It is passed to Python's ``random``, NumPy's global RNG,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed``, and it is exported
            as ``PYTHONHASHSEED``.
    """
    # cuDNN: no autotuned kernel selection, deterministic algorithms only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # The hash seed is fixed at interpreter start-up, so this only influences
    # Python subprocesses launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
# Phoneme inventory accepted by encode(), in vocabulary order: ARPAbet phones
# (vowels with and without stress digits 0/1/2, plus consonants), the
# word-boundary token <BLK>, punctuation, and the span markers for
# <B_start>/<B_end> and <I_start>/<I_end> tagged transcripts.
valid_symbols = (
    "AA AA0 AA1 AA2 AE AE0 AE1 AE2 AH AH0 AH1 AH2 "
    "AO AO0 AO1 AO2 AW AW0 AW1 AW2 AY AY0 AY1 AY2 "
    "B CH D DH EH EH0 EH1 EH2 ER ER0 ER1 ER2 EY "
    "EY0 EY1 EY2 F G HH IH IH0 IH1 IH2 IY IY0 IY1 "
    "IY2 JH K L M N NG OW OW0 OW1 OW2 OY OY0 "
    "OY1 OY2 P R S SH T TH UH UH0 UH1 UH2 UW "
    "UW0 UW1 UW2 V W Y Z ZH <BLK> , . ! ? "
    "<B_start> <B_end> <I_start> <I_end>"
).split()
def _encode_marked_span(text, text_tokenizer, start_tag, end_tag):
    """Phonemize a transcript containing one ``start_tag ... end_tag`` span.

    The leading ``<label>`` audio tag is removed first. The text before the span,
    inside it and after it is phonemized separately. The pieces are then re-joined
    around the marker tokens, with ``" "`` tokens as separators.

    NOTE(review): the label is removed with ``split(">", 1)``. This assumes the
    transcript really starts with a ``<label>``; without one, the cut lands on the
    first marker's ``>``. Only the first start/end pair is honoured. Text after a
    second ``end_tag`` is dropped by the ``[1]`` index below.
    """
    assert end_tag in text, text
    text = text.split(">", 1)[1].strip()  # remove the audio label
    before = text.split(start_tag)[0]
    inside = text.split(start_tag)[1].split(end_tag)[0]
    after = text.split(end_tag)[1]

    head = text_tokenizer(before)
    if len(head) > 0:
        head.append(" ")
    head.append(start_tag)
    head.append(" ")

    middle = text_tokenizer(inside)
    if len(middle) > 0:
        middle.append(" ")
    middle.append(end_tag)

    tail = text_tokenizer(after)
    if len(tail) > 0:
        # Separator between end_tag and the trailing phonemes.
        middle.append(" ")
    return [*head, *middle, *tail]


def encode(text, text_tokenizer):
    """Convert a transcript into a list of symbols drawn from ``valid_symbols``.

    Args:
        text: transcript string. A transcript containing ``<B_start>`` (checked first)
            or ``<I_start>`` must also contain the matching end marker, and is
            expected to start with a ``<label>`` audio tag.
        text_tokenizer: callable mapping a string to a list of phoneme strings
            (the ``G2p`` instance from ``load_model``). ``" "`` tokens mark word
            boundaries.

    Returns:
        List of symbols. ``" "`` becomes ``<BLK>``, and anything not in
        ``valid_symbols`` is dropped.

    Raises:
        AssertionError: a start marker is present without its end marker.
    """
    if '<B_start>' in text:
        phn = _encode_marked_span(text, text_tokenizer, '<B_start>', '<B_end>')
    elif '<I_start>' in text:
        phn = _encode_marked_span(text, text_tokenizer, '<I_start>', '<I_end>')
    else:
        phn = text_tokenizer(text)
    allowed = set(valid_symbols)  # O(1) membership instead of scanning the list per token
    blanked = (item.replace(' ', '<BLK>') for item in phn)
    return [item for item in blanked if item in allowed]
def estimate_duration_range(text):
    """Derive duration bounds from the whitespace-separated word count of ``text``.

    The bounds are word-rate heuristics: 4 words/unit for the minimum, 1.5 for the
    maximum, and a reference window of 3 to 1.5. The unit is presumably seconds,
    since ``run`` multiplies the duration by the 24 kHz sample rate.

    Returns:
        ``(min_duration, max_duration, ref_min, ref_max)``. All values are 0.0
        for text with no words.
    """
    word_count = len(text.split())  # split() with no args already ignores outer whitespace
    return (
        word_count / 4.0,
        word_count / 1.5,
        word_count / 3.0,
        word_count / 1.5,
    )
def get_duration(text, predicted_duration):
    """Keep ``predicted_duration`` if it is plausible for ``text``; otherwise draw a fallback.

    Plausibility is judged against ``estimate_duration_range`` on ``text`` with every
    ``<...>`` tag removed. If the transcript contains ``<I_start>``, a random allowance
    in [0.5, 2.0] is added to both ends of the accepted window. The same allowance
    is added to the fallback, which is drawn uniformly from the reference range and
    rounded to 2 decimals before the allowance is added.

    NOTE(review): a NaN prediction fails both comparisons and is returned unchanged.
    """
    untagged = re.sub(r"<[^>]*>", "", text)
    low, high, ref_low, ref_high = estimate_duration_range(untagged)
    event_pad = 0
    if "<I_start>" in text:
        # Drawn unconditionally for <I_start> transcripts, before the range check.
        event_pad = random.uniform(0.5, 2.0)
    too_short = predicted_duration < low + event_pad
    too_long = predicted_duration > high + event_pad
    if not (too_short or too_long):
        return predicted_duration
    return round(random.uniform(ref_low, ref_high), 2) + event_pad
def run(
    model_list,
    device,
    duration,
    transcript,
    caption,
    speed=1.0,
    steps=25,
    cfg=2.0
):
    """Synthesize speech for ``transcript`` in the style described by ``caption``.

    Args:
        model_list: the list returned by ``load_model``, unpacked positionally below.
        device: torch device for the phoneme tensor, caption prompt and CLAP embedding.
        duration: target length in seconds, or None to predict it with the
            duration model (then sanity-checked by ``get_duration``).
        transcript: text to speak. A transcript with a ``<B_start>``/``<I_start>``
            span is expected to start with a ``<label>`` tag, which becomes the
            CLAP text prompt.
        caption: style description encoded with the T5 encoder.
        speed: the duration is divided by this. Non-positive values fall back to 1.
        steps: number of sampling steps passed to ``sample``.
        cfg: guidance scale passed to ``sample``.

    Returns:
        The output of ``sample``. Its last dimension is treated as 24 kHz samples
        when printing the real-time factor.
    """
    (model, vocoder, phn2num, text_tokenizer, clap_model, duration_tokenizer,
     duration_model, caption_tokenizer, caption_encoder) = model_list
    print("Start Generation...")
    start_time = time.time()

    # The CLAP prompt is the leading "<label>" without "<", lower-cased, with "_" -> " ".
    if "<B_start>" in transcript or "<I_start>" in transcript:
        tag = transcript.split(">", 1)[0].strip()
        tag = tag[1:].lower().replace("_", " ")
    else:
        tag = "none"

    phn = encode(transcript, text_tokenizer)
    text_tokens = [phn2num[item] for item in phn]
    text = torch.LongTensor(text_tokens).unsqueeze(0).to(device)

    if duration is None:
        # Stays on CPU: load_model never moves duration_model to `device`.
        duration_inputs = caption + " <NEW_SEP> " + transcript
        duration_inputs = duration_tokenizer(duration_inputs, return_tensors="pt", padding="max_length", truncation=True, max_length=400)

    with torch.no_grad():
        batch_encoding = caption_tokenizer(caption, return_tensors="pt")
        ori_tokens = batch_encoding["input_ids"].to(device)
        # last_hidden_state is (batch=1, seq_len, hidden) already. The former
        # .squeeze().unsqueeze(0) dropped seq_len when the caption tokenized to a
        # single token (e.g. an empty caption), which corrupted prompt.shape[1] below.
        prompt = caption_encoder(input_ids=ori_tokens).last_hidden_state.to(device)

        tag_data = [tag]
        tag_embed = clap_model.get_text_embedding(tag_data, use_tensor=True)
        # Presumably (1, dim) for a single tag; normalised to a leading batch dim of 1.
        clap = tag_embed.squeeze().unsqueeze(0).to(device)

        if duration is None:
            duration_outputs = duration_model(**duration_inputs)
            predicted_duration = duration_outputs.logits.squeeze().item()
            duration = get_duration(transcript, predicted_duration)

    # A non-positive speed would give a zero/negative frame count; treat it as 1.
    if speed <= 0:
        speed = 1
    duration = duration / speed
    # Frames at 24 kHz with a hop of 256 samples, 100 mel bins.
    audio_clips = torch.zeros([1, math.ceil(duration * 24000 / 256), 100]).to(device)
    cond = None

    seq_len_prompt = prompt.shape[1]
    prompt_lens = torch.Tensor([prompt.shape[1]])
    prompt_mask = make_pad_mask(prompt_lens, seq_len_prompt).to(prompt.device)

    gen = sample(model, vocoder,
                 audio_clips, cond, text, prompt, clap, prompt_mask,
                 steps=steps, cfg=cfg,
                 sway_sampling_coef=-1.0, device=device)

    end_time = time.time()
    audio_len = gen.shape[-1] / 24000  # sampling rate fixed in this work
    if audio_len > 0:
        rtf = (end_time - start_time) / audio_len
        print(f"RTF: {rtf:.4f}")
    return gen
def load_model(device, task):
    """Download the CapSpeech NAR assets and build every component ``run`` needs.

    Args:
        device: torch device for the CrossDiT generator, the BigVGAN vocoder and
            the T5 caption encoder.
        task: one of "PT", "CapTTS", "EmoCapTTS", "AccCapTTS", "AgentTTS". It selects
            the ``nar_<task>.pt`` checkpoint.

    Returns:
        ``[model, vocoder, phn2num, text_tokenizer, clap_model, duration_tokenizer,
        duration_model, caption_tokenizer, caption_encoder]``, the order ``run`` unpacks.

    Raises:
        ValueError: ``task`` is not a known task name. This is checked before downloading.
    """
    known_tasks = ("PT", "CapTTS", "EmoCapTTS", "AccCapTTS", "AgentTTS")
    # Fail fast and explicitly. The former `assert 1 == 0` disappears under
    # `python -O` and then surfaced later as an unbound `model_path`.
    if task not in known_tasks:
        raise ValueError(f"Unknown task {task!r}; expected one of {known_tasks}")

    print("Downloading model from Huggingface...")
    local_dir = snapshot_download(
        repo_id="OpenSound/CapSpeech-models"
    )
    model_path = os.path.join(local_dir, f"nar_{task}.pt")

    print("Loading models...")
    params = load_yaml_with_includes(os.path.join(local_dir, "nar_pretrain.yaml"))
    model = CrossDiT(**params['model']).to(device)
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies the weights onto the model's device.
    checkpoint = torch.load(model_path, map_location="cpu")['model']
    model.load_state_dict(checkpoint, strict=True)
    # Inference only: disable train-mode layers such as dropout (harmless if
    # sample() also switches to eval).
    model.eval()

    # load vocab: each non-blank line is "<index> <symbol>"
    vocab_fn = os.path.join(local_dir, "vocab.txt")
    with open(vocab_fn, "r", encoding="utf-8") as f:
        # Skip blank lines. The old `len(l) != 0` test let "\n" through, and that line
        # then failed on item[1].
        rows = [line.strip().split(" ") for line in f if line.strip()]
    phn2num = {row[1]: int(row[0]) for row in rows}

    # load g2p
    text_tokenizer = G2p()

    # load vocoder
    # instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
    vocoder = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
    # remove weight norm in the model and set to eval mode
    vocoder.remove_weight_norm()
    vocoder = vocoder.eval().to(device)

    # load t5
    caption_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
    caption_encoder = T5EncoderModel.from_pretrained("google/flan-t5-large").to(device).eval()

    # load clap
    clap_model = laion_clap.CLAP_Module(enable_fusion=False)
    clap_model.load_ckpt(os.path.join(local_dir, "clap-630k-best.pt"))

    # load duration predictor. It is deliberately left on CPU: run() feeds it the
    # CPU tensors produced by duration_tokenizer.
    duration_tokenizer = AutoTokenizer.from_pretrained(os.path.join(local_dir, "nar_duration_predictor"))
    duration_model = AutoModelForSequenceClassification.from_pretrained(os.path.join(local_dir, "nar_duration_predictor"))
    duration_model.eval()

    model_list = [model, vocoder, phn2num, text_tokenizer, clap_model, duration_tokenizer, duration_model, caption_tokenizer, caption_encoder]
    return model_list