Spaces:
Build error
Build error
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
import soundfile as sf | |
from xcodec2.modeling_xcodec2 import XCodec2Model | |
import numpy as np | |
import ChatTTS | |
import re | |
# Model id used when no TTS model is explicitly selected.
DEFAULT_TTS_MODEL_NAME = "HKUSTAudio/LLasa-1B"

# (speaker, transcript) pairs for the demo UI. Each speaker's reference clip is
# expected at "<speaker>.wav" (a path relative to the working directory), and the
# transcript is the text spoken in that clip.
_DEMO_SPEAKER_LINES = (
    ("太乙真人", "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"),
    ("邓紫棋", "特别大的不同,因为以前在香港是过年的时候,我们可能见到的亲戚都是爸爸那边的亲戚"),
    ("雷军", "这是个好问题,我把来龙去脉给你简单讲,就是这个社会对小米有很多的误解,有很多的误解,呃,也能理解啊,就是小米这个模式呢"),
    ("周杰伦", "但如果你这兴趣可以得到很大的回响,那会更开心"),
    ("Taylor Swift", "It's actually uh, it's a concept record, but it's my first directly autobiographical album in a while because the last album that I put out was, uh, a rework."),
)

# Gradio examples: one [audio path, transcript] list per row.
DEMO_EXAMPLES = [[f"{speaker}.wav", transcript] for speaker, transcript in _DEMO_SPEAKER_LINES]
class TTSapi:
    """Text-to-speech wrapper around either Llasa (+ XCodec2 codec) or ChatTTS.

    The backend is selected from ``model_name`` (case-insensitive): a name
    containing ``"llasa"`` loads a causal LM plus the XCodec2 codec and
    produces 16 kHz audio; a name containing ``"chattts"`` loads ChatTTS and
    produces 24 kHz audio. Any other name raises ``ValueError``.
    """

    def __init__(self,
                 model_name=DEFAULT_TTS_MODEL_NAME,
                 codec_model_name="HKUST-Audio/xcodec2",
                 device=None):
        """Load the backend named by ``model_name``; see ``reload`` for parameters."""
        self.reload(model_name, codec_model_name, device)

    @staticmethod
    def _resolve_device(device):
        """Return ``device``, or ``cuda:0`` (CPU if CUDA is unavailable) when it is None."""
        if device is not None:
            return device
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _release_models(self):
        """Drop references to a previously loaded backend so its memory can be reclaimed."""
        for attr in ('model', 'tokenizer', 'codec_model'):
            if hasattr(self, attr):
                delattr(self, attr)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def reload(self,
               model_name=DEFAULT_TTS_MODEL_NAME,
               codec_model_name="HKUST-Audio/xcodec2",
               device=None):
        """(Re)load the TTS backend.

        Args:
            model_name: Llasa model id, or a name containing "chattts" for ChatTTS.
            codec_model_name: XCodec2 checkpoint; only used by the Llasa backend.
            device: Torch device for the Llasa backend. ``None`` means ``cuda:0``
                when CUDA is available, otherwise CPU.

        Raises:
            ValueError: if ``model_name`` names neither a Llasa nor a ChatTTS model.
        """
        lowered = model_name.lower()
        if 'llasa' not in lowered and 'chattts' not in lowered:
            raise ValueError(f'不支持的TTS模型:{model_name}')
        device = self._resolve_device(device)

        # Free the old backend before loading the new one so two models never
        # occupy (GPU) memory at the same time. model_name is cleared so that a
        # failed load below does not leave a stale name claiming a loaded backend.
        self._release_models()
        self.model_name = None

        if 'llasa' in lowered:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.model.eval().to(device)
            self.codec_model = XCodec2Model.from_pretrained(codec_model_name)
            self.codec_model.eval().to(device)
            self.codec_model_name = codec_model_name
            self.sr = 16000
        else:
            self.model = ChatTTS.Chat()
            # compile=True gives better inference performance but loads much more slowly.
            self.model.load(compile=False)
            self.sr = 24000
            # Punctuation used to count text segments when choosing the break token.
            self.punctuation = r'[,,.。??!!~~;;]'
        self.device = device
        self.model_name = model_name

    def ids_to_speech_tokens(self, speech_ids):
        """Map integer speech ids to Llasa token strings, e.g. 12345 -> "<|s_12345|>"."""
        return [f"<|s_{speech_id}|>" for speech_id in speech_ids]

    def extract_speech_ids(self, speech_tokens_str):
        """Parse "<|s_N|>" token strings back into integers N.

        Tokens that do not match the pattern are reported on stdout and skipped.
        """
        speech_ids = []
        for token_str in speech_tokens_str:
            if token_str.startswith('<|s_') and token_str.endswith('|>'):
                speech_ids.append(int(token_str[4:-2]))
            else:
                print(f"Unexpected token: {token_str}")
        return speech_ids

    def _forward_chattts(self, input_text):
        """Synthesize ``input_text`` with ChatTTS and return its first output waveform."""
        # Scale the [break_N] refine-text token with the number of
        # punctuation-delimited segments, clamped to the range 2..7.
        segment_count = len(re.split(self.punctuation, input_text))
        break_num = max(min(segment_count, 7), 2)
        params_refine_text = ChatTTS.Chat.RefineTextParams(
            prompt=f'[oral_2][laugh_0][break_{break_num}]',
        )
        wavs = self.model.infer([input_text], params_refine_text=params_refine_text)
        return wavs[0]

    def _load_prompt_wav(self, speech_prompt):
        """Read a reference clip as a float tensor of shape (1, num_samples) at ``self.sr``.

        Multi-channel audio is averaged to mono. Audio at another sample rate is
        linearly resampled, because the codec only supports ``self.sr`` (16 kHz) input.
        """
        wav, file_sr = sf.read(speech_prompt)
        if wav.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel files.
            wav = wav.mean(axis=1)
        if file_sr != self.sr and len(wav) > 0:
            n_out = int(round(len(wav) * self.sr / file_sr))
            src_t = np.arange(len(wav)) / file_sr
            dst_t = np.arange(n_out) / self.sr
            wav = np.interp(dst_t, src_t, wav)
        return torch.from_numpy(wav).float().unsqueeze(0)

    def _forward_llasa(self, input_text, speech_prompt):
        """Synthesize ``input_text`` with Llasa, optionally continuing a reference clip."""
        prompt_wav = None
        speech_ids_prefix = []
        if speech_prompt:
            prompt_wav = self._load_prompt_wav(speech_prompt)
            vq_code_prompt = self.codec_model.encode_code(input_waveform=prompt_wav)
            print("Prompt Vq Code Shape:", vq_code_prompt.shape)
            # Convert the prompt codes (int 12345) to tokens (<|s_12345|>); .tolist()
            # avoids a per-element device sync.
            speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt[0, 0, :].tolist())

        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
        # The assistant turn is left open (continue_final_message=True) so the model
        # continues generating speech tokens after the optional prompt tokens.
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True,
        ).to(self.device)
        speech_end_id = self.tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

        # Generate the speech autoregressively.
        outputs = self.model.generate(
            input_ids,
            max_length=2048,  # The model was trained with a max length of 2048.
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=1,  # Adjusts the diversity of generated content.
            temperature=1,  # Controls randomness in output.
        )

        # Keep the prompt's speech tokens too (assumes each <|s_N|> is one token),
        # so the codec decodes prompt + continuation as one waveform.
        generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):]
        # Strip the end-of-speech token only if generation actually emitted it
        # (it is absent when max_length was reached).
        if generated_ids.numel() > 0 and generated_ids[-1].item() == speech_end_id:
            generated_ids = generated_ids[:-1]

        speech_tokens = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        speech_ids = self.extract_speech_ids(speech_tokens)
        if not speech_ids:
            raise RuntimeError('Llasa produced no speech tokens for the given text.')
        codes = torch.tensor(speech_ids).to(self.device).unsqueeze(0).unsqueeze(0)
        gen_wav = self.codec_model.decode_code(codes)
        if prompt_wav is not None:
            # Keep only the newly generated part, dropping the re-decoded prompt.
            gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]
        return gen_wav[0, 0, :].cpu().numpy()

    def forward(self, input_text, speech_prompt=None, save_path='wavs/generated/gen.wav'):
        """Synthesize speech for ``input_text``.

        Args:
            input_text: Text to synthesize. With a Llasa ``speech_prompt``, this
                presumably should start with the prompt clip's transcript.
            speech_prompt: Optional path to a reference clip (Llasa only; ChatTTS
                ignores it). It is down-mixed to mono and resampled to ``self.sr``.
            save_path: WAV file to write the result to, at ``self.sr``. Missing
                parent directories are created. Pass ``None`` or ``''`` to skip saving.

        Returns:
            The generated waveform as a numpy array. For Llasa with a prompt, only
            the audio generated after the prompt is returned.
        """
        with torch.no_grad():
            if 'chattts' in self.model_name.lower():
                gen_wav_save = self._forward_chattts(input_text)
            else:
                gen_wav_save = self._forward_llasa(input_text, speech_prompt)
        if save_path:
            import os
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            sf.write(save_path, gen_wav_save, self.sr)
        return gen_wav_save
