# Source: cren/weclone-audio — src/Llasa/text_to_speech.py
# Uploaded by CrenCren ("Upload folder using huggingface_hub", commit 88aba71, verified).
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
class TextToSpeech:
    """Voice-cloning text-to-speech using Llasa-3B and the XCodec2 codec.

    A reference recording and its transcript are encoded once at construction
    time. Every call to :meth:`infer` then prompts Llasa-3B with that
    reference (transcript + reference speech tokens) so that the generated
    speech continues in the reference voice.
    """

    def __init__(self, sample_audio_path, sample_audio_text):
        """Load the models and encode the reference (prompt) audio.

        Args:
            sample_audio_path: Path to the reference recording, loaded with
                ``torchaudio.load``. Only the first 15 seconds are used.
            sample_audio_text: Transcript of the reference recording. It is
                prepended to every target text in :meth:`infer`.
        """
        self.sample_audio_text = sample_audio_text

        # Initialise the models.
        llasa_3b = "HKUSTAudio/Llasa-3B"
        xcodec2 = "HKUSTAudio/xcodec2"
        self.tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
        self.llasa_3b_model = AutoModelForCausalLM.from_pretrained(
            llasa_3b,
            trust_remote_code=True,
            device_map="auto",
        )
        self.llasa_3b_model.eval()
        self.xcodec_model = XCodec2Model.from_pretrained(xcodec2)
        self.xcodec_model.eval().cuda()

        # Prepare the reference audio.
        waveform, sample_rate = torchaudio.load(sample_audio_path)
        if len(waveform[0]) / sample_rate > 15:
            print("已将音频裁剪至前15秒。")
            waveform = waveform[:, : sample_rate * 15]
        # Down-mix multi-channel (e.g. stereo) audio to mono by averaging channels.
        if waveform.size(0) > 1:
            waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
        else:
            waveform_mono = waveform
        # Resample to 16 kHz; infer() also returns audio at this rate.
        self.prompt_wav = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_mono)

        # Encode the prompt wav into codec ids. Gradients are never needed
        # here, so avoid building an autograd graph.
        with torch.no_grad():
            vq_code_prompt = self.xcodec_model.encode_code(
                input_waveform=self.prompt_wav
            )
            vq_code_prompt = vq_code_prompt[0, 0, :]
            self.speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
        self.speech_end_id = self.tokenizer.convert_tokens_to_ids(
            "<|SPEECH_GENERATION_END|>"
        )

    def ids_to_speech_tokens(self, speech_ids):
        """Turn codec ids into speech-token strings of the form ``<|s_N|>``."""
        return [f"<|s_{speech_id}|>" for speech_id in speech_ids]

    def extract_speech_ids(self, speech_tokens_str):
        """Parse ``<|s_N|>`` token strings back into integer codec ids.

        Strings that do not have the ``<|s_...|>`` shape are reported with a
        print and skipped.
        """
        speech_ids = []
        for token_str in speech_tokens_str:
            if token_str.startswith("<|s_") and token_str.endswith("|>"):
                speech_ids.append(int(token_str[4:-2]))
            else:
                print(f"Unexpected token: {token_str}")
        return speech_ids

    def _select_speech_ids(self, output_ids, prompt_length):
        """Return the token ids to decode from one generated sequence.

        The slice starts ``len(self.speech_ids_prefix)`` tokens before the end
        of the prompt, so the reference speech tokens are included. The audio
        they produce is trimmed off after decoding in :meth:`infer`. A
        trailing end-of-speech token is removed only when it is present.
        When generation stopped at ``max_length``, the last token is a real
        speech token and is kept.
        """
        start = prompt_length - len(self.speech_ids_prefix)
        speech_ids = output_ids[start:]
        if len(speech_ids) > 0 and int(speech_ids[-1]) == self.speech_end_id:
            speech_ids = speech_ids[:-1]
        return speech_ids

    @torch.inference_mode()
    def infer(self, target_text):
        """Synthesise ``target_text`` in the voice of the reference audio.

        Args:
            target_text: Text to speak. Texts longer than 300 characters are
                truncated to their first 300 characters.

        Returns:
            ``(16000, samples)``, where ``samples`` is a 1-D numpy array of
            audio at 16 kHz. Returns ``None`` when ``target_text`` is empty.
        """
        if not target_text:
            return None
        if len(target_text) > 300:
            print("文本过长,请保持在300字符以内。")
            target_text = target_text[:300]

        input_text = self.sample_audio_text + " " + target_text
        formatted_text = (
            f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
        )
        chat = [
            {
                "role": "user",
                "content": "Convert the text to speech:" + formatted_text,
            },
            {
                "role": "assistant",
                "content": "<|SPEECH_GENERATION_START|>"
                + "".join(self.speech_ids_prefix),
            },
        ]
        # continue_final_message=True leaves the assistant turn open, so
        # generation continues right after the reference speech tokens.
        input_ids = self.tokenizer.apply_chat_template(
            chat, tokenize=True, return_tensors="pt", continue_final_message=True
        )
        input_ids = input_ids.to(self.llasa_3b_model.device)
        outputs = self.llasa_3b_model.generate(
            input_ids,
            max_length=2048,
            eos_token_id=self.speech_end_id,
            do_sample=True,
            top_p=1,
            temperature=0.8,
        )
        generated_ids = self._select_speech_ids(outputs[0], input_ids.shape[1])
        speech_tokens = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        speech_tokens = self.extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        gen_wav = self.xcodec_model.decode_code(speech_tokens)
        # Drop the leading samples that correspond to the reference audio.
        gen_wav = gen_wav[:, :, self.prompt_wav.shape[1]:]
        return (16000, gen_wav[0, 0, :].cpu().numpy())
if __name__ == "__main__":
    # If you run into problems, try converting the reference audio to WAV or
    # MP3, trimming it to 15 seconds or less, and shortening the prompt text.
    script_dir = os.path.dirname(__file__)
    reference_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
    reference_wav = os.path.join(script_dir, "sample.wav")
    tts = TextToSpeech(reference_wav, reference_text)

    # Synthesise each sentence in the reference voice, in order, and save it
    # next to this script.
    jobs = (
        ("晚上好啊,吃了吗您", "output.wav"),
        ("我是老北京正黄旗!", "output1.wav"),
    )
    for sentence, out_name in jobs:
        rate, samples = tts.infer(sentence)
        sf.write(os.path.join(script_dir, out_name), samples, rate)