if __name__ == '__main__':
    # Standalone usage without the web UI (Llasa-8B shows better text understanding):
    #   prompt_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
    #   tts = TTSapi()
    #   wav = tts.forward(prompt_text + "<text to synthesize>",
    #                     speech_prompt='wavs/ref/太乙真人.wav',
    #                     save_path='wavs/generated/test.wav')
    # With a reference clip, the clip's transcript is prepended to the target text.
    import traceback

    import gradio as gr

    synthesiser = TTSapi()
    TTS_LOADED = True

    def predict(config):
        """Synthesize ``config['msg']`` and return ``(sample_rate, waveform)`` for gr.Audio.

        ``config`` holds the keys 'msg', 'tts_model', 'ref_audio' and
        'ref_audio_transcribe', as collected by the gr.on listener below. On empty
        input or any synthesis error a one-sample silent clip is returned and the
        error is logged and shown to the user.
        """
        global TTS_LOADED, synthesiser
        print(f"待合成文本:{config['msg']}")
        print(f"选中TTS模型:{config['tts_model']}")
        print(f"参考音频路径:{config['ref_audio']}")
        print(f"参考音频文本:{config['ref_audio_transcribe']}")
        text = config['msg']
        # Default to silence so the return below is always defined, even on errors.
        audio_output = np.array([0], dtype=np.int16)
        try:
            # `not text` also covers None (Submit clicked before anything was typed).
            if not text:
                print("输入为空,无法合成语音")
            else:
                if not TTS_LOADED:
                    print('TTS模型首次加载...')
                    gr.Info("初次加载TTS模型,请稍候..", duration=63)
                    synthesiser = TTSapi(model_name=config['tts_model'])
                    TTS_LOADED = True
                    print('加载完毕...')
                # Reload if the loaded model is not the one currently selected.
                if config['tts_model'] != synthesiser.model_name:
                    print(f'当前TTS模型{synthesiser.model_name}非所选,重新加载')
                    synthesiser.reload(model_name=config['tts_model'])
                ref_audio = config['ref_audio']
                # Llasa continues from the reference clip, so the clip's transcript
                # must precede the target text. ChatTTS ignores the reference
                # audio, so prepending the transcript would only make it read aloud.
                if ref_audio and 'llasa' in synthesiser.model_name.lower():
                    prompt_text = config['ref_audio_transcribe']
                    if not prompt_text:
                        # TODO: consider adding an ASR model to transcribe the clip.
                        raise NotImplementedError('暂时必须提供文本')
                    text = prompt_text + text
                audio_output = synthesiser.forward(text, speech_prompt=ref_audio)
        except Exception as e:
            traceback.print_exc()
            gr.Warning(f'语音合成失败:{e}')
        return (synthesiser.sr if synthesiser else 16000, audio_output)

    with gr.Blocks(title="TTS Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
        gr.Markdown("""
# Personalized TTS Demo
## 使用步骤
* 上传你想要合成的目标说话人的语音,10s左右即可,并在下面填入对应的文本。或直接点击下方示例
* 输入你想要合成的文字,点击合成语音按钮,稍等片刻即可
""")
        with gr.Row():
            with gr.Column():
                # TTS model selection.
                tts_model = gr.Dropdown(
                    label="选择TTS模型",
                    choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
                    value=DEFAULT_TTS_MODEL_NAME,
                    interactive=True,
                    visible=False,  # Hidden for product demos for now.
                )
                # Reference audio upload.
                ref_audio = gr.Audio(
                    label="上传参考音频",
                    type="filepath",
                    interactive=True,
                )
                ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)
                # Clicking an example only fills the two inputs. No fn is attached:
                # predict takes a single config dict, not these two components.
                examples = gr.Examples(
                    examples=DEMO_EXAMPLES,
                    inputs=[ref_audio, ref_audio_transcribe],
                )
            with gr.Column():
                audio_player = gr.Audio(
                    label="听听我声音~",
                    type="numpy",
                    interactive=False,
                )
                msg = gr.Textbox(label="输入文本", placeholder="请输入想要合成的文本")
                submit_btn = gr.Button("合成语音", variant="primary")

        current_config = gr.State({
            "msg": None,
            "tts_model": DEFAULT_TTS_MODEL_NAME,
            "ref_audio": None,
            "ref_audio_transcribe": None,
        })
        # Keep current_config in sync with every input component.
        gr.on(triggers=[msg.change, tts_model.change, ref_audio.change,
                        ref_audio_transcribe.change],
              fn=lambda text, model, audio, ref_text: {"msg": text, "tts_model": model, "ref_audio": audio,
                                                       "ref_audio_transcribe": ref_text},
              inputs=[msg, tts_model, ref_audio, ref_audio_transcribe],
              outputs=current_config,
              )
        # Route synthesis through the queue: requests share (and may reload) the
        # single global synthesiser, so they must not run concurrently outside it;
        # the queue is also what lets gr.Info/gr.Warning toasts reach the browser.
        submit_btn.click(
            predict,
            [current_config],
            [audio_player],
        )
    demo.launch(share=False, server_name='0.0.0.0', server_port=7863, inbrowser=True